# Source code for scgo.initialization.atomic_radii

"""Atomic radii with ASE gap-filling, interpolation, and per-element caching.

ASE's ``vdw_radii`` table has NaN for many transition metals and lanthanides.
Missing values are resolved once per element via linear interpolation (or
extrapolation) along atomic number, with a scaled-covalent fallback for VdW.
Resolved values are cached so production runs do not repeat work or log noise.
"""

from __future__ import annotations

import logging
from collections.abc import Iterable
from functools import lru_cache
from typing import TYPE_CHECKING

import numpy as np
from ase.data import atomic_numbers, chemical_symbols, covalent_radii, vdw_radii

if TYPE_CHECKING:
    from ase import Atoms

# NOTE: Keep stdlib logging here to avoid early package-import cycles during
# scgo bootstrap (initialization imports can execute before scgo.utils is ready).
logger = logging.getLogger(__name__)

# Multiplier applied to an element's covalent radius to build the fallback
# value that ``get_vdw_radius`` hands to ``_resolve_ase_radius``; the fallback
# is only used when the ASE VdW entry is invalid and interpolation fails.
VDW_COVALENT_FALLBACK_SCALE = 1.3

# ``(kind, symbol)`` pairs for which a gap-fill message has already been
# logged; consulted by ``_log_patch_once`` and emptied by
# ``clear_atomic_radii_cache``.
_patched_log_keys: set[tuple[str, str]] = set()


def _is_valid_radius(value: float) -> bool:
    return bool(np.isfinite(value) and value > 0)


def _interpolate_radius_at_z(z: int, radii: np.ndarray) -> float | None:
    """Linearly interpolate or extrapolate a radius at atomic number ``z``.

    Scans left and right for the nearest elements with finite positive radii.
    """
    n = len(radii)
    if z <= 0 or z >= n:
        return None

    left: tuple[int, float] | None = None
    for zz in range(z - 1, 0, -1):
        if zz < n and _is_valid_radius(float(radii[zz])):
            left = (zz, float(radii[zz]))
            break

    right: tuple[int, float] | None = None
    for zz in range(z + 1, n):
        if _is_valid_radius(float(radii[zz])):
            right = (zz, float(radii[zz]))
            break

    if left is not None and right is not None:
        left_z, left_val = left
        right_z, right_val = right
        weight = (z - left_z) / (right_z - left_z)
        return left_val + weight * (right_val - left_val)

    if left is not None:
        left_z, left_val = left
        for zz in range(left_z - 1, 0, -1):
            if zz < n and _is_valid_radius(float(radii[zz])):
                weight = (z - zz) / (left_z - zz)
                return float(radii[zz]) + weight * (left_val - float(radii[zz]))
        return left_val

    if right is not None:
        right_z, right_val = right
        for zz in range(right_z + 1, n):
            if _is_valid_radius(float(radii[zz])):
                weight = (z - right_z) / (zz - right_z)
                return right_val + weight * (float(radii[zz]) - right_val)
        return right_val

    return None


def _log_patch_once(kind: str, symbol: str, message: str) -> None:
    """Log ``message`` at INFO level the first time ``(kind, symbol)`` is seen.

    Later calls with the same ``(kind, symbol)`` pair are silent until the
    pair is removed from ``_patched_log_keys``.
    """
    marker = (kind, symbol)
    if marker not in _patched_log_keys:
        # Record the pair before logging so the message is emitted only once.
        _patched_log_keys.add(marker)
        logger.info(message)


def _resolve_ase_radius(
    symbol: str,
    *,
    kind: str,
    radii_table: np.ndarray,
    fallback: float | None = None,
) -> float:
    """Resolve the ``kind`` radius of ``symbol`` from ``radii_table``.

    Resolution order:

    1. the tabulated entry, when it is finite and positive;
    2. a value interpolated/extrapolated along atomic number;
    3. ``fallback``, when one is supplied.

    Steps 2 and 3 log a one-time INFO message per ``(kind, symbol)``.

    Args:
        symbol: Chemical symbol to look up.
        kind: Label for the radius type, used in log and error messages.
        radii_table: Radius table indexed by atomic number.
        fallback: Value to return when the table entry is invalid and
            interpolation yields nothing.

    Raises:
        ValueError: If ``symbol`` is unknown, or if no step above produces a
            value.
    """
    try:
        atomic_number = atomic_numbers[symbol]
    except KeyError as exc:
        raise ValueError(
            f"Unknown element symbol: {symbol}. Could not find {kind} radius."
        ) from exc

    tabulated = float(radii_table[atomic_number])
    if _is_valid_radius(tabulated):
        return tabulated

    # Both substitution messages share the same opening clause.
    notice = f"{kind.capitalize()} radius for {symbol} is missing/NaN in ASE; "

    estimate = _interpolate_radius_at_z(atomic_number, radii_table)
    if estimate is not None:
        _log_patch_once(
            kind, symbol, notice + f"using interpolated value {estimate:.3f} Å"
        )
        return estimate

    if fallback is None:
        raise ValueError(
            f"Could not resolve {kind} radius for {symbol}: ASE value is invalid "
            "and interpolation/extrapolation failed."
        )

    _log_patch_once(kind, symbol, notice + f"using fallback value {fallback:.3f} Å")
    return fallback


@lru_cache(maxsize=256)
def get_covalent_radius(symbol: str) -> float:
    """Return the covalent radius for ``symbol`` in Angstroms.

    The value is resolved by ``_resolve_ase_radius`` against ASE's
    ``covalent_radii`` table (no fallback value is supplied) and memoised per
    symbol.

    Raises:
        ValueError: Propagated from ``_resolve_ase_radius`` when ``symbol`` is
            unknown or no radius can be resolved.
    """
    return _resolve_ase_radius(symbol, kind="covalent", radii_table=covalent_radii)
@lru_cache(maxsize=256)
def get_vdw_radius(symbol: str) -> float:
    """Return the van-der-Waals radius for ``symbol`` in Angstroms.

    A valid entry in ASE's ``vdw_radii`` table is returned directly. Otherwise
    the value is resolved by ``_resolve_ase_radius``, with the covalent radius
    scaled by ``VDW_COVALENT_FALLBACK_SCALE`` supplied as the fallback.
    Results are memoised per symbol.

    Raises:
        ValueError: If ``symbol`` is not a known element symbol, or if the
            radius cannot be resolved.
    """
    try:
        z = atomic_numbers[symbol]
    except KeyError as exc:
        raise ValueError(
            f"Unknown element symbol: {symbol}. Could not find vdw radius."
        ) from exc

    # Short-circuit on a valid table entry so the covalent fallback below is
    # only computed for elements that actually need gap-filling.
    raw = float(vdw_radii[z])
    if _is_valid_radius(raw):
        return raw

    return _resolve_ase_radius(
        symbol,
        kind="vdw",
        radii_table=vdw_radii,
        fallback=get_covalent_radius(symbol) * VDW_COVALENT_FALLBACK_SCALE,
    )
def clear_atomic_radii_cache() -> None:
    """Clear cached radii and one-shot patch logs (mainly for tests)."""
    get_covalent_radius.cache_clear()
    get_covalent_radius_by_z.cache_clear()
    get_vdw_radius.cache_clear()
    # Forget which (kind, symbol) pairs were logged so messages can fire again.
    _patched_log_keys.clear()


@lru_cache(maxsize=256)
def get_covalent_radius_by_z(z: int) -> float:
    """Return the covalent radius for atomic number ``z`` in Angstroms.

    ``z`` is converted with ``int`` and mapped to a symbol through ASE's
    ``chemical_symbols`` before delegating to ``get_covalent_radius``.
    """
    return get_covalent_radius(chemical_symbols[int(z)])
def build_blmin_from_zs(
    zs: Iterable[int],
    ratio: float = 0.7,
) -> dict[tuple[int, int], float]:
    """Build an ASE-compatible blmin table using scgo gap-filled covalent radii.

    Args:
        zs: Atomic numbers; duplicates are ignored and each value is converted
            with ``int``.
        ratio: Factor applied to the sum of the two covalent radii.

    Returns:
        A mapping from ``(zi, zj)`` to ``(r_i + r_j) * ratio`` that contains
        both orderings of every pair, including ``(z, z)``. An empty input
        yields an empty dict.
    """
    unique = sorted({int(z) for z in zs})
    out: dict[tuple[int, int], float] = {}
    for i, zi in enumerate(unique):
        ri = get_covalent_radius_by_z(zi)
        # Start at ``i`` so each unordered pair is computed once, then mirrored.
        for zj in unique[i:]:
            rj = get_covalent_radius_by_z(zj)
            dist = (ri + rj) * ratio
            out[(zi, zj)] = dist
            if zi != zj:
                out[(zj, zi)] = dist
    return out
def build_blmin(
    symbols: Iterable[str], ratio: float = 0.7
) -> dict[tuple[int, int], float]:
    """Build an ASE-compatible blmin table for the given element symbols.

    Each symbol is converted with ``str`` and mapped to its atomic number via
    ASE's ``atomic_numbers``; the table itself is produced by
    ``build_blmin_from_zs``.

    Raises:
        KeyError: If a symbol is not present in ``atomic_numbers``.
    """
    zs = [atomic_numbers[str(s)] for s in symbols]
    return build_blmin_from_zs(zs, ratio)
def cluster_passes_ga_blmin(
    atoms: Atoms,
    blmin_ratio: float,
) -> bool:
    """Return True if ``atoms`` satisfies ASE GA steric checks at ``blmin_ratio``.

    A blmin table is built from the atomic numbers in ``atoms`` and passed to
    ``ase_ga.utilities.atoms_too_close`` with ``use_tags=False``; the result is
    the negation of that check.
    """
    # Imported at call time rather than at module import.
    from ase_ga.utilities import atoms_too_close

    blmin = build_blmin_from_zs(atoms.get_atomic_numbers(), ratio=blmin_ratio)
    return not atoms_too_close(atoms, blmin, use_tags=False)


def resolve_steric_floor(
    min_distance_factor: float,
    blmin_ratio: float | None,
) -> float:
    """Minimum clash factor for placement: at least ``blmin_ratio`` when set.

    Returns ``min_distance_factor`` unchanged when ``blmin_ratio`` is ``None``;
    otherwise returns the larger of the two values.
    """
    if blmin_ratio is None:
        return min_distance_factor
    return max(min_distance_factor, blmin_ratio)