Source code for atomworks.ml.utils.testing

import os

import numpy as np
from biotite.structure import AtomArray, CellList

from atomworks.io import parse
from atomworks.io.common import immutable_lru_cache
from atomworks.io.constants import PDB_MIRROR_PATH
from atomworks.ml.preprocessing.constants import CELL_SIZE
from atomworks.ml.preprocessing.utils.structure_utils import get_atom_mask_from_cell_list


def get_pdb_mirror_path(pdbid: str, base_dir: str = PDB_MIRROR_PATH) -> str:
    """Return the path to a gzipped mmCIF file in a local PDB mirror (e.g. on the DIGS).

    Files are looked up in a sub-directory named after the 2nd and 3rd
    characters of the lower-cased PDB ID, e.g. ``1ABC`` resolves to
    ``{base_dir}/ab/1abc.cif.gz``.

    Args:
        pdbid: PDB identifier. Case-insensitive; it is lower-cased before use.
        base_dir: Root directory of the mirror. Defaults to ``PDB_MIRROR_PATH``.

    Returns:
        The path ``{base_dir}/{pdbid[1:3]}/{pdbid}.cif.gz`` as a string.

    Raises:
        ValueError: If ``base_dir`` does not exist, or if the resolved file
            does not exist.
    """
    # Explicit raise instead of ``assert``: assertions are stripped when Python
    # runs with ``-O``, which would silently skip this check.
    if not os.path.exists(base_dir):
        raise ValueError(f"Base directory {base_dir} does not exist")

    pdbid = pdbid.lower()
    # Built with an f-string (not os.path.join) to keep the exact path format
    # callers have always received.
    filename = f"{base_dir}/{pdbid[1:3]}/{pdbid}.cif.gz"
    if not os.path.exists(filename):
        raise ValueError(f"File {filename} does not exist")
    return filename
@immutable_lru_cache(maxsize=1000)
def cached_parse(pdb_id: str, **kwargs) -> dict:
    """Parse a PDB entry from the local mirror and memoize the result.

    The file location comes from ``get_pdb_mirror_path(pdb_id)``, and any extra
    keyword arguments are forwarded unchanged to ``atomworks.io.parse``. The
    ``immutable_lru_cache`` decorator keeps up to 1000 results. Judging by its
    name, callers presumably get an immutable/protected copy; confirm in
    ``atomworks.io.common``.

    Two keys are guaranteed to be present in the returned dict:

    * ``"atom_array"``: taken from the parser output if present. Otherwise it
      is element ``[0]`` of the first entry in ``"assemblies"`` (presumably
      the first model of that assembly; TODO confirm).
    * ``"pdb_id"``: the ``pdb_id`` argument, as passed in.
    """
    parsed = parse(filename=get_pdb_mirror_path(pdb_id), **kwargs)

    if "atom_array" not in parsed:
        # The parser gave no top-level atom array, so fall back to the first
        # assembly. list(...)[0] raises IndexError if "assemblies" is empty.
        first_assembly_id = list(parsed["assemblies"])[0]
        parsed["atom_array"] = parsed["assemblies"][first_assembly_id][0]

    parsed["pdb_id"] = pdb_id
    return parsed
def is_clash(atom_array_1: AtomArray, atom_array_2: AtomArray, clash_distance: float = 1.0) -> bool:
    """Check whether two atom arrays clash.

    Based on atomworks.ml.preprocessing.process.DataPreprocessor. A cell list
    is built over ``atom_array_1`` and queried with the coordinates of
    ``atom_array_2``. Any hit within ``clash_distance`` counts as a clash.
    The exact neighbour semantics (strict vs. inclusive cutoff, units,
    presumably Å) are defined by ``get_atom_mask_from_cell_list``; confirm
    there.

    Recommended to pass in minimal masks of arrays to check to reduce runtime.

    Args:
        atom_array_1: Atoms indexed by the cell list.
        atom_array_2: Atoms whose coordinates are tested against ``atom_array_1``.
        clash_distance: Distance cutoff passed to the mask helper.

    Returns:
        A plain Python ``bool``: True if any clash was found, else False.
        Also False when either array is empty.
    """
    # With no atoms on one side there are no pairs, so nothing can clash.
    # Returning early also skips building a CellList over an empty array.
    if len(atom_array_1) == 0 or len(atom_array_2) == 0:
        return False

    cell_list = CellList(atom_array_1, cell_size=CELL_SIZE)
    # Presumably a boolean mask over atom_array_1 (its length is passed in)
    # marking atoms near any coordinate of atom_array_2; confirm in structure_utils.
    clashing_atom_mask = get_atom_mask_from_cell_list(
        atom_array_2.coord, cell_list, len(atom_array_1), clash_distance
    )
    # np.any returns numpy.bool_; cast so the result matches the ``-> bool`` annotation.
    return bool(np.any(clashing_atom_mask))