# Source code for scgo.initialization.geometry_helpers

"""Geometric utility functions for cluster structure generation.

This module provides helper functions for placing atoms based on cluster geometry,
particularly using convex hull analysis to guide growth strategies.
"""

from __future__ import annotations

import hashlib
from collections.abc import Sequence

import numpy as np
from ase import Atoms
from scipy.spatial import (
    ConvexHull,
    KDTree,
    QhullError,
)

from scgo.database.cache import get_global_cache
from scgo.utils.helpers import get_composition_counts
from scgo.utils.logging import get_logger

from .atomic_radii import get_covalent_radius
from .initialization_config import (
    CLASH_TOLERANCE,
    CONNECTIVITY_FACTOR,
    CONNECTIVITY_SUGGESTION_BUFFER,
    CONVEX_HULL_PERTURBATION_SCALE,
    CONVEX_HULL_VOLUME_TOLERANCE,
    LINEAR_GEOMETRY_TOLERANCE,
    MIN_DISTANCE_FACTOR_DEFAULT,
    ROTATION_AXIS_TOLERANCE,
    SMART_FILTERING_PERTURBATION_SCALE,
)

# Core utilities

# Debug logger registered under the "scgo.initialization.templates" name.
# In this module it reports convex-hull placement candidates that fail
# clash/connectivity validation.
template_debug_logger = get_logger("scgo.initialization.templates")


def format_placement_error_message(
    context: str,
    composition: list[str] | None,
    n_atoms: int | None,
    placement_radius_scaling: float,
    min_distance_factor: float,
    connectivity_factor: float,
    cell_side: float | None = None,
    diagnostics: StructureDiagnostics | None = None,
    additional_info: str = "",
) -> str:
    """Format a consistent error message for placement failures.

    The message starts with a single headline sentence
    ("Could not <context> for composition <composition> (<n_atoms> atoms).").
    It is followed by blank-line separated sections: the current parameter
    values, optional diagnostics, optional extra information, and suggested
    parameter changes.

    Args:
        context: Context description (e.g., "random_spherical", "seed+growth").
        composition: Composition list (for display); omitted when empty/None.
        n_atoms: Number of atoms (for display); omitted when falsy.
        placement_radius_scaling: Current placement radius scaling.
        min_distance_factor: Current minimum distance factor.
        connectivity_factor: Current connectivity factor.
        cell_side: Optional cell side length (Å).
        diagnostics: Optional structure diagnostics; its ``summary`` attribute
            is included in the message.
        additional_info: Additional context information.

    Returns:
        Formatted, multi-line error message string.
    """
    # Join the headline pieces with spaces so the first sentence stays on one
    # line (joining with "\n" would fragment it and strand the final ".").
    headline_parts = [f"Could not {context}"]
    if composition:
        headline_parts.append(f"for composition {composition}")
    if n_atoms:
        headline_parts.append(f"({n_atoms} atoms)")
    parts = [" ".join(headline_parts) + "."]

    # Parameter values
    param_parts = [
        f"placement_radius_scaling={placement_radius_scaling:.2f}",
        f"min_distance_factor={min_distance_factor:.2f}",
        f"connectivity_factor={connectivity_factor:.2f}",
    ]
    if cell_side is not None:
        param_parts.append(f"cell_side={cell_side:.2f} Å")
    # The leading "\n" on each section yields a blank separator line once
    # everything is joined with "\n" below.
    parts.append(f"\nParameters: {', '.join(param_parts)}")

    if diagnostics:
        parts.append(f"\nDiagnostics: {diagnostics.summary}")

    if additional_info:
        parts.append(f"\n{additional_info}")

    # Suggestions
    parts.append("\nSuggestions:")
    parts.append(
        f"  - Increase placement_radius_scaling to {placement_radius_scaling * 1.5:.2f}"
    )
    # The suggested value is floored at 0.3; only suggest it when it is
    # actually a decrease, otherwise the advice would be contradictory.
    reduced_min_distance = max(min_distance_factor * 0.8, 0.3)
    if reduced_min_distance < min_distance_factor:
        parts.append(
            f"  - Decrease min_distance_factor to {reduced_min_distance:.2f}"
        )
    parts.append(f"  - Increase connectivity_factor to {connectivity_factor * 1.2:.2f}")
    if cell_side is not None:
        parts.append(f"  - Increase cell_side to {cell_side * 1.5:.2f} Å")

    return "\n".join(parts)


# Convex hull and caching

# Namespace under which convex hulls are stored in the global cache
# (passed to ``get_global_cache().get_or_compute`` by ``_get_cached_hull``).
_CONVEX_HULL_CACHE_NS = "convex_hull"


def _get_positions_hash(positions: np.ndarray) -> str:
    """Generate a collision-resistant hash for positions array.

    Uses SHA256 to avoid hash collisions that could return wrong cached results.
    For small arrays (<100 points), also stores positions bytes for collision detection.

    Args:
        positions: Array of atomic positions

    Returns:
        SHA256 hash string
    """
    positions_bytes = positions.tobytes()
    return hashlib.sha256(positions_bytes).hexdigest()


def _get_cached_hull(positions: np.ndarray) -> ConvexHull:
    """Get convex hull from cache or compute and cache it.

    Uses LRU eviction policy to maintain cache size limit.
    Uses SHA256 hashing to avoid collisions and verifies cached results match positions.

    Args:
        positions: Array of atomic positions

    Returns:
        ConvexHull object for the given positions

    Raises:
        ValueError: If positions array has fewer than 4 points (insufficient for 3D hull)
    """
    if len(positions) < 4:
        raise ValueError(
            f"Convex hull requires at least 4 points in 3D, got {len(positions)}"
        )

    positions_hash = _get_positions_hash(positions)
    positions_bytes = positions.tobytes()

    def compute_hull() -> ConvexHull:
        try:
            return ConvexHull(positions)
        except (QhullError, ValueError) as e:
            # Handle degenerate cases (collinear/coplanar points)
            raise ValueError(
                f"Convex hull computation failed for {len(positions)} points: {e}"
            ) from e

    # Use cache key with validation bytes to detect collisions
    cache_key = (positions_hash, positions_bytes)
    return get_global_cache().get_or_compute(
        _CONVEX_HULL_CACHE_NS, cache_key, compute_hull
    )


def get_convex_hull_vertex_indices(atoms: Atoms) -> np.ndarray:
    """Return the indices of atoms lying on vertices of the cluster's hull.

    The hull is obtained through the cached ``_get_cached_hull`` helper
    (scipy ``ConvexHull``). Its ``vertices`` are the extreme points.

    Args:
        atoms: The Atoms object representing the cluster.

    Returns:
        1D ``np.intp`` array of atom indices. The array is empty when no hull
        can be built (fewer than 4 atoms, or degenerate geometry).
    """
    if len(atoms) >= 4:
        try:
            hull = _get_cached_hull(atoms.get_positions())
            vertex_idx = np.asarray(hull.vertices, dtype=np.intp)
        except (ValueError, QhullError):
            pass
        else:
            return vertex_idx
    return np.array([], dtype=np.intp)


def _adjust_bond_distance_for_facet_geometry(
    bond_distance: float,
    min_centroid_dist: float,
    max_centroid_dist: float,
    min_connectivity_dist: float | None,
    max_connectivity_dist: float | None,
) -> float:
    """Adjust bond distance based on facet geometry constraints.

    This function adjusts the bond distance to ensure connectivity constraints
    are satisfied when placing atoms on convex hull facets. For small facets,
    it uses strict constraints to ensure connectivity. For large facets, it
    gradually relaxes the constraint while still relying on per-candidate
    connectivity checks to maintain cluster connectivity.

    Args:
        bond_distance: Initial bond distance for placement
        min_centroid_dist: Minimum distance from centroid to facet vertices
        max_centroid_dist: Maximum distance from centroid to facet vertices
        min_connectivity_dist: Minimum connectivity distance threshold
        max_connectivity_dist: Maximum connectivity distance threshold

    Returns:
        Adjusted bond distance that satisfies geometry constraints
    """
    adjusted_bond_distance = bond_distance
    if (
        min_connectivity_dist is not None
        and max_connectivity_dist is not None
        and min_centroid_dist < max_connectivity_dist
    ):
        # Calculate maximum bond_distance that ensures connectivity
        # Use strict constraint: no relaxation of connectivity threshold
        max_bond_from_geometry = np.sqrt(
            max_connectivity_dist**2 - min_centroid_dist**2
        )
        if max_bond_from_geometry > 0:
            adjusted_bond_distance = min(bond_distance, max_bond_from_geometry)
            # Also ensure it's at least the minimum if possible
            if min_centroid_dist < min_connectivity_dist:
                diff_sq = min_connectivity_dist**2 - max_centroid_dist**2
                if diff_sq > 0:
                    min_bond_from_geometry = np.sqrt(diff_sq)
                    adjusted_bond_distance = max(
                        adjusted_bond_distance, min_bond_from_geometry
                    )
    return adjusted_bond_distance


def compute_bond_distance_params(
    max_existing_radius: float,
    avg_new_radius: float,
    connectivity_factor: float,
    min_distance_factor: float,
    placement_radius_scaling: float,
    *,
    effective_min_distance: float | None = None,
    effective_scaling: float | None = None,
) -> tuple[float, float, float]:
    """Derive the placement distance and its bounds from covalent radii.

    Every quantity is a multiple of ``max_existing_radius + avg_new_radius``:

    * upper bound   = radius_sum * connectivity_factor
    * lower bound   = radius_sum * (effective_min_distance or min_distance_factor)
    * bond distance = radius_sum * min(scaling, connectivity_factor), clipped
      into [lower, upper]

    This logic is shared by template growth (grow_template_via_facets) and
    batch placement (_add_atoms_batch_mode).

    Args:
        max_existing_radius: Max covalent radius among existing atoms.
        avg_new_radius: Mean covalent radius of atoms to add.
        connectivity_factor: Connectivity threshold factor.
        min_distance_factor: Factor for minimum separation.
        placement_radius_scaling: Scaling for placement distance.
        effective_min_distance: Optional override *factor* used in place of
            ``min_distance_factor`` (it is also multiplied by the radius sum).
        effective_scaling: Optional override used in place of
            ``placement_radius_scaling``.

    Returns:
        ``(bond_distance, min_connectivity_dist, max_connectivity_dist)``.
    """
    radius_sum = max_existing_radius + avg_new_radius

    min_factor = (
        min_distance_factor if effective_min_distance is None else effective_min_distance
    )
    scaling = placement_radius_scaling if effective_scaling is None else effective_scaling

    upper = float(radius_sum * connectivity_factor)
    lower = radius_sum * min_factor
    # Never aim beyond the connectivity threshold, then clamp into [lower, upper].
    target = radius_sum * min(scaling, connectivity_factor)
    bond = float(np.clip(target, lower, upper))
    return bond, float(lower), upper


def _compute_facet_properties(
    hull: ConvexHull, atoms: Atoms
) -> list[tuple[np.ndarray, np.ndarray, float, tuple[float, float]]]:
    """Compute properties for all facets of a convex hull.

    Args:
        hull: ConvexHull object from scipy
        atoms: Atoms object to compute facet properties for

    Returns:
        List of tuples (centroid, normal, area, (min_centroid_dist, max_centroid_dist))
        for each facet
    """
    positions = atoms.get_positions()
    center_of_mass = atoms.get_center_of_mass()
    facet_properties = []

    for simplex_indices in hull.simplices:
        facet_positions = positions[simplex_indices]

        # Calculate facet area (assuming triangular facets for 3D convex hull)
        v1 = facet_positions[1] - facet_positions[0]
        v2 = facet_positions[2] - facet_positions[0]
        area = 0.5 * np.linalg.norm(np.cross(v1, v2))

        # Calculate facet centroid
        centroid = np.mean(facet_positions, axis=0)

        # Calculate distances from centroid to each vertex
        centroid_to_vertices = [
            np.linalg.norm(vertex_pos - centroid) for vertex_pos in facet_positions
        ]
        min_centroid_dist = min(centroid_to_vertices)
        max_centroid_dist = max(centroid_to_vertices)

        # Calculate facet normal (outward pointing)
        normal = np.cross(v1, v2)
        normal /= np.linalg.norm(normal)

        # Ensure normal points outwards from the center of mass
        if np.dot(normal, centroid - center_of_mass) < 0:
            normal *= -1

        facet_properties.append(
            (centroid, normal, area, (min_centroid_dist, max_centroid_dist))
        )

    return facet_properties


def _filter_safe_facets_for_placement(
    atoms: Atoms,
    facet_properties: list[tuple[np.ndarray, np.ndarray, float, tuple[float, float]]],
    bond_distance: float,
    min_connectivity_dist: float | None,
    max_connectivity_dist: float | None,
    min_distance_factor: float | None,
    new_atom_radius: float | None,
    positions: np.ndarray,
    symbols_list: list[str],
    connectivity_factor: float = CONNECTIVITY_FACTOR,
) -> list[int]:
    """Return indices of facets whose unperturbed placement point is usable.

    For each facet, the probe point ``centroid + normal * h`` is checked
    against every existing atom. ``h`` is ``bond_distance`` adjusted by
    ``_adjust_bond_distance_for_facet_geometry``.

    * Facets whose nearest vertex is more than 1.5x ``max_connectivity_dist``
      from the centroid are skipped outright.
    * When both ``min_distance_factor`` and ``new_atom_radius`` are given, the
      probe must keep every existing atom j at least
      ``(new_atom_radius + r_j) * min_distance_factor`` away.
    * Under the same condition, and when ``max_connectivity_dist`` is given,
      at least one atom must lie within
      ``(new_atom_radius + r_j) * connectivity_factor`` of the probe.

    Args:
        atoms: Current cluster; not read by this implementation (kept for
            interface compatibility).
        facet_properties: ``(centroid, normal, area, (min_dist, max_dist))``
            per facet.
        bond_distance: Target bond distance for placement.
        min_connectivity_dist: Minimum connectivity distance threshold.
        max_connectivity_dist: Maximum connectivity distance threshold.
        min_distance_factor: Factor for the clash check.
        new_atom_radius: Covalent radius of the atom to be placed.
        positions: Existing atom positions.
        symbols_list: Element symbols for ``positions``, used to look up radii.
        connectivity_factor: Multiplier on the radius sum for the
            connectivity check. Defaults to CONNECTIVITY_FACTOR.

    Returns:
        Indices into ``facet_properties`` of the facets that pass, in order.
    """
    run_checks = min_distance_factor is not None and new_atom_radius is not None
    accepted: list[int] = []

    for facet_idx, facet in enumerate(facet_properties):
        centroid, normal, _area, (nearest_d, farthest_d) = facet

        # Very large facets rarely yield valid placements even after relaxation.
        if max_connectivity_dist is not None and nearest_d > max_connectivity_dist * 1.5:
            continue

        height = _adjust_bond_distance_for_facet_geometry(
            bond_distance,
            nearest_d,
            farthest_d,
            min_connectivity_dist,
            max_connectivity_dist,
        )
        probe = centroid + normal * height

        if run_checks:
            clashes = any(
                np.linalg.norm(probe - pos)
                < (new_atom_radius + get_covalent_radius(symbols_list[j])) * min_distance_factor
                for j, pos in enumerate(positions)
            )
            if clashes:
                continue
            if max_connectivity_dist is not None and not any(
                np.linalg.norm(probe - pos)
                <= (new_atom_radius + get_covalent_radius(symbols_list[j])) * connectivity_factor
                for j, pos in enumerate(positions)
            ):
                continue

        accepted.append(facet_idx)

    return accepted


def _generate_batch_positions_on_convex_hull(
    atoms: Atoms,
    n_candidates: int,
    bond_distance: float,
    rng: np.random.Generator,
    min_connectivity_dist: float | None = None,
    max_connectivity_dist: float | None = None,
    use_all_facets: bool = False,
    min_distance_factor: float | None = None,
    new_atom_symbol: str | None = None,
    smart_facet_filtering: bool = False,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
) -> list[np.ndarray]:
    """Generate multiple candidate atom positions on convex hull facets.

    The convex hull is computed (or fetched from cache) once. Facets are then
    chosen, either all of them or an area-weighted random sample without
    replacement, and one perturbed candidate is generated per chosen facet.

    Args:
        atoms: The Atoms object representing the current cluster structure.
        n_candidates: Maximum number of candidate positions to generate. Ignored if
            use_all_facets is True.
        bond_distance: Distance (in Angstroms) from the surface at which to place
            the new atoms. This represents the bond separation distance.
        rng: Numpy random number generator for reproducible randomness.
        min_connectivity_dist: Optional minimum distance constraint. If provided along
            with max_connectivity_dist, bond_distance is adjusted per facet.
        max_connectivity_dist: Optional maximum distance constraint for connectivity.
            Also used to skip very large facets and to decide when to move the
            placement base toward a facet vertex.
        use_all_facets: If True, generate one position per (safe) facet in
            deterministic order (area descending); n_candidates is ignored.
        min_distance_factor: Optional factor for minimum distance checks. If provided
            along with new_atom_symbol, candidates are validated for clashes and
            connectivity before being returned.
        new_atom_symbol: Optional symbol of the atom to be placed. Used with
            min_distance_factor for validation.
        smart_facet_filtering: If True (and validation inputs are given), pre-filter
            facets with ``_filter_safe_facets_for_placement`` and use a smaller
            perturbation.
        connectivity_factor: Factor to multiply sum of covalent radii for
            connectivity threshold. Defaults to CONNECTIVITY_FACTOR.

    Returns:
        List of 3D numpy arrays representing candidate positions for new atoms.
        The list may have fewer than n_candidates entries when fewer facets are
        available, when some facets have zero area (they are never sampled),
        or when candidates fail validation.

    Note:
        For clusters with <4 atoms, or degenerate geometry where no hull can be
        built, returns an empty list (caller should handle fallback).
    """
    if len(atoms) < 4:
        # Cannot compute convex hull for <4 atoms
        return []

    # Get convex hull from cache or compute it
    positions = atoms.get_positions()
    try:
        hull = _get_cached_hull(positions)
    except ValueError:
        # Convex hull computation failed (degenerate geometry)
        return []

    # Compute facet properties using shared helper
    facet_properties = _compute_facet_properties(hull, atoms)
    facet_areas = [prop[2] for prop in facet_properties]
    facet_centroids = [prop[0] for prop in facet_properties]
    facet_normals = [prop[1] for prop in facet_properties]
    facet_vertex_distances = [prop[3] for prop in facet_properties]

    total_area = sum(facet_areas)
    if total_area == 0:  # Handle degenerate cases
        return []

    # Sort facets by area (descending), then centroid coordinates for reproducibility
    facet_indices = list(range(len(facet_areas)))
    facet_indices.sort(
        key=lambda i: (
            -facet_areas[i],  # Negative for descending order
            round(facet_centroids[i][0], 8),
            round(facet_centroids[i][1], 8),
            round(facet_centroids[i][2], 8),
        )
    )

    # Reorder facets according to deterministic sort
    sorted_facet_centroids = [facet_centroids[i] for i in facet_indices]
    sorted_facet_normals = [facet_normals[i] for i in facet_indices]
    sorted_facet_vertex_distances = [facet_vertex_distances[i] for i in facet_indices]

    # Get symbols for validation if enabled
    symbols_list = atoms.get_chemical_symbols()
    new_atom_radius = None
    if new_atom_symbol is not None:
        new_atom_radius = get_covalent_radius(new_atom_symbol)

    n_facets = len(sorted_facet_centroids)

    # Pre-filter facets if smart filtering is enabled
    if (
        smart_facet_filtering
        and min_distance_factor is not None
        and new_atom_radius is not None
    ):
        sorted_facet_properties = [
            (
                sorted_facet_centroids[i],
                sorted_facet_normals[i],
                facet_areas[facet_indices[i]],
                sorted_facet_vertex_distances[i],
            )
            for i in range(n_facets)
        ]
        # Returns indices in sorted_facet_properties order
        safe_sorted_indices = _filter_safe_facets_for_placement(
            atoms,
            sorted_facet_properties,
            bond_distance,
            min_connectivity_dist,
            max_connectivity_dist,
            min_distance_factor,
            new_atom_radius,
            positions,
            symbols_list,
            connectivity_factor=connectivity_factor,
        )
        n_safe_facets = len(safe_sorted_indices)
    else:
        safe_sorted_indices = list(range(n_facets))
        n_safe_facets = n_facets

    if use_all_facets:
        # Use all safe facets (or all facets if not filtering)
        selected_indices = safe_sorted_indices
    else:
        n_to_select = min(n_candidates, n_safe_facets)
        if n_to_select <= 0:
            return []
        if smart_facet_filtering and len(safe_sorted_indices) > 0:
            # Select from safe facets based on area
            sorted_safe_areas = [
                facet_areas[facet_indices[i]] for i in safe_sorted_indices
            ]
            safe_total_area = sum(sorted_safe_areas)
            if safe_total_area > 0:
                probabilities = np.array(sorted_safe_areas) / safe_total_area
                # Generator.choice(replace=False) raises when size exceeds the
                # number of non-zero weights; zero-area (degenerate) facets can
                # cause that, so never ask for more than can be drawn.
                selected_safe = rng.choice(
                    len(safe_sorted_indices),
                    size=min(n_to_select, int(np.count_nonzero(probabilities))),
                    replace=False,
                    p=probabilities,
                )
                selected_indices = [safe_sorted_indices[i] for i in selected_safe]
            else:
                selected_indices = safe_sorted_indices[:n_to_select]
        else:
            sorted_facet_areas = [
                facet_areas[facet_indices[i]] for i in range(n_facets)
            ]
            probabilities = np.array(sorted_facet_areas) / total_area
            # Same cap as above: zero-area facets cannot be drawn.
            selected_indices = rng.choice(
                n_facets,
                size=min(n_to_select, int(np.count_nonzero(probabilities))),
                replace=False,
                p=probabilities,
            )

    # Generate candidate positions for selected facets
    candidates = []
    perturbation_scale = CONVEX_HULL_PERTURBATION_SCALE

    for idx in selected_indices:
        chosen_centroid = sorted_facet_centroids[idx]
        chosen_normal = sorted_facet_normals[idx]
        min_centroid_dist, max_centroid_dist = sorted_facet_vertex_distances[idx]

        # Skip very large facets (min_centroid_dist > 1.5 * max_connectivity_dist):
        # they are unlikely to yield valid placements.
        if (
            max_connectivity_dist is not None
            and min_centroid_dist > max_connectivity_dist * 1.5
        ):
            continue

        # For large facets, place closer to vertices instead of centroid
        # so the new atom stays within connectivity distance of existing atoms.
        facet_size_ratio = (
            min_centroid_dist / max_connectivity_dist
            if max_connectivity_dist is not None and max_connectivity_dist > 0
            else 0.0
        )
        if facet_size_ratio > 0.6:
            # Large facet: interpolate between centroid and the vertex nearest
            # to it. facet_indices maps the sorted index back to hull.simplices.
            facet_vertex_positions = positions[hull.simplices[facet_indices[idx]]]
            vertex_dists = [
                np.linalg.norm(vpos - chosen_centroid)
                for vpos in facet_vertex_positions
            ]
            nearest_vertex = facet_vertex_positions[np.argmin(vertex_dists)]

            # ratio=0.6 -> centroid (0%), ratio=1.0 -> 50% toward vertex,
            # ratio=1.5 -> at vertex (100%). Larger ratios were skipped above.
            if facet_size_ratio <= 1.0:
                interpolation_factor = (facet_size_ratio - 0.6) / 0.4 * 0.5
            else:
                interpolation_factor = 0.5 + (facet_size_ratio - 1.0) / 0.5 * 0.5

            placement_base = (
                chosen_centroid * (1 - interpolation_factor)
                + nearest_vertex * interpolation_factor
            )
        else:
            # Small/medium facets: use centroid
            placement_base = chosen_centroid

        # Adjust bond_distance based on facet geometry if constraints are provided
        adjusted_bond_distance = _adjust_bond_distance_for_facet_geometry(
            bond_distance,
            min_centroid_dist,
            max_centroid_dist,
            min_connectivity_dist,
            max_connectivity_dist,
        )

        # Use more conservative perturbation when close to connectivity limit
        effective_perturbation_scale = perturbation_scale
        if (
            max_connectivity_dist is not None
            and min_centroid_dist > max_connectivity_dist * 0.8
        ):
            effective_perturbation_scale = perturbation_scale * 0.5

        if smart_facet_filtering:
            perturbation = rng.standard_normal(3) * (
                effective_perturbation_scale * SMART_FILTERING_PERTURBATION_SCALE
            )
        else:
            perturbation = rng.standard_normal(3) * effective_perturbation_scale

        candidate_pos = (
            placement_base + chosen_normal * adjusted_bond_distance + perturbation
        )

        # Always validate if constraints are provided, even with smart filtering,
        # because perturbation or interpolation might shift the point into a clash.
        is_valid = True
        if min_distance_factor is not None and new_atom_radius is not None:
            for i, existing_pos in enumerate(positions):
                dist = np.linalg.norm(candidate_pos - existing_pos)
                r_existing = get_covalent_radius(symbols_list[i])
                min_allowed = (new_atom_radius + r_existing) * min_distance_factor
                if dist < min_allowed:
                    is_valid = False
                    break

            if is_valid and max_connectivity_dist is not None:
                is_connected = False
                for i, existing_pos in enumerate(positions):
                    dist = np.linalg.norm(candidate_pos - existing_pos)
                    r_existing = get_covalent_radius(symbols_list[i])
                    # Connectivity threshold for this pair
                    connectivity_threshold = (
                        new_atom_radius + r_existing
                    ) * connectivity_factor
                    if dist <= connectivity_threshold:
                        is_connected = True
                        break
                if not is_connected:
                    is_valid = False

        if not is_valid:
            # Debug logging for rejected candidates on large facets
            if (
                max_connectivity_dist is not None
                and min_centroid_dist > max_connectivity_dist * 0.7
                and min_distance_factor is not None
                and new_atom_radius is not None
            ):
                min_dist_to_existing = min(
                    np.linalg.norm(candidate_pos - existing_pos)
                    for existing_pos in positions
                )
                template_debug_logger.debug(
                    f"candidate rejected: min_dist={min_dist_to_existing:.3f}, "
                    f"min_centroid_dist={min_centroid_dist:.3f}, "
                    f"max_conn={max_connectivity_dist:.3f}"
                )
            continue

        candidates.append(candidate_pos)

    return candidates


def get_largest_facets(
    atoms: Atoms, n_facets: int = 3
) -> list[tuple[np.ndarray, np.ndarray, float]]:
    """Return the ``n_facets`` largest convex-hull facets of a cluster.

    Args:
        atoms: The Atoms object representing the cluster.
        n_facets: Number of largest facets to return.

    Returns:
        Up to ``n_facets`` tuples ``(centroid, normal, area)``, largest area
        first. If no hull can be built, a single pseudo-facet
        ``(center_of_mass, direction, 1.0)`` is returned instead. This covers
        fewer than 4 atoms and degenerate geometry.
    """
    if len(atoms) < 4:
        return [(atoms.get_center_of_mass(), np.array([1.0, 0.0, 0.0]), 1.0)]

    try:
        hull = _get_cached_hull(atoms.get_positions())
    except (ValueError, RuntimeError, QhullError):
        # Degenerate geometry (collinear/coplanar points): synthesise a single
        # pseudo-facet at the center of mass.
        com = atoms.get_center_of_mass()
        direction = np.array([1.0, 0.0, 0.0])
        if _classify_seed_geometry(atoms) in ("linear", "planar"):
            pos = atoms.get_positions()
            if len(pos) > 1:
                _eigvals, eigvecs = np.linalg.eigh(np.cov((pos - com).T))
                # NOTE(review): eigh sorts eigenvalues ascending, so this is
                # the direction of *largest* variance (in-plane for a planar
                # seed). Confirm this is intended rather than the plane normal
                # eigvecs[:, 0].
                principal = eigvecs[:, -1]
                direction = principal / np.linalg.norm(principal)
        return [(com, direction, 1.0)]

    ranked = sorted(
        (
            (centroid, normal, area)
            for centroid, normal, area, _dists in _compute_facet_properties(hull, atoms)
        ),
        key=lambda facet: facet[2],
        reverse=True,
    )
    return ranked[:n_facets]


def _classify_seed_geometry(atoms: Atoms) -> str:
    """Classify the geometric structure of a seed cluster.

    Args:
        atoms: The Atoms object to classify

    Returns:
        Geometry classification: "single", "linear", "planar", or "3d"
    """
    n_atoms = len(atoms)

    if n_atoms == 1:
        return "single"

    if n_atoms == 2:
        return "linear"

    if n_atoms >= 3:
        positions = atoms.get_positions()
        center = np.mean(positions, axis=0)
        centered_positions = positions - center

        cov_matrix = np.cov(centered_positions.T)
        eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)

        # Check if structure is linear (1D) vs planar (2D)
        # Linear: λ1 ≈ 0 AND λ2 ≈ 0 (only one dimension has variation)
        # Planar: λ1 ≈ 0 BUT λ2 > 0 (two dimensions have variation)
        # Use tolerance to allow for structures that are almost linear but not perfectly linear
        linear_tolerance = LINEAR_GEOMETRY_TOLERANCE
        if (
            eigenvalues[0] < linear_tolerance * eigenvalues[2]
            and eigenvalues[1] < linear_tolerance * eigenvalues[2]
        ):
            return "linear"  # Both λ1 ≈ 0 and λ2 ≈ 0 → 1D linear (or almost linear)
        # λ1 ≈ 0 but λ2 > tolerance → 2D planar (fall through to convex hull check)

    try:
        if len(positions) >= 4:
            hull = _get_cached_hull(positions)
            if hull.volume < CONVEX_HULL_VOLUME_TOLERANCE:
                return "planar"
            return "3d"
        else:
            return "planar"

    except (QhullError, ValueError, RuntimeError):
        # Convex hull computation failed (degenerate geometry, collinear points, etc.)
        return "planar"


def _generate_rotation_matrix(axis: np.ndarray, angle: float) -> np.ndarray:
    """Build the 3x3 matrix rotating by ``angle`` radians about ``axis``.

    Implements Rodrigues' formula ``R = I + sin(θ)·K + (1 - cos(θ))·K²``,
    where ``K`` is the skew-symmetric cross-product matrix of the unit axis.

    Args:
        axis: Rotation axis; any length is accepted and it is normalized here.
        angle: Rotation angle in radians.

    Returns:
        The 3x3 rotation matrix, or the identity matrix when the axis norm is
        below ``ROTATION_AXIS_TOLERANCE`` (no usable direction).
    """
    direction = np.asarray(axis)
    length = np.linalg.norm(direction)
    if length < ROTATION_AXIS_TOLERANCE:
        return np.eye(3)

    unit = direction / length
    ux, uy, uz = unit[0], unit[1], unit[2]

    # K such that K @ v == cross(unit, v)
    skew = np.array(
        [
            [0, -uz, uy],
            [uz, 0, -ux],
            [-uy, ux, 0],
        ]
    )
    identity = np.eye(3)
    return identity + np.sin(angle) * skew + (1 - np.cos(angle)) * np.dot(skew, skew)


def _rotation_aligning_vectors(
    source: np.ndarray, target: np.ndarray
) -> np.ndarray | None:
    """Return a rotation matrix turning direction ``source`` onto ``target``.

    Returns None when no rotation is needed: the directions already agree to
    within ~1e-6 rad, or either vector has zero length. Anti-parallel inputs get
    a 180° rotation about an axis perpendicular to ``source``. In that case the
    cross product vanishes and carries no axis information.
    """
    source_len = np.linalg.norm(source)
    target_len = np.linalg.norm(target)
    if source_len == 0 or target_len == 0:
        return None
    u = np.asarray(source, dtype=float) / source_len
    v = np.asarray(target, dtype=float) / target_len

    axis = np.cross(u, v)
    sin_angle = np.linalg.norm(axis)
    cos_angle = float(np.clip(np.dot(u, v), -1.0, 1.0))

    if sin_angle > 1e-6:
        # arctan2 stays accurate near 0 and pi, unlike arccos.
        return _generate_rotation_matrix(axis / sin_angle, np.arctan2(sin_angle, cos_angle))
    if cos_angle > 0:
        return None  # already aligned

    # Anti-parallel: any axis perpendicular to u works for a half turn.
    reference = (
        np.array([1.0, 0.0, 0.0]) if abs(u[0]) < 0.9 else np.array([0.0, 1.0, 0.0])
    )
    return _generate_rotation_matrix(np.cross(u, reference), np.pi)


def place_multi_atom_seed_on_facet(
    seed_atoms: Atoms,
    target_facet_centroid: np.ndarray,
    target_facet_normal: np.ndarray,
    bond_distance: float,
    rng: np.random.Generator,
) -> Atoms:
    """Place a multi-atom seed so that its largest facet contacts the target facet.

    The seed is moved as a rigid body, in two steps:
      1. Rotate about the seed's center of mass so that the normal of its
         largest convex-hull facet (from ``get_largest_facets``) points along
         ``-target_facet_normal``, i.e. back toward the target.
      2. Translate so that this facet's centroid lands at
         ``target_facet_centroid + n * bond_distance``, where ``n`` is the
         unit target normal.

    Args:
        seed_atoms: The seed to place. It is not modified; a copy is returned.
        target_facet_centroid: Centroid of the target facet.
        target_facet_normal: Normal vector of the target facet. It is
            normalized internally (a zero vector disables the rotation).
        bond_distance: Desired separation between the facet centroids along
            the target normal.
        rng: Currently unused. Kept for interface compatibility; no random
            rotation is applied.

    Returns:
        A copy of ``seed_atoms`` with new positions.
    """
    normal = np.asarray(target_facet_normal, dtype=float)
    normal_length = np.linalg.norm(normal)
    if normal_length > 0:
        normal = normal / normal_length

    placed_seed = seed_atoms.copy()
    center = placed_seed.get_center_of_mass()

    seed_facets = get_largest_facets(seed_atoms, n_facets=1)
    if seed_facets:
        facet_centroid, seed_normal, _ = seed_facets[0]
        facet_centroid = np.asarray(facet_centroid, dtype=float)
        seed_normal = np.asarray(seed_normal, dtype=float)
    else:
        # No hull facets available: use the center of mass as the contact
        # point and an arbitrary reference normal.
        facet_centroid = np.asarray(center, dtype=float)
        seed_normal = np.array([1.0, 0.0, 0.0])

    # Step 1: the seed normal should point toward the target (opposite to the
    # target facet's normal).
    rotation = _rotation_aligning_vectors(seed_normal, -normal)
    if rotation is not None:
        positions = placed_seed.get_positions()
        placed_seed.set_positions(center + (positions - center) @ rotation.T)
        # Move the chosen facet's centroid with the seed. Re-running the
        # facet search after rotating could pick a different (equal-area) facet.
        facet_centroid = center + (facet_centroid - center) @ rotation.T

    # Step 2: facet centroids may lie inside the clusters, hence the offset
    # along the normal.
    target_position = (
        np.asarray(target_facet_centroid, dtype=float) + normal * bond_distance
    )
    placed_seed.translate(target_position - facet_centroid)

    return placed_seed


def _find_connected_components(
    atoms: Atoms, connectivity_factor: float, use_mic: bool = False
) -> tuple[dict[int, list[int]], list[int]]:
    """Find connected components using a Union-Find (disjoint-set) structure.

    Atoms i and j are bonded when their distance is at most
    ``(r_i + r_j) * connectivity_factor``, where ``r`` is the covalent radius
    from ``get_covalent_radius``.

    Args:
        atoms: The Atoms object to check.
        connectivity_factor: Factor to multiply the sum of covalent radii by
            for the connectivity threshold.
        use_mic: If True, use the minimum image convention for distances. The
            full periodic distance matrix is used in that case. A KDTree over
            raw Cartesian positions would miss bonds that only exist across a
            periodic boundary.

    Returns:
        Tuple ``(components, parent)``:
          * ``components`` maps each Union-Find root to the ascending list of
            atom indices in that component.
          * ``parent`` is the Union-Find parent array.

        For an empty Atoms object this is ``({0: []}, [])``; for one atom it
        is ``({0: [0]}, [0])``.
    """
    n_atoms = len(atoms)
    if n_atoms <= 1:
        return {0: [0] if n_atoms == 1 else []}, list(range(n_atoms))

    symbols = atoms.get_chemical_symbols()
    # Look each element's radius up once instead of once per atom pair.
    radius_by_symbol = {s: get_covalent_radius(s) for s in set(symbols)}
    radii = np.array([radius_by_symbol[s] for s in symbols], dtype=float)

    parent = list(range(n_atoms))

    def find(x: int) -> int:
        """Find root of x with (iterative) full path compression."""
        root = x
        while parent[root] != root:
            root = parent[root]
        # Second pass: point every node on the path directly at the root.
        # Iterative so long chains cannot hit the recursion limit.
        while x != root:
            next_node = parent[x]
            parent[x] = root
            x = next_node
        return root

    def union(x: int, y: int) -> bool:
        """Union two components. Returns True if union was performed."""
        px, py = find(x), find(y)
        if px != py:
            parent[px] = py
            return True
        return False

    if use_mic:
        distances = atoms.get_all_distances(mic=True)
        thresholds = (radii[:, None] + radii[None, :]) * connectivity_factor
        # np.nonzero yields (i, j) in row-major order: i ascending, then j.
        rows, cols = np.nonzero(np.triu(distances <= thresholds, k=1))
        for i, j in zip(rows.tolist(), cols.tolist()):
            union(i, j)
    else:
        positions = atoms.get_positions()
        tree = KDTree(positions)
        # Slight slack so rounding inside the tree's distance evaluation
        # cannot drop a pair that the exact threshold check would accept.
        query_radius = 2.0 * float(radii.max()) * connectivity_factor * (1.0 + 1e-9)
        # Sorting gives the deterministic (i ascending, j ascending) order.
        for i, j in sorted(tree.query_pairs(query_radius)):
            distance = np.linalg.norm(positions[i] - positions[j])
            threshold = (radii[i] + radii[j]) * connectivity_factor
            if distance <= threshold:
                union(i, j)

    components: dict[int, list[int]] = {}
    for i in range(n_atoms):
        components.setdefault(find(i), []).append(i)

    return components, parent


def is_cluster_connected(
    atoms: Atoms,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
    use_mic: bool = False,
) -> bool:
    """Return True when every atom belongs to a single bonded network.

    Two atoms are linked when their distance is at most
    ``(r_i + r_j) * connectivity_factor``, with ``r`` the covalent radii. The
    component search itself is delegated to ``_find_connected_components``
    (a Union-Find over candidate neighbor pairs).

    Args:
        atoms: The Atoms object to check.
        connectivity_factor: Factor to multiply the sum of covalent radii by
            for the connectivity threshold. Defaults to CONNECTIVITY_FACTOR.
        use_mic: If True, use the minimum image convention for distances.

    Returns:
        True if all atoms are in one connected component (also for empty or
        single-atom inputs), False otherwise.
    """
    component_map = _find_connected_components(atoms, connectivity_factor, use_mic)[0]
    n_groups = len(component_map)
    return n_groups <= 1
def analyze_disconnection(
    atoms: Atoms,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
    use_mic: bool = False,
) -> tuple[float, float, str]:
    """Analyze disconnection in a cluster and suggest an appropriate connectivity factor.

    The suggested factor is the one that would just bond the closest pair of
    atoms in different components, multiplied by
    ``CONNECTIVITY_SUGGESTION_BUFFER``.

    Args:
        atoms: The Atoms object to analyze.
        connectivity_factor: Current connectivity factor used.
        use_mic: If True, use the minimum image convention for distance
            calculations, both for finding components and for measuring the
            gap between them.

    Returns:
        Tuple of (max_disconnection_distance, suggested_connectivity_factor,
        analysis_message). For connected, empty or single-atom clusters the
        distance is 0.0 and ``connectivity_factor`` is returned unchanged.
    """
    if len(atoms) <= 1:
        return 0.0, connectivity_factor, "Single atom or empty cluster"

    components, _ = _find_connected_components(atoms, connectivity_factor, use_mic)
    if len(components) <= 1:
        return 0.0, connectivity_factor, "Cluster is connected"

    symbols = atoms.get_chemical_symbols()
    # Same distance convention (MIC or not) as the component search above.
    distances = atoms.get_all_distances(mic=use_mic)

    min_inter_component_distance = float("inf")
    closest_atoms = None
    component_list = list(components.values())
    for ci, comp1 in enumerate(component_list):
        for comp2 in component_list[ci + 1 :]:
            block = distances[np.ix_(comp1, comp2)]
            # argmin returns the first minimum in row-major order, matching a
            # nested (atom1, atom2) scan with a strict "<" comparison.
            row, col = np.unravel_index(int(np.argmin(block)), block.shape)
            if block[row, col] < min_inter_component_distance:
                min_inter_component_distance = float(block[row, col])
                atom1, atom2 = comp1[row], comp2[col]
                closest_atoms = (atom1, atom2, symbols[atom1], symbols[atom2])

    if closest_atoms is None:
        return float("inf"), connectivity_factor, "Unable to analyze disconnection"

    atom1_idx, atom2_idx, sym1, sym2 = closest_atoms
    r1 = get_covalent_radius(sym1)
    r2 = get_covalent_radius(sym2)

    # Calculate what connectivity factor would be needed to connect these atoms
    suggested_factor = min_inter_component_distance / (r1 + r2)
    # Add a small buffer to ensure connectivity
    suggested_factor *= CONNECTIVITY_SUGGESTION_BUFFER

    analysis_msg = (
        f"Cluster has {len(components)} disconnected components. "
        f"Closest atoms are {sym1}({atom1_idx}) and {sym2}({atom2_idx}) "
        f"at distance {min_inter_component_distance:.3f}Å. "
        f"Suggested connectivity_factor: {suggested_factor:.2f}"
    )

    return min_inter_component_distance, suggested_factor, analysis_msg


def _identify_safe_removal_candidates(
    cluster: Atoms,
    candidate_indices: list[int],
    connectivity_factor: float,
    use_mic: bool = False,
    max_to_check: int = 10,
) -> list[int]:
    """Identify which candidates can be safely removed without disconnecting.

    Each candidate is tested by deleting it from a copy of the cluster and
    checking that the remainder is still connected.

    Args:
        cluster: The cluster to analyze.
        candidate_indices: List of atom indices to test for safe removal.
        connectivity_factor: Factor for the connectivity threshold.
        use_mic: If True, use the minimum image convention in the
            connectivity check.
        max_to_check: Maximum number of candidates (taken from the front of
            ``candidate_indices``) to check, for performance.

    Returns:
        Atom indices, in candidate order, that can be removed without
        disconnecting the cluster. Out-of-range indices are skipped. Empty if
        none are safe, no candidates are given, or the cluster has <= 2 atoms.
    """
    if len(cluster) <= 2:
        # Removing from 1-2 atom clusters would leave disconnected or empty
        return []

    if not candidate_indices:
        return []

    safe_candidates = []
    for idx in candidate_indices[:max_to_check]:
        if not 0 <= idx < len(cluster):
            continue

        # Create test cluster without this atom
        test_cluster = cluster.copy()
        del test_cluster[idx]

        # len(cluster) > 2 here, so test_cluster always keeps >= 2 atoms.
        if is_cluster_connected(test_cluster, connectivity_factor, use_mic):
            safe_candidates.append(idx)

    return safe_candidates
class StructureDiagnostics:
    """Container for comprehensive structure diagnostics.

    Attributes:
        is_valid: True if structure has no clashes and is connected
        has_clashes: True if atomic clashes were detected
        is_disconnected: True if cluster has multiple disconnected components
        clash_details: List of clash description strings
        n_components: Number of disconnected components (1 if connected)
        closest_inter_component_distance: Distance between closest atoms in
            different components
        suggested_connectivity_factor: Connectivity factor needed to connect
            all components
        summary: Human-readable summary of all issues
    """

    def __init__(
        self,
        is_valid: bool,
        has_clashes: bool,
        is_disconnected: bool,
        clash_details: list[str],
        n_components: int,
        closest_inter_component_distance: float,
        suggested_connectivity_factor: float,
        summary: str,
    ):
        self.is_valid = is_valid
        self.has_clashes = has_clashes
        self.is_disconnected = is_disconnected
        self.clash_details = clash_details
        self.n_components = n_components
        self.closest_inter_component_distance = closest_inter_component_distance
        self.suggested_connectivity_factor = suggested_connectivity_factor
        self.summary = summary

    def __repr__(self) -> str:
        # Field-by-field repr so diagnostics are readable when logged/inspected.
        return (
            f"{type(self).__name__}("
            f"is_valid={self.is_valid!r}, "
            f"has_clashes={self.has_clashes!r}, "
            f"is_disconnected={self.is_disconnected!r}, "
            f"clash_details={self.clash_details!r}, "
            f"n_components={self.n_components!r}, "
            f"closest_inter_component_distance={self.closest_inter_component_distance!r}, "
            f"suggested_connectivity_factor={self.suggested_connectivity_factor!r}, "
            f"summary={self.summary!r})"
        )
def get_structure_diagnostics(
    atoms: Atoms,
    min_distance_factor: float,
    connectivity_factor: float,
    use_mic: bool = False,
) -> StructureDiagnostics:
    """Get comprehensive diagnostics for a cluster structure.

    Analyzes both clashes and connectivity issues and returns detailed
    diagnostic information useful for debugging initialization failures.

    Args:
        atoms: The Atoms object to analyze.
        min_distance_factor: Factor to scale covalent radii for minimum
            distance checks. A pair clashes when its distance is below
            ``(r_i + r_j) * min_distance_factor - CLASH_TOLERANCE``.
        connectivity_factor: Factor to multiply the sum of covalent radii by
            for the connectivity threshold.
        use_mic: If True, use the minimum image convention for all distances:
            the clash scan, component detection and the inter-component gap.

    Returns:
        StructureDiagnostics object containing detailed analysis results. At
        most 10 clashes are recorded (the first ones in (i, j) order).
    """
    if len(atoms) == 0:
        return StructureDiagnostics(
            is_valid=True,
            has_clashes=False,
            is_disconnected=False,
            clash_details=[],
            n_components=0,
            closest_inter_component_distance=0.0,
            suggested_connectivity_factor=connectivity_factor,
            summary="Empty cluster",
        )

    symbols = atoms.get_chemical_symbols()

    # Pre-compute covalent radii to avoid repeated lookups
    radius_by_symbol = {symbol: get_covalent_radius(symbol) for symbol in set(symbols)}
    radii = np.array([radius_by_symbol[s] for s in symbols], dtype=float)

    # One distance matrix serves both the clash scan and the gap search. It
    # honours use_mic so both agree with the connectivity analysis.
    distances = atoms.get_all_distances(mic=use_mic)

    # Analyze clashes: upper triangle only (i < j); np.argwhere yields pairs in
    # row-major order, so the first 10 match a nested i/j scan.
    min_allowed = (radii[:, None] + radii[None, :]) * min_distance_factor
    clash_mask = np.triu(distances < min_allowed - CLASH_TOLERANCE, k=1)
    clash_details = []
    for i, j in np.argwhere(clash_mask)[:10].tolist():
        distance = distances[i, j]
        allowed = min_allowed[i, j]
        clash_details.append(
            f"{symbols[i]}({i})-{symbols[j]}({j}): "
            f"{distance:.3f}Å < {allowed:.3f}Å (gap: {allowed - distance:.3f}Å)"
        )

    has_clashes = bool(clash_details)

    # Analyze connectivity
    components, _ = _find_connected_components(atoms, connectivity_factor, use_mic)
    n_components = len(components)
    is_disconnected = n_components > 1

    closest_inter_component_distance = 0.0
    suggested_connectivity_factor = connectivity_factor

    if is_disconnected:
        # Find closest inter-component distance (first minimum wins on ties).
        min_dist = float("inf")
        closest_pair = None
        component_list = list(components.values())
        for ci, group_a in enumerate(component_list):
            for group_b in component_list[ci + 1 :]:
                block = distances[np.ix_(group_a, group_b)]
                row, col = np.unravel_index(int(np.argmin(block)), block.shape)
                if block[row, col] < min_dist:
                    min_dist = float(block[row, col])
                    closest_pair = (group_a[row], group_b[col])

        closest_inter_component_distance = min_dist
        if closest_pair is not None:
            i1, i2 = closest_pair
            suggested_connectivity_factor = (
                min_dist / (radii[i1] + radii[i2])
            ) * CONNECTIVITY_SUGGESTION_BUFFER

    # Build summary
    summary_parts = []
    if has_clashes:
        summary_parts.append(
            f"Clashes ({len(clash_details)}): {'; '.join(clash_details[:3])}"
            + (f" (+{len(clash_details) - 3} more)" if len(clash_details) > 3 else "")
        )
    if is_disconnected:
        summary_parts.append(
            f"Disconnected: {n_components} components, "
            f"gap={closest_inter_component_distance:.3f}Å, "
            f"suggested factor={suggested_connectivity_factor:.2f}"
        )

    if summary_parts:
        summary = "; ".join(summary_parts)
    else:
        summary = "Structure is valid (no clashes, connected)"

    return StructureDiagnostics(
        is_valid=not has_clashes and not is_disconnected,
        has_clashes=has_clashes,
        is_disconnected=is_disconnected,
        clash_details=clash_details,
        n_components=n_components,
        closest_inter_component_distance=closest_inter_component_distance,
        suggested_connectivity_factor=suggested_connectivity_factor,
        summary=summary,
    )
def validate_cluster_structure(
    atoms: Atoms,
    min_distance_factor: float,
    connectivity_factor: float,
    check_clashes: bool = True,
    check_connectivity: bool = True,
    use_mic: bool = False,
) -> tuple[bool, str]:
    """Check a cluster for atomic clashes and/or disconnection.

    This is the single validation entry point for returned cluster structures.
    The analysis itself is delegated to ``get_structure_diagnostics``, so the
    same distance rules apply as in the placement algorithms.

    Args:
        atoms: The Atoms object to validate.
        min_distance_factor: Factor to scale covalent radii for minimum
            distance checks.
        connectivity_factor: Factor to multiply the sum of covalent radii by
            for the connectivity threshold.
        check_clashes: Whether clashes count as a failure (default: True).
        check_connectivity: Whether disconnection counts as a failure
            (default: True).
        use_mic: If True, use the minimum image convention for distances.

    Returns:
        ``(is_valid, error_message)``. The message is empty on success and
        otherwise lists the problems found (at most 5 clash lines).
    """
    if not (check_clashes or check_connectivity):
        return True, ""

    report = get_structure_diagnostics(
        atoms, min_distance_factor, connectivity_factor, use_mic
    )

    problems: list[str] = []
    if check_clashes and report.has_clashes:
        listed = "\n ".join(report.clash_details[:5])
        problems.append(
            f"Atomic clashes detected with min_distance_factor={min_distance_factor}:\n "
            + listed
        )
    if check_connectivity and report.is_disconnected:
        problems.append(
            f"Cluster is not connected with connectivity_factor={connectivity_factor}. "
            "Atoms are not within bonding distance of each other."
        )

    if not problems:
        return True, ""

    formula = atoms.get_chemical_formula()
    header = f"Validation failed for {formula} cluster ({len(atoms)} atoms):\n"
    return False, header + "\n".join(problems)
def reorder_cluster_to_composition(cluster: Atoms, composition: Sequence[str]) -> Atoms:
    """Reorder cluster atoms to match the campaign composition symbol sequence.

    GA cut-and-splice pairing requires identical per-index atomic numbers
    across parents, so all structures for a given composition must share the
    same order. Atoms of the same element keep their relative order.

    Args:
        cluster: The cluster to reorder.
        composition: Desired per-index element symbols.

    Returns:
        ``cluster`` itself when it is already in the desired order. Otherwise
        a reordered copy built by indexing ``cluster``.

    Raises:
        ValueError: If the cluster's symbols are not a permutation of
            ``composition``, i.e. the lengths differ or an element is missing
            or over-represented.
    """
    from collections import deque

    desired = list(composition)
    current = cluster.get_chemical_symbols()
    if current == desired:
        return cluster

    # Without this check, surplus atoms would be silently dropped.
    if len(current) != len(desired):
        raise ValueError(
            "Generated cluster symbols do not match requested composition."
        )

    # FIFO queue of original indices per element (stable reordering, O(1) pops).
    by_symbol: dict[str, deque[int]] = {}
    for idx, sym in enumerate(current):
        by_symbol.setdefault(sym, deque()).append(idx)

    selection: list[int] = []
    for sym in desired:
        matching = by_symbol.get(sym)
        if not matching:
            raise ValueError(
                "Generated cluster symbols do not match requested composition."
            )
        selection.append(matching.popleft())

    return cluster[selection].copy()
def validate_cluster(
    atoms: Atoms,
    composition: list[str] | None = None,
    min_distance_factor: float | None = None,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
    check_clashes: bool = True,
    check_connectivity: bool | None = None,
    sort_atoms: bool = True,
    raise_on_failure: bool = False,
    source: str = "",
    use_mic: bool = False,
) -> tuple[Atoms, bool, str]:
    """Run every initialization-time check on a cluster in one call.

    Steps, in order:
      1. Optional exact composition check (element counts and length).
      2. Optional canonical reordering: follow ``composition`` when given
         (needed for GA pairing), otherwise sort alphabetically by element.
      3. Optional clash and connectivity validation.

    Args:
        atoms: The Atoms object to validate.
        composition: Optional expected composition to verify exact match.
        min_distance_factor: Factor for minimum distance checks. If None,
            MIN_DISTANCE_FACTOR_DEFAULT is used whenever a structural check
            runs.
        connectivity_factor: Factor for connectivity threshold.
        check_clashes: Whether to check for atomic clashes (default: True).
        check_connectivity: Whether to check connectivity. If None, decided by
            ``_should_check_connectivity`` (more than 2 atoms).
        sort_atoms: When True and ``composition`` is set, reorder atoms to
            match the composition list; when True without ``composition``,
            fall back to alphabetical element sort.
        raise_on_failure: Raise ValueError instead of returning a failure.
        source: Context tag prepended to error messages as ``"[source] "``
            (e.g., "template", "seed+growth").
        use_mic: If True, use the minimum image convention for distances.

    Returns:
        ``(validated_atoms, is_valid, error_message)``. The message is empty
        on success; ``validated_atoms`` may be reordered when
        ``sort_atoms=True``.

    Raises:
        ValueError: If raise_on_failure=True and validation fails.
    """

    def _tag(message: str) -> str:
        # Prefix failure messages with the optional source context.
        return ("[" + source + "] " if source else "") + message

    if check_connectivity is None:
        check_connectivity = _should_check_connectivity(atoms)

    if min_distance_factor is None and (check_clashes or check_connectivity):
        min_distance_factor = MIN_DISTANCE_FACTOR_DEFAULT

    if composition is not None and not _verify_exact_composition(atoms, composition):
        expected_counts = get_composition_counts(composition)
        actual_counts = get_composition_counts(atoms.get_chemical_symbols())
        mismatch = _tag(
            "Composition mismatch. "
            f"Expected {composition} (counts: {expected_counts}), "
            f"got {atoms.get_chemical_symbols()} (counts: {actual_counts})"
        )
        if raise_on_failure:
            raise ValueError(mismatch)
        return atoms, False, mismatch

    if sort_atoms:
        atoms = (
            _sort_atoms_by_element(atoms)
            if composition is None
            else reorder_cluster_to_composition(atoms, composition)
        )

    if not (check_clashes or check_connectivity):
        return atoms, True, ""

    structure_ok, structure_error = validate_cluster_structure(
        atoms,
        min_distance_factor,
        connectivity_factor,
        check_clashes=check_clashes,
        check_connectivity=check_connectivity,
        use_mic=use_mic,
    )
    if structure_ok:
        return atoms, True, ""

    tagged_error = _tag(structure_error)
    if raise_on_failure:
        raise ValueError(tagged_error)
    return atoms, False, tagged_error
def _sort_atoms_by_element(atoms: Atoms) -> Atoms: """Sort atoms by element symbol to ensure consistent ordering. This ensures that clusters with the same composition always have the same atom ordering, which is required for GA pairing operations. Args: atoms: The Atoms object to sort Returns: A new Atoms object with atoms sorted by element symbol """ if len(atoms) <= 1: return atoms # No sorting needed for single atoms or empty clusters # Get element symbols symbols = atoms.get_chemical_symbols() # Early exit: check if already sorted is_sorted = all(symbols[i] <= symbols[i + 1] for i in range(len(symbols) - 1)) if is_sorted: # Already sorted, return copy to maintain API contract return atoms.copy() positions = atoms.get_positions() # Create tuples of (element_symbol, original_index) for stable sorting indexed_symbols = [(symbol, i) for i, symbol in enumerate(symbols)] # Sort by element symbol (alphabetically), then by original index for stability indexed_symbols.sort(key=lambda x: (x[0], x[1])) # Extract sorted indices sorted_indices = [idx for _, idx in indexed_symbols] # Create new Atoms object with sorted order sorted_symbols = [symbols[i] for i in sorted_indices] sorted_positions = positions[sorted_indices] sorted_atoms = Atoms( symbols=sorted_symbols, positions=sorted_positions, cell=atoms.get_cell(), pbc=atoms.get_pbc(), ) # Copy calculator and info if present if atoms.calc is not None: sorted_atoms.calc = atoms.calc if hasattr(atoms, "info") and atoms.info: sorted_atoms.info = atoms.info.copy() return sorted_atoms def _should_check_connectivity(atoms: Atoms) -> bool: """Determine if connectivity check should be performed for a cluster. Connectivity checks are only meaningful for clusters with more than 2 atoms. For very small clusters (<= 2 atoms), the notion of connectivity is ambiguous. 
Args: atoms: The Atoms object to check Returns: True if connectivity should be checked, False otherwise """ return len(atoms) > 2 def _verify_exact_composition(atoms: Atoms, composition: list[str]) -> bool: """Verify that atoms object has exactly the composition specified. Args: atoms: The Atoms object to verify composition: Target composition list Returns: True if composition matches exactly, False otherwise """ if len(atoms) != len(composition): return False atoms_symbols = atoms.get_chemical_symbols() atoms_counts = get_composition_counts(atoms_symbols) comp_counts = get_composition_counts(composition) return atoms_counts == comp_counts def _cycle_composition_to_length( composition: list[str], target_length: int ) -> list[str]: """Cycle a composition list to match a target length, producing exact element counts. This function repeats the composition list as many times as needed to reach the target length, then truncates to exactly match the target length. This produces exact element counts that match what cycling the pattern would create. Args: composition: List of element symbols to cycle target_length: Target length for the resulting composition list Returns: List of element symbols with length equal to target_length Example: >>> _cycle_composition_to_length(["Pt", "Au"], 5) ["Pt", "Au", "Pt", "Au", "Pt"] """ if not composition: raise ValueError("Cannot cycle empty composition to target length") if target_length <= 0: return [] n_cycles = (target_length // len(composition)) + ( 1 if target_length % len(composition) > 0 else 0 ) return (composition * n_cycles)[:target_length] def _assign_exact_composition( cluster: Atoms, composition: list[str], n_atoms: int | None = None, rng: np.random.Generator | None = None, ) -> Atoms: """Assign exact composition to cluster, ensuring atom count matches. This function ensures the final composition matches the target composition exactly by cycling through the composition list to match n_atoms, producing exact element counts. 
Args: cluster: The cluster to assign composition to composition: Target composition list (will be cycled to match n_atoms) n_atoms: Expected number of atoms. If None, uses len(cluster) rng: Optional RNG to shuffle the assigned composition (prevents patterns) Returns: Atoms object with exact composition assigned Raises: ValueError: If cluster atom count doesn't match n_atoms (when provided) or if composition assignment fails """ if n_atoms is None: n_atoms = len(cluster) if len(cluster) != n_atoms: raise ValueError( f"Cannot assign composition: cluster has {len(cluster)} atoms " f"but target is {n_atoms} atoms" ) # Assign exact composition if not composition: raise ValueError( f"Cannot assign empty composition to cluster with {n_atoms} atoms" ) elif len(composition) == n_atoms: # Exact match - use composition directly cluster.set_chemical_symbols(composition) else: # Create extended composition list by cycling extended_composition = _cycle_composition_to_length(composition, n_atoms) if rng is not None: rng.shuffle(extended_composition) # Verify exact counts match what cycling produced expected_counts = get_composition_counts(extended_composition) # Assign the extended composition cluster.set_chemical_symbols(extended_composition) # Verify the assignment produced exact counts actual_counts = get_composition_counts(cluster.get_chemical_symbols()) if actual_counts != expected_counts: raise ValueError( f"Composition assignment failed: expected counts {expected_counts}, " f"got {actual_counts} after assignment" ) return cluster def _compute_composition_delta( base_counts: dict[str, int], target_counts: dict[str, int], ) -> tuple[list[str], dict[str, int], dict[str, int]]: """Compute atoms to add and remove to match target composition. 
Args: base_counts: Current composition counts as dict mapping element to count target_counts: Target composition counts as dict mapping element to count Returns: Tuple of (atoms_to_add, atoms_to_remove_dict, excess_elements_dict): - atoms_to_add: List of element symbols to add - atoms_to_remove: Dict mapping element to count to remove - excess_elements: Dict of elements that would need to be removed but aren't present in sufficient quantity (for error reporting) """ atoms_to_add = [] atoms_to_remove = {} excess_elements = {} all_elements = set(base_counts.keys()) | set(target_counts.keys()) for elem in all_elements: base_count = base_counts.get(elem, 0) target_count = target_counts.get(elem, 0) diff = target_count - base_count if diff > 0: # Need to add this element atoms_to_add.extend([elem] * diff) elif diff < 0: # Need to remove this element removal_count = -diff if base_count >= removal_count: atoms_to_remove[elem] = removal_count else: # Can't remove enough - this is an excess element excess_elements[elem] = removal_count - base_count return atoms_to_add, atoms_to_remove, excess_elements def _check_composition_feasibility( base_composition: list[str], target_composition: list[str], operation: str = "grow", ) -> tuple[bool, str]: """Check if composition change is feasible. Validates whether it's possible to transform base_composition into target_composition through growth or reduction operations. 
Args: base_composition: Current composition as list of element symbols target_composition: Target composition as list of element symbols operation: Type of operation - "grow" or "reduce" (affects error messages) Returns: Tuple of (is_feasible, error_message): - is_feasible: True if operation is possible, False otherwise - error_message: Empty string if feasible, detailed error message if not """ # Handle empty composition cases if not base_composition: if not target_composition or operation == "grow": return True, "" return False, ( f"Cannot {operation} from empty composition to non-empty target " f"{target_composition}. Use growth operation instead." ) if not target_composition: if operation == "reduce": return True, "" return False, ( f"Cannot {operation} from non-empty composition {base_composition} " f"to empty target. Use reduction operation instead." ) base_counts = get_composition_counts(base_composition) target_counts = get_composition_counts(target_composition) _, atoms_to_remove, excess_elements = _compute_composition_delta( base_counts, target_counts ) # Check if we can achieve target composition if excess_elements: # Some elements need to be removed but aren't present in sufficient quantity excess_details = ", ".join( f"{elem} (need to remove {count} more than available)" for elem, count in excess_elements.items() ) return False, ( f"Cannot achieve target composition {target_composition} from " f"base {base_composition}: insufficient quantity of elements to remove. " f"Excess elements: {excess_details}" ) base_total = sum(base_counts.values()) target_total = sum(target_counts.values()) if operation == "grow" and target_total < base_total: return False, ( f"Cannot grow from {base_total} atoms to {target_total} atoms. " f"Target has fewer atoms than base. Use reduction operation instead." ) if operation == "reduce" and target_total > base_total: return False, ( f"Cannot reduce from {base_total} atoms to {target_total} atoms. 
" f"Target has more atoms than base. Use growth operation instead." ) return True, "" def clear_convex_hull_cache() -> None: """Clear the convex hull computation cache. This is useful for testing or when memory usage becomes a concern. """ get_global_cache().clear_namespace(_CONVEX_HULL_CACHE_NS)