Source code for atomworks.io.utils.testing

"""Utility functions for writings tests for AtomArray objects."""

__all__ = ["assert_same_atom_array"]

import os
from collections.abc import Iterable

import biotite.structure as struc
import numpy as np
from biotite.structure.atoms import AtomArray, AtomArrayStack

import atomworks.io.utils.bonds as cb
from atomworks.io.constants import PDB_MIRROR_PATH
from atomworks.io.utils.scatter import apply_group_wise, apply_segment_wise


def get_pdb_path(pdbid: str, mirror_path: str | os.PathLike = PDB_MIRROR_PATH) -> str:
    """Get the local path to a PDB file based on the provided mirror path.

    Args:
        pdbid (str): The PDB ID.
        mirror_path (str | os.PathLike, optional): Path to the PDB mirror directory.
            Defaults to PDB_MIRROR_PATH constant.

    Returns:
        str: The local path to the PDB file.

    Raises:
        FileNotFoundError: If the file does not exist at the expected location or
            if no mirror path is provided.
    """
    if mirror_path is None:
        raise FileNotFoundError("No mirror path provided.")
    pdbid = pdbid.lower()
    filename = os.path.join(mirror_path, pdbid[1:3], f"{pdbid}.cif.gz")
    if not os.path.exists(filename):
        raise FileNotFoundError(f"File {filename} does not exist")
    return filename


def is_same_in_segment(segment_start_stop: np.ndarray, data: np.ndarray, raise_if_false: bool = False) -> np.ndarray:
    """Check if all elements in a segment are the same.

    Args:
        segment_start_stop (np.ndarray): Array of segment start and stop indices (end of segment is inclusive),
            as obtained from `struc.get_residue_starts(... add_exclusive_stop=True)` for example.
        data (np.ndarray): Data array to check for sameness within segments.

    Returns:
        np.ndarray: Boolean array indicating whether all elements in each segment are the same.
    """
    all_same = lambda x: np.all(x == x[0]) if len(x) > 0 else True  # noqa: E731
    is_segment_valid = apply_segment_wise(segment_start_stop, data, all_same)
    return is_segment_valid


def is_same_in_group(groups: np.ndarray, data: np.ndarray) -> np.ndarray:
    """
    Check if all elements in `data` are the same within each group defined by `groups`.

    Args:
        groups: 1D array of group identifiers, same length as `data`.
        data: 1D array of data values to check for sameness within each group.

    Returns:
        np.ndarray: Boolean array of shape (n_groups,) indicating whether all elements in each group are the same.

    Example:
        >>> groups = np.array([1, 1, 2, 2, 2, 3])
        >>> data = np.array([5, 5, 7, 7, 7, 9])
        >>> is_same_in_group(groups, data)
        array([ True,  True,  True])
        >>> data = np.array([5, 5, 7, 8, 7, 9])
        >>> is_same_in_group(groups, data)
        array([ True, False,  True])
    """
    is_same = lambda x: np.all(x == x[0]) if len(x) > 0 else True  # noqa: E731
    is_group_data_same = apply_group_wise(groups, data, is_same)
    return is_group_data_same


def is_monotonic_increasing(arr: np.ndarray, strict: bool = True) -> bool:
    """Check if an array is monotonically increasing.

    Args:
        arr (np.ndarray): Array to check.
    """
    if strict:
        return np.all(np.diff(arr) > 0)
    else:
        return np.all(np.diff(arr) >= 0)


def is_monotonic_decreasing(arr: np.ndarray, strict: bool = True) -> bool:
    """Check if an array is monotonically decreasing.

    Args:
        arr (np.ndarray): Array to check.
    """
    if strict:
        return np.all(np.diff(arr) < 0)
    else:
        return np.all(np.diff(arr) <= 0)


def is_increasing_without_gaps(arr: np.ndarray, strict: bool = True) -> bool:
    """Check if an array is monotonically increasing without gaps.

    Args:
        arr (np.ndarray): Array to check.
    """
    if strict:
        return np.all(np.diff(arr) == 0)
    else:
        return np.all((np.diff(arr) == 0) | (np.diff(arr) == 1))


def has_annotation(arr: AtomArray, annotation: str | list[str]) -> bool:
    """Check if an AtomArray has an annotation.

    Args:
        arr: AtomArray to check.
        annotation: Annotation(s) to check for.
    """
    existing_annotations = frozenset(["coord", *arr.get_annotation_categories()])
    if isinstance(annotation, str):
        return annotation in existing_annotations
    else:
        return set(annotation).issubset(existing_annotations)


def _get_atom_array_stats(arr: AtomArray) -> str:
    msg = f"AtomArray: {len(arr)} atoms, {struc.get_residue_count(arr)} residues, {struc.get_chain_count(arr)} chains\n"
    msg += f"\t... unique chain ids: {np.unique(arr.chain_id)}\n"
    msg += f"\t... unique residue ids: {np.unique(arr.res_id)}\n"
    msg += f"\t... unique atom types: {np.unique(arr.atom_name)}\n"
    msg += f"\t... unique elements: {np.unique(arr.element)}\n"
    return msg


[docs] def assert_same_atom_array( arr1: AtomArray | AtomArrayStack, arr2: AtomArray | AtomArrayStack, compare_coords: bool = True, compare_bonds: bool = True, compare_box: bool = False, annotations_to_compare: list[str] | None = None, enforce_order: bool = True, compare_bond_order: bool = True, _n_mismatches_to_show: int = 5, ) -> None: """Asserts that two AtomArray objects are equal. Args: arr1 (AtomArray): The first AtomArray to compare. arr2 (AtomArray): The second AtomArray to compare. compare_coords (bool, optional): Whether to compare coordinates. Defaults to True. compare_bonds (bool, optional): Whether to compare bonds. Defaults to True. compare_box (bool, optional): Whether to compare the box attribute. Defaults to False. annotations_to_compare (list[str] | None, optional): List of annotation categories to compare. Defaults to None, in which case all annotations are compared. enforce_order (bool, optional): Whether to enforce the order of the atoms. Defaults to True. NOTE: Enforcing order is much faster; use False only when strictly necessary. compare_bond_order (bool, optional): Whether to compare bond order. Defaults to True. _n_mismatches_to_show (int, optional): Number of mismatches to show. Defaults to 20. WARNING: If AtomArrayStack objects are passed, only the first array is compared. Raises: AssertionError: If the AtomArray objects are not equal. """ assert isinstance( arr1, AtomArray | AtomArrayStack ), f"arr1 is not an AtomArray or AtomArrayStack but has type {type(arr1)}" assert isinstance( arr2, AtomArray | AtomArrayStack ), f"arr2 is not an AtomArray or AtomArrayStack but has type {type(arr2)}" # If the input is a stack, only compare the first array if isinstance(arr1, AtomArrayStack): arr1 = arr1[0] if isinstance(arr2, AtomArrayStack): arr2 = arr2[0] # Compare lengths, down to the residue-level if necessary if len(arr1) != len(arr2): msg = "AtomArrays are not the same length!\n" # Find the chains that are different lengths for chain_id in np.unique(arr1.chain_id): arr1_chain_aa = arr1[arr1.chain_id == chain_id] arr2_chain_aa = arr2[arr2.chain_id == chain_id] if len(arr1_chain_aa) != len(arr2_chain_aa): msg += f"+--------- Mismatches for chain: {chain_id} -----------+\n" # Find the residues that are different lengths for res_id in np.unique(arr1_chain_aa.res_id): arr1_res_aa = arr1_chain_aa[arr1_chain_aa.res_id == res_id] arr2_res_aa = arr2_chain_aa[arr2_chain_aa.res_id == res_id] # Give an informative error message if len(arr1_res_aa) != len(arr2_res_aa): msg += f"Mismatch at residue {res_id}:\n" msg += f"\tarr1: {_get_atom_array_stats(arr1_res_aa)}\n" msg += f"\tarr2: {_get_atom_array_stats(arr2_res_aa)}\n" raise AssertionError(msg) if compare_coords: assert ( arr1.coord.shape == arr2.coord.shape ), f"Coord shapes do not match: {arr1.coord.shape} != {arr2.coord.shape}" if not np.allclose(arr1.coord, arr2.coord, equal_nan=True, atol=1e-3, rtol=1e-3): mismatch_idxs = np.where(arr1.coord != arr2.coord)[0] msg = f"Coords do not match at {len(mismatch_idxs)} indices. First few mismatches:" + "\n" for idx in mismatch_idxs[:_n_mismatches_to_show]: msg += f"\t{idx}: {arr1.coord[idx]} != {arr2.coord[idx]}\n" raise AssertionError(msg) # Not returned by `get_annotation_categories` if compare_box: if arr1._box is None: assert arr2._box is None else: assert np.array_equal(arr1._box, arr2._box, equal_nan=True) if annotations_to_compare is None: arr1_annotation_keys = arr1.get_annotation_categories() arr2_annotation_keys = arr2.get_annotation_categories() missing_in_arr1 = set(arr2_annotation_keys) - set(arr1_annotation_keys) missing_in_arr2 = set(arr1_annotation_keys) - set(arr2_annotation_keys) assert len(missing_in_arr1) == 0, f"Annotations missing in arr1: {missing_in_arr1}" assert len(missing_in_arr2) == 0, f"Annotations missing in arr2: {missing_in_arr2}" annotations_to_compare = arr1_annotation_keys if enforce_order: # Compare annotations directly for annotation in annotations_to_compare: if annotation not in arr1.get_annotation_categories(): raise AssertionError(f"Annotation {annotation} not in arr1.") if annotation not in arr2.get_annotation_categories(): raise AssertionError(f"Annotation {annotation} not in arr2.") # Check if the arrays contain floating-point numbers (in which case, we allow NaN == NaN) if np.issubdtype(arr1.get_annotation(annotation).dtype, np.floating) and np.issubdtype( arr2.get_annotation(annotation).dtype, np.floating ): arrays_equal = np.array_equal( arr1.get_annotation(annotation), arr2.get_annotation(annotation), equal_nan=True ) else: arrays_equal = np.array_equal( arr1.get_annotation(annotation), arr2.get_annotation(annotation), equal_nan=False ) if not arrays_equal: mismatch_idxs = np.where(arr1.get_annotation(annotation) != arr2.get_annotation(annotation))[0] msg = ( f"Annotation {annotation} does not match at {len(mismatch_idxs)} indices. First few mismatches:" + "\n" ) for idx in mismatch_idxs[:_n_mismatches_to_show]: msg += ( f"\t{idx}: {arr1.get_annotation(annotation)[idx]} != {arr2.get_annotation(annotation)[idx]}\n" ) if idx >= _n_mismatches_to_show: break raise AssertionError(msg) else: # Convert annotations to a sorted list of tuples and compare (order-invariant) def convert_atom_array_to_sorted_tuples(arr: AtomArray, annotations: list[str]) -> list[tuple]: atoms = [] for atom in arr: atom_info = [(annotation, atom.__getattr__(annotation)) for annotation in annotations] atoms.append(tuple(sorted(atom_info))) return sorted(atoms) arr1_atoms_sorted = convert_atom_array_to_sorted_tuples(arr1, annotations_to_compare) arr2_atoms_sorted = convert_atom_array_to_sorted_tuples(arr2, annotations_to_compare) if arr1_atoms_sorted != arr2_atoms_sorted: msg = "Annotations do not match. First few mismatches:\n" for idx, atom in enumerate(set(arr1_atoms_sorted).symmetric_difference(set(arr2_atoms_sorted))): msg += f"\t{idx}: {atom}\n" if idx >= _n_mismatches_to_show: break raise AssertionError(msg) if compare_bonds: assert arr1.bonds is not None, "arr1.bonds is None" assert arr2.bonds is not None, "arr2.bonds is None" if enforce_order: # Compare bond arrays directly bonds1 = arr1.bonds.as_array() bonds2 = arr2.bonds.as_array() if not compare_bond_order: bonds1 = bonds1[:, :2] bonds2 = bonds2[:, :2] if not np.array_equal(bonds1, bonds2): mismatch_idxs = np.where(bonds1 != bonds2)[0] msg = f"Bonds do not match at {len(mismatch_idxs)} indices. First few mismatches:" + "\n" for idx in mismatch_idxs[:_n_mismatches_to_show]: msg += f"\t{idx}: {bonds1[idx]} != {bonds2[idx]}\n" raise AssertionError(msg) else: # Check graph isomorphisms, labeling nodes with element arr1_hash = cb.hash_atom_array( arr1, annotations=["element"], bond_order=compare_bond_order, cast_aromatic_bonds_to_same_type=True ) arr2_hash = cb.hash_atom_array( arr2, annotations=["element"], bond_order=compare_bond_order, cast_aromatic_bonds_to_same_type=True ) assert arr1_hash == arr2_hash, f"Graph hashes do not match: {arr1_hash} != {arr2_hash}"
def has_ambiguous_annotation_set( atom_array: AtomArray, annotation_set: Iterable = ["chain_id", "res_id", "res_name", "atom_name", "ins_code"], ) -> bool: """Detect whether a given set of annotations is insufficient to distinguish all atoms in the input AtomArray. For example, this is used to detect ambiguous annotation of the structure that would lead to loss of information when writing out the structure. This happens because the `struct_conn` category distinguishes bonds between different atoms based on the 5-tuple: (chain_id, res_id, res_name, atom_name, ins_code) To properly save bonds with a structure, make sure that all atoms have unique 5-tuples. Args: atom_array (AtomArray): The atom array to check for ambiguous annotations. annotation_set (Iterable, optional): The set of annotations to check for ambiguity. Defaults to ["chain_id", "res_id", "res_name", "atom_name", "ins_code"], which is relevant for determining possible bond ambiguity. Returns: bool: True if ambiguous annotations are detected, False otherwise. """ # Create a structured array with the 5-tuple elements identifier_dtypes = [ ( annotation, atom_array.get_annotation(annotation).dtype if annotation in atom_array.get_annotation_categories() else "U1", ) for annotation in annotation_set ] structured_array = np.empty(atom_array.array_length(), dtype=identifier_dtypes) for category in identifier_dtypes: name, dtype = category structured_array[name] = ( atom_array.get_annotation(name) if name in atom_array.get_annotation_categories() else ["."] * atom_array.array_length() ) # Use numpy's unique function with return_counts=True to find duplicates _, counts = np.unique(structured_array, return_counts=True) # If any count is greater than 1, we have ambiguous annotations return np.any(counts > 1)