# Source code for scgo.initialization.seed_combiners

"""Seed combination and growth strategies for cluster generation."""

from __future__ import annotations

import numpy as np
from ase import Atoms

from scgo.utils.logging import get_logger

from .geometry_helpers import (
    _check_composition_feasibility,
    _generate_rotation_matrix,
    analyze_disconnection,
    compute_bond_distance_params,
    get_covalent_radius,
    get_largest_facets,
    is_cluster_connected,
    place_multi_atom_seed_on_facet,
    validate_cluster,
)
from .initialization_config import (
    CLASH_TOLERANCE,
    CONNECTIVITY_FACTOR,
    SEED_CLASH_FACTOR,
)
from .random_spherical import grow_from_seed

logger = get_logger(__name__)


def _apply_random_rotation(atoms: Atoms, rng: np.random.Generator) -> Atoms:
    """Return a randomly rotated copy of ``atoms``.

    All work is done on a copy, so the input object is left untouched. The
    copy is first passed through ``Atoms.center()`` and is then turned about
    its centre of mass by a random angle around a random unit axis.

    Args:
        atoms: Structure to rotate.
        rng: Generator that supplies the rotation axis (three standard-normal
            draws) followed by the rotation angle (one uniform draw). Using
            the same generator state reproduces the same rotation.

    Returns:
        A new ``Atoms`` object holding the rotated positions.
    """
    spun = atoms.copy()
    spun.center()
    pivot = spun.get_center_of_mass()

    # Draw order is part of the reproducibility contract: axis first, angle second.
    # NOTE(review): a uniform axis combined with a uniform angle is not a
    # uniform sampling of 3D rotations; kept as-is to preserve the RNG stream.
    raw_axis = rng.standard_normal(3)
    unit_axis = raw_axis / np.linalg.norm(raw_axis)
    theta = rng.uniform(0, 2 * np.pi)
    rotation = _generate_rotation_matrix(unit_axis, theta)

    # Positions are row vectors: shift to the pivot, rotate, shift back.
    offsets = spun.get_positions() - pivot
    spun.set_positions(pivot + offsets @ rotation.T)
    return spun


def _is_valid_placement(
    seed_to_add: Atoms,
    combined_atoms: Atoms,
    connectivity_factor: float,
    min_distance_factor: float = SEED_CLASH_FACTOR,
) -> bool:
    """Check whether a positioned seed can be merged into the existing cluster.

    A placement is accepted when both conditions hold:

    1. No seed atom is closer to any existing atom than
       ``(r_seed + r_existing) * min_distance_factor - CLASH_TOLERANCE``,
       where the radii come from ``get_covalent_radius``.
    2. The union of both structures passes ``is_cluster_connected`` with
       ``connectivity_factor`` and ``use_mic=False``.

    Neither input is modified; the connectivity test runs on a copy.

    Args:
        seed_to_add: Candidate seed, already moved to its trial position.
        combined_atoms: Cluster assembled so far.
        connectivity_factor: Forwarded to ``is_cluster_connected``.
        min_distance_factor: Scale applied to the sum of covalent radii to
            obtain the minimum allowed inter-atomic distance.

    Returns:
        ``True`` if the placement is clash-free and connected, else ``False``.
    """
    seed_positions = np.asarray(seed_to_add.get_positions(), dtype=float)
    existing_positions = np.asarray(combined_atoms.get_positions(), dtype=float)

    # With no atoms on either side there are no pairs to test (and, as before,
    # no radius lookups are performed).
    if len(seed_positions) and len(existing_positions):
        # Look each atom's radius up once instead of once per atom pair.
        seed_radii = np.array(
            [get_covalent_radius(sym) for sym in seed_to_add.get_chemical_symbols()],
            dtype=float,
        )
        existing_radii = np.array(
            [
                get_covalent_radius(sym)
                for sym in combined_atoms.get_chemical_symbols()
            ],
            dtype=float,
        )

        # (n_seed, n_existing) matrices of pair distances and pair thresholds.
        pair_distances = np.linalg.norm(
            seed_positions[:, None, :] - existing_positions[None, :, :], axis=-1
        )
        min_allowed = (
            seed_radii[:, None] + existing_radii[None, :]
        ) * min_distance_factor

        if np.any(pair_distances < min_allowed - CLASH_TOLERANCE):
            return False

    # Clash-free: now make sure the merged structure is a single connected cluster.
    trial_cluster = combined_atoms.copy()
    trial_cluster.extend(seed_to_add)
    return is_cluster_connected(trial_cluster, connectivity_factor, use_mic=False)


def combine_seeds(
    seeds: list[Atoms],
    cell_side: float,
    rng: np.random.Generator,
    separation_scaling: float = 1.0,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
    min_distance_factor: float = SEED_CLASH_FACTOR,
) -> Atoms | None:
    """Combine several seed clusters into one structure, one seed at a time.

    The first seed is randomly rotated and becomes the starting cluster. Every
    further seed is randomly rotated and then positioned as follows:

    1. It is tried on each facet returned by
       ``get_largest_facets(cluster, n_facets=3)`` using
       ``place_multi_atom_seed_on_facet``; the first candidate accepted by
       ``_is_valid_placement`` is used.
    2. If no facet works, a single fallback is tried: the seed's centre of
       mass is put at ``bond_distance`` from the cluster's centre of mass
       along a random direction.

    After all seeds are merged, the convex-hull cache is cleared, the cubic
    cell is set, the atoms are centred and the result is passed through
    ``validate_cluster``.

    Args:
        seeds: Seed clusters to merge, in placement order.
        cell_side: Edge length of the cubic cell assigned to the result.
        rng: Random generator used for rotations, facet placement and the
            fallback direction.
        separation_scaling: Forwarded to ``compute_bond_distance_params`` as
            ``placement_radius_scaling``.
        connectivity_factor: Used by all connectivity checks.
        min_distance_factor: Used by the clash check and the final validation.

    Returns:
        An empty ``Atoms`` with the cubic cell when ``seeds`` is empty;
        ``None`` when a seed cannot be placed or the merged cluster is
        reported as disconnected; otherwise the first element of the tuple
        returned by ``validate_cluster``.

    Raises:
        ``validate_cluster`` is called with ``raise_on_failure=True``, so it
        presumably raises when the final structure is invalid — confirm the
        exception type against ``geometry_helpers``.
    """
    if not seeds:
        return Atoms(cell=[cell_side, cell_side, cell_side])

    # Random rotation of the first seed adds diversity while staying
    # reproducible for a given generator state.
    combined_atoms = _apply_random_rotation(seeds[0], rng)

    # seed_number is 1-based and only used in log messages.
    for seed_number, seed in enumerate(seeds[1:], start=2):
        base_seed_to_add = _apply_random_rotation(seed, rng)

        seed_avg_radius = np.mean(
            [get_covalent_radius(sym) for sym in base_seed_to_add.get_chemical_symbols()]
        )
        existing_avg_radius = np.mean(
            [get_covalent_radius(sym) for sym in combined_atoms.get_chemical_symbols()]
        )

        # Shared bond-distance calculation keeps this consistent with other
        # placement code; only the first returned value is needed here.
        # NOTE(review): the *mean* radius of the existing cluster is passed
        # as ``max_existing_radius`` — confirm this is intended.
        bond_distance, _, _ = compute_bond_distance_params(
            max_existing_radius=existing_avg_radius,
            avg_new_radius=seed_avg_radius,
            connectivity_factor=connectivity_factor,
            min_distance_factor=min_distance_factor,
            placement_radius_scaling=separation_scaling,
        )

        seed_to_add = None

        # Preferred strategy: dock the seed onto one of the largest facets.
        for facet_centroid, facet_normal, _ in get_largest_facets(
            combined_atoms, n_facets=3
        ):
            candidate = place_multi_atom_seed_on_facet(
                base_seed_to_add,
                facet_centroid,
                facet_normal,
                bond_distance,
                rng,
            )
            if _is_valid_placement(
                candidate, combined_atoms, connectivity_factor, min_distance_factor
            ):
                seed_to_add = candidate
                break

        # Fallback: one attempt along a random direction from the centre of mass.
        if seed_to_add is None:
            candidate = base_seed_to_add.copy()
            center = combined_atoms.get_center_of_mass()
            direction = rng.standard_normal(3)
            direction /= np.linalg.norm(direction)
            candidate.translate(
                center + direction * bond_distance - candidate.get_center_of_mass()
            )
            if _is_valid_placement(
                candidate, combined_atoms, connectivity_factor, min_distance_factor
            ):
                seed_to_add = candidate

        if seed_to_add is None:
            logger.warning(
                f"Failed to place seed {seed_number} after trying all facets. Returning None."
            )
            return None

        combined_atoms.extend(seed_to_add)
        combined_atoms.center()

        # Defensive re-check: _is_valid_placement already tested connectivity
        # of the same set of atoms before they were centred.
        if not is_cluster_connected(combined_atoms, connectivity_factor, use_mic=False):
            _, suggested_factor, analysis_msg = analyze_disconnection(
                combined_atoms, connectivity_factor, use_mic=False
            )
            logger.warning(
                f"Seed {seed_number} placement created disconnected cluster. "
                f"Current connectivity_factor={connectivity_factor:.2f}. "
                f"Analysis: {analysis_msg}. "
                f"Suggested connectivity_factor: {suggested_factor:.2f}"
            )
            return None

    # NOTE(review): assumed to run once after all seeds are placed (the
    # reviewed source had lost its indentation) — confirm.
    from .geometry_helpers import clear_convex_hull_cache

    clear_convex_hull_cache()

    combined_atoms.set_cell([cell_side, cell_side, cell_side])
    combined_atoms.center()

    validated_atoms, _, _ = validate_cluster(
        combined_atoms,
        composition=None,
        min_distance_factor=min_distance_factor,
        connectivity_factor=connectivity_factor,
        sort_atoms=False,
        raise_on_failure=True,
        source="combine_seeds",
    )
    return validated_atoms


def combine_and_grow(
    seeds: list[Atoms],
    target_composition: list[str],
    cell_side: float,
    rng: np.random.Generator,
    vdw_scaling: float = 1.0,
    min_distance_factor: float = SEED_CLASH_FACTOR,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
) -> Atoms | None:
    """Merge the seeds into one cluster, then grow it to ``target_composition``.

    The seeds are first merged with ``combine_seeds``. If that succeeds and
    ``_check_composition_feasibility`` reports that the merged composition can
    be grown into the target, growth is delegated to ``grow_from_seed``.

    Args:
        seeds: Seed clusters to merge.
        target_composition: Chemical symbols the final cluster should contain.
        cell_side: Edge length of the cubic cell, passed to both steps.
        rng: Random generator shared by the merge and the growth step.
        vdw_scaling: Passed to ``combine_seeds`` as ``separation_scaling`` and
            to ``grow_from_seed`` as ``placement_radius_scaling``.
        min_distance_factor: Clash factor passed to both steps.
        connectivity_factor: Connectivity factor passed to both steps.

    Returns:
        The result of ``grow_from_seed``, or ``None`` when the seeds could not
        be combined or the target composition is not reachable by growth.
    """
    merged = combine_seeds(
        seeds,
        cell_side,
        rng,
        separation_scaling=vdw_scaling,
        connectivity_factor=connectivity_factor,
        min_distance_factor=min_distance_factor,
    )
    if merged is None:
        logger.warning("Initial seed combination failed.")
        return None

    # Only attempt growth when the merged atoms can actually reach the target.
    base_composition = merged.get_chemical_symbols()
    feasibility = _check_composition_feasibility(
        base_composition, target_composition, operation="grow"
    )
    is_feasible, error_message = feasibility

    if is_feasible:
        return grow_from_seed(
            seed_atoms=merged,
            target_composition=target_composition,
            placement_radius_scaling=vdw_scaling,
            cell_side=cell_side,
            rng=rng,
            min_distance_factor=min_distance_factor,
            connectivity_factor=connectivity_factor,
        )

    logger.warning(
        f"Composition growth not feasible after seed combination: {error_message}. "
        f"Combined seed: {base_composition}, Target: {target_composition}"
    )
    return None