Source code for scgo.initialization.templates

"""Template structure generators for high-symmetry cluster motifs.

This module provides functions to generate regular polyhedral structures
(icosahedra, decahedra, octahedra) using ASE's cluster module. These templates
are used in the smart initialization mode to ensure exploration of important
high-symmetry basins in the potential energy surface.

Based on the Doye group's comprehensive study of Morse/LJ clusters:
https://doye.chem.ox.ac.uk/jon/structures/Morse/paper/node5.html

Polyhedra scaling:
------------------
All template structures are scaled so the **longest vertex–vertex distance**
(or longest edge) equals **(r_i + r_j) × connectivity_factor**, with
a = 2 × avg covalent radius for the composition.

- Custom templates (Tetrahedron, Cube, Cuboctahedron, Truncated Octahedron):
  Apply this rule directly: geometric parameters (edge lengths, vertex
  positions) are derived from a and connectivity_factor so the longest
  edge ≤ a × connectivity_factor.

- ASE-based templates (Icosahedron, Decahedron, Octahedron):
  ASE scales using atomic radii. We rescale after generation via
  _rescale_cluster_to_bond_length so nn distances match covalent-based
  bond length, keeping structures within the connectivity threshold.

- Validation:
  All templates are validated with validate_cluster() (connectivity
  (r_i + r_j) × connectivity_factor, minimum distances × min_distance_factor).
  No magic numbers; all thresholds derive from covalent radii.
"""

from __future__ import annotations

from collections import Counter
from collections.abc import Callable
from logging import Logger
from threading import Event, Lock
from typing import Any, cast

import numpy as np
from ase import Atom, Atoms
from ase.cluster import Decahedron, Icosahedron, Octahedron
from ase.cluster.cluster import Cluster
from ase.symbols import Symbols
from numpy.random import Generator

from scgo.utils.helpers import get_composition_counts
from scgo.utils.logging import get_logger
from scgo.utils.rng_helpers import ensure_rng_or_create

from .geometry_helpers import (
    _assign_exact_composition,
    _check_composition_feasibility,
    _cycle_composition_to_length,
    _generate_batch_positions_on_convex_hull,
    _identify_safe_removal_candidates,
    _should_check_connectivity,
    _verify_exact_composition,
    clear_convex_hull_cache,
    compute_bond_distance_params,
    get_convex_hull_vertex_indices,
    get_covalent_radius,
    validate_cluster,
    validate_cluster_structure,
)
from .initialization_config import (
    BOND_DISTANCE_MULTIPLIER_2ATOM,
    BOND_DISTANCE_MULTIPLIER_3ATOM,
    CONNECTIVITY_FACTOR,
    MAGIC_NUMBER_TOLERANCE,
    MAGIC_NUMBERS,
    MIN_DISTANCE_FACTOR_DEFAULT,
    PLACEMENT_RADIUS_SCALING_DEFAULT,
    POSITION_COMPARISON_TOLERANCE_FACTOR,
    VACUUM_DEFAULT,
)
from .random_spherical import grow_from_seed

# Module-level logger obtained from the project's logging helper.
logger: Logger = get_logger(__name__)

# Maps ASE's ``noshells`` parameter to the atom count used when searching for
# the closest icosahedron size (see _find_icosahedron_shells).
ICOSAHEDRON_SHELL_TO_ATOMS: dict[int, int] = {1: 1, 2: 13, 3: 55, 4: 147, 5: 309}

# Template name -> {"find_params": callable, "generate_base": callable};
# populated by _register_ase_template().
_TEMPLATE_REGISTRY: dict[str, dict[str, Callable[..., Any]]] = {}
# NOTE(review): the three names below are not used in the visible part of this
# file; presumably they back a thread-safe per-size cache of valid template
# types (cache, in-flight markers, guarding lock) -- confirm against callers.
_VALID_TEMPLATE_TYPES_CACHE: dict[int, tuple[str, ...]] = {}
_VALID_TEMPLATE_TYPES_INFLIGHT: dict[int, Event] = {}
_VALID_TEMPLATE_TYPES_LOCK = Lock()


def _get_base_element(composition: list[str]) -> str:
    """Get the base element from composition.

    Args:
        composition: List of element symbols

    Returns:
        First element symbol

    Raises:
        ValueError: If composition is empty
    """
    if not composition:
        raise ValueError("Cannot get base element from empty composition")
    return composition[0]


def _get_typical_bond_length(composition: list[str]) -> float:
    """Calculate typical bond length from composition using covalent radii.

    Computes the average covalent radius for all unique elements in the
    composition and returns twice that value as the typical bond length.
    This provides element-specific scaling for template structures.

    Args:
        composition: List of element symbols

    Returns:
        Typical bond length in Angstroms (2 × average covalent radius)

    Raises:
        ValueError: If composition is empty
    """
    if not composition:
        raise ValueError("Cannot calculate bond length from empty composition")

    unique_elements: set[str] = set(composition)
    radii: list[float] = [get_covalent_radius(elem) for elem in unique_elements]
    avg_radius: float = sum(radii) / len(radii)
    return 2.0 * avg_radius


def get_nearest_magic_number(n_atoms: int) -> int | None:
    """Return the magic number closest to ``n_atoms``.

    When several magic numbers are equally close, the one that comes first in
    ``MAGIC_NUMBERS`` is returned.

    Args:
        n_atoms: Number of atoms in the cluster

    Returns:
        The nearest magic number, or None if no magic numbers are defined
    """
    closest: int | None = None
    smallest_gap: int | None = None
    for magic in MAGIC_NUMBERS:
        gap = abs(magic - n_atoms)
        # Strict comparison keeps the earliest entry on ties.
        if smallest_gap is None or gap < smallest_gap:
            closest = magic
            smallest_gap = gap
    return closest
def is_near_magic_number(n_atoms: int, tolerance: int = MAGIC_NUMBER_TOLERANCE) -> bool:
    """Tell whether ``n_atoms`` lies close to a magic number.

    Args:
        n_atoms: Number of atoms in the cluster
        tolerance: Maximum difference from magic number to be considered "near"

    Returns:
        True if n_atoms is within tolerance of any magic number
    """
    closest: int | None = get_nearest_magic_number(n_atoms)
    # No magic numbers defined means nothing can be "near".
    return closest is not None and abs(n_atoms - closest) <= tolerance
def _find_icosahedron_shells(n_atoms: int) -> int | None:
    """Find the noshells parameter for an icosahedron closest to n_atoms.

    An exact match has a difference of zero and therefore always wins; ties
    are resolved in favour of the entry that appears first in
    ``ICOSAHEDRON_SHELL_TO_ATOMS``.

    Args:
        n_atoms: Target number of atoms

    Returns:
        noshells parameter, or None if no suitable match
    """
    if not ICOSAHEDRON_SHELL_TO_ATOMS:
        return None
    return min(
        ICOSAHEDRON_SHELL_TO_ATOMS,
        key=lambda noshells: abs(ICOSAHEDRON_SHELL_TO_ATOMS[noshells] - n_atoms),
    )


def _find_decahedron_params(n_atoms: int) -> tuple[int, int, int] | None:
    """Find decahedron parameters (p, q, r) closest to n_atoms.

    Note: Uses "Pt" as a placeholder element symbol for parameter search only.
    The actual cluster will be created with the correct element later.

    Args:
        n_atoms: Target number of atoms

    Returns:
        Tuple of (p, q, r) parameters, or None if no suitable match
    """
    best_params = None
    min_diff = float("inf")

    for p in range(1, 6):
        for q in range(1, 6):
            for r in [0, 1]:
                try:
                    cluster: Atoms = Decahedron(symbol="Pt", p=p, q=q, r=r)
                    count: int = len(cluster)
                    diff: int = abs(count - n_atoms)
                    if diff < min_diff:
                        min_diff = diff
                        best_params = (p, q, r)
                except (ValueError, RuntimeError, TypeError):
                    # Parameter combinations ASE rejects are simply skipped.
                    continue

    return best_params


def _find_octahedron_params(n_atoms: int) -> tuple[int, int] | None:
    """Find octahedron parameters (length, cutoff) closest to n_atoms.

    Note: Uses "Pt" as a placeholder element symbol for parameter search only.
    The actual cluster will be created with the correct element later.

    Args:
        n_atoms: Target number of atoms

    Returns:
        Tuple of (length, cutoff) parameters, or None if no suitable match
    """
    best_params = None
    min_diff = float("inf")

    for length in range(1, 8):
        max_cutoff: int = (length - 1) // 2
        for cutoff in range(max_cutoff + 1):
            try:
                cluster: Cluster = Octahedron(symbol="Pt", length=length, cutoff=cutoff)
                count: int = len(cluster)
                diff: int = abs(count - n_atoms)
                if diff < min_diff:
                    min_diff = diff
                    best_params = (length, cutoff)
            except (ValueError, RuntimeError, TypeError):
                # Parameter combinations ASE rejects are simply skipped.
                continue

    return best_params


def remove_atoms_from_vertices(
    cluster: Atoms,
    n_remove: int,
    target_composition: list[str] | None = None,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
    min_distance_factor: float = MIN_DISTANCE_FACTOR_DEFAULT,
    rng: np.random.Generator | None = None,
) -> Atoms | None:
    """Remove atoms from convex-hull vertices in bulk.

    Uses hull.vertices as the only candidate set. Candidates are ordered by
    coordination number (ascending), then distance from COM (descending), with
    random noise as the final tie-breaker. Respects target_composition when
    given, and supports multi-round removal when n_remove exceeds the number
    of vertices. Single validation per round (no per-atom connectivity loop).

    Args:
        cluster: The cluster to remove atoms from.
        n_remove: Number of atoms to remove.
        target_composition: Optional; preserves exact element counts in result.
        connectivity_factor: Factor for connectivity threshold when validating.
        min_distance_factor: Factor for minimum distance checks when validating.
        rng: Optional RNG.

    Returns:
        New Atoms with atoms removed, or None if removal fails (e.g. <4 atoms,
        cannot satisfy composition from vertices, or validation fails).

    Raises:
        ValueError: If n_remove >= len(cluster), or if the target composition
            cannot be preserved exactly.
    """
    rng = ensure_rng_or_create(rng)

    if n_remove <= 0:
        return cluster.copy()

    if n_remove >= len(cluster):
        raise ValueError(
            f"Cannot remove {n_remove} atoms from cluster with {len(cluster)} atoms"
        )

    if len(cluster) < 4:
        return None

    initial_len: int = len(cluster)
    final_total: int = initial_len - n_remove
    base_composition = cluster.get_chemical_symbols()

    final_counts: Counter[str] | None
    if target_composition is not None:
        final_composition: list[str] = _cycle_composition_to_length(
            target_composition, final_total
        )
        final_counts = get_composition_counts(final_composition)
        # Check composition feasibility before attempting removal
        is_feasible, _ = _check_composition_feasibility(
            base_composition, final_composition, operation="reduce"
        )
        if not is_feasible:
            return None
    else:
        final_counts = None

    current: Atoms = cluster.copy()
    total_removed = 0

    while total_removed < n_remove:
        vertices: np.ndarray = get_convex_hull_vertex_indices(current)
        if len(vertices) == 0:
            return None

        positions: np.ndarray = current.get_positions()
        symbols = current.get_chemical_symbols()
        center = current.get_center_of_mass()
        distances = np.linalg.norm(positions - center, axis=1)
        remaining_to_remove: int = n_remove - total_removed

        # Coordination number of every atom under the covalent-radius
        # connectivity criterion.
        coordination: np.ndarray = np.zeros(len(current), dtype=np.int_)
        if len(current) > 1:
            for i in range(len(current)):
                r_i: float = get_covalent_radius(symbols[i])
                for j in range(len(current)):
                    if i == j:
                        continue
                    r_j: float = get_covalent_radius(symbols[j])
                    d = np.linalg.norm(positions[i] - positions[j])
                    if d <= (r_i + r_j) * connectivity_factor:
                        coordination[i] += 1

        vert_coord: np.ndarray = coordination[vertices]
        vert_dist = distances[vertices]
        # Add random noise for tie-breaking
        noise: np.ndarray = rng.random(len(vertices))
        # lexsort: last key is primary -> coordination asc, distance desc, noise.
        order = np.lexsort((noise, -vert_dist, vert_coord))
        sorted_vertices: np.ndarray = vertices[order]

        remove_indices: list[int]
        to_remove_this: int
        if target_composition is not None:
            assert final_counts is not None
            current_counts: Counter[str] = get_composition_counts(symbols)
            # Iterate elements in sorted order: plain set iteration depends on
            # string hash randomization, which would make the element removed
            # first (and hence the result) vary between runs despite a seeded
            # RNG.
            remove_counts: dict[str, int] = {}
            for el in sorted(set(symbols) | set(final_counts)):
                excess: int = max(
                    0, current_counts.get(el, 0) - final_counts.get(el, 0)
                )
                if excess > 0:
                    remove_counts[el] = excess

            total_to_remove_this: int = sum(remove_counts.values())
            if total_to_remove_this != remaining_to_remove:
                raise ValueError(
                    f"Cannot preserve target composition: need to remove "
                    f"{total_to_remove_this} atoms this round but "
                    f"{remaining_to_remove} remaining. "
                    f"Target: {target_composition}, current: {symbols}"
                )

            remove_indices = []
            for el, count in remove_counts.items():
                el_verts: list[int] = [
                    int(i) for i in sorted_vertices if symbols[i] == el
                ]
                remove_indices.extend(el_verts[:count])
            if not remove_indices and total_to_remove_this > 0:
                return None
            # NOTE(review): the trailing ``1`` caps removal at one atom per
            # round in composition-preserving mode; kept as-is -- confirm this
            # is the intended (conservative) behaviour.
            to_remove_this = min(total_to_remove_this, len(remove_indices), 1)
            remove_indices = remove_indices[:to_remove_this]
        else:
            # Without a composition constraint, strip at most half of the hull
            # vertices per round (but always at least one).
            max_remove_this: int = max(1, len(sorted_vertices) // 2)
            to_remove_this = min(remaining_to_remove, max_remove_this)
            remove_indices = sorted_vertices[:to_remove_this].tolist()

        safe_candidates: list[int] = _identify_safe_removal_candidates(
            current,
            remove_indices,
            connectivity_factor,
            max_to_check=len(remove_indices),
        )
        if not safe_candidates:
            return None
        remove_indices = safe_candidates[:to_remove_this]

        keep: np.ndarray = np.setdiff1d(np.arange(len(current)), remove_indices)
        new_symbols: list[Symbols | str] = [current.symbols[i] for i in keep]
        new_positions: np.ndarray = positions[keep]
        new_cluster = Atoms(
            symbols=new_symbols,
            positions=new_positions,
            cell=current.get_cell(),
            pbc=current.get_pbc(),
        )

        is_valid, _ = validate_cluster_structure(
            new_cluster,
            min_distance_factor,
            connectivity_factor,
            check_clashes=True,
            check_connectivity=_should_check_connectivity(new_cluster),
            use_mic=False,
        )
        if not is_valid:
            return None

        clear_convex_hull_cache()
        current = new_cluster
        total_removed += len(remove_indices)
        if not remove_indices:
            # Nothing was removed this round; stop instead of looping forever.
            break

    if target_composition is not None:
        actual: Counter[str] = get_composition_counts(current.get_chemical_symbols())
        if actual != final_counts:
            raise ValueError(
                f"Vertex removal did not preserve exact composition. "
                f"Expected {final_composition} (counts: {final_counts}), "
                f"got counts {actual}."
            )

    return current


def grow_template_via_facets(
    seed_atoms: Atoms,
    target_composition: list[str],
    placement_radius_scaling: float,
    cell_side: float,
    rng: np.random.Generator,
    min_distance_factor: float = MIN_DISTANCE_FACTOR_DEFAULT,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
    template_name: str | None = None,
) -> Atoms | None:
    """Grow a template seed to target composition by placing atoms on all facets.

    One hull per round, one position per facet, single validation per round.
    Falls back to grow_from_seed when seed has <4 atoms (no 3D hull).

    Args:
        seed_atoms: The template seed to grow from.
        target_composition: Target composition as list of element symbols.
        placement_radius_scaling: Scaling for placement shell radius.
        cell_side: Cubic cell side length.
        rng: Random number generator.
        min_distance_factor: Factor for minimum distance checks.
        connectivity_factor: Factor for connectivity threshold.
        template_name: Optional name of template type (e.g., "cuboctahedron").
            Used to enable smart facet filtering for specific templates.

    Returns:
        Grown Atoms with target composition, or None on failure.

    Raises:
        ValueError: If the grown cluster does not have exactly the target
            composition.
    """
    base: Atoms = seed_atoms.copy()
    base_composition = base.get_chemical_symbols()

    is_feasible, _ = _check_composition_feasibility(
        base_composition, target_composition, operation="grow"
    )
    if not is_feasible:
        return None

    base_counts: Counter[str] = get_composition_counts(base_composition)
    target_counts: Counter[str] = get_composition_counts(target_composition)
    atoms_to_add: list[str] = list((target_counts - base_counts).elements())

    if not atoms_to_add:
        base.set_cell([cell_side, cell_side, cell_side])
        base.center()
        return base

    if len(base) < 4:
        return grow_from_seed(
            seed_atoms=base,
            target_composition=target_composition,
            placement_radius_scaling=placement_radius_scaling,
            cell_side=cell_side,
            rng=rng,
            min_distance_factor=min_distance_factor,
            connectivity_factor=connectivity_factor,
        )

    current: Atoms = base.copy()
    to_add: list[str] = list(atoms_to_add)
    radii_to_add: dict[str, float] = {s: get_covalent_radius(s) for s in set(to_add)}
    max_round_retries = 3
    round_retry_count = 0

    while to_add:
        symbols = current.get_chemical_symbols()
        current_radii: np.ndarray = np.array(
            [get_covalent_radius(s) for s in symbols]
        )
        max_existing: float = (
            float(np.max(current_radii)) if len(current_radii) > 0 else 0.0
        )
        avg_new = float(np.mean([radii_to_add[s] for s in to_add]))
        bond_distance, min_dist, max_conn = compute_bond_distance_params(
            max_existing,
            avg_new,
            connectivity_factor,
            min_distance_factor,
            placement_radius_scaling,
        )

        new_atom_symbol: str | None = to_add[0] if to_add else None
        use_smart_filtering: bool = template_name == "cuboctahedron"
        candidates: list[np.ndarray] = _generate_batch_positions_on_convex_hull(
            current,
            n_candidates=0,
            bond_distance=bond_distance,
            rng=rng,
            min_connectivity_dist=min_dist,
            max_connectivity_dist=max_conn,
            use_all_facets=True,
            min_distance_factor=min_distance_factor,
            new_atom_symbol=new_atom_symbol,
            smart_facet_filtering=use_smart_filtering,
            connectivity_factor=connectivity_factor,
        )
        if not candidates:
            logger.debug(
                f"grow_template_via_facets: no candidates generated for {template_name}, "
                f"n_atoms={len(current)}, to_add={len(to_add)}, "
                f"bond_distance={bond_distance:.3f}, max_conn={max_conn:.3f}"
                f" (discovery failure: candidate discarded; not a per-structure fallback)"
            )
            return None

        placed_count = 0
        candidate_idx = 0
        while placed_count < len(to_add) and candidate_idx < len(candidates):
            sym: str = to_add[placed_count]
            pos: np.ndarray = candidates[candidate_idx]
            candidate_idx += 1

            test_atom = Atom(sym, pos)
            # Radius of the new atom is the same for every pair below.
            r_new: float = radii_to_add[sym]
            has_clash = False
            for existing_atom in current:
                dist = np.linalg.norm(test_atom.position - existing_atom.position)
                r_existing: float = get_covalent_radius(existing_atom.symbol)
                min_allowed: float = (r_new + r_existing) * min_distance_factor
                if dist < min_allowed:
                    has_clash = True
                    break

            if not has_clash:
                current.append(test_atom)
                placed_count += 1

        to_add = to_add[placed_count:]

        if placed_count == 0 and to_add and round_retry_count < max_round_retries:
            round_retry_count += 1
            logger.debug(
                f"grow_template_via_facets: no atoms placed, retry {round_retry_count}/{max_round_retries}, "
                f"candidates={len(candidates)}, to_add={len(to_add)}"
            )
            clear_convex_hull_cache()
            continue
        elif placed_count == 0 and to_add:
            logger.debug(
                f"grow_template_via_facets: failed to place any atoms after {max_round_retries} retries, "
                f"candidates={len(candidates)}, to_add={len(to_add)}"
                f" (discovery failure: candidate discarded; not a per-structure fallback)"
            )
            return None

        # Progress was made, so the retry budget starts over.
        round_retry_count = 0

        is_valid, err = validate_cluster_structure(
            current,
            min_distance_factor,
            connectivity_factor,
            check_clashes=True,
            check_connectivity=_should_check_connectivity(current),
            use_mic=False,
        )
        if not is_valid:
            logger.debug(
                f"grow_template_via_facets: validation failed after placing {placed_count} atoms, "
                f"n_atoms={len(current)}, error: {err}"
            )
            return None

        clear_convex_hull_cache()

    current.set_cell([cell_side, cell_side, cell_side])
    current.center()

    if not _verify_exact_composition(current, target_composition):
        expected: Counter[str] = get_composition_counts(target_composition)
        actual: Counter[str] = get_composition_counts(current.get_chemical_symbols())
        raise ValueError(
            f"grow_template_via_facets produced wrong composition: "
            f"expected {target_composition} (counts {expected}), "
            f"got counts {actual}"
        )

    return current


def _create_balanced_base_composition(
    composition: list[str], base_n_atoms: int
) -> list[str]:
    """Create a balanced base composition by cycling through elements.

    For multi-element compositions, this ensures the base template has a
    balanced distribution of elements, making it easier to adjust to the
    target composition by adding, removing, or switching labels.

    Args:
        composition: Target composition list
        base_n_atoms: Number of atoms in the base template

    Returns:
        List of element symbols with balanced distribution

    Raises:
        ValueError: If composition is empty
    """
    if not composition:
        raise ValueError(
            f"Cannot create balanced base composition from empty composition "
            f"for {base_n_atoms} atoms"
        )

    if len(composition) == 1:
        return composition * base_n_atoms

    return [composition[i % len(composition)] for i in range(base_n_atoms)]


def _assign_balanced_composition_if_multi(
    cluster: Atoms, composition: list[str]
) -> None:
    """Assign balanced composition to cluster if composition has multiple entries.

    For compositions with more than one entry, assigns symbols to the cluster
    by cycling through the composition. Single-entry compositions leave the
    cluster unchanged. Modifies ``cluster`` in place.

    Args:
        cluster: The Atoms object to assign composition to
        composition: Target composition list
    """
    if len(composition) > 1:
        base_symbols: list[str] = _create_balanced_base_composition(
            composition, len(cluster)
        )
        cluster.set_chemical_symbols(base_symbols)


def _deduplicate_positions(
    positions: list[np.ndarray] | list[list[float]], bond_length: float
) -> list[np.ndarray]:
    """Remove duplicate positions from a list using bond-length-based tolerance.

    The first occurrence of each position is kept; input order is preserved.

    Args:
        positions: List of position arrays or lists
        bond_length: Typical bond length for calculating tolerance

    Returns:
        List of unique positions as numpy arrays
    """
    position_tolerance: float = bond_length * POSITION_COMPARISON_TOLERANCE_FACTOR
    unique_positions: list[np.ndarray] = []
    for pos in positions:
        pos_array: np.ndarray = np.array(pos)
        if not any(
            np.allclose(pos_array, up, atol=position_tolerance)
            for up in unique_positions
        ):
            unique_positions.append(pos_array)
    return unique_positions


def _validate_n_atoms(n_atoms: int, expected: int | None, template_name: str) -> bool:
    """Validate that n_atoms matches expected value for a template.

    Mismatches are expected during template discovery, so nothing is logged
    here (debug logging would create excessive noise).

    Args:
        n_atoms: Number of atoms to validate
        expected: Expected number of atoms (None means no specific requirement)
        template_name: Name of template; currently unused, kept for API
            compatibility

    Returns:
        True if valid, False otherwise
    """
    if n_atoms <= 0:
        return False
    return expected is None or n_atoms == expected


def _generate_custom_template(
    template_name: str,
    composition: list[str],
    n_atoms: int,
    position_generator: Callable[
        [list[str], int, float, float], list[np.ndarray] | list[list[float]]
    ],
    rng: np.random.Generator | None = None,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
    n_atoms_validator: Callable[[int], tuple[bool, str | None]] | None = None,
    post_process: Callable[[Atoms, list[str], int], Atoms] | None = None,
    expected_n_atoms: int | None = None,
) -> Atoms | None:
    """Generate a custom template structure using a position generator function.

    This wrapper handles the common boilerplate for custom template generators:
    validation, RNG setup, error handling, composition assignment, and
    adjustment.

    Args:
        template_name: Name of template type (e.g., "tetrahedron")
        composition: List of element symbols
        n_atoms: Target number of atoms
        position_generator: Function that takes (composition, n_atoms,
            bond_length, connectivity_factor) and returns list of positions
        rng: Optional random number generator
        connectivity_factor: Factor for connectivity threshold
        n_atoms_validator: Optional function that takes n_atoms and returns
            (is_valid, error_msg)
        post_process: Optional function to modify cluster after creation
            (takes cluster, composition, n_atoms)
        expected_n_atoms: Optional expected atom count for validation
            (uses _validate_n_atoms)

    Returns:
        Atoms object with template structure, or None if generation fails
    """
    if expected_n_atoms is not None:
        if not _validate_n_atoms(n_atoms, expected_n_atoms, template_name):
            return None
    elif n_atoms <= 0:
        return None

    if n_atoms_validator is not None:
        is_valid, _ = n_atoms_validator(n_atoms)
        if not is_valid:
            return None

    rng = ensure_rng_or_create(rng)

    try:
        base_element: str = _get_base_element(composition)
        a: float = _get_typical_bond_length(composition)

        positions: list[np.ndarray] | list[list[float]] = position_generator(
            composition, n_atoms, a, connectivity_factor
        )
        cluster = Atoms([base_element] * len(positions), positions=positions)

        if post_process is not None:
            cluster = post_process(cluster, composition, n_atoms)

        _assign_balanced_composition_if_multi(cluster, composition)

        return _adjust_template_to_target(
            cluster,
            n_atoms,
            composition,
            rng,
            template_name,
            connectivity_factor,
            MIN_DISTANCE_FACTOR_DEFAULT,
        )
    except (ValueError, RuntimeError, AttributeError) as exc:
        # Best-effort generation: report the reason at debug level instead of
        # dropping it silently, then signal failure to the caller.
        logger.debug(
            f"Failed to generate {template_name} template for {n_atoms} atoms: {exc}"
        )
        return None


def _rescale_cluster_to_bond_length(
    atoms: Atoms,
    composition: list[str],
    connectivity_factor: float,
) -> None:
    """Rescale ASE-generated cluster so nn distances match covalent-based bond length.

    ASE Icosahedron, Decahedron, Octahedron use atomic radii. We validate with
    covalent radii × connectivity_factor. Rescaling ensures nn distances align
    with our connectivity model so structures pass validation.

    Modifies atoms in place. Positions are scaled about their unweighted mean
    (geometric centroid), which therefore stays fixed.

    Args:
        atoms: ASE-generated cluster (e.g. Icosahedron, Decahedron, Octahedron).
        composition: Target composition (used for typical bond length).
        connectivity_factor: Connectivity factor; the target nearest-neighbour
            distance is ``a * min(1.0, connectivity_factor * 0.95)``.
    """
    if len(atoms) < 2:
        return
    a: float = _get_typical_bond_length(composition)
    positions: np.ndarray = atoms.get_positions()
    n: int = len(positions)
    # Nearest-neighbour distance of every atom; their mean is the current scale.
    min_dists = [
        min(np.linalg.norm(positions[i] - positions[j]) for j in range(n) if j != i)
        for i in range(n)
    ]
    current_scale = float(np.mean(min_dists))
    if current_scale <= 0:
        return
    target_scale: float = a * min(1.0, connectivity_factor * 0.95)
    scale: float = target_scale / current_scale
    com = np.mean(positions, axis=0)
    new_positions = com + (positions - com) * scale
    atoms.set_positions(new_positions)


def _register_ase_template(
    template_name: str,
    find_params_func: Callable[[int], Any],
    generate_base_func: Callable[[str, Any], Atoms],
) -> None:
    """Register an ASE-based template generator in the template registry.

    Args:
        template_name: Name of the template type
        find_params_func: Function to find parameters for the template
        generate_base_func: Function to generate base cluster from parameters
    """
    _TEMPLATE_REGISTRY[template_name] = {
        "find_params": find_params_func,
        "generate_base": generate_base_func,
    }


_register_ase_template(
    "icosahedron",
    find_params_func=_find_icosahedron_shells,
    generate_base_func=lambda elem, params: Icosahedron(symbol=elem, noshells=params),
)
_register_ase_template(
    "decahedron",
    find_params_func=_find_decahedron_params,
    generate_base_func=lambda elem, params: Decahedron(
        symbol=elem, p=params[0], q=params[1], r=params[2]
    ),
)
_register_ase_template(
    "octahedron",
    find_params_func=_find_octahedron_params,
    generate_base_func=lambda elem, params: Octahedron(
        symbol=elem, length=params[0], cutoff=params[1]
    ),
)


def _generate_ase_template_with_common_pattern(
    composition: list[str],
    n_atoms: int,
    rng: np.random.Generator | None,
    template_name: str,
    find_params_func: Callable[[int], Any],
    generate_base_func: Callable[[str, Any], Atoms],
    connectivity_factor: float = CONNECTIVITY_FACTOR,
) -> Atoms | None:
    """Helper function for ASE-based template generators with common pattern.

    This function handles the common pattern used by icosahedron, decahedron,
    and octahedron generators:
    1. Validate n_atoms
    2. Ensure RNG exists
    3. Find parameters (with early return if None)
    4. Generate base cluster
    5. Assign balanced composition if multi-element
    6. Adjust atom count and composition

    Note on scaling: ASE's Icosahedron, Decahedron, and Octahedron scale using
    atomic radii. We rescale via _rescale_cluster_to_bond_length so nn
    distances match covalent-based bond length (a = 2×avg covalent radius),
    ensuring structures stay within (r_i + r_j) × connectivity_factor before
    validation.

    Args:
        composition: List of element symbols
        n_atoms: Target number of atoms
        rng: Optional random number generator
        template_name: Name of template type (for error messages)
        find_params_func: Function to find parameters for the template
            (returns params or None)
        generate_base_func: Function to generate base cluster from params
            (returns Atoms)
        connectivity_factor: Factor for connectivity threshold (based on
            covalent radii)

    Returns:
        Atoms object with template structure, or None if generation fails
    """
    if n_atoms <= 0:
        return None

    rng = ensure_rng_or_create(rng)

    params = find_params_func(n_atoms)
    if params is None:
        return None

    try:
        base_element: str = _get_base_element(composition)
        cluster: Atoms = generate_base_func(base_element, params)
        _rescale_cluster_to_bond_length(cluster, composition, connectivity_factor)
        _assign_balanced_composition_if_multi(cluster, composition)

        return _adjust_template_to_target(
            cluster,
            n_atoms,
            composition,
            rng,
            template_name,
            connectivity_factor,
            MIN_DISTANCE_FACTOR_DEFAULT,
        )
    except (ValueError, RuntimeError, AttributeError) as exc:
        # Best-effort generation: report the reason at debug level instead of
        # dropping it silently, then signal failure to the caller.
        logger.debug(
            f"Failed to generate {template_name} template for {n_atoms} atoms: {exc}"
        )
        return None


def _set_template_info(atoms: Atoms, template_name: str) -> None:
    """Set template type in atoms info for tracking.

    Args:
        atoms: The Atoms object to set info on
        template_name: Name of template type (e.g., "icosahedron")
    """
    if atoms.info is None:
        atoms.info = {}
    atoms.info["template_type"] = template_name


def _adjust_template_to_target(
    cluster: Atoms,
    target_n_atoms: int,
    composition: list[str],
    rng: np.random.Generator,
    template_name: str,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
    min_distance_factor: float = MIN_DISTANCE_FACTOR_DEFAULT,
    cell_side: float | None = None,
    placement_radius_scaling: float = PLACEMENT_RADIUS_SCALING_DEFAULT,
) -> Atoms | None:
    """Shared helper to adjust template cluster to target atom count and composition.

    Handles three cases: grow (add atoms), shrink (remove atoms), or exact
    match. This consolidates the common adjustment logic used throughout
    template generation.

    Args:
        cluster: The base template cluster
        target_n_atoms: Target number of atoms
        composition: Target composition
        rng: Random number generator
        template_name: Name of template type (for error messages)
        connectivity_factor: Factor for connectivity threshold
            (default: CONNECTIVITY_FACTOR)
        min_distance_factor: Factor for minimum distance checks
            (default: MIN_DISTANCE_FACTOR_DEFAULT)
        cell_side: Cell side length (defaults to cluster cell or
            VACUUM_DEFAULT * 2)
        placement_radius_scaling: Scaling for placement radius (for growth).

    Returns:
        Adjusted Atoms object, or None if adjustment fails

    Raises:
        ValueError: Propagated from remove_atoms_from_vertices when shrinking
            cannot preserve the target composition.
    """
    base_count: int = len(cluster)

    if cell_side is None:
        cell_side = (
            cluster.cell.lengths()[0] if cluster.cell.any() else VACUUM_DEFAULT * 2
        )

    if base_count < target_n_atoms:
        target_composition: list[str] = _cycle_composition_to_length(
            composition, target_n_atoms
        )
        try:
            grown = grow_template_via_facets(
                seed_atoms=cluster,
                target_composition=target_composition,
                placement_radius_scaling=placement_radius_scaling,
                cell_side=cell_side,
                rng=rng,
                min_distance_factor=min_distance_factor,
                connectivity_factor=connectivity_factor,
                template_name=template_name,
            )
            if grown is None:
                logger.debug(
                    f"Failed to add atoms to {template_name} template "
                    f"while maintaining connectivity"
                    f" (discovery failure: candidate discarded; not a per-structure fallback)"
                )
                return None

            if len(grown) != target_n_atoms:
                logger.debug(
                    f"{template_name} template has {len(grown)} atoms after growth, "
                    f"expected {target_n_atoms}"
                )
                return None

            _set_template_info(grown, template_name)
            return grown
        except ValueError as e:
            logger.debug(
                f"Failed to grow {template_name} template from {base_count} "
                f"to {target_n_atoms} atoms: {e}"
            )
            return None
    elif base_count > target_n_atoms:
        n_remove: int = base_count - target_n_atoms
        adjusted = remove_atoms_from_vertices(
            cluster,
            n_remove,
            target_composition=composition,
            connectivity_factor=connectivity_factor,
            min_distance_factor=min_distance_factor,
            rng=rng,
        )
        if adjusted is None:
            return None
        cluster = adjusted
        _set_template_info(cluster, template_name)
    else:
        # Atom count already matches; only the element labels need assigning.
        cluster = _assign_exact_composition(
            cluster, composition, target_n_atoms, rng=rng
        )
        _set_template_info(cluster, template_name)

    return cluster


def _generate_ase_template_from_registry(
    template_name: str,
    composition: list[str],
    n_atoms: int,
    rng: np.random.Generator | None = None,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
) -> Atoms | None:
    """Helper to generate ASE-based template from registry.

    Args:
        template_name: Name of template type (e.g., "icosahedron")
        composition: List of element symbols
        n_atoms: Target number of atoms
        rng: Optional random number generator
        connectivity_factor: Factor for connectivity threshold

    Returns:
        Atoms object with template structure, or None if generation fails
        (including when the template name is not registered)
    """
    config = _TEMPLATE_REGISTRY.get(template_name)
    if config is None:
        logger.warning(f"{template_name.capitalize()} template not registered")
        return None

    find_params = cast(Callable[[int], Any], config["find_params"])
    generate_base = cast(Callable[[str, Any], Atoms], config["generate_base"])

    return _generate_ase_template_with_common_pattern(
        composition=composition,
        n_atoms=n_atoms,
        rng=rng,
        template_name=template_name,
        find_params_func=find_params,
        generate_base_func=generate_base,
        connectivity_factor=connectivity_factor,
    )
def generate_icosahedron(
    composition: list[str],
    n_atoms: int,
    rng: np.random.Generator | None = None,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
) -> Atoms | None:
    """Build an icosahedral cluster template.

    Thin public wrapper: the work is delegated to
    ``_generate_ase_template_from_registry`` with the ``"icosahedron"``
    registry entry. The base structure comes from ASE's Icosahedron
    generator, and the atom count is adjusted by adding/removing surface
    atoms when needed.

    Args:
        composition: List of element symbols (cycled to match n_atoms)
        n_atoms: Target number of atoms
        rng: Optional random number generator for reproducibility
        connectivity_factor: Factor for connectivity threshold

    Returns:
        Atoms object with icosahedral structure, or None if generation fails
    """
    return _generate_ase_template_from_registry(
        template_name="icosahedron",
        composition=composition,
        n_atoms=n_atoms,
        rng=rng,
        connectivity_factor=connectivity_factor,
    )
def generate_decahedron(
    composition: list[str],
    n_atoms: int,
    rng: np.random.Generator | None = None,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
) -> Atoms | None:
    """Build a decahedral cluster template.

    Thin public wrapper: the work is delegated to
    ``_generate_ase_template_from_registry`` with the ``"decahedron"``
    registry entry. The base structure comes from ASE's Decahedron
    generator, and the atom count is adjusted by adding/removing surface
    atoms when needed.

    Args:
        composition: List of element symbols (cycled to match n_atoms)
        n_atoms: Target number of atoms
        rng: Optional random number generator for reproducibility
        connectivity_factor: Factor for connectivity threshold

    Returns:
        Atoms object with decahedral structure, or None if generation fails
    """
    return _generate_ase_template_from_registry(
        template_name="decahedron",
        composition=composition,
        n_atoms=n_atoms,
        rng=rng,
        connectivity_factor=connectivity_factor,
    )
def generate_octahedron(
    composition: list[str],
    n_atoms: int,
    rng: np.random.Generator | None = None,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
) -> Atoms | None:
    """Build an octahedral cluster template.

    Thin public wrapper: the work is delegated to
    ``_generate_ase_template_from_registry`` with the ``"octahedron"``
    registry entry. The base structure comes from ASE's Octahedron
    generator, and the atom count is adjusted by adding/removing surface
    atoms when needed.

    Args:
        composition: List of element symbols (cycled to match n_atoms)
        n_atoms: Target number of atoms
        rng: Optional random number generator for reproducibility
        connectivity_factor: Factor for connectivity threshold

    Returns:
        Atoms object with octahedral structure, or None if generation fails
    """
    return _generate_ase_template_from_registry(
        template_name="octahedron",
        composition=composition,
        n_atoms=n_atoms,
        rng=rng,
        connectivity_factor=connectivity_factor,
    )
def generate_tetrahedron(
    composition: list[str],
    n_atoms: int,
    rng: np.random.Generator | None = None,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
) -> Atoms | None:
    """Build a regular tetrahedral cluster (vertices only).

    The four atoms sit on the vertices of a regular tetrahedron whose edge
    equals the bond length handed to the position generator. The request is
    forwarded to ``_generate_custom_template`` with ``expected_n_atoms=4``,
    so only 4-atom clusters are supported.

    Args:
        composition: List of element symbols (cycled to match n_atoms)
        n_atoms: Target number of atoms (must be 4)
        rng: Optional random number generator for reproducibility
        connectivity_factor: Factor for connectivity threshold

    Returns:
        Atoms object with tetrahedral structure, or None if generation fails
        (e.g., n_atoms != 4)
    """

    def _generate_tetrahedron_positions(
        comp: list[str], n: int, bond_length: float, cf: float
    ) -> list[np.ndarray]:
        # Equilateral triangle in the z = 0 plane plus an apex above its
        # centroid; every edge has length ``bond_length``.
        half_edge = bond_length / 2
        base_origin = np.array([0.0, 0.0, 0.0])
        base_along_x = np.array([bond_length, 0.0, 0.0])
        base_third = np.array([half_edge, bond_length * np.sqrt(3) / 2, 0.0])
        apex = np.array(
            [
                half_edge,
                bond_length / (2 * np.sqrt(3)),
                bond_length * np.sqrt(2 / 3),
            ]
        )
        return [base_origin, base_along_x, base_third, apex]

    return _generate_custom_template(
        template_name="tetrahedron",
        composition=composition,
        n_atoms=n_atoms,
        position_generator=_generate_tetrahedron_positions,
        rng=rng,
        connectivity_factor=connectivity_factor,
        expected_n_atoms=4,
    )
def generate_cube(
    composition: list[str],
    n_atoms: int,
    rng: np.random.Generator | None = None,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
) -> Atoms | None:
    """Generate a cubic cluster with the specified number of atoms.

    Creates a k×k×k cubic lattice with spacing equal to the bond length
    handed to the position generator. Only perfect-cube sizes
    (8, 27, 64, 125, etc.) are supported; the size check and the position
    generator are passed to ``_generate_custom_template``.

    Args:
        composition: List of element symbols (cycled to match n_atoms)
        n_atoms: Target number of atoms (must be a perfect cube: n³)
        rng: Optional random number generator for reproducibility
        connectivity_factor: Factor for connectivity threshold

    Returns:
        Atoms object with cubic structure, or None if generation fails
        (e.g., n_atoms is not a perfect cube)
    """

    def _nearest_cube_root(n: int) -> int:
        """Return the integer nearest to the real cube root of ``n``.

        Non-positive ``n`` maps to 0. In Python a negative base raised to a
        fractional power is a complex number, and ``round()`` rejects
        complex values with ``TypeError``; guarding here lets negative sizes
        fail validation cleanly instead of crashing.
        """
        if n <= 0:
            return 0
        return round(n ** (1 / 3))

    def _validate_cube(n: int) -> tuple[bool, str | None]:
        """Validate that n_atoms is a perfect cube."""
        if _nearest_cube_root(n) ** 3 == n:
            return True, None
        return False, (
            f"generate_cube only supports perfect cubes (n³), got {n}. "
            f"Returning None instead of falling back to other template."
        )

    def _generate_cube_positions(
        comp: list[str], n: int, bond_length: float, cf: float
    ) -> list[np.ndarray]:
        """Generate positions for a k×k×k cubic lattice (k = cube root of n)."""
        side = range(_nearest_cube_root(n))
        return [
            np.array([i * bond_length, j * bond_length, k * bond_length])
            for i in side
            for j in side
            for k in side
        ]

    return _generate_custom_template(
        template_name="cube",
        composition=composition,
        n_atoms=n_atoms,
        position_generator=_generate_cube_positions,
        rng=rng,
        connectivity_factor=connectivity_factor,
        n_atoms_validator=_validate_cube,
    )
def generate_cuboctahedron(
    composition: list[str],
    n_atoms: int,
    rng: np.random.Generator | None = None,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
) -> Atoms | None:
    """Build a cuboctahedral cluster template.

    The 12 vertices are the sign combinations of ``(±s, ±s, 0)`` and its
    coordinate permutations, with ``s = bond_length * cf / 2``. When 13
    atoms are requested, a post-processing step appends one extra atom at
    the origin.

    Args:
        composition: List of element symbols (cycled to match n_atoms)
        n_atoms: Target number of atoms (12 or 13 for perfect structures)
        rng: Optional random number generator for reproducibility
        connectivity_factor: Factor for connectivity threshold

    Returns:
        Atoms object with cuboctahedral structure, or None if generation fails
    """

    def _generate_cuboctahedron_positions(
        comp: list[str], n: int, bond_length: float, cf: float
    ) -> list[np.ndarray]:
        s: float = bond_length * cf / 2.0
        positions = []
        for sign1 in (-1, 1):
            for sign2 in (-1, 1):
                first = sign1 * s
                second = sign2 * s
                # One vertex per coordinate plane for this sign pair.
                positions.extend(
                    (
                        [first, second, 0.0],
                        [first, 0.0, second],
                        [0.0, first, second],
                    )
                )
        return _deduplicate_positions(positions, bond_length)

    def _post_process_cuboctahedron(cluster: Atoms, comp: list[str], n: int) -> Atoms:
        """Add center atom for 13-atom cuboctahedron."""
        if n != 13:
            return cluster
        origin = np.array([0.0, 0.0, 0.0])
        cluster.append(Atom(_get_base_element(comp), origin))
        return cluster

    return _generate_custom_template(
        template_name="cuboctahedron",
        composition=composition,
        n_atoms=n_atoms,
        position_generator=_generate_cuboctahedron_positions,
        rng=rng,
        connectivity_factor=connectivity_factor,
        post_process=_post_process_cuboctahedron,
    )
def generate_truncated_octahedron(
    composition: list[str],
    n_atoms: int,
    rng: np.random.Generator | None = None,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
) -> Atoms | None:
    """Build a truncated-octahedral cluster template (vertices only).

    A truncated octahedron has 24 vertices (6 square faces, 8 hexagonal
    faces). The vertices generated here are the coordinate permutations of
    ``(0, ±s, ±2s)`` with ``s = bond_length * cf / 2``. The request is
    forwarded to ``_generate_custom_template`` with ``expected_n_atoms=24``,
    so only 24-atom clusters are supported.

    Args:
        composition: List of element symbols (cycled to match n_atoms)
        n_atoms: Target number of atoms (must be 24)
        rng: Optional random number generator for reproducibility
        connectivity_factor: Factor for connectivity threshold

    Returns:
        Atoms object with truncated octahedral structure, or None if generation
        fails (e.g., n_atoms != 24 or position generation doesn't yield exactly
        24 positions)
    """

    def _generate_truncated_octahedron_positions(
        comp: list[str], n: int, bond_length: float, cf: float
    ) -> list[np.ndarray]:
        s: float = bond_length * cf / 2.0
        signs = (-1, 1)
        positions: list[list[float]] = []

        for x_sign in signs:
            for y_sign in signs:
                for z_sign in signs:
                    negatives = (x_sign < 0) + (y_sign < 0) + (z_sign < 0)
                    # Only sign triples with an even number of minus signs
                    # are expanded; together they cover every sign pair
                    # needed for the two non-zero coordinates.
                    if negatives % 2:
                        continue
                    candidates = [
                        [2 * s * x_sign, s * y_sign, 0],
                        [2 * s * x_sign, 0, s * z_sign],
                        [s * x_sign, 2 * s * y_sign, 0],
                        [s * x_sign, 0, 2 * s * z_sign],
                        [0, 2 * s * y_sign, s * z_sign],
                        [0, s * y_sign, 2 * s * z_sign],
                    ]
                    for candidate in candidates:
                        if any(np.allclose(candidate, seen) for seen in positions):
                            continue
                        positions.append(candidate)

        unique_positions = _deduplicate_positions(positions, bond_length)
        if len(unique_positions) != 24:
            raise ValueError(
                f"generate_truncated_octahedron requires exactly 24 positions, "
                f"got {len(unique_positions)}"
            )
        return unique_positions[:24]

    return _generate_custom_template(
        template_name="truncated_octahedron",
        composition=composition,
        n_atoms=n_atoms,
        position_generator=_generate_truncated_octahedron_positions,
        rng=rng,
        connectivity_factor=connectivity_factor,
        expected_n_atoms=24,
    )
# Dispatch table: template type name -> public generator function.
# Callers in this module invoke the generators positionally as
# (composition, n_atoms, rng, connectivity_factor), or with only the first
# three arguments, so every entry must accept that call shape.
# The key order here mirrors the "auto" preference order used by
# generate_template_structure.
_TEMPLATE_GENERATORS: dict[str, Callable[..., Atoms | None]] = {
    "icosahedron": generate_icosahedron,
    "decahedron": generate_decahedron,
    "cuboctahedron": generate_cuboctahedron,
    "truncated_octahedron": generate_truncated_octahedron,
    "octahedron": generate_octahedron,
    "cube": generate_cube,
    "tetrahedron": generate_tetrahedron,
}
def generate_template_structure(
    composition: list[str],
    n_atoms: int,
    template_type: str = "auto",
    rng: np.random.Generator | None = None,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
) -> Atoms | None:
    """Generate a template structure of the requested type.

    With an explicit ``template_type`` the matching generator from
    ``_TEMPLATE_GENERATORS`` is called once. With ``"auto"`` the generators
    are tried in a fixed preference order and the first non-None result is
    returned.

    Args:
        composition: List of element symbols
        n_atoms: Target number of atoms
        template_type: Type of template. Can be:
            - "auto": Try each template type in preference order
            - "icosahedron": Icosahedral structure
            - "decahedron": Decahedral structure
            - "octahedron": Octahedral structure
            - "tetrahedron": Tetrahedral structure
            - "cube": Cubic structure
            - "cuboctahedron": Cuboctahedral structure
            - "truncated_octahedron": Truncated octahedral structure
        rng: Optional random number generator
        connectivity_factor: Factor for connectivity threshold

    Returns:
        Atoms object with template structure, or None if generation fails
        (an unknown ``template_type`` is logged as a warning and yields None)
    """
    if template_type != "auto":
        generator = _TEMPLATE_GENERATORS.get(template_type)
        if generator is None:
            logger.warning(f"Unknown template type: {template_type}")
            return None
        return generator(composition, n_atoms, rng, connectivity_factor)

    auto_order = (
        "icosahedron",
        "decahedron",
        "cuboctahedron",
        "truncated_octahedron",
        "octahedron",
        "cube",
        "tetrahedron",
    )
    for candidate_name in auto_order:
        generator = _TEMPLATE_GENERATORS.get(candidate_name)
        if generator is None:
            continue
        candidate = generator(composition, n_atoms, rng, connectivity_factor)
        if candidate is not None:
            return candidate
    return None
def _find_valid_template_types(n_atoms: int) -> list[str]:
    """Find all template types that can successfully generate a structure with n_atoms.

    Validity probing is deterministic and keyed only by ``n_atoms`` (fixed
    "Pt" composition, RNG seeded with ``n_atoms``, default connectivity
    factor) so cached and concurrent calls produce stable results
    independent of caller RNG state.

    Concurrency: the first caller for a given ``n_atoms`` registers an
    ``Event`` in the in-flight table and performs the probe; later callers
    for the same size wait on that event and then read the cache.

    Args:
        n_atoms: Target number of atoms

    Returns:
        List of template type names that can generate this size. A waiting
        caller gets an empty list if the probing caller did not cache a
        result (e.g. it exited with an uncaught exception).
    """
    if n_atoms <= 0:
        return []

    leader = False
    with _VALID_TEMPLATE_TYPES_LOCK:
        cached = _VALID_TEMPLATE_TYPES_CACHE.get(n_atoms)
        if cached is not None:
            return list(cached)
        event = _VALID_TEMPLATE_TYPES_INFLIGHT.get(n_atoms)
        if event is None:
            event = Event()
            _VALID_TEMPLATE_TYPES_INFLIGHT[n_atoms] = event
            leader = True

    if not leader:
        event.wait()
        with _VALID_TEMPLATE_TYPES_LOCK:
            return list(_VALID_TEMPLATE_TYPES_CACHE.get(n_atoms, ()))

    deterministic_rng = np.random.default_rng(n_atoms)
    valid_types: list[str] = []
    test_composition: list[str] = ["Pt"] * n_atoms
    computed: tuple[str, ...] | None = None
    try:
        # Sorted iteration keeps the shared RNG's consumption order stable.
        for template_type in sorted(_TEMPLATE_GENERATORS):
            gen_func: Callable[..., Atoms | None] = _TEMPLATE_GENERATORS[template_type]
            try:
                result: Atoms | None = gen_func(
                    test_composition, n_atoms, deterministic_rng
                )
                if result is not None and len(result) == n_atoms:
                    valid_types.append(template_type)
            except (ValueError, RuntimeError, TypeError):
                continue
        computed = tuple(sorted(valid_types))
    finally:
        # Always clear the in-flight marker and wake waiters, even when the
        # probe raised; the cache is only written on success.
        with _VALID_TEMPLATE_TYPES_LOCK:
            inflight_event = _VALID_TEMPLATE_TYPES_INFLIGHT.pop(n_atoms, None)
            if computed is not None:
                _VALID_TEMPLATE_TYPES_CACHE[n_atoms] = computed
        if inflight_event is not None:
            inflight_event.set()

    return list(computed) if computed is not None else []


def _generate_template_with_atom_adjustment(
    base_template_type: str,
    base_n_atoms: int,
    target_n_atoms: int,
    composition: list[str],
    rng: np.random.Generator,
    cell_side: float | None = None,
    placement_radius_scaling: float = PLACEMENT_RADIUS_SCALING_DEFAULT,
    min_distance_factor: float = MIN_DISTANCE_FACTOR_DEFAULT,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
) -> Atoms | None:
    """Generate a template structure and adjust atom count to match target.

    A base template with ``base_n_atoms`` atoms is generated, placed in a
    cubic cell of side ``cell_side`` and centered. Atoms are then removed
    (``remove_atoms_from_vertices``) or added (``_adjust_template_to_target``)
    to reach ``target_n_atoms``.

    Args:
        base_template_type: Template type to start from
        base_n_atoms: Number of atoms in the base template
        target_n_atoms: Target number of atoms
        composition: Target composition
        rng: Random number generator
        cell_side: Cell side length (defaults to VACUUM_DEFAULT * 2)
        placement_radius_scaling: Scaling for atom placement
        min_distance_factor: Factor for minimum distance checks
        connectivity_factor: Factor for connectivity threshold; applied to
            the base template generation as well as to the adjustment step

    Returns:
        Atoms object with target composition, or None if generation fails

    Raises:
        ValueError: If ``composition`` is empty (and target_n_atoms > 0).
    """
    if cell_side is None:
        cell_side = VACUUM_DEFAULT * 2

    if target_n_atoms <= 0:
        return None

    if not composition:
        raise ValueError(
            f"Cannot generate template with empty composition for {target_n_atoms} atoms"
        )

    # Composition for the base template: truncate or cycle to base size.
    if len(composition) >= base_n_atoms:
        base_composition = composition[:base_n_atoms]
    else:
        base_composition = _cycle_composition_to_length(composition, base_n_atoms)

    gen_func: Callable[..., Atoms | None] | None = _TEMPLATE_GENERATORS.get(
        base_template_type
    )
    if gen_func is None:
        return None

    # Forward the caller's connectivity_factor so the base cluster is built
    # under the same threshold as the removal/growth steps below.
    base_cluster: Atoms | None = gen_func(
        base_composition, base_n_atoms, rng, connectivity_factor
    )
    if base_cluster is None:
        return None

    base_cluster.set_cell([cell_side, cell_side, cell_side])
    base_cluster.center()

    n_diff: int = target_n_atoms - base_n_atoms

    if n_diff == 0:
        return _assign_exact_composition(
            base_cluster, composition, target_n_atoms, rng=rng
        )

    if n_diff < 0:
        n_remove: int = -n_diff
        # Give up when half or more of the base template would be removed.
        if n_remove / base_n_atoms >= 0.5:
            return None

        for attempt in range(min(3, base_n_atoms)):
            # First attempt uses the caller's RNG; retries use a fresh
            # generator seeded from it.
            attempt_rng = (
                rng if attempt == 0 else np.random.default_rng(rng.integers(0, 2**31))
            )
            adjusted = remove_atoms_from_vertices(
                base_cluster,
                n_remove,
                target_composition=composition,
                connectivity_factor=connectivity_factor,
                min_distance_factor=min_distance_factor,
                rng=attempt_rng,
            )
            if adjusted is not None:
                return adjusted
        return None

    return _adjust_template_to_target(
        cluster=base_cluster,
        target_n_atoms=target_n_atoms,
        composition=composition,
        rng=rng,
        template_name=base_template_type,
        connectivity_factor=connectivity_factor,
        min_distance_factor=min_distance_factor,
        cell_side=cell_side,
        placement_radius_scaling=placement_radius_scaling,
    )


def _validate_template_geometry(atoms: Atoms) -> bool:
    """Validate that a template structure has reasonable geometry.

    Two checks are applied over all atom pairs:

    - For clusters of 2 or 3 atoms, the largest pair distance must not
      exceed the largest pair sum of covalent radii times
      ``BOND_DISTANCE_MULTIPLIER_2ATOM`` / ``BOND_DISTANCE_MULTIPLIER_3ATOM``.
    - The smallest pair distance must not be below the smallest pair sum of
      covalent radii times ``MIN_DISTANCE_FACTOR_DEFAULT``.

    Args:
        atoms: The Atoms object to validate

    Returns:
        True if geometry is reasonable, False otherwise
    """
    n = len(atoms)
    if n <= 1:
        return True

    positions = atoms.get_positions()
    # One radius lookup per atom instead of one per pair.
    radii = [get_covalent_radius(symbol) for symbol in atoms.get_chemical_symbols()]

    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    distances = [np.linalg.norm(positions[i] - positions[j]) for i, j in pairs]
    if not distances:
        return True
    covalent_sums = [radii[i] + radii[j] for i, j in pairs]

    min_dist = min(distances)
    max_dist = max(distances)

    if n <= 3:
        max_covalent_sum = max(0.0, *covalent_sums)
        multiplier = (
            BOND_DISTANCE_MULTIPLIER_2ATOM
            if n == 2
            else BOND_DISTANCE_MULTIPLIER_3ATOM
        )
        if max_dist > multiplier * max_covalent_sum:
            return False

    min_allowed_distance: float = min(covalent_sums) * MIN_DISTANCE_FACTOR_DEFAULT
    return not (min_dist < min_allowed_distance)


def _validate_and_add_template(
    atoms: Atoms,
    results: list[Atoms],
    template_type: str,
    template_description: str,
    min_distance_factor: float,
    connectivity_factor: float,
    logger_instance: Any = None,
) -> bool:
    """Validate a template structure and add it to results if valid.

    Runs the geometry check (``_validate_template_geometry``) followed by
    ``validate_cluster``. A structure passing both is tagged via
    ``_set_template_info`` and appended to ``results``; a rejected one is
    reported at debug level on ``logger_instance``.

    Args:
        atoms: The Atoms object to validate
        results: List to append valid templates to
        template_type: Type of template (e.g., "icosahedron")
        template_description: Description string for logging (e.g., "for 13 atoms")
        min_distance_factor: Factor for minimum distance checks
        connectivity_factor: Factor for connectivity threshold
        logger_instance: Logger instance (defaults to module logger)

    Returns:
        True if template was added to results, False otherwise
    """
    if logger_instance is None:
        logger_instance = logger

    if not _validate_template_geometry(atoms):
        logger_instance.debug(
            "Rejected %s template %s: geometry check failed",
            template_type,
            template_description,
        )
        return False

    validated_atoms, is_valid, error_message = validate_cluster(
        atoms,
        composition=None,
        min_distance_factor=min_distance_factor,
        connectivity_factor=connectivity_factor,
        check_clashes=True,
        check_connectivity=None,
        sort_atoms=False,
        raise_on_failure=False,
        source="",
    )
    if not is_valid:
        logger_instance.debug(
            "Rejected %s template %s: %s",
            template_type,
            template_description,
            error_message,
        )
        return False

    _set_template_info(validated_atoms, template_type)
    results.append(validated_atoms)
    return True


def generate_template_matches(
    composition: list[str],
    n_atoms: int,
    rng: np.random.Generator | None = None,
    cell_side: float | None = None,
    placement_radius_scaling: float = PLACEMENT_RADIUS_SCALING_DEFAULT,
    min_distance_factor: float = MIN_DISTANCE_FACTOR_DEFAULT,
    connectivity_factor: float = CONNECTIVITY_FACTOR,
    include_exact: bool = True,
    include_near: bool = True,
) -> list[Atoms]:
    """Generate template structures for the target size.

    Provides both exact and near-match template generation in a single
    interface. Exact matches when n_atoms equals the value returned by
    ``get_nearest_magic_number``; near matches by adjusting a template of
    that nearest magic size to n_atoms.

    Args:
        composition: Target composition
        n_atoms: Target number of atoms
        rng: Optional random number generator
        cell_side: Cell side length (needed for near matches with growth)
        placement_radius_scaling: Scaling for atom placement (needed for near matches)
        min_distance_factor: Factor for minimum distance checks
        connectivity_factor: Factor for connectivity threshold
        include_exact: If True, generate exact matches when n_atoms is a magic number
        include_near: If True, generate near matches from nearest magic number

    Returns:
        List of Atoms objects with template structures (possibly empty).
        Per-template failures raising one of the handled exception types are
        logged at debug level and skipped.
    """
    rng = ensure_rng_or_create(rng)
    results: list[Atoms] = []

    nearest_magic: int | None = get_nearest_magic_number(n_atoms)
    if nearest_magic is None:
        return results

    is_exact_match: bool = nearest_magic == n_atoms
    handled_errors = (ValueError, RuntimeError, AttributeError, TypeError, KeyError)

    if include_exact and is_exact_match:
        for template_type in _find_valid_template_types(n_atoms):
            try:
                atoms = _TEMPLATE_GENERATORS[template_type](
                    composition, n_atoms, rng, connectivity_factor
                )
                if atoms is None or len(atoms) != n_atoms:
                    continue
                assigned = _assign_exact_composition(
                    atoms, composition, n_atoms, rng=rng
                )
                if assigned is None:
                    continue
                _validate_and_add_template(
                    atoms=assigned,
                    results=results,
                    template_type=template_type,
                    template_description=f"for {n_atoms} atoms",
                    min_distance_factor=min_distance_factor,
                    connectivity_factor=connectivity_factor,
                )
            except handled_errors as e:
                logger.debug(
                    "Template generation failed for %s (n_atoms=%s): %s: %s",
                    template_type,
                    n_atoms,
                    type(e).__name__,
                    e,
                )

    if include_near and not is_exact_match:
        for template_type in _find_valid_template_types(nearest_magic):
            try:
                adjusted: Atoms | None = _generate_template_with_atom_adjustment(
                    base_template_type=template_type,
                    base_n_atoms=nearest_magic,
                    target_n_atoms=n_atoms,
                    composition=composition,
                    rng=rng,
                    cell_side=cell_side,
                    placement_radius_scaling=placement_radius_scaling,
                    min_distance_factor=min_distance_factor,
                    connectivity_factor=connectivity_factor,
                )
                if adjusted is None or len(adjusted) != n_atoms:
                    continue
                _validate_and_add_template(
                    atoms=adjusted,
                    results=results,
                    template_type=template_type,
                    template_description=f"({nearest_magic} -> {n_atoms})",
                    min_distance_factor=min_distance_factor,
                    connectivity_factor=connectivity_factor,
                )
            except handled_errors as e:
                logger.debug(
                    "Template adjustment failed for %s (base=%s, target=%s): %s: %s",
                    template_type,
                    nearest_magic,
                    n_atoms,
                    type(e).__name__,
                    e,
                )

    return results