import logging
from collections import defaultdict
from typing import Any, ClassVar, Literal
import biotite.structure as struc
import numpy as np
import torch
from biotite.structure import AtomArray
from rdkit import Chem
from atomworks.enums import GroundTruthConformerPolicy
from atomworks.io.constants import CCD_MIRROR_PATH, ELEMENT_NAME_TO_ATOMIC_NUMBER, UNKNOWN_LIGAND
from atomworks.io.tools.rdkit import atom_array_from_rdkit, remove_hydrogens
from atomworks.io.utils.ccd import get_available_ccd_codes
from atomworks.io.utils.selection import get_residue_starts
from atomworks.ml.common import exists
from atomworks.ml.transforms._checks import check_atom_array_annotation, check_contains_keys, check_is_instance
from atomworks.ml.transforms.base import Transform
from atomworks.ml.transforms.rdkit_utils import (
ccd_code_to_rdkit_with_conformers,
sample_rdkit_conformer_for_atom_array,
)
from atomworks.ml.utils.geometry import masked_center, random_rigid_augmentation
# Logger for this module (registered under the "atomworks.ml" logger name).
logger = logging.getLogger("atomworks.ml")
# CCD codes reported as available under `CCD_MIRROR_PATH`, excluding UNL.
# UNL is a special CCD code for unknown ligands; we do not consider it "known" as it has no structure.
# Residues whose name is NOT in this set get their reference conformer generated directly from their
# own atoms (see `get_af3_reference_molecule_features`) instead of from the CCD entry.
KNOWN_CCD_CODES = get_available_ccd_codes(CCD_MIRROR_PATH) - {UNKNOWN_LIGAND}
def _extract_cached_conformers(
    res_stochiometry: dict[str, int],
    max_conformers_per_residue: int | None,
    cached_residue_level_data: dict | None,
) -> tuple[dict[str, Chem.Mol], dict[str, int]]:
    """Split residue types into those served from the conformer cache and those still needing generation.

    A residue type is served from the cache only if its cache entry holds a molecule (under the ``"mol"``
    key) with at least as many conformers as would otherwise be generated for it, i.e.
    ``min(count, max_conformers_per_residue)`` (or ``count`` when no maximum is set).

    Args:
        res_stochiometry: Mapping from residue name to the number of instances of that residue.
        max_conformers_per_residue: Optional cap on the number of conformers needed per residue type.
        cached_residue_level_data: Optional mapping from residue name to a cache entry (a dict that may
            contain a ``"mol"`` item). Entries that are missing, ``None``, or have no ``"mol"`` are skipped.

    Returns:
        A tuple ``(cached_mols, remaining_stochiometry)``:
            cached_mols: Residue name -> cached molecule (hydrogens removed) for residue types served from cache.
            remaining_stochiometry: A copy of ``res_stochiometry`` without the residue types served from cache.
    """
    cached_mols: dict[str, Chem.Mol] = {}
    remaining_stochiometry = res_stochiometry.copy()
    if cached_residue_level_data is None:
        return cached_mols, remaining_stochiometry

    for res_name, count in res_stochiometry.items():
        cache_entry = cached_residue_level_data.get(res_name)
        if cache_entry is None:
            continue
        cached_mol = cache_entry.get("mol")
        # Check for a missing molecule BEFORE removing hydrogens, so incomplete cache entries are skipped
        # instead of being passed as `None` into `remove_hydrogens`.
        if cached_mol is None:
            continue
        # (We remove hydrogens to be consistent with on-the-fly conformer generation)
        cached_mol = remove_hydrogens(cached_mol)
        needed_conformers = (
            min(count, max_conformers_per_residue) if max_conformers_per_residue is not None else count
        )
        if cached_mol is None or cached_mol.GetNumConformers() < needed_conformers:
            # Not enough cached conformers -> leave this residue type for on-the-fly generation
            continue
        cached_mols[res_name] = cached_mol
        del remaining_stochiometry[res_name]
    return cached_mols, remaining_stochiometry
def _get_rdkit_mols_with_conformers(
    res_stochiometry: dict[str, int],
    max_conformers_per_residue: int | None = None,
    timeout: float | None | tuple[float, float] = (3.0, 0.15),
    timeout_strategy: Literal["signal", "subprocess"] = "subprocess",
    **generate_conformers_kwargs,
) -> dict[str, Chem.Mol | None]:
    """Generate RDKit molecules with conformers for each residue type in bulk (given the counts in `res_stochiometry`).

    Args:
        res_stochiometry: A dictionary mapping residue names to their count.
        max_conformers_per_residue: Maximum number of conformers to generate per residue type.
            If None, generates conformers equal to the count. If set, generates min(count, max_conformers_per_residue).
        timeout: The timeout for conformer generation, forwarded to `ccd_code_to_rdkit_with_conformers`.
            If None, no timeout is applied and the timeout strategy is ignored (no subprocesses will be spawned).
            Defaults to (3.0, 0.15), which gives a timeout of 3.0 + 0.15 * (n_conformers - 1) seconds per
            unique CCD code.
        timeout_strategy: The strategy to use for the timeout. Defaults to "subprocess".
        **generate_conformers_kwargs: Additional keyword arguments forwarded to
            `ccd_code_to_rdkit_with_conformers` (e.g. `hydrogen_policy`).

    Returns:
        A dictionary mapping every residue name in `res_stochiometry` to an RDKit molecule with generated
        conformers. Residue names that are not in `KNOWN_CCD_CODES` are mapped to ``None`` as a placeholder
        (their conformers must be generated by the caller, e.g. directly from the atom array).

    Note:
        Conformers are generated with `ccd_code_to_rdkit_with_conformers`. What happens when generation fails
        or times out is decided by that helper (presumably a fallback to the idealized CCD conformer - confirm
        in its implementation).

    Reference:
        - https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf
    """
    ref_mols: dict[str, Chem.Mol | None] = {}
    for res_name, count in res_stochiometry.items():
        if res_name not in KNOWN_CCD_CODES:
            ref_mols[res_name] = None  # placeholder so that the unknown CCD codes are still counted later on
            continue
        n_conformers_to_generate = (
            min(count, max_conformers_per_residue) if max_conformers_per_residue is not None else count
        )
        ref_mols[res_name] = ccd_code_to_rdkit_with_conformers(
            ccd_code=res_name,
            n_conformers=n_conformers_to_generate,
            timeout=timeout,
            timeout_strategy=timeout_strategy,
            **generate_conformers_kwargs,
        )
    return ref_mols
def _encode_atom_names_like_af3(atom_names: np.ndarray) -> np.ndarray:
    """Turn atom names into AF3-style per-character codes (the `ref_atom_name_chars` feature).

    Each name is upper-cased, truncated to its first 4 characters and right-padded with spaces to
    exactly 4 characters. Every character ``c`` is then mapped to ``ord(c) - 32``, so a padding
    space becomes 0. The result holds integer codes, NOT a one-hot encoding. AF3 one-hot encodes
    these codes over 64 classes, presumably downstream of this function.

    Args:
        atom_names: Array of atom names. Names must be ASCII, because they are cast to a 4-byte
            bytes dtype.

    Returns:
        Integer array of shape (N, 4) with one row of character codes per atom name.

    Reference:
        - https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf
    """
    # Upper-case, then cast to fixed-width 4-byte ASCII (longer names are cut off by the cast)
    four_byte_names = np.char.upper(atom_names).astype("|S4")
    # Pad short names with spaces (ord(" ") == 32 -> code 0 after the shift below)
    space_padded = np.char.ljust(four_byte_names, width=4, fillchar=" ")
    # Reinterpret the raw bytes as uint8 and group them 4 per atom name
    ascii_codes = space_padded.view(np.uint8).reshape(-1, 4)
    return ascii_codes - 32
def _map_reference_conformer_to_residue(
    res_name: str, atom_names: np.ndarray, conformer: AtomArray
) -> tuple[np.ndarray, np.ndarray]:
    """Map the coordinates of a reference conformer onto the atoms of a residue, matching by atom name.

    Conformer atoms without a counterpart in the residue are dropped. The outputs always have exactly
    one row per residue atom, in the order of ``atom_names``, so they can be written directly into the
    residue's slice of atom-level feature arrays:

    - The k-th occurrence of a name in ``atom_names`` is matched to the k-th conformer atom carrying that
      name. A residue matched against itself therefore maps 1:1, even with duplicate atom names.
    - A residue atom without a matching conformer atom gets zero coordinates and ``ref_mask = False``,
      and a warning is logged for it.

    Args:
        res_name: The name of the residue to map to (used for logging only).
        atom_names: Array of atom names in the residue to map to, in residue order.
        conformer: The reference conformer (its ``atom_name`` and ``coord`` are read).

    Returns:
        ref_pos: [n_atoms_in_res, 3] Reference positions for atoms in the residue
            (same dtype as ``conformer.coord``).
        ref_mask: [n_atoms_in_res] True where the atom was matched and its reference position is finite.
    """
    # ... index the conformer atoms by name (in conformer order) for O(1) lookups per residue atom
    conformer_idx_by_name: dict[str, list[int]] = {}
    for conformer_idx, conformer_atom_name in enumerate(conformer.atom_name):
        conformer_idx_by_name.setdefault(conformer_atom_name, []).append(conformer_idx)

    n_res_atoms = len(atom_names)
    ref_pos = np.zeros((n_res_atoms, 3), dtype=conformer.coord.dtype)  # [n_atoms_in_res, 3]
    matched = np.zeros(n_res_atoms, dtype=bool)  # [n_atoms_in_res]
    # How many conformer atoms of each name have already been used (handles duplicate names in the residue)
    n_consumed: dict[str, int] = {}
    for i, atom_name in enumerate(atom_names):
        candidates = conformer_idx_by_name.get(atom_name, [])
        k = n_consumed.get(atom_name, 0)
        if k >= len(candidates):
            logger.warning(f"Atom {atom_name} not found in conformer for residue {res_name} with {atom_names=}.")
            continue
        n_consumed[atom_name] = k + 1
        ref_pos[i] = conformer.coord[candidates[k]]
        matched[i] = True

    # Unmatched atoms and atoms with non-finite conformer coordinates are masked out
    ref_mask = matched & np.isfinite(ref_pos).all(axis=-1)  # [n_atoms_in_res]
    return ref_pos, ref_mask  # [n_atoms_in_res, 3], [n_atoms_in_res]
def _apply_ground_truth_conformer_policy(
    atom_array: AtomArray,
    res_start: int,
    res_end: int,
    conformer: AtomArray,
) -> tuple[AtomArray, AtomArray | None, bool]:
    """Apply the `ground_truth_conformer_policy` annotation of one residue to its reference conformer.

    Residues whose ground-truth coordinates contain NaNs are treated as policy IGNORE.

    Args:
        atom_array: The full input atom array (must carry the `ground_truth_conformer_policy` annotation).
        res_start: Index of the first atom of the residue.
        res_end: Exclusive end index of the residue.
        conformer: The reference conformer currently selected for this residue.

    Returns:
        A tuple ``(conformer, ground_truth_conformer, replaced)``:
            conformer: The input conformer, or the centered ground-truth residue if it was REPLACED.
            ground_truth_conformer: The centered ground-truth residue if the policy is ADD, else None.
            replaced: Whether ``conformer`` was replaced by the ground truth.
    """
    if np.isnan(atom_array.coord[res_start:res_end]).any():
        logger.debug(
            "Ground-truth conformer policy set, but NaNs found in the atom array. Conformer policy will be treated as IGNORE."
        )
        return conformer, None, False

    policy = atom_array.ground_truth_conformer_policy[res_start:res_end]
    replaced = False
    # We REPLACE the generated conformer with the ground-truth conformer if either:
    # (a) the ground-truth conformer policy is set to "REPLACE" for all atoms in the residue
    # (b) the current conformer is all 0's/NaN's (i.e., the conformer generation failed), and the policy is set to "FALLBACK" for all atoms in the residue
    if np.all(policy == GroundTruthConformerPolicy.REPLACE) or (
        np.all(np.nan_to_num(conformer.coord) == 0) and np.all(policy == GroundTruthConformerPolicy.FALLBACK)
    ):
        # NOTE: Inefficient since we generate with RDKit, and then discard, the conformer; however, this replacement-based approach is more interpretable and thus preferred
        # ... use the ground-truth AtomArray (e.g., during inference if we provide a SDF, or if we want to leak ligand geometry)
        conformer = atom_array[res_start:res_end]
        # (Center around the origin to avoid leaking 1D information)
        conformer.coord = masked_center(conformer.coord)
        replaced = True

    # We ADD another feature, `ref_pos_ground_truth`, if the policy is set to "ADD" for all atoms in the residue
    # (No NaN re-check needed: residues with NaN ground truth returned early above.)
    ground_truth_conformer = None
    if np.all(policy == GroundTruthConformerPolicy.ADD):
        ground_truth_conformer = atom_array[res_start:res_end]
        ground_truth_conformer.coord = masked_center(ground_truth_conformer.coord)

    return conformer, ground_truth_conformer, replaced


def get_af3_reference_molecule_features(
    atom_array: AtomArray,
    conformer_generation_timeout: float | tuple[float, float] | None = (3.0, 0.15),
    apply_random_rotation_and_translation: bool = True,
    use_element_for_atom_names_of_atomized_tokens: bool = False,
    timeout_strategy: Literal["signal", "subprocess"] = "subprocess",
    max_conformers_per_residue: int | None = None,
    cached_residue_level_data: dict | None = None,
    residue_conformer_indices: dict[int, np.ndarray] | None = None,
    **generate_conformers_kwargs,
) -> tuple[dict[str, Any], dict[str, Chem.Mol]]:
    """Get AF3 reference features for each residue in the atom array.

    Args:
        atom_array: The input atom array.
        conformer_generation_timeout: Maximum time allowed for conformer generation per residue.
            Defaults to (3.0, 0.15), which gives a timeout of 3.0 + 0.15 * (n_conformers - 1) seconds.
            If None, no timeout is applied and the timeout strategy is ignored (no subprocesses will be spawned).
        apply_random_rotation_and_translation: Whether to apply a random rotation and translation to each
            conformer (AF-3-style).
        use_element_for_atom_names_of_atomized_tokens: If True, the atom-name codes of atomized atoms
            (`atomize` annotation) are replaced by the codes of their element symbol.
        timeout_strategy: The strategy to use for the timeout.
            Defaults to "subprocess" (which is the most reliable choice).
        max_conformers_per_residue: Maximum number of conformers to generate per residue type.
            If None, generates conformers equal to the residue count. If set, generates
            min(count, max_conformers_per_residue). Instances of the same residue type then cycle through
            the available conformers in order (index ``i % n_conformers`` for the i-th instance), unless
            `residue_conformer_indices` pins a specific index.
        cached_residue_level_data: Optional cached conformer data by residue name. If provided,
            cached conformers will be preferred over generated ones.
        residue_conformer_indices: Optional mapping of global residue IDs (`res_id_global` annotation) to
            specific conformer indices. If provided, these specific conformers will be used for the
            corresponding residues (only the first index is used).
        **generate_conformers_kwargs: Additional keyword arguments to pass to the conformer generation functions.

    Returns:
        A tuple ``(ref_conformer, ref_mols)``:
            ref_conformer: A dictionary containing the reference features listed below.
            ref_mols: A dictionary mapping residue names to the RDKit molecule used for that residue type,
                including residues with unknown CCD codes. For unknown codes, only the molecule generated
                for the last instance of that residue name is kept.

    Raises:
        AssertionError: If `use_element_for_atom_names_of_atomized_tokens` is set but the atom array has
            no `atomize` annotation.

    This function generates the following reference features, following AF3:
        - ref_pos: [N_atoms, 3] Atom positions in the reference conformer (in Å), with a random rotation
          and translation applied per residue if `apply_random_rotation_and_translation`.
        - ref_mask: [N_atoms] True where a finite reference position was found for the atom.
        - ref_element: [N_atoms] Atomic number of each atom (from the `atomic_number` annotation if present,
          otherwise looked up from `element`). AF3 one-hot encodes this (up to atomic number 128); that
          expansion is not done here.
        - ref_charge: [N_atoms] Charge for each atom (the `charge` annotation).
        - ref_atom_name_chars: [N_atoms, 4] Per-character codes ``ord(c) - 32`` of the upper-cased atom names,
          padded/truncated to length 4. AF3 one-hot encodes these over 64 classes; that expansion is not done here.
        - ref_space_uid: [N_atoms] Integer id of the residue instance each atom belongs to
          (0, 1, 2, ... in order of appearance of the residues).
        - is_atomized_atom_level: [N_atoms] Whether the atom is atomized (atom-level version of "is_ligand");
          None if the atom array has no `atomize` annotation.

    (Optionally, only if the `ground_truth_conformer_policy` annotation is present) the following custom features:
        - ref_pos_is_ground_truth: [N_atoms] Whether the reference conformer is the ground-truth conformer.
        - ref_pos_ground_truth: [N_atoms, 3] The centered ground-truth conformer positions for residues
          whose policy is ADD (zeros elsewhere).

    Reference:
        - Section 2.8 of the AF3 supplementary information
          https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf
    """
    annotation_categories = atom_array.get_annotation_categories()
    _has_ground_truth_conformer_policy = "ground_truth_conformer_policy" in annotation_categories
    _has_global_res_id = "res_id_global" in annotation_categories

    # Generate reference conformers for each residue (if cropped, each residue that has tokens in the crop)
    # ... get residue-level stochiometry
    _res_start_ends = get_residue_starts(atom_array, add_exclusive_stop=True)
    _res_starts, _res_ends = _res_start_ends[:-1], _res_start_ends[1:]
    _res_names = atom_array.res_name[_res_starts]
    res_stochiometry = dict(zip(*np.unique(_res_names, return_counts=True), strict=False))

    # ... prefer cached conformers; only the remaining residue types need on-the-fly generation
    if cached_residue_level_data is not None:
        cached_mols, remaining_stochiometry = _extract_cached_conformers(
            res_stochiometry=res_stochiometry,
            max_conformers_per_residue=max_conformers_per_residue,
            cached_residue_level_data=cached_residue_level_data,
        )
    else:
        cached_mols, remaining_stochiometry = {}, res_stochiometry

    # ... get reference molecules with conformers for remaining residues
    # (Unknown CCD codes only get a `None` placeholder here; their conformers are generated below)
    generated_mols = _get_rdkit_mols_with_conformers(
        res_stochiometry=remaining_stochiometry,
        max_conformers_per_residue=max_conformers_per_residue,
        hydrogen_policy="remove",
        timeout=conformer_generation_timeout,
        timeout_strategy=timeout_strategy,
        **generate_conformers_kwargs,
    )
    # Merge cached and generated molecules
    ref_mols = {**cached_mols, **generated_mols}

    # ... generate one conformer per residue instance with an unknown CCD code (including UNL), from its own atoms
    # (Set membership is O(1) per residue; converting KNOWN_CCD_CODES to a list for `np.isin` would copy the
    #  whole CCD code set on every call)
    unknown_ccd_conformers = defaultdict(list)
    is_unknown_res = np.array([res_name not in KNOWN_CCD_CODES for res_name in _res_names], dtype=bool)
    for res_index in np.flatnonzero(is_unknown_res):
        res_name = _res_names[res_index]
        conf_i, mol_i = sample_rdkit_conformer_for_atom_array(
            atom_array[_res_starts[res_index] : _res_ends[res_index]],
            timeout=conformer_generation_timeout,
            timeout_strategy=timeout_strategy,
            return_mol=True,
            **generate_conformers_kwargs,
        )
        unknown_ccd_conformers[res_name].append(conf_i)
        ref_mols[res_name] = mol_i

    # ... initialize reference features
    n_atoms = len(atom_array)
    ref_pos = np.zeros((n_atoms, 3), dtype=np.float32)
    ref_mask = np.zeros(n_atoms, dtype=bool)
    if _has_ground_truth_conformer_policy:
        ref_pos_is_ground_truth = np.zeros(n_atoms, dtype=bool)
        ref_pos_ground_truth = np.zeros((n_atoms, 3), dtype=np.float32)

    # Fill `ref_pos` and `ref_mask` arrays
    # ... helper variable to keep track of the next conformer to use for each residue type
    _next_conf_idx = dict.fromkeys(ref_mols, 0)
    for res_start, res_end in zip(_res_starts, _res_ends, strict=False):
        res_name = atom_array.res_name[res_start]
        res_atom_names = atom_array.atom_name[res_start:res_end]

        # ... pick the conformer index: pinned via `residue_conformer_indices`, or the next one for this residue type
        conf_idx = _next_conf_idx[res_name]
        if _has_global_res_id and residue_conformer_indices is not None:
            res_global_id = int(atom_array.res_id_global[res_start])  # Convert to Python int
            if res_global_id in residue_conformer_indices:
                conformer_indices = residue_conformer_indices[res_global_id]
                # (We don't yet support multiple conformers per residue, so we just use the first one, which is random anyhow)
                conf_idx = int(conformer_indices[0] if isinstance(conformer_indices, np.ndarray) else conformer_indices)

        # ... turn conformer into an atom array (indices wrap around the number of available conformers)
        if res_name not in KNOWN_CCD_CODES:
            # (conformers for unknown CCD codes are already atom arrays, since we generated them directly)
            candidates = unknown_ccd_conformers[res_name]
            conformer = candidates[conf_idx % len(candidates)]
        else:
            mol = ref_mols[res_name]
            conformer = atom_array_from_rdkit(
                mol,
                conformer_id=conf_idx % mol.GetNumConformers(),
                remove_hydrogens=True,
            )

        # ... optionally replace / complement the conformer with the ground truth, per `ground_truth_conformer_policy`
        ground_truth_conformer = None
        if _has_ground_truth_conformer_policy:
            conformer, ground_truth_conformer, replaced = _apply_ground_truth_conformer_policy(
                atom_array, res_start, res_end, conformer
            )
            if replaced:
                ref_pos_is_ground_truth[res_start:res_end] = True

        # ... map the reference conformer information to the given residue
        _ref_pos, _ref_mask = _map_reference_conformer_to_residue(
            res_name=res_name,
            atom_names=res_atom_names,
            conformer=conformer,
        )
        # ... apply a random rotation and translation to the reference conformer, if requested
        if apply_random_rotation_and_translation:
            # TODO: Implement more elegantly directly in numpy
            _ref_pos = random_rigid_augmentation(torch.from_numpy(_ref_pos[np.newaxis, :]), batch_size=1).numpy()
        # ... fill the reference features for this residue
        ref_pos[res_start:res_end] = _ref_pos
        ref_mask[res_start:res_end] = _ref_mask

        # (Repeat for the ground truth conformer, if adding through an additional feature)
        if exists(ground_truth_conformer):
            _ref_pos_ground_truth, _ = _map_reference_conformer_to_residue(
                res_name=res_name,
                atom_names=res_atom_names,
                conformer=ground_truth_conformer,
            )
            if apply_random_rotation_and_translation:
                _ref_pos_ground_truth = random_rigid_augmentation(
                    torch.from_numpy(_ref_pos_ground_truth[np.newaxis, :]), batch_size=1
                ).numpy()
            ref_pos_ground_truth[res_start:res_end] = _ref_pos_ground_truth

        # ... update to the next conformer index
        _next_conf_idx[res_name] += 1

    # Generate remaining reference features
    # ... element
    ref_element = (
        atom_array.atomic_number
        if "atomic_number" in annotation_categories
        else np.vectorize(ELEMENT_NAME_TO_ATOMIC_NUMBER.get)(atom_array.element)
    )
    # ... charge
    ref_charge = atom_array.charge
    # ... atom name
    ref_atom_name_chars = _encode_atom_names_like_af3(atom_array.atom_name)
    if use_element_for_atom_names_of_atomized_tokens:
        assert (
            "atomize" in annotation_categories
        ), "Atomize annotation is required when using element for atom names of atomized tokens."
        ref_atom_name_chars[atom_array.atomize] = _encode_atom_names_like_af3(atom_array.element[atom_array.atomize])
    # ... space uid: a unique integer for each residue instance (int64 for compatibility with some older torch versions)
    ref_space_uid = struc.segments.spread_segment_wise(_res_start_ends, np.arange(len(_res_starts), dtype=np.int64))
    is_atomized_atom_level = atom_array.atomize if "atomize" in annotation_categories else None

    ref_conformer = {
        "ref_pos": ref_pos,  # (n_atoms, 3)
        "ref_mask": ref_mask,  # (n_atoms)
        "ref_element": ref_element,  # (n_atoms)
        "ref_charge": ref_charge,  # (n_atoms)
        "ref_atom_name_chars": ref_atom_name_chars,  # (n_atoms, 4)
        "ref_space_uid": ref_space_uid,  # (n_atoms)
        "is_atomized_atom_level": is_atomized_atom_level,  # (n_atoms)
    }
    if _has_ground_truth_conformer_policy:
        ref_conformer["ref_pos_ground_truth"] = ref_pos_ground_truth  # (n_atoms, 3)
        ref_conformer["ref_pos_is_ground_truth"] = ref_pos_is_ground_truth  # (n_atoms)
    return ref_conformer, ref_mols
class GetAF3ReferenceMoleculeFeatures(Transform):
    """Generate AF3 reference molecule features for each residue in the atom array.

    This transform adds the following features to the data dictionary under the 'feats' key, following AF3:
        - ref_pos: [N_atoms, 3] Atom positions in the reference conformer (in Å), with a random rotation and
          translation applied per residue (if `apply_random_rotation_and_translation`).
        - ref_mask: [N_atoms] Mask indicating which atoms have a valid (finite) reference position.
        - ref_element: [N_atoms] Atomic number of each atom. AF3 one-hot encodes this (up to atomic
          number 128); that expansion is not done by this transform.
        - ref_charge: [N_atoms] Charge for each atom.
        - ref_atom_name_chars: [N_atoms, 4] Per-character codes ``ord(c) - 32`` of the upper-cased atom names,
          padded/truncated to length 4. AF3 one-hot encodes these over 64 classes; that expansion is not done
          by this transform.
        - ref_space_uid: [N_atoms] Integer id of the residue instance each atom belongs to.

    And the following custom features, helpful for extra conditioning/downstream use:
        - ref_pos_is_ground_truth: [N_atoms] Whether the reference conformer is the ground-truth conformer.
          Only present if the atom array has the `ground_truth_conformer_policy` annotation.
        - ref_pos_ground_truth: [N_atoms, 3] The ground-truth conformer positions.
          Only present if the atom array has the `ground_truth_conformer_policy` annotation.
        - is_atomized_atom_level: [N_atoms] Whether the atom is atomized (atom-level version of "is_ligand");
          None if the atom array has no `atomize` annotation.

    If `save_rdkit_mols` is True, the RDKit molecules used per residue name are also stored under the
    'rdkit' key.

    Note: This transform should be applied after cropping.

    Reference:
        - Section 2.8 of the AF3 supplementary information
          https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf
    """

    def __init__(
        self,
        conformer_generation_timeout: float | tuple[float, float] | None = 10.0,
        save_rdkit_mols: bool = True,
        use_element_for_atom_names_of_atomized_tokens: bool = False,
        apply_random_rotation_and_translation: bool = True,
        max_conformers_per_residue: int | None = None,
        use_cached_conformers: bool = True,
        **generate_conformers_kwargs,
    ):
        """Configure the transform.

        Args:
            conformer_generation_timeout: Timeout for conformer generation, forwarded to
                `get_af3_reference_molecule_features`.
            save_rdkit_mols: Whether to store the RDKit molecules under ``data["rdkit"]``.
            use_element_for_atom_names_of_atomized_tokens: Whether to encode the element symbol instead of
                the atom name for atomized atoms.
            apply_random_rotation_and_translation: Whether to randomly rotate/translate each reference conformer.
            max_conformers_per_residue: Maximum number of conformers to generate per residue type.
            use_cached_conformers: Whether to use ``data["cached_residue_level_data"]`` and
                ``data["residue_conformer_indices"]`` when present.
            **generate_conformers_kwargs: Additional keyword arguments forwarded to
                `get_af3_reference_molecule_features`.
        """
        self.conformer_generation_timeout = conformer_generation_timeout
        self.generate_conformers_kwargs = generate_conformers_kwargs
        self.save_rdkit_mols = save_rdkit_mols
        self.use_element_for_atom_names_of_atomized_tokens = use_element_for_atom_names_of_atomized_tokens
        self.apply_random_rotation_and_translation = apply_random_rotation_and_translation
        self.max_conformers_per_residue = max_conformers_per_residue
        self.use_cached_conformers = use_cached_conformers
        if self.use_element_for_atom_names_of_atomized_tokens:
            logger.warning("Using element type for atom names of atomized tokens.")

    def forward(self, data: dict) -> dict:
        """Compute the reference features for ``data["atom_array"]`` and merge them into ``data["feats"]``.

        Reads ``data["cached_residue_level_data"]["residues"]`` and ``data["residue_conformer_indices"]``
        (if present) only when `use_cached_conformers` is enabled. Writes the features into ``data["feats"]``
        and, if `save_rdkit_mols`, the RDKit molecules into ``data["rdkit"]``. Both dicts are created if
        missing and are updated in place.
        """
        atom_array = data["atom_array"]

        # Extract cached data and conformer indices, if enabled
        cached_residue_level_data = None
        if self.use_cached_conformers and "cached_residue_level_data" in data:
            cached_residue_level_data = data["cached_residue_level_data"]["residues"]
        residue_conformer_indices = data.get("residue_conformer_indices") if self.use_cached_conformers else None

        # Generate reference features
        reference_features, rdkit_mols = get_af3_reference_molecule_features(
            atom_array,
            conformer_generation_timeout=self.conformer_generation_timeout,
            use_element_for_atom_names_of_atomized_tokens=self.use_element_for_atom_names_of_atomized_tokens,
            apply_random_rotation_and_translation=self.apply_random_rotation_and_translation,
            max_conformers_per_residue=self.max_conformers_per_residue,
            cached_residue_level_data=cached_residue_level_data,
            residue_conformer_indices=residue_conformer_indices,
            **self.generate_conformers_kwargs,
        )

        # Add reference features to the 'feats' dictionary (and molecules to 'rdkit', if requested)
        data.setdefault("feats", {}).update(reference_features)
        if self.save_rdkit_mols:
            data.setdefault("rdkit", {}).update(rdkit_mols)
        return data