Transforms

This module contains various transformation classes and utilities for processing molecular data.

Core Transform Classes

Base classes for transformations.

class atomworks.ml.transforms.base.AddData(data: dict, allow_overwrite: bool = False)[source]#

Bases: Transform

Add data to the data dictionary.

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.base.ApplyFunction(func: callable)[source]#

Bases: Transform

Applies a function to the data dictionary.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.base.Compose(transforms: list[Transform], track_rng_state: bool = True, print_rng_state: bool = False)[source]#

Bases: Transform

Compose multiple transformations together.

This class allows you to chain multiple transformations and apply them sequentially to a data dictionary. It is particularly useful for preprocessing pipelines where multiple steps need to be applied in a specific order.

- transforms

A list of transformations to be applied.

Type:

list[Transform]

- track_rng_state

Whether to track and serialize the random number generator (RNG) state. This is useful for debugging probabilistic transformations: if the transform pipeline fails, the RNG state is included in the error message, so you can restore the same RNG state with eval and reproduce the failure.

Type:

bool

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict, rng_state_dict: dict[str, Any] | None = None, _stop_before: Transform | str | int | None = None) dict[source]#

Apply a series of transformations to the input data.

Parameters:
  • data (dict) – The input data to be transformed.

  • rng_state_dict (dict[str, Any] | None, optional) – Random number generator state dictionary. If provided, sets the RNG state before applying transforms. Defaults to None.

  • _stop_before (Transform | str | int | None, optional) – Specifies a point to stop the transformation process. Can be a Transform instance, a string (transform class name), or an integer (index). Defaults to None.

Returns:

The transformed data.

Return type:

dict

Raises:

Exception – If any transform in the pipeline fails, with details about the failure point and RNG state.
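The sequential application and RNG-state reporting described above can be sketched with the standard library alone. This is a minimal illustration, not the atomworks implementation: MiniCompose and PipelineError are hypothetical names, the state of Python's random module stands in for whatever RNG state atomworks actually tracks, and stop_before is only supported as an integer index.

```python
import random
from typing import Any, Callable, Optional

Step = Callable[[dict[str, Any]], dict[str, Any]]


class PipelineError(Exception):
    """Hypothetical stand-in for TransformPipelineError."""

    def __init__(self, message: str, rng_state: Any = None):
        super().__init__(message)
        self.rng_state = rng_state


class MiniCompose:
    """Hypothetical sketch: apply steps in order, report the RNG state on failure."""

    def __init__(self, steps: list[Step], track_rng_state: bool = True):
        self.steps = steps
        self.track_rng_state = track_rng_state

    def __call__(
        self,
        data: dict[str, Any],
        rng_state: Any = None,
        stop_before: Optional[int] = None,
    ) -> dict[str, Any]:
        if rng_state is not None:
            # Restore a previously captured state to replay a failing run.
            random.setstate(rng_state)
        # Capture the state before any step consumes randomness.
        start_state = random.getstate() if self.track_rng_state else None
        for i, step in enumerate(self.steps):
            if stop_before is not None and i == stop_before:
                break
            try:
                data = step(data)
            except Exception as exc:
                name = getattr(step, "__name__", repr(step))
                raise PipelineError(f"Step {i} ({name}) failed: {exc}", start_state) from exc
        return data
```

Restoring the captured state before re-running the pipeline replays the same random draws, which is the debugging workflow that track_rng_state is meant to support.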

class atomworks.ml.transforms.base.ConditionalRoute(condition_func: Callable[[dict[str, Any]], Any], transform_map: dict[Any, Transform])[source]#

Bases: Transform

Route conditionally between various transforms.

This Transform is useful for routing between different transforms based on a condition, e.g. skipping transforms during inference.
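A minimal sketch of the routing idea, assuming the condition function returns a hashable key that is looked up in transform_map. route_conditionally is a hypothetical helper, and raising KeyError for an unmapped value is an assumption; the actual ConditionalRoute may handle that case differently.

```python
from typing import Any, Callable

Step = Callable[[dict[str, Any]], dict[str, Any]]


def route_conditionally(
    data: dict[str, Any],
    condition_func: Callable[[dict[str, Any]], Any],
    transform_map: dict[Any, Step],
) -> dict[str, Any]:
    """Hypothetical sketch: evaluate the condition, then apply the matching transform."""
    key = condition_func(data)
    if key not in transform_map:
        # Assumption: an unmapped condition value is treated as an error.
        raise KeyError(f"No transform registered for condition value {key!r}")
    return transform_map[key](data)
```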

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply the appropriate transformation based on the condition value.

Parameters:

data (dict[str, Any]) – The input data dictionary.

Returns:

The transformed data dictionary.

Return type:

dict[str, Any]

class atomworks.ml.transforms.base.ConvertToTorch(keys: list[str], device: str = 'cpu')[source]#

Bases: Transform

Converts the contents of specified data keys to torch tensors and moves them to the specified device.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.base.Identity[source]#

Bases: Transform

Identity transform. Does nothing and just passes the data through.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

raise_if_invalid_input: bool = False#
validate_input: bool = False#
class atomworks.ml.transforms.base.LogData(log_level: int = 20, depth: int | None = 1, **pprint_kwargs)[source]#

Bases: Transform

Log the data dictionary. Meant for debugging.

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.base.PickleToDisk(dir_path: str, file_name_func: Callable[[dict], str] | None = None, save_transform_history: bool = False, overwrite: bool = False)[source]#

Bases: Transform

Save the data dictionary to a pickle file.

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.base.RaiseError(error_type: Exception = <class 'ValueError'>, error_message: str = 'User requested raising an error.')[source]#

Bases: Transform

Raises an error for testing and debugging purposes.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.base.RaiseOnCondition(condition: callable, error_message: str, exception_to_raise: type[Exception] = <class 'ValueError'>)[source]#

Bases: Transform

Raises a user-specified exception if a given condition is met.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.base.RandomRoute(transforms: list[Transform], probs: list[float])[source]#

Bases: Transform

Route probabilistically between various transforms.

This transform is useful for routing between different transforms probabilistically, e.g. for sampling different cropping strategies.
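Probabilistic routing amounts to drawing one transform with the given weights and applying it. The sketch below uses random.choices; random_route is a hypothetical helper, and the checks that probs matches transforms in length and sums to 1 are assumptions rather than documented RandomRoute behavior.

```python
import random
from typing import Any, Callable

Step = Callable[[dict[str, Any]], dict[str, Any]]


def random_route(
    data: dict[str, Any],
    transforms: list[Step],
    probs: list[float],
    rng: Any = random,
) -> dict[str, Any]:
    """Hypothetical sketch: pick one transform according to probs and apply it."""
    if len(transforms) != len(probs):
        raise ValueError("transforms and probs must have the same length")
    if abs(sum(probs) - 1.0) > 1e-6:
        raise ValueError("probs must sum to 1")
    # choices() draws k=1 item with the given relative weights.
    (chosen,) = rng.choices(transforms, weights=probs, k=1)
    return chosen(data)
```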

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

classmethod from_dict(transform_dict: dict[Transform, float]) RandomRoute[source]#
classmethod from_list(transform_list: list[tuple[float, Transform]]) RandomRoute[source]#
raise_if_invalid_input: bool = False#
validate_input: bool = False#
class atomworks.ml.transforms.base.RemoveKeys(keys: list[str], require_keys_exist: bool = True)[source]#

Bases: Transform

Remove keys from the data dictionary.

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.base.SubsetToKeys(keys: list[str], require_keys_exist: bool = True)[source]#

Bases: Transform

Keep only the specified keys in the data dictionary.
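The difference between SubsetToKeys and RemoveKeys, and the effect of require_keys_exist, can be sketched on a plain dictionary. Both helpers are hypothetical, and raising KeyError for missing keys is an assumption about the error type.

```python
from typing import Any


def subset_to_keys(data: dict[str, Any], keys: list[str], require_keys_exist: bool = True) -> dict[str, Any]:
    """Hypothetical sketch of SubsetToKeys: keep only the listed keys."""
    missing = [k for k in keys if k not in data]
    if require_keys_exist and missing:
        raise KeyError(f"Missing keys: {missing}")
    return {k: data[k] for k in keys if k in data}


def remove_keys(data: dict[str, Any], keys: list[str], require_keys_exist: bool = True) -> dict[str, Any]:
    """Hypothetical sketch of RemoveKeys: drop the listed keys."""
    missing = [k for k in keys if k not in data]
    if require_keys_exist and missing:
        raise KeyError(f"Missing keys: {missing}")
    drop = set(keys)
    return {k: v for k, v in data.items() if k not in drop}
```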

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.base.Transform[source]#

Bases: ABC

Abstract base class for transformations on dictionary objects.

Class-level attributes:
  • validate_input (bool): Whether to validate the input.

  • raise_if_invalid_input (bool): Whether to raise an error if the input is invalid.

  • requires_previous_transforms (list[str]): Transforms that must have been applied before this transform.

  • incompatible_previous_transforms (list[str]): Transforms that cannot have preceded this transform.

  • previous_transforms_order_matters (bool): Whether the order of the transforms is important.

  • _track_transform_history (bool): Whether to track the transform history.

To write a subclass, you need to implement the following methods:
  • check_input(data: dict): Validates the input data. Should raise an error if the input is invalid.

    The returned value is not used.

  • forward(data: dict): Applies the transformation to the input data and returns the transformed data.
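A minimal sketch of the subclassing contract described above. MiniTransform is a hypothetical stand-in for Transform that models only validate_input (it ignores raise_if_invalid_input and the previous-transform checks), and AddSequenceLength is an illustrative subclass, not part of atomworks.

```python
from abc import ABC, abstractmethod
from typing import Any, ClassVar


class MiniTransform(ABC):
    """Hypothetical stand-in for the Transform contract."""

    validate_input: ClassVar[bool] = True

    def check_input(self, data: dict[str, Any]) -> None:
        """Default: accept any input. Subclasses raise on invalid input."""

    @abstractmethod
    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        ...

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        if self.validate_input:
            self.check_input(data)  # the return value is ignored
        return self.forward(data)


class AddSequenceLength(MiniTransform):
    """Illustrative subclass: store len(data['sequence']) under 'seq_len'."""

    def check_input(self, data: dict[str, Any]) -> None:
        if "sequence" not in data:
            raise KeyError("AddSequenceLength requires a 'sequence' key")

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        data["seq_len"] = len(data["sequence"])
        return data
```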

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

abstract forward(data: dict[str, Any], *args, **kwargs) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str]] = []#
previous_transforms_order_matters: bool = False#
raise_if_invalid_input: bool = True#
requires_previous_transforms: ClassVar[list[str]] = []#
validate_input: bool = True#
exception atomworks.ml.transforms.base.TransformPipelineError(message: str, rng_state_dict: dict[str, Any] | None = None)[source]#

Bases: Exception

A custom error class for Transform pipelines (via Compose).

class atomworks.ml.transforms.base.TransformedDict(_TransformedDict__existing_dict_to_wrap: dict[str, Any] | None = None, **kwargs)[source]#

Bases: dict

A thin wrapper around a dictionary that can be used to track the transform history.

atomworks.ml.transforms.base.convert_to_torch(data: dict[str, Any], keys: list[str], device: str = 'cpu') dict[str, Any][source]#

Convert the contents of specified data keys to torch tensors and move them to the specified device.

For each given top-level data key, all nested numpy arrays are converted to torch tensors.

Parameters:
  • data (dict[str, Any]) – The input data dictionary.

  • keys (list[str]) – List of data keys within which to search for numpy arrays to convert to torch tensors.

  • device (str) – The device to which the tensors should be moved (e.g., ‘cpu’, ‘cuda’). Default is ‘cpu’.

Returns:

The data dictionary with numpy arrays converted to torch tensors.

Return type:

dict[str, Any]
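The recursive search for arrays inside each selected key can be sketched generically. convert_nested and convert_keys are hypothetical helpers that take the leaf test and the conversion as parameters so the sketch runs without NumPy or PyTorch; for the behavior described above you would pass something like is_leaf=lambda x: isinstance(x, np.ndarray) and convert=lambda a: torch.from_numpy(a).to(device). Only dicts, lists, and tuples are traversed here, which is an assumption.

```python
from typing import Any, Callable


def convert_nested(obj: Any, is_leaf: Callable[[Any], bool], convert: Callable[[Any], Any]) -> Any:
    """Rebuild dicts/lists/tuples recursively, converting every leaf that matches is_leaf."""
    if is_leaf(obj):
        return convert(obj)
    if isinstance(obj, dict):
        return {k: convert_nested(v, is_leaf, convert) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # type(obj) keeps lists as lists and tuples as tuples.
        return type(obj)(convert_nested(v, is_leaf, convert) for v in obj)
    return obj  # anything else passes through unchanged


def convert_keys(
    data: dict[str, Any],
    keys: list[str],
    is_leaf: Callable[[Any], bool],
    convert: Callable[[Any], Any],
) -> dict[str, Any]:
    """Hypothetical sketch: convert leaves only under the given top-level keys."""
    out = dict(data)
    for key in keys:
        out[key] = convert_nested(data[key], is_leaf, convert)
    return out
```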

class atomworks.ml.transforms.af3_reference_molecule.GetAF3ReferenceMoleculeFeatures(conformer_generation_timeout: float = 10.0, save_rdkit_mols: bool = True, use_element_for_atom_names_of_atomized_tokens: bool = False, apply_random_rotation_and_translation: bool = True, max_conformers_per_residue: int | None = None, use_cached_conformers: bool = True, **generate_conformers_kwargs)[source]#

Bases: Transform

Generate AF3 reference molecule features for each residue in the atom array.

This transform adds the following features to the data dictionary under the ‘feats’ key, following AF3:
  • ref_pos: [N_atoms, 3] Atom positions in the reference conformer, with a random rotation and translation applied. Atom positions are given in Å.

  • ref_mask: [N_atoms] Mask indicating which atom slots are used in the reference conformer.

  • ref_element: [N_atoms, 128] One-hot encoding of the element atomic number for each atom in the reference conformer, up to atomic number 128.

  • ref_charge: [N_atoms] Charge for each atom in the reference conformer.

  • ref_atom_name_chars: [N_atoms, 4, 64] One-hot encoding of the unique atom names in the reference conformer. Each character is encoded as ord(c) - 32, and names are padded to length 4.

  • ref_space_uid: [N_atoms] Numerical encoding of the chain id and residue index associated with this reference conformer. Each (chain id, residue index) tuple is assigned an integer on first appearance.

And the following custom features, helpful for extra conditioning/downstream use:
  • ref_pos_is_ground_truth: [N_atoms] Whether the reference conformer is the ground-truth conformer. Determined by the ground_truth_conformer_policy annotation.

  • ref_pos_ground_truth: [N_atoms, 3] The ground-truth conformer positions. Determined by the ground_truth_conformer_policy annotation.

  • is_atomized_atom_level: [N_atoms] Whether the atom is atomized (atom-level version of “is_ligand”)
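The ref_atom_name_chars encoding can be worked through for a single atom name. With ord(c) - 32 and 64 classes, only characters from ASCII 32 (space) to 95 (underscore) are representable, which covers digits, uppercase letters, and common punctuation. encode_atom_name is a hypothetical helper, and padding with spaces (index 0) is an assumption about how names shorter than 4 characters are filled.

```python
def encode_atom_name(name: str, width: int = 4, n_classes: int = 64) -> list[list[int]]:
    """Hypothetical sketch: one-hot encode an atom name as [width, n_classes]."""
    if len(name) > width:
        raise ValueError(f"Atom name {name!r} is longer than {width} characters")
    padded = name.ljust(width)  # assumption: pad with spaces, which map to index 0
    rows = []
    for ch in padded:
        idx = ord(ch) - 32
        if not 0 <= idx < n_classes:
            raise ValueError(f"Character {ch!r} cannot be encoded")
        row = [0] * n_classes
        row[idx] = 1
        rows.append(row)
    return rows
```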

Note: This transform should be applied after cropping.

Reference:
check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.af3_reference_molecule.RandomApplyGroundTruthConformerByChainType(chain_type_probabilities: dict | None = None, default_probability: float = 0.0, policy: GroundTruthConformerPolicy = GroundTruthConformerPolicy.REPLACE)[source]#

Bases: Transform

Apply ground truth conformer policy with configurable probabilities per chain type.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str | Transform]] = ['GetAF3ReferenceMoleculeFeatures']#
atomworks.ml.transforms.af3_reference_molecule.get_af3_reference_molecule_features(atom_array: AtomArray, conformer_generation_timeout: float | tuple[float, float] = (3.0, 0.15), apply_random_rotation_and_translation: bool = True, use_element_for_atom_names_of_atomized_tokens: bool = False, timeout_strategy: Literal['signal', 'subprocess'] = 'subprocess', max_conformers_per_residue: int | None = None, cached_residue_level_data: dict | None = None, residue_conformer_indices: dict[int, ndarray] | None = None, **generate_conformers_kwargs) tuple[dict[str, Any], dict[str, Mol]][source]#

Get AF3 reference features for each residue in the atom array.

Parameters:
  • atom_array – The input atom array.

  • conformer_generation_timeout – Maximum time allowed for conformer generation per residue. Defaults to (3.0, 0.15), which gives a timeout of 3.0 + 0.15 * (n_conformers - 1) seconds. If None, no timeout is applied and the timeout strategy is ignored (no subprocesses will be spawned).

  • apply_random_rotation_and_translation – Whether to apply a random rotation and translation to each conformer (AF3-style).

  • timeout_strategy – The strategy to use for the timeout. Defaults to “subprocess” (which is the most reliable choice).

  • max_conformers_per_residue – Maximum number of conformers to generate per residue type. If None, generates as many conformers as there are instances of that residue type. If set, generates min(count, max_conformers_per_residue) conformers and randomly samples one of them for each residue instance.

  • cached_residue_level_data – Optional cached conformer data by residue name. If provided, cached conformers will be preferred over generated ones.

  • residue_conformer_indices – Optional mapping of global residue IDs to specific conformer indices. If provided, these specific conformers will be used for the corresponding residues.

  • **generate_conformers_kwargs – Additional keyword arguments to pass to the generate_conformers function.
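The timeout rule and the conformer-sampling rule from the parameter list can be made concrete with two small helpers. For example, with the default (3.0, 0.15) and 5 conformers, the timeout is 3.0 + 0.15 * 4 = 3.6 seconds. Both helpers are hypothetical; treating a plain float as a fixed timeout, mapping instances one-to-one to conformers when max_conformers_per_residue is None, and sampling with replacement are assumptions.

```python
import random
from typing import Optional, Union


def conformer_timeout(timeout: Union[None, float, tuple[float, float]], n_conformers: int) -> Optional[float]:
    """Per-residue timeout in seconds; None means no timeout."""
    if timeout is None:
        return None
    if isinstance(timeout, tuple):
        base, per_extra = timeout
        return base + per_extra * (n_conformers - 1)
    return float(timeout)  # assumption: a plain float is used as-is


def assign_conformers(
    n_instances: int, max_conformers_per_residue: Optional[int], rng: random.Random
) -> tuple[int, list[int]]:
    """Return (number of conformers to generate, conformer index for each instance)."""
    if max_conformers_per_residue is None:
        # Assumption: one conformer per instance, assigned in order.
        return n_instances, list(range(n_instances))
    n_generate = min(n_instances, max_conformers_per_residue)
    # Assumption: each instance samples independently (with replacement).
    return n_generate, [rng.randrange(n_generate) for _ in range(n_instances)]
```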

Returns:

A tuple (ref_conformer, ref_mols): ref_conformer is a dictionary containing the generated reference features, and ref_mols is a dictionary containing all generated RDKit molecules, including those with unknown CCD codes.

Return type:

tuple[dict[str, Any], dict[str, Mol]]

This function generates the following reference features, following AF3:
  • ref_pos: [N_atoms, 3] Atom positions in the reference conformer, with a random rotation and translation applied. Atom positions are given in Å.

  • ref_mask: [N_atoms] Mask indicating which atom slots are used in the reference conformer.

  • ref_element: [N_atoms, 128] One-hot encoding of the element atomic number for each atom in the reference conformer, up to atomic number 128.

  • ref_charge: [N_atoms] Charge for each atom in the reference conformer.

  • ref_atom_name_chars: [N_atoms, 4, 64] One-hot encoding of the unique atom names in the reference conformer. Each character is encoded as ord(c) - 32, and names are padded to length 4.

  • ref_space_uid: [N_atoms] Numerical encoding of the chain id and residue index associated with this reference conformer. Each (chain id, residue index) tuple is assigned an integer on first appearance.

Optionally, the following custom features, helpful for extra conditioning:
  • ref_pos_is_ground_truth (optional): [N_atoms] Whether the reference conformer is the ground-truth conformer. Determined by the ground_truth_conformer_policy annotation.

  • ref_pos_ground_truth (optional): [N_atoms, 3] The ground-truth conformer positions. Determined by the ground_truth_conformer_policy annotation.

  • is_atomized_atom_level: [N_atoms] Whether the atom is atomized (atom-level version of “is_ligand”).
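The ref_space_uid rule (an integer assigned to each (chain id, residue index) tuple on first appearance) can be sketched as follows. first_appearance_uids is a hypothetical helper, and starting the numbering at 0 is an assumption.

```python
from typing import Hashable, Iterable


def first_appearance_uids(keys: Iterable[Hashable]) -> list[int]:
    """Map each per-atom key, e.g. (chain_id, res_id), to an integer in order of first appearance."""
    uid_of: dict[Hashable, int] = {}
    uids = []
    for key in keys:
        if key not in uid_of:
            uid_of[key] = len(uid_of)  # next unused integer (assumed to start at 0)
        uids.append(uid_of[key])
    return uids
```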

Reference:
atomworks.ml.transforms.af3_reference_molecule.random_apply_ground_truth_conformer_by_chain_type(atom_array: AtomArray, chain_type_probabilities: dict | None = None, default_probability: float = 0.0, policy: GroundTruthConformerPolicy = GroundTruthConformerPolicy.REPLACE, is_unconditional: bool = False) AtomArray[source]#

Apply ground truth conformer policy with configurable probabilities per chain type.

Adds the ground_truth_conformer_policy annotation to the AtomArray if it does not already exist. This annotation indicates if/how residues should use the ground-truth coordinates (i.e., the coordinates from the original structure) as the reference conformer.

Possible values are (as defined in the GroundTruthConformerPolicy enum):
  • REPLACE: Use the ground-truth coordinates as the reference conformer (replacing the RDKit-generated conformer in-place)

  • ADD: Use the ground-truth coordinates as an additional feature (rather than replacing the RDKit-generated conformer)

  • FALLBACK: Use the ground-truth coordinates only if our standard conformer generation pipeline fails (e.g., we cannot generate a conformer with RDKit, and the molecule is either not in the CCD or the CCD entry is invalid)

  • IGNORE: Do not use the ground-truth coordinates as the reference conformer, under any circumstances

Parameters:
  • atom_array (AtomArray) – The input atom array.

  • chain_type_probabilities (dict, optional) – Dictionary mapping chain types to their probability of using ground truth conformer. Defaults to None.

  • default_probability (float, optional) – Default probability for any chain type not explicitly specified. Defaults to 0.0.

  • policy (GroundTruthConformerPolicy, optional) – Which ground truth conformer policy to apply when selected. Defaults to GroundTruthConformerPolicy.REPLACE.

  • is_unconditional (bool, optional) – Whether we are sampling unconditionally (and thus should not apply the policy). Defaults to False.

Returns:

The input atom array with the ground_truth_conformer_policy annotation updated.

Return type:

AtomArray
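The per-chain-type probability lookup can be sketched as below. This is an illustration under several assumptions: Policy is a hypothetical mirror of GroundTruthConformerPolicy, the decision is drawn independently for each chain, chain types are plain strings, and chains that are not selected receive IGNORE. The real function writes an annotation on the AtomArray and may handle existing annotations or unselected residues differently.

```python
import random
from enum import Enum
from typing import Any, Optional


class Policy(Enum):
    """Hypothetical mirror of GroundTruthConformerPolicy."""

    REPLACE = "replace"
    ADD = "add"
    FALLBACK = "fallback"
    IGNORE = "ignore"


def sample_policy_per_chain(
    chain_types: dict[str, str],
    chain_type_probabilities: Optional[dict[str, float]] = None,
    default_probability: float = 0.0,
    policy: Policy = Policy.REPLACE,
    is_unconditional: bool = False,
    rng: Any = random,
) -> dict[str, Policy]:
    """Return a {chain_id: Policy} mapping (hypothetical sketch)."""
    probs = chain_type_probabilities or {}
    result = {}
    for chain_id, chain_type in chain_types.items():
        # Fall back to default_probability for chain types not listed.
        p = probs.get(chain_type, default_probability)
        use_ground_truth = (not is_unconditional) and rng.random() < p
        result[chain_id] = policy if use_ground_truth else Policy.IGNORE
    return result
```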

Transforms on atom arrays.

class atomworks.ml.transforms.atom_array.AddGlobalAtomIdAnnotation(allow_overwrite: bool = False)[source]#

Bases: Transform

Adds a global atom ID annotation to the atom array.

Useful for keeping track of atoms after cropping, slicing or shuffling operations.

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str | Transform]] = ['AddGlobalAtomIdAnnotation']#
class atomworks.ml.transforms.atom_array.AddGlobalResIdAnnotation[source]#

Bases: Transform

Adds a global residue ID annotation to the atom array.

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str | Transform]] = ['AddGlobalResIdAnnotation']#
class atomworks.ml.transforms.atom_array.AddGlobalTokenIdAnnotation[source]#

Bases: Transform

Adds a global token ID annotation token_id to the atom array.

Useful for keeping track of tokens after cropping, slicing or shuffling operations.

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str | Transform]] = ['AddGlobalTokenIdAnnotation']#
class atomworks.ml.transforms.atom_array.AddMoleculeSymmetricIdAnnotation[source]#

Bases: Transform

Adds the molecule_symmetric_id annotation to the AtomArray.

For a molecule, the symmetric_id is a unique integer within the set of molecules that share the same molecule_entity.

Example: if molecule_entity 0 has 3 molecules, they will have symmetric_ids 0, 1, and 2.
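The numbering in the example above can be sketched at the molecule level. molecule_symmetric_ids is a hypothetical helper; assigning IDs in order of appearance is an assumption, and spreading the result to atoms is omitted.

```python
from typing import Hashable


def molecule_symmetric_ids(molecule_entities: list[Hashable]) -> list[int]:
    """Given the entity of each molecule (in molecule order), number molecules within each entity."""
    seen: dict[Hashable, int] = {}
    ids = []
    for entity in molecule_entities:
        ids.append(seen.get(entity, 0))  # how many molecules of this entity came before
        seen[entity] = seen.get(entity, 0) + 1
    return ids
```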

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.atom_array.AddProteinTerminiAnnotation[source]#

Bases: Transform

Annotate protein termini (i.e. N- and C-terminus) for protein chains in the atom array.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str | Transform]] = ['CropContiguousLikeAF3', 'CropSpatialLikeAF3']#
class atomworks.ml.transforms.atom_array.AddWithinChainInstanceResIdx[source]#

Bases: Transform

Add the within-chain instance residue index to the atom array (0-indexed).

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.atom_array.AddWithinPolyResIdxAnnotation[source]#

Bases: Transform

Adds the within_poly_res_idx (within polymer residue index) annotation to the AtomArray.

For polymers, the within_poly_res_idx is a zero-indexed, continuous residue index within the chain. For non-polymers, the within_poly_res_idx is set to -1. This annotation is later used to index into the MSA, as it remains consistent with MSA indices even after cropping the AtomArray.

Note

The within_poly_res_idx is zero-indexed, since it is used as an index into the MSA. In contrast, the res_id annotation (derived from the mmCIF file) is one-indexed. We generate within_poly_res_idx from scratch rather than inferring from res_id to avoid any mmCIF annotation errors.
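A per-atom sketch of the within_poly_res_idx rule, assuming residue boundaries are detected where (chain_id, res_id) changes between consecutive atoms. within_poly_res_idx here is a hypothetical helper operating on plain lists. A gap in res_id (for example 6 followed by 8) still yields consecutive indices, which is why the annotation is built from scratch rather than from res_id.

```python
from typing import Hashable


def within_poly_res_idx(
    chain_ids: list[Hashable], res_ids: list[int], is_polymer: list[bool]
) -> list[int]:
    """Zero-indexed, continuous residue index per polymer chain; -1 for non-polymer atoms."""
    out = []
    last_idx: dict[Hashable, int] = {}  # last residue index assigned in each chain
    prev_key = None
    for chain, res, poly in zip(chain_ids, res_ids, is_polymer):
        if not poly:
            out.append(-1)
            continue
        key = (chain, res)
        if key != prev_key:
            # A new residue starts: take the next index for this chain.
            last_idx[chain] = last_idx.get(chain, -1) + 1
            prev_key = key
        out.append(last_idx[chain])
    return out
```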

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str | Transform]] = ['CropContiguousLikeAF3', 'CropSpatialLikeAF3']#
class atomworks.ml.transforms.atom_array.ApplyFunctionToAtomArray(func: Callable[[AtomArray], AtomArray])[source]#

Bases: Transform

Apply a function to the atom array.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.atom_array.ComputeAtomToTokenMap[source]#

Bases: Transform

Adds an array of length [n_atom] to the feats dictionary that gives the token_id of each atom.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

requires_previous_transforms: ClassVar[list[str | Transform]] = ['AddGlobalTokenIdAnnotation']#
class atomworks.ml.transforms.atom_array.CopyAnnotation(annotation_to_copy: str, new_annotation: str)[source]#

Bases: Transform

Copies an existing annotation from the AtomArray and assigns it a new name.

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.atom_array.RaiseIfTooManyAtoms(max_atoms: int)[source]#

Bases: Transform

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.atom_array.RenumberNonPolymerResidueIdx[source]#

Bases: Transform

Renumbers non-polymer residue indices to be one-indexed, similar to polymer residues.

This transformation ensures that non-polymer residue indices start from 1, providing a consistent indexing scheme across both polymer and non-polymer residues. It addresses the issue where non-polymer residue indices may start at “101”, which can lead to non-deterministic behavior.

Note

The renumbering is applied to each non-polymer chain independently, ensuring that the indices are continuous and start from 1 for each chain.

Returns:

The updated data dictionary containing the modified atom_array with renumbered non-polymer residue indices.

Return type:

dict
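A list-based sketch of the renumbering, applied to each non-polymer chain independently. renumber_non_polymer is a hypothetical helper, and numbering residues in order of first appearance within each chain is an assumption.

```python
from typing import Hashable


def renumber_non_polymer(
    chain_ids: list[Hashable], res_ids: list[int], is_polymer: list[bool]
) -> list[int]:
    """Return per-atom residue IDs with non-polymer residues renumbered from 1 within each chain."""
    new_ids = list(res_ids)
    maps: dict[Hashable, dict[int, int]] = {}
    for i, (chain, res, poly) in enumerate(zip(chain_ids, res_ids, is_polymer)):
        if poly:
            continue  # polymer residue IDs are left unchanged
        chain_map = maps.setdefault(chain, {})
        if res not in chain_map:
            chain_map[res] = len(chain_map) + 1  # one-indexed
        new_ids[i] = chain_map[res]
    return new_ids
```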

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.atom_array.SortLikeRF2AA[source]#

Bases: Transform

Sort the atom array into three groups, in the order listed below. Within each group, atoms are ordered by their pn_unit_iid (and within a pn_unit, their original order is preserved).

    1. polymer atoms

    2. non-poly atoms of a pn-unit bonded to a polymer (covalent modifications)

    3. non-poly atoms of a free-floating pn-unit (free-floating ligands)
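The three-group ordering can be expressed as a single stable sort on (group, pn_unit_iid). In this sketch, rf2aa_sort_order is hypothetical, bonded_to_polymer is a hypothetical per-atom flag standing in for however atomworks detects covalent modifications, and pn_unit_iid values are compared as strings. Because Python's sorted is stable, atoms that share a group and a pn_unit_iid keep their original order.

```python
def rf2aa_sort_order(
    is_polymer: list[bool], bonded_to_polymer: list[bool], pn_unit_iids: list[str]
) -> list[int]:
    """Return a permutation of atom indices: polymers, then covalent modifications, then free ligands."""

    def group(i: int) -> int:
        if is_polymer[i]:
            return 0
        return 1 if bonded_to_polymer[i] else 2

    # sorted() is stable, so ties on (group, pn_unit_iid) keep their input order.
    return sorted(range(len(is_polymer)), key=lambda i: (group(i), pn_unit_iids[i]))
```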

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str | Transform]] = ['EncodeAtomArray', 'CropSpatialLikeAF3', 'CropContiguousLikeAF3']#
requires_previous_transforms: ClassVar[list[str | Transform]] = ['AtomizeByCCDName']#
class atomworks.ml.transforms.atom_array.SortPolyThenNonPoly(treat_atomized_as_non_poly: bool = True)[source]#

Bases: Transform

Sort the atom array such that polymer chains are first, followed by non-polymer chains.

The order within the poly and non_poly chains is preserved.

This transformation is useful for models like RF2AA, which expect the input to be formatted as [polys, non-polys].

Parameters:

treat_atomized_as_non_poly (bool) – If True, atomized structures are treated as non-polymer. Defaults to True.
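A sketch of the stable poly-then-non-poly partition, returning a permutation of atom indices. poly_then_non_poly_order and the per-atom atomize flag are hypothetical; the real transform works on chains of an AtomArray, whereas this version works on plain per-atom lists.

```python
def poly_then_non_poly_order(
    is_polymer: list[bool], atomize: list[bool], treat_atomized_as_non_poly: bool = True
) -> list[int]:
    """Indices of polymer atoms first, then non-polymer atoms, each in original order."""

    def counts_as_poly(i: int) -> bool:
        if treat_atomized_as_non_poly and atomize[i]:
            return False
        return is_polymer[i]

    indices = range(len(is_polymer))
    poly = [i for i in indices if counts_as_poly(i)]
    non_poly = [i for i in indices if not counts_as_poly(i)]
    return poly + non_poly
```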

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str | Transform]] = ['EncodeAtomArray', 'CropSpatialLikeAF3', 'CropContiguousLikeAF3']#
atomworks.ml.transforms.atom_array.add_global_atom_id_annotation(atom_array: AtomArray) AtomArray[source]#

Adds a global atom ID annotation atom_id to the atom array.

This annotation is useful for tracking atoms after operations such as cropping, slicing, or shuffling. The atom_id values are consecutive integers, one per atom in the atom array.

Parameters:

atom_array (AtomArray) – The AtomArray to which the atom ID annotation will be added.

Returns:

The AtomArray with the added atom_id annotation.

Return type:

AtomArray

atomworks.ml.transforms.atom_array.add_global_res_id_annotation(atom_array: AtomArray) AtomArray[source]#

Add a global residue ID annotation to the atom array.

atomworks.ml.transforms.atom_array.add_global_token_id_annotation(atom_array: AtomArray) AtomArray[source]#

Adds a global token ID annotation token_id to the atom array.

This annotation is useful for tracking tokens after operations such as cropping, slicing, or shuffling. The token_id values are consecutive integers, one per token in the atom array. Each value is spread to every atom of its token, which maintains the association between atoms and tokens.

Parameters:

atom_array (AtomArray) – The AtomArray to which the token ID annotation will be added.

Returns:

The AtomArray with the added token_id annotation.

Return type:

AtomArray
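The "spread" step can be sketched with a cumulative sum over a token-start mask. This is a minimal NumPy illustration with two assumptions: tokens are contiguous atom ranges with known start indices, and IDs are 0-based. Because the IDs are stored per atom, they survive later slicing, which is what makes them useful for tracking.

```python
import numpy as np


def spread_token_ids(token_starts: np.ndarray, n_atoms: int) -> np.ndarray:
    """Give every atom the (0-based) ID of the token it belongs to.

    Assumes tokens are contiguous atom ranges and `token_starts` holds the
    sorted first-atom index of each token, starting at 0.
    """
    is_token_start = np.zeros(n_atoms, dtype=bool)
    is_token_start[token_starts] = True
    # The running count of token starts seen so far, minus one, is the token ID.
    return np.cumsum(is_token_start) - 1


token_id = spread_token_ids(np.array([0, 3, 4]), n_atoms=6)
print(token_id)      # [0 0 0 1 2 2]
print(token_id[3:])  # after slicing, the IDs still identify the tokens: [1 2 2]
```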

atomworks.ml.transforms.atom_array.add_protein_termini_annotation(atom_array: AtomArray) AtomArray[source]#

Adds the annotations is_N_terminus and is_C_terminus to the respective residues in the atom array.

Parameters:

atom_array (AtomArray) – The AtomArray that the annotations will be added to.

Returns:

The AtomArray with is_N_terminus and is_C_terminus annotations.

Return type:

AtomArray

atomworks.ml.transforms.atom_array.apply_and_spread_chain_wise(atom_array: AtomArray, data: ndarray, function: Callable[[ndarray], generic], axis: int | None = None) ndarray[source]#

Apply a function chain-wise and then spread the result to the atoms.

atomworks.ml.transforms.atom_array.apply_and_spread_residue_wise(atom_array: AtomArray, data: ndarray, function: Callable[[ndarray], generic], axis: int | None = None) ndarray[source]#

Apply a function residue-wise and then spread the result to the atoms.
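Both helpers follow an apply-then-spread pattern. First, the data is reduced over each segment (chain or residue) to a single value. Then that value is broadcast back to every atom of the segment. Biotite provides residue-level building blocks for this pattern, such as biotite.structure.apply_residue_wise and biotite.structure.spread_residue_wise. The sketch below reproduces the idea in plain NumPy. It assumes the segment start indices are already known, and it ignores the axis argument.

```python
import numpy as np
from typing import Callable


def apply_and_spread_segment_wise(
    segment_starts: np.ndarray,
    data: np.ndarray,
    function: Callable[[np.ndarray], np.generic],
) -> np.ndarray:
    """Reduce `data` per contiguous segment, then repeat each result per atom.

    `segment_starts` holds the sorted start index of each segment (first entry 0),
    e.g. residue starts or chain starts.
    """
    segment_stops = np.append(segment_starts[1:], len(data))
    per_segment = np.array(
        [function(data[start:stop]) for start, stop in zip(segment_starts, segment_stops)]
    )
    # Repeat each per-segment value once for every atom in that segment.
    return np.repeat(per_segment, segment_stops - segment_starts)


b_factors = np.array([1.0, 2.0, 3.0, 10.0, 20.0])
mean_per_residue = apply_and_spread_segment_wise(np.array([0, 3]), b_factors, np.mean)
print(mean_per_residue)
```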

atomworks.ml.transforms.atom_array.atom_id_to_atom_idx(atom_array: AtomArray, atom_id: int) int[source]#

Convert an atom ID to an atom index in the given array.

atomworks.ml.transforms.atom_array.atom_id_to_token_idx(atom_array: AtomArray, atom_id: int) int[source]#

Convert an atom ID to a token index in the given array.

atomworks.ml.transforms.atom_array.chain_instance_iter(array: AtomArray) Iterator[AtomArray][source]#

Returns an iterator over the chain instances (chain_iid) in the atom array.

This will match biotite.structure.chain_iter in the case where there are no transformations.

atomworks.ml.transforms.atom_array.compute_atom_to_token_map(atom_array: AtomArray) dict[source]#
atomworks.ml.transforms.atom_array.copy_annotation(atom_array: AtomArray, annotation_to_copy: str, new_annotation: str) AtomArray[source]#

Copies an existing annotation from the AtomArray and assigns it a new name.

Particularly useful for scenarios such as diffusive training, where the new annotation is altered (e.g., adding noise) without affecting the ground truth data.

Parameters:
  • atom_array (AtomArray) – The AtomArray object containing the annotations.

  • annotation_to_copy (str) – The name of the annotation to be copied.

  • new_annotation (str) – The name for the new annotation.

Returns:

The AtomArray with the newly added annotation.

Return type:

AtomArray

Example

>>> updated_atom_array = copy_annotation(atom_array, "coord", "coord_to_be_noised")

atomworks.ml.transforms.atom_array.get_chain_instance_starts(array: AtomArray, add_exclusive_stop: bool = False) ndarray[source]#

Get indices for an atom array, each indicating the beginning of a new chain instance (chain_iid).

Inspired by biotite.structure.get_chain_starts.

Parameters:
  • array (AtomArray) – The atom array to get the chain_iid starts from.

  • add_exclusive_stop (bool, optional) – If True, append the exclusive stop of the last chain instance (i.e., the length of the atom array). Defaults to False.

Returns:

An array of indices indicating the beginning of each chain instance.

Return type:

np.ndarray
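At its core, this is a change-point search on the chain_iid annotation. The following minimal NumPy sketch operates on a bare annotation array rather than an AtomArray, which is a simplification. As with any contiguous-run detection, a chain_iid value that reappears later in the array starts a new run.

```python
import numpy as np


def chain_instance_starts(chain_iid: np.ndarray, add_exclusive_stop: bool = False) -> np.ndarray:
    """Start index of every contiguous run of identical chain_iid values.

    Assumes a non-empty annotation array.
    """
    # A new chain instance begins wherever the annotation differs from the previous atom.
    changes = np.nonzero(chain_iid[1:] != chain_iid[:-1])[0] + 1
    starts = np.concatenate(([0], changes))
    if add_exclusive_stop:
        # The exclusive stop of the last chain instance is the array length.
        starts = np.append(starts, len(chain_iid))
    return starts


chain_iid = np.array(["A_1", "A_1", "B_1", "B_1", "B_1", "A_1"])
print(chain_instance_starts(chain_iid))                           # [0 2 5]
print(chain_instance_starts(chain_iid, add_exclusive_stop=True))  # [0 2 5 6]
```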

atomworks.ml.transforms.atom_array.get_within_entity_idx(atom_array: AtomArray, level: Literal['chain', 'pn_unit', 'molecule']) tuple[ndarray, list[ndarray]][source]#
Get the within-entity instance index for the atom array.
  • Allowed levels are “chain”, “pn_unit”, or “molecule”.

  • Entities do not need to be contiguous.

  • Entities are defined by the unique values of the {level}_entity annotation.

Parameters:
  • atom_array (AtomArray) – The atom array to process.

  • level (str) – The level at which to calculate the within-entity index.

Returns:

The within-entity instance indices for each atom in the atom array, returned together with the instance IIDs of each entity.

Return type:

  • tuple[np.ndarray, list[np.ndarray]]

Example

>>> import biotite.structure as struc
>>> atom_array = struc.AtomArray(7)
>>> atom_array.set_annotation("chain_iid", ["A", "A", "B", "C", "D", "D", "E"])
>>> atom_array.set_annotation("chain_entity", ["1", "1", "1", "1", "2", "2", "2"])
>>> iids, within_entity_idx = get_within_entity_idx(atom_array, level="chain")
>>> print(within_entity_idx)
[0 0 1 2 0 0 1]
>>> print(iids)
['A' 'B' 'C'] ['D' 'E']
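The per-atom index in the example above can be reproduced in a single pass. The pass numbers the instances within each entity in order of first appearance. This sketch works on plain NumPy annotation arrays and returns only the per-atom index, not the IIDs. Numbering by first appearance is an assumption about how the library orders instances.

```python
import numpy as np


def within_entity_index(instance_iid: np.ndarray, entity: np.ndarray) -> np.ndarray:
    """Per-atom index of each instance within its entity (first appearance = 0)."""
    result = np.empty(len(instance_iid), dtype=int)
    instances_per_entity: dict = {}  # entity -> {instance iid -> within-entity index}
    for atom_idx, (iid, ent) in enumerate(zip(instance_iid, entity)):
        instances = instances_per_entity.setdefault(ent, {})
        if iid not in instances:
            instances[iid] = len(instances)  # next free index within this entity
        result[atom_idx] = instances[iid]
    return result


chain_iid = np.array(["A", "A", "B", "C", "D", "D", "E"])
chain_entity = np.array(["1", "1", "1", "1", "2", "2", "2"])
print(within_entity_index(chain_iid, chain_entity))  # [0 0 1 2 0 0 1]
```

The per-entity lookup tables mean that entities need not be contiguous.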
atomworks.ml.transforms.atom_array.get_within_group_atom_idx(atom_array: AtomArray, group_by: str) ndarray[source]#

Get the within-group atom index for the atom array.

Of note:
  • Groups do not need to be contiguous.

  • Groups are defined by the unique values of the group_by annotation.
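For the atom-level variant, the within-group index can be understood as a running count per group label. The following minimal sketch works on a plain annotation array. Because it only counts earlier occurrences of each label, the groups do not need to be contiguous.

```python
import numpy as np


def within_group_atom_index(group_by: np.ndarray) -> np.ndarray:
    """For each atom, the number of earlier atoms that share its group label."""
    seen_so_far: dict = {}
    result = np.empty(len(group_by), dtype=int)
    for atom_idx, label in enumerate(group_by):
        result[atom_idx] = seen_so_far.get(label, 0)
        seen_so_far[label] = result[atom_idx] + 1
    return result


print(within_group_atom_index(np.array(["A", "A", "B", "A", "B"])))  # [0 1 0 2 1]
```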

atomworks.ml.transforms.atom_array.get_within_group_res_idx(atom_array: AtomArray, group_by: str) ndarray[source]#

Get the within-group residue index for the atom array.

Of note:
  • Groups do not need to be contiguous.

  • Groups are defined by the unique values of the group_by annotation.

atomworks.ml.transforms.atom_array.get_within_poly_res_idx(atom_array: AtomArray) ndarray[source]#
atomworks.ml.transforms.atom_array.sort_like_rf2aa(atom_array: AtomArray) AtomArray[source]#

Sort the atom array such that non-polymer chains are sorted by their covalent bonds and PN unit IIDs.

atomworks.ml.transforms.atom_array.sort_poly_then_non_poly(atom_array: AtomArray, treat_atomized_as_non_poly: bool = True) AtomArray[source]#

Sort the atom array such that polymer chains are first, followed by non-polymer chains.

The order within the poly and non_poly chains is preserved.

This function is useful for ensuring that models like RF2AA, which expect the input to be formatted as [polys, non-polys], receive the correctly ordered atom array.

Parameters:
  • atom_array (AtomArray) – The AtomArray to be sorted.

  • treat_atomized_as_non_poly (bool) – If True, atomized structures are treated as non-polymer. Defaults to True.

Returns:

The sorted AtomArray with polymer chains first, followed by non-polymer chains.

Return type:

AtomArray

Transforms to handle the assignment of RF2AA’s atom frames

class atomworks.ml.transforms.atom_frames.AddAtomFrames(order_independent_atom_frame_prioritization: bool = True)[source]#

Bases: Transform

Add atom frames to the data dictionary. See the RF2AA supplement for more details.

NOTE: We do not assume that all atomized residues are at the end of the AtomArray to allow for more flexibility in the future.

Parameters:

order_independent_atom_frame_prioritization (bool, optional) – If True, sorts atom types within frames to consider them order-independent. Defaults to True.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

requires_previous_transforms: ClassVar[list[str | Transform]] = [<class 'atomworks.ml.transforms.atomize.AtomizeByCCDName'>, <class 'atomworks.ml.transforms.encoding.EncodeAtomArray'>]#
class atomworks.ml.transforms.atom_frames.AddIsRealAtom(token_encoding: TokenEncoding)[source]#

Bases: Transform

Makes a faux version of is_real_atom, which was previously derived from the ChemData heavy-atom mask. The number of atoms in each residue is determined from the atom array itself, which accommodates terminal oxygens, etc. This mask is used in the pLDDT calculation, where it masks the pLDDT logits in the [B, I, Max_N_Atoms] representation. The transform uses the atom_to_token_map to count the atoms in each residue and outputs a boolean mask in [I, 36] format. This accommodates up to 36 atoms per residue, because the RF2AA is_real_atom object has 36 atoms per residue. In AF3, the maximum number of atoms per residue is 23, and this tensor is truncated in the pLDDT calculation.

Adds:
  • ‘is_real_atom’: torch.Tensor of shape [I, 36] (bool)
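The mask construction can be sketched in two steps: count the atoms of each token, then compare the counts against a column range of width 36. This NumPy version has two simplifications. It takes a hypothetical per-atom array of token indices rather than the actual atom_to_token_map structure. It also returns a NumPy array instead of a torch.Tensor.

```python
import numpy as np

MAX_ATOMS_PER_TOKEN = 36  # width of the RF2AA-style per-residue atom dimension


def is_real_atom_mask(atom_token_idx: np.ndarray, n_tokens: int) -> np.ndarray:
    """Boolean [n_tokens, 36] mask; row i has its first k entries True if token i has k atoms."""
    atoms_per_token = np.bincount(atom_token_idx, minlength=n_tokens)
    # Broadcasting [1, 36] against [n_tokens, 1] yields the [n_tokens, 36] mask.
    return np.arange(MAX_ATOMS_PER_TOKEN)[None, :] < atoms_per_token[:, None]


mask = is_real_atom_mask(np.array([0, 0, 0, 1, 2, 2]), n_tokens=3)
print(mask.shape)        # (3, 36)
print(mask.sum(axis=1))  # [3 1 2]
```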

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

requires_previous_transforms: ClassVar[list[str | Transform]] = [<class 'atomworks.ml.transforms.atom_array.ComputeAtomToTokenMap'>, <class 'atomworks.ml.transforms.filters.RemoveTerminalOxygen'>, <class 'atomworks.ml.transforms.filters.RemoveNucleicAcidTerminalOxygen'>]#
class atomworks.ml.transforms.atom_frames.AddPolymerFrameIndices[source]#

Bases: Transform

Adds indices for the atoms that constitute the backbone frames of non-ligands. The transform adds an [I, 3] tensor. For proteins, the first column holds the atom index of N, the second the atom index of CA, and the third the atom index of C. Nucleic acids follow the AF3 pattern. For ligands and non-canonical residues (i.e., anything atomized), this function writes the index of each atom into the CA position.

Adds:
  • ‘frame_idxs’: torch.Tensor of shape [I, 3] (long)

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

atomworks.ml.transforms.atom_frames.find_all_paths_of_length_n(graph: Graph, n: int, order_independent_atom_frame_prioritization: bool = True) list[source]#

Find all paths of a given length n in a NetworkX graph.

Parameters:
  • graph (nx.Graph) – The input graph.

  • n (int) – The length of the paths to find.

  • order_independent_atom_frame_prioritization (bool, optional) – If True, considers paths with the same nodes but in different orders as equivalent. Defaults to True.

Returns:

All unique paths of length n.

Reference:

https://stackoverflow.com/questions/28095646/finding-all-paths-walks-of-given-length-in-a-networkx-graph
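The following is a dependency-free sketch of the path search. It uses a plain adjacency dictionary instead of an nx.Graph. It assumes that n counts edges, so a path has n + 1 nodes. Order independence is implemented by de-duplicating on the sorted node tuple. The library's own conventions may differ on both points.

```python
from __future__ import annotations


def find_paths_of_length(
    adjacency: dict[int, list[int]], n: int, order_independent: bool = True
) -> list[tuple[int, ...]]:
    """All simple paths with exactly n edges in an undirected graph."""

    def extend(path: list[int]):
        if len(path) == n + 1:
            yield tuple(path)
            return
        for neighbor in adjacency[path[-1]]:
            if neighbor not in path:  # simple path: never revisit a node
                yield from extend(path + [neighbor])

    unique_paths, seen = [], set()
    for start in adjacency:
        for path in extend([start]):
            # With order independence, paths over the same nodes (e.g. a path and
            # its reverse) collapse to a single entry.
            key = tuple(sorted(path)) if order_independent else path
            if key not in seen:
                seen.add(key)
                unique_paths.append(path)
    return unique_paths


# A central atom 1 bonded to atoms 0, 2 and 3.
adjacency = {0: [1], 1: [0, 2, 3], 2: [1], 3: [1]}
print(find_paths_of_length(adjacency, 2))  # [(0, 1, 2), (0, 1, 3), (2, 1, 3)]
```

With order_independent=False, each path is reported once per direction, so the same graph yields six paths.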

atomworks.ml.transforms.atom_frames.get_rf2aa_atom_frames(encoded_query_pn_unit: ndarray, graph: Graph, order_independent_atom_frame_prioritization: bool = True) Tensor[source]#

Choose a frame of 3 bonded atoms for each atom in the molecule, using a rule-based system that prioritizes frames based on atom types.

Parameters:
  • encoded_query_pn_unit (np.ndarray) – Sequence of the pn_unit that we want to build frames for, encoded using the RF2AA TokenEncoding.

  • graph (nx.Graph) – The input graph representing the non-polymer molecule.

  • order_independent_atom_frame_prioritization (bool, optional) – If True, sorts atom types within frames to consider them order-independent. Defaults to True.

Returns:

A tensor containing the selected frames for each atom.

Return type:

torch.Tensor

class atomworks.ml.transforms.atomize.AtomizeByCCDName(atomize_by_default: bool, res_names_to_atomize: list[str] | None = None, res_names_to_ignore: list[str] | None = None, move_atomized_part_to_end: bool = False, validate_atomize: bool = False)[source]#

Bases: Transform

Atomize residues by breaking down the CCD res_name field into the actual element names.

NOTE: Both polymers AND non-polymers are considered “residues” by the CCD, and have a corresponding res_name.

This transform allows for the atomization of residues in an AtomArray by breaking down the residue names into their constituent atoms. It provides options to atomize residues by default, specify residues to atomize or ignore, and move atomized parts to the end of the array. It must be run before any transforms that rely on the tokens established during atomization, such as AddTokenBondAdjacency.

- atomize_by_default

Whether to atomize residues by default.

Type:

bool

- res_names_to_atomize

List of residue names to atomize.

Type:

list[str]

- res_names_to_ignore

List of residue names to ignore.

Type:

list[str]

- move_atomized_part_to_end

Whether to move atomized parts to the end of the array. This is done e.g. in RF2AA, when atomizing polymer residues covalently bound to a ligand.

Type:

bool

Raises:
  • ValueError – If a residue name appears in both res_names_to_atomize and res_names_to_ignore.

  • ValueError – If some atoms in a residue are atomized and some are not.
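The interaction of the three options can be sketched as a per-residue decision. This is a hypothetical reading of the documented behaviour, not the library's code. Under this reading, explicit requests win, ignored residues stay unatomized unless already flagged, and everything else follows atomize_by_default. The hypothetical `already_flagged` argument stands in for a pre-existing atomize annotation.

```python
def should_atomize(
    res_name: str,
    atomize_by_default: bool,
    res_names_to_atomize: list[str],
    res_names_to_ignore: list[str],
    already_flagged: bool = False,
) -> bool:
    """Hypothetical per-residue atomization decision."""
    overlap = set(res_names_to_atomize) & set(res_names_to_ignore)
    if overlap:
        raise ValueError(f"Residue names both atomized and ignored: {sorted(overlap)}")
    if already_flagged:  # e.g. set by an earlier random-atomization transform
        return True
    if res_name in res_names_to_atomize:
        return True
    if res_name in res_names_to_ignore:
        return False
    return atomize_by_default


print(should_atomize("HEM", False, ["HEM"], []))  # True
print(should_atomize("ALA", True, [], ["ALA"]))   # False
```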

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str | Transform]] = ['AddTokenBondAdjacency', 'AddRF2AAChiralFeatures', 'AddGlobalTokenIdAnnotation', 'AtomizeByCCDName', 'EncodeAtomArray']#
class atomworks.ml.transforms.atomize.FlagNonPolymersForAtomization[source]#

Bases: Transform

Flag all non-polymer residues for atomization.

This is relevant for structures such as 6w12, which contain a protein residue outside of a polymer (an individual SER bonded to a sugar).

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str | Transform]] = ['AtomizeByCCDName']#
atomworks.ml.transforms.atomize.atomize_by_ccd_name(atom_array: AtomArray, atomize_by_default: bool = True, res_names_to_atomize: list[str] = [], res_names_to_ignore: list[str] = [], move_atomized_part_to_end: bool = False, validate_atomize: bool = False) AtomArray[source]#

Atomize residues by breaking down the res_name field into the actual element names.

Parameters:
  • atom_array (AtomArray) – The atom array to atomize.

  • atomize_by_default (bool) – Whether to atomize residues by default.

  • res_names_to_atomize (list[str]) – List of residue names to atomize. Defaults to [].

  • res_names_to_ignore (list[str]) – List of residue names to ignore. These residues will only be atomized if their atomize flag is already explicitly set to True, e.g. from a previous transform that samples random residues for atomization for data augmentation. Defaults to [].

  • move_atomized_part_to_end (bool, optional) – Whether to move atomized parts to the end of the array. Defaults to False. This is relevant for RF2AA, which follows the convention that atomized parts are grouped together at the end of the input.

  • validate_atomize (bool, optional) – Whether to validate that a residue is either atomized or not. Defaults to False.

Returns:

The atomized atom array. The atomize flag is set for each atom in the array.

NOTE: The returned array may be reordered if move_atomized_part_to_end is True.

Return type:

AtomArray

class atomworks.ml.transforms.bonds.AddAF3TokenBondFeatures(distance_cutoff: float = 2.4)[source]#

Bases: Transform

Transform that generates AF3-style token bond features for an AtomArray.

This transform creates a 2D matrix indicating if there is a bond between any atom in token i and token j, restricted to just polymer-ligand and ligand-ligand bonds and bonds less than a specified distance cutoff.

Parameters:

distance_cutoff (float) – The maximum distance (in Angstroms) for considering a bond. Defaults to 2.4.

Returns:

A dictionary containing the input data and the new ‘af3_token_bond_features’ key with the computed boolean matrix.

Return type:

dict

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

requires_previous_transforms: ClassVar[list[str | Transform]] = ['AtomizeByCCDName']#
class atomworks.ml.transforms.bonds.AddAtomLevelBondAdjacencyMatrix[source]#

Bases: Transform

Adds the atom-level bond adjacency matrix to the data as a feature.

This transform uses Biotite’s adjacency_matrix() function to create a binary matrix where element (i, j) is 1 if atoms i and j are bonded, and 0 otherwise.

The matrix is added to the data dictionary under data[“feats”][“atom_level_bond_adjacency”].

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.bonds.AddRF2AABondFeaturesMatrix[source]#

Bases: Transform

Adds a matrix indicating the RF2AA bond type between two nodes to the data. This transform builds from the Biotite bond type, modifying as needed for residue-residue and residue-atom mappings. We then add the matrix to the data dictionary under the key rf2aa_bond_features_matrix.

From the RF2AA supplement, Supplementary Methods Table 8: Inputs to RFAA:

bond_feats | (L, L, 7) | Pairwise bond adjacency matrix. Pairs of residues are either single, double, triple, aromatic, residue-residue, residue-atom, or other.


Specifically, we map to the following enum, as described in ChemData:
  • 0 = No bonds

  • 1 = Single bond

  • 2 = Double bond

  • 3 = Triple bond

  • 4 = Aromatic

  • 5 = Residue-residue

  • 6 = Residue-atom

  • 7 = Other

We build the matrix from the Biotite bond types. The Biotite BondType enum contains the following mapping:

  • ANY = 0

  • SINGLE = 1

  • DOUBLE = 2

  • TRIPLE = 3

  • QUADRUPLE = 4

  • AROMATIC_SINGLE = 5

  • AROMATIC_DOUBLE = 6

  • AROMATIC_TRIPLE = 7

The index -1 is used for non-bonded interactions.

Reference: Biotite documentation (https://www.biotite-python.org/apidoc/biotite.structure.BondType.html#biotite.structure.BondType)

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

requires_previous_transforms: ClassVar[list[str | Transform]] = [<class 'atomworks.ml.transforms.atomize.AtomizeByCCDName'>, <class 'atomworks.ml.transforms.bonds.AddTokenBondAdjacency'>]#
class atomworks.ml.transforms.bonds.AddRF2AATraversalDistanceMatrix[source]#

Bases: Transform

Generates a matrix indicating the minimum number of bonds to traverse between two nodes. The traversal distance between two protein nodes is defined as zero. Sets the “traversal_distance_matrix” key in the data dictionary.

From the RF2AA supplement, Supplementary Methods Table 8: Inputs to RFAA:

dist_matrix | (L, L) | Minimum amount of bonds to traverse between two nodes. This is 0 between all protein nodes.
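The traversal distance is a breadth-first-search shortest path over the bond graph. This minimal sketch takes a boolean node-level bond adjacency matrix and a per-node is_protein mask as inputs. Unreachable pairs are left as np.inf, an arbitrary choice made for illustration.

```python
from collections import deque

import numpy as np


def traversal_distance_matrix(adjacency: np.ndarray, is_protein: np.ndarray) -> np.ndarray:
    """Minimum number of bonds between every pair of nodes; 0 between protein nodes."""
    n_nodes = len(adjacency)
    dist = np.full((n_nodes, n_nodes), np.inf)
    for source in range(n_nodes):
        # Breadth-first search from `source` reaches nodes in order of bond count.
        dist[source, source] = 0
        queue = deque([source])
        while queue:
            node = queue.popleft()
            for neighbor in np.nonzero(adjacency[node])[0]:
                if np.isinf(dist[source, neighbor]):
                    dist[source, neighbor] = dist[source, node] + 1
                    queue.append(neighbor)
    # By definition, the distance between any two protein nodes is zero.
    dist[np.ix_(is_protein, is_protein)] = 0
    return dist


# Protein nodes 0-1 bonded to ligand atoms 2-3; node 4 is disconnected.
adjacency = np.zeros((5, 5), dtype=bool)
for i, j in [(0, 1), (1, 2), (2, 3)]:
    adjacency[i, j] = adjacency[j, i] = True
is_protein = np.array([True, True, False, False, False])
dist = traversal_distance_matrix(adjacency, is_protein)
print(dist[0])
```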


check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.bonds.AddTokenBondAdjacency[source]#

Bases: Transform

Adds the token bond adjacency matrix to the data.

This transform computes the token bond adjacency matrix from the atom bond adjacency matrix and adds it to the data dictionary under the key token_bond_adjacency.

The token bond adjacency matrix is a binary [n_tokens, n_tokens] matrix where element (i, j) is True if there is at least one bond between any atom in token i and any atom in token j, and False otherwise.

Depends on the definition of tokens and therefore has to be applied after any transform that alters what is considered a token (e.g. AtomizeByCCDName) or that changes the order or number of tokens. By default, a token is defined as a residue in the input AtomArray.

Raises:

AssertionError – If the input data does not contain the required keys or types.

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

requires_previous_transforms: ClassVar[list[str | Transform]] = [<class 'atomworks.ml.transforms.atomize.AtomizeByCCDName'>]#
atomworks.ml.transforms.bonds.get_af3_token_bond_features(atom_array: AtomArray, distance_cutoff: float = 2.4) ndarray[source]#

Generates AF3-style token bond features for an AtomArray. For bonds between multi-atom tokens (i.e., residues), we define the “bond distance” as the minimum distance between an atom of one token and any atom of the other token.

From AF3:

Returns a 2D matrix indicating if there is a bond between any atom in token i and token j, restricted to just polymer-ligand and ligand-ligand bonds and bonds less than 2.4 Å during training.

Parameters:
  • atom_array (AtomArray) – The input AtomArray containing atomic coordinates and bond information.

  • distance_cutoff (float) – The maximum distance (in Angstroms) for considering a bond. Defaults to 2.4.

Returns:

A boolean matrix where True indicates a bond between tokens that meets the specified criteria.

Return type:

np.ndarray
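The following loop-based sketch implements these criteria. Its inputs are hypothetical: a bond list of atom-index pairs, per-atom coordinates, a per-atom token index, and a per-token polymer flag. Polymer-polymer bonds are skipped. The "bond distance" between two tokens is the minimum inter-atom distance between them, as described above.

```python
import numpy as np


def af3_token_bond_features(
    bonds: np.ndarray,
    coords: np.ndarray,
    atom_token_idx: np.ndarray,
    token_is_polymer: np.ndarray,
    distance_cutoff: float = 2.4,
) -> np.ndarray:
    """Boolean [n_tokens, n_tokens] matrix of polymer-ligand / ligand-ligand token bonds."""
    n_tokens = len(token_is_polymer)
    feats = np.zeros((n_tokens, n_tokens), dtype=bool)
    for atom_a, atom_b in bonds:
        tok_a, tok_b = atom_token_idx[atom_a], atom_token_idx[atom_b]
        if tok_a == tok_b or (token_is_polymer[tok_a] and token_is_polymer[tok_b]):
            continue  # skip intra-token and polymer-polymer bonds
        xyz_a = coords[atom_token_idx == tok_a]
        xyz_b = coords[atom_token_idx == tok_b]
        # "Bond distance": minimum distance between any atom of token a and any atom of token b.
        min_dist = np.linalg.norm(xyz_a[:, None, :] - xyz_b[None, :, :], axis=-1).min()
        if min_dist < distance_cutoff:
            feats[tok_a, tok_b] = feats[tok_b, tok_a] = True
    return feats


# Two polymer residues (tokens 0, 1), then two single-atom ligand tokens (2, 3).
coords = np.array([[0.0, 0, 0], [1.5, 0, 0], [3.0, 0, 0], [4.5, 0, 0], [6.0, 0, 0], [10.0, 0, 0]])
atom_token_idx = np.array([0, 0, 1, 1, 2, 3])
token_is_polymer = np.array([True, True, False, False])
bonds = np.array([[1, 2], [3, 4], [4, 5]])  # peptide bond, polymer-ligand bond, long ligand-ligand bond
feats = af3_token_bond_features(bonds, coords, atom_token_idx, token_is_polymer)
print(feats.astype(int))
```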

atomworks.ml.transforms.bonds.get_bond_distance_matrix(atom_array: AtomArray) ndarray[source]#

Returns the bond adjacency matrix with bond distances as values.

atomworks.ml.transforms.bonds.get_bond_distances(atom_array: AtomArray) ndarray[source]#

Returns the bond distance (adjacency) list as a 1D array.

atomworks.ml.transforms.bonds.get_token_bond_adjacency(atom_array: AtomArray) ndarray[source]#

Computes the token bond adjacency matrix from the atom bond adjacency matrix.

This is done by performing a block-wise reduction of the atom adjacency matrix, where block (i, j) is the sub-matrix of the atom adjacency matrix for bonds between atoms of token i and j. The reduction is performed by np.any, which returns True if at least one bond exists between the two tokens.
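If every token is a contiguous range of atoms with known start indices, the block-wise np.any can be expressed by applying np.logical_or.reduceat along both axes. The following minimal sketch operates on a plain boolean adjacency matrix and relies on that contiguity assumption.

```python
import numpy as np


def token_bond_adjacency(atom_adjacency: np.ndarray, token_starts: np.ndarray) -> np.ndarray:
    """Block-wise OR of a boolean [n_atoms, n_atoms] matrix into [n_tokens, n_tokens].

    Assumes tokens are contiguous atom ranges; `token_starts` is sorted and starts at 0.
    """
    # Collapse the rows of each token, then the columns; each output entry is
    # np.any over the corresponding (token i, token j) block.
    rows_reduced = np.logical_or.reduceat(atom_adjacency, token_starts, axis=0)
    return np.logical_or.reduceat(rows_reduced, token_starts, axis=1)


# Atoms 0-1 form token 0, atom 2 is token 1, atom 3 is token 2; bonds 0-1, 1-2, 2-3.
atom_adjacency = np.zeros((4, 4), dtype=bool)
for i, j in [(0, 1), (1, 2), (2, 3)]:
    atom_adjacency[i, j] = atom_adjacency[j, i] = True
token_adj = token_bond_adjacency(atom_adjacency, np.array([0, 2, 3]))
print(token_adj.astype(int))
```

The diagonal entry for token 0 is True here because of the bond between its own two atoms.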

Transforms relating to stereochemistry (chirality)

class atomworks.ml.transforms.chirals.AddAF3ChiralFeatures(take_first_chiral_subordering: bool = True)[source]#

Bases: Transform

Adds chiral features into the feats dictionary.

Adds the following features to the data dictionary under the ‘feats’ key:
  • chiral_feats: [N_chiral_centers, 5] A listing of chiral centers of the format:

    tensor([[ 5., 1., 2., 3., 0.61546…],…])

    Here, the first 4 columns are the atom indices that define the chiral center; the 5th is the target dihedral angle.

Metadata from GetRDKitChiralCenters, held in the “chiral_centers” key, is needed for this transform.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

requires_previous_transforms: ClassVar[list[str | Transform]] = ['GetRDKitChiralCenters']#
class atomworks.ml.transforms.chirals.AddRF2AAChiralFeatures[source]#

Bases: Transform

AddRF2AAChiralFeatures adds chiral features to the atom array data under the “chiral_feats” key. Chiral centers are taken from data[“chiral_centers”], which is a list of dictionaries, of the form:

{“chiral_center_atom_id”: int, “bonded_explicit_atom_ids”: list[int]}

This metadata can be added by running, e.g., the AddOpenBabelMoleculesForAtomizedMolecules and GetChiralCentersFromOpenBabel transforms. This transform also requires the AtomizeByCCDName transform to be applied first, so that the atom array is properly atomized.

Parameters:

data (dict[str, Any]) – A dictionary containing the input data, including the atom array and chiral centers.

Returns:

The updated data dictionary with the added chiral features under the “chiral_feats” key.

Return type:

dict[str, Any]

Example

>>> data = {
...     "atom_array": atom_array,
...     "chiral_centers": [
...         {"chiral_center_atom_id": 5, "bonded_explicit_atom_ids": [1, 2, 3, 4]},
...         {"chiral_center_atom_id": 10, "bonded_explicit_atom_ids": [6, 7, 8, 9]},
...     ],
... }
>>> transform = AddRF2AAChiralFeatures()
>>> result = transform.forward(data)
>>> print(result["chiral_feats"])
# Output might look like the following (assuming the atom_ids above also correspond
# to the indices in the atom array; otherwise the first 4 columns look different,
# as they are the indices in the atom array):
# tensor([[ 5., 1., 2., 3., 0.61546...],
#         [ 5., 2., 3., 4., -0.61546...],
#         ...
#         [10., 7., 8., 9., -0.61546...]])

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

requires_previous_transforms: ClassVar[list[str | Transform]] = ['AtomizeByCCDName']#
atomworks.ml.transforms.chirals.add_af3_chiral_features(atom_array: AtomArray, chiral_centers: dict, rdkit_mols: dict[str, Mol], take_first_chiral_subordering: bool = True) Tensor[source]#

Computes chiral features from atom array, chiral centers, and RDKit molecules.

See AddAF3ChiralFeatures for more details.

atomworks.ml.transforms.chirals.get_dih(a: Tensor, b: Tensor, c: Tensor, d: Tensor, eps: float = 0.0001) Tensor[source]#

Calculate dihedral angles for all consecutive quadruples (a[i], b[i], c[i], d[i]), given the Cartesian coordinates of four sets of atoms a, b, c, d.

Copied from rf2aa.kinematics.get_dih to decouple the transform from the rf2aa package.

Parameters:
  • a, b, c, d (torch.Tensor) – PyTorch tensors of shape [batch, nres, 3] that store the Cartesian coordinates of the four sets of atoms.

Returns:

A PyTorch tensor of shape [batch, nres] that stores the resulting dihedrals.

Return type:

torch.Tensor
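As a reference point, the standard atan2 formulation of a signed dihedral is sketched below in NumPy. This is not the get_dih code: the sign convention used here is the common one (positive for clockwise rotation when viewed along b→c), and agreement with get_dih in every detail is not guaranteed. Because the formula never divides by a vector norm, it needs no eps term.

```python
import numpy as np


def dihedral(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray:
    """Signed dihedral angle a-b-c-d in radians, for coordinate arrays of shape [..., 3]."""
    b1, b2, b3 = b - a, c - b, d - c
    n1 = np.cross(b1, b2)  # normal of plane (a, b, c)
    n2 = np.cross(b2, b3)  # normal of plane (b, c, d)
    # atan2 formulation: no division by vector norms, so no eps is needed.
    y = np.linalg.norm(b2, axis=-1) * np.sum(b1 * n2, axis=-1)
    x = np.sum(n1 * n2, axis=-1)
    return np.arctan2(y, x)


a, b, c = np.array([1.0, 0, 0]), np.zeros(3), np.array([0.0, 0, 1])
print(np.degrees(dihedral(a, b, c, np.array([0.0, 1, 1]))))  # approximately 90 degrees
print(np.degrees(dihedral(a, b, c, np.array([1.0, 0, 1]))))  # approximately 0 degrees (cis)
```

Because the reductions use axis=-1, the same function also accepts batched inputs such as [batch, nres, 3].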

atomworks.ml.transforms.chirals.get_rf2aa_chiral_features(chiral_centers: list[dict], coords: ndarray, take_first_chiral_subordering: bool = True) Tensor[source]#

Extracts chiral centers and featurizes them for RF2AA.

NOTE: Each row of output features contains the indices of the plane pairs and the signed ideal dihedral angle for each chiral center. For example, the entry:

[c, i, j, k, angle]

means that the atom at index c is a chiral center with atoms at indices (i, j, k) bonded to it. The value angle is the signed dihedral angle between the planes (c, i, j) and (i, j, k); its sign determines the chirality of the chiral center.

NOTE: Each chiral center results in more than one feature. In particular:
  • 3 features if one of the 4 atoms bonded to the chiral center is an implicit hydrogen (as we do not look at any pair of planes where one plane contains an implicit hydrogen).

  • 12 features if all 4 atoms bonded to the chiral center are explicit atoms.

In RF2AA, these features are used to compute the angles for all unique pairs of planes in the center that are explicitly modeled (hydrogens are implicit) and to measure their error from the ideal tetrahedron in the unit sphere. The gradients of this error in the predicted angles with respect to the predicted coordinates are passed into the subsequent blocks as vector input features in the SE(3)-Transformer. This breaks the symmetry over reflections present in the rest of the network and allows the network to iteratively refine its predictions to match ideal tetrahedral geometry.

Parameters:
  • chiral_centers (list[dict]) – A list of dictionaries, of the form: {“chiral_center_idx”: int, “bonded_explicit_atom_idxs”: list[int]} where chiral_center_idx is the index of the chiral center atom, and bonded_explicit_atom_idxs is a list of the indices of the atoms bonded to the chiral center (excluding implicit hydrogens).

  • coords (np.ndarray) – A numpy array of atomic coordinates.

  • take_first_chiral_subordering (bool) – If True, only the first subordering is considered (when four bonded non-hydrogen atoms are present). If False, all orderings are considered (leading to 12 unique plane pairs in the case of 4 bonded atoms, or 3 unique plane pairs in the case of 3 bonded atoms).

Returns:

A tensor of shape [n_chirals, 5] where each row contains the indices of the plane pairs and the signed ideal dihedral angle for each chiral center. The sign of the dihedral angle determines the chirality of the chiral center (+1 for clockwise, -1 for counterclockwise). If no stereocenters are found, returns an empty tensor of shape [0, 5].

Return type:

torch.Tensor

Transforms to handle covalent modifications

class atomworks.ml.transforms.covalent_modifications.FlagAndReassignCovalentModifications[source]#

Bases: Transform

Handles covalent modifications within the AtomArray.

Covalent modifications, e.g., glycosylation, are handled by the following algorithm:

for polymer residues with atoms covalently bound to a NON-POLYMER:
    for ALL atoms in the polymer residue:
        set the pn_unit_iid and pn_unit_id identifying annotations to that of the NON-POLYMER polymer/non-polymer unit
        set atomize = true (thus, this transform must be run before the Atomize transform)
        set is_covalent_modification = true (for the entire pn_unit)
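The algorithm above can be sketched over plain per-atom NumPy arrays. This illustration rests on three assumptions. Residues are identified by a hypothetical res_key annotation. Bonds are given as atom-index pairs. Only pn_unit_iid is reassigned, although pn_unit_id would be handled the same way.

```python
import numpy as np


def flag_covalent_modifications(
    res_key: np.ndarray,
    is_polymer: np.ndarray,
    pn_unit_iid: np.ndarray,
    bonds: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Sketch of the algorithm above; `res_key` uniquely identifies each residue (hypothetical)."""
    pn_unit_iid = np.asarray(pn_unit_iid, dtype=object).copy()  # object dtype: no string truncation
    atomize = np.zeros(len(res_key), dtype=bool)
    is_covalent_modification = np.zeros(len(res_key), dtype=bool)
    for i, j in bonds:
        if is_polymer[j] and not is_polymer[i]:
            i, j = j, i  # orient the bond as (polymer atom, non-polymer atom)
        if not (is_polymer[i] and not is_polymer[j]):
            continue  # only polymer/non-polymer bonds mark a covalent modification
        residue = res_key == res_key[i]
        pn_unit_iid[residue] = pn_unit_iid[j]  # move the whole residue into the non-polymer pn_unit
        atomize[residue] = True
        is_covalent_modification[pn_unit_iid == pn_unit_iid[j]] = True  # flag the entire pn_unit
    return pn_unit_iid, atomize, is_covalent_modification


# Residue A_5 (atoms 0-1) is bonded to a sugar pn_unit B_1 (atoms 3-4); A_6 is untouched.
res_key = np.array(["A_5", "A_5", "A_6", "B_1", "B_1"])
is_polymer = np.array([True, True, True, False, False])
pn_unit_iid = np.array(["A_1", "A_1", "A_1", "B_1", "B_1"])
bonds = np.array([[0, 1], [1, 3], [3, 4]])
new_iid, atomize, is_cov = flag_covalent_modifications(res_key, is_polymer, pn_unit_iid, bonds)
print(new_iid.tolist())  # ['B_1', 'B_1', 'A_1', 'B_1', 'B_1']
```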


check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str | Transform]] = [<class 'atomworks.ml.transforms.atomize.AtomizeByCCDName'>, 'AddGlobalTokenIdAnnotation']#
atomworks.ml.transforms.covalent_modifications.flag_and_reassign_covalent_modifications(atom_array: AtomArray) AtomArray[source]#

Mark covalent modifications for atomization and reassign the corresponding PN unit annotations.

Parameters:

atom_array (AtomArray) – Current AtomArray within the Transform pipeline

Returns:

The modified AtomArray with updated annotations for covalent modifications. The pn_unit_id and pn_unit_iid of polymer atoms are reassigned to those of the non-polymer unit they are bound to, and the atomize annotation is set to True for these atoms. Additionally, the entire pn_unit is marked with is_covalent_modification = True.

Return type:

AtomArray

NOTE: If the atomize annotation is not present in the AtomArray, it will be added.

NOTE: If the is_covalent_modification annotation is not present in the AtomArray, it will be added.

NOTE: We do not modify the is_polymer annotation, which will still refer to the protein chain for the atomized polymer atoms.

class atomworks.ml.transforms.crop.CropContiguousLikeAF3(crop_size: int, keep_uncropped_atom_array: bool = False, max_atoms_in_crop: int | None = None, **kwargs)[source]#

Bases: CropTransformBase

A transform that performs contiguous cropping similar to AF3.

This class implements the contiguous cropping procedure as described in AF3. It samples a contiguous stretch of tokens from each instance (visiting instances in random order) until the crop size is reached, following crop_contiguous_af2_multimer.

WARNING: This transform is probabilistic if the atom array is larger than the crop size!

crop_size#

The maximum number of tokens to crop.

Type:

int

keep_uncropped_atom_array#

Whether to keep the uncropped atom array in the data. If True, the uncropped atom array will be stored in the crop_info dictionary under the key “atom_array”. Defaults to False.

Type:

bool

max_atoms_in_crop#

Maximum number of atoms allowed in a crop. If None, no resizing is performed.

Type:

int | None

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str | Transform]] = ['EncodeAtomArray', 'CropSpatialLikeAF3', 'CropContiguousLikeAF3', 'PlaceUnresolvedTokenOnClosestResolvedTokenInSequence']#
requires_previous_transforms: ClassVar[list[str | Transform]] = ['AtomizeByCCDName']#
class atomworks.ml.transforms.crop.CropSpatialLikeAF3(crop_size: int, jitter_scale: float = 0.001, crop_center_cutoff_distance: float = 15.0, keep_uncropped_atom_array: bool = False, force_crop: bool = False, max_atoms_in_crop: int | None = None, raise_if_missing_query: bool = True, **kwargs)[source]#

Bases: CropTransformBase

A transform that performs spatial cropping similar to AF3 and AF2 Multimer.

This class implements the spatial cropping procedure as described in AF3. It selects a crop center from a spatial region of the atom array and samples a crop around this center.

WARNING: This transform is probabilistic if the atom array is larger than the crop size!

crop_size#

The maximum number of tokens to crop. Must be greater than 0.

Type:

int

jitter_scale#

The scale of the jitter to apply to the crop center. This is to break ties between atoms with the same spatial distance. Defaults to 1e-3.

Type:

float

crop_center_cutoff_distance#

The cutoff distance to consider for selecting crop centers. Measured in Angstroms. Defaults to 15.0.

Type:

float

keep_uncropped_atom_array#

Whether to keep the uncropped atom array in the data. If True, the uncropped atom array will be stored in the crop_info dictionary under the key “atom_array”. Defaults to False.

Type:

bool

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str | Transform]] = ['EncodeAtomArray', 'CropContiguousLikeAF3', 'CropSpatialLikeAF3', 'PlaceUnresolvedTokenOnClosestResolvedTokenInSequence']#
requires_previous_transforms: ClassVar[list[str | Transform]] = ['AddGlobalAtomIdAnnotation', 'AtomizeByCCDName']#
class atomworks.ml.transforms.crop.CropTransformBase(annotate_crop_boundary: bool = False, crop_boundary_radius: float = 6.0, **kwargs)[source]#

Bases: Transform

Base class for crop-type transforms.

atomworks.ml.transforms.crop.compute_local_hash(atom_array: AtomArray, radius: float = 6.0) ndarray[source]#

Compute a local hash for each atom in the atom array.

Currently, the hash is the number of neighbours within a given radius.

Parameters:
  • atom_array (AtomArray) – The atom array to compute the local hash for.

  • radius (float) – The radius to use for the local hash.

Returns:

A numpy array of shape (n_atoms,) containing the local hash for each atom.

Return type:

np.ndarray
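A minimal sketch of the neighbour-count hash, assuming that an atom does not count itself and that the radius is inclusive (neither detail is stated above). local_neighbour_count is a hypothetical name. It builds the dense pairwise distance matrix, which is quadratic in memory; a spatial index such as a KD-tree would scale better for large structures.

```python
import numpy as np


def local_neighbour_count(coords, radius=6.0):
    """Hypothetical sketch: number of other atoms within `radius` of each atom."""
    coords = np.asarray(coords, dtype=float)
    # Dense pairwise distances, shape (n_atoms, n_atoms); fine for small arrays
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    within = dist <= radius  # comparisons with NaN coordinates are False
    np.fill_diagonal(within, False)  # an atom is not its own neighbour
    return within.sum(axis=1)


print(local_neighbour_count([[0, 0, 0], [1, 0, 0], [10, 0, 0]], radius=6.0))  # [1 1 0]
```

Atoms with NaN coordinates (e.g. unresolved atoms) get a count of 0 in this sketch.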

atomworks.ml.transforms.crop.crop_contiguous_af2_multimer(iids: list[int | str], instance_lens: list[int], crop_size: int) dict[source]#

Crop contiguous tokens from the given instances to reach the given crop size probabilistically.

Implements crop_contiguous as described in Algorithm 1 (section 7.2.1) of AF2 Multimer and in section 2.7.2 of AF3.

Parameters:
  • iids (list[int | str]) – List of instance identifiers.

  • instance_lens (list[int]) – List of lengths corresponding to each instance.

  • crop_size (int) – Desired number of tokens to crop. Must be greater than 0.

Returns:

Dictionary mapping instance identifiers

(iids) to crop masks (i.e. boolean arrays) indicating which tokens to keep.

Return type:

keep_tokens (dict[int | str, np.ndarray[bool]])

Example

>>> iids = [1, 2, 3]
>>> instance_lens = [3, 4, 2]
>>> crop_size = 5
>>> result = crop_contiguous_af2_multimer(iids, instance_lens, crop_size)
>>> print(result)
# Output might look like (probabilistic!):
# {
#     3: array([False, True]),
#     2: array([False, True, True, False]),
#     1: array([True, True, False])
# }
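For reference, here is a from-scratch sketch of Algorithm 1 of AF2 Multimer as we understand it, not the atomworks source. crop_contiguous_sketch is a hypothetical name, and passing a numpy.random.Generator is our choice for reproducibility. The lower bound on each instance's crop size guarantees that exactly min(crop_size, sum(instance_lens)) tokens are kept.

```python
import numpy as np


def crop_contiguous_sketch(iids, instance_lens, crop_size, rng=None):
    """Sketch of AF2 Multimer's contiguous crop (Algorithm 1), not atomworks code."""
    rng = np.random.default_rng() if rng is None else rng
    keep_tokens = {}
    n_added = 0
    n_remaining = sum(instance_lens)
    for k in rng.permutation(len(iids)):  # visit instances in random order
        n_k = instance_lens[k]
        n_remaining -= n_k
        # Never exceed the crop budget or the instance length
        size_max = min(crop_size - n_added, n_k)
        # Take enough tokens that the remaining instances can still fill the crop
        size_min = min(n_k, max(0, crop_size - (n_added + n_remaining)))
        size = int(rng.integers(size_min, size_max + 1))
        n_added += size
        start = int(rng.integers(0, n_k - size + 1))
        mask = np.zeros(n_k, dtype=bool)
        mask[start:start + size] = True
        keep_tokens[iids[k]] = mask
    return keep_tokens


result = crop_contiguous_sketch([1, 2, 3], [3, 4, 2], crop_size=5, rng=np.random.default_rng(0))
print(sum(int(m.sum()) for m in result.values()))  # 5
```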
atomworks.ml.transforms.crop.crop_spatial_like_af3(atom_array: AtomArray, query_pn_unit_iids: list[str], crop_size: int, jitter_scale: float = 0.001, crop_center_cutoff_distance: float = 15.0, force_crop: bool = False, raise_if_missing_query: bool = True) dict[source]#

Crop spatial tokens around a given crop_center by keeping the crop_size nearest neighbors (with jitter).

Parameters:
  • atom_array (AtomArray) – The atom array to crop.

  • query_pn_unit_iids (list[str]) – List of query polymer/non-polymer unit instance IDs.

  • crop_size (int) – The maximum number of tokens to crop.

  • jitter_scale (float) – Scale of jitter to apply when calculating distances. Defaults to 1e-3.

  • crop_center_cutoff_distance (float) – Maximum distance from query units to consider for crop center. Defaults to 15.0 Angstroms.

  • force_crop (bool) – Whether to force crop even if the atom array is already small enough. Defaults to False.

  • raise_if_missing_query (bool) – Whether to raise an Exception if no crop centers are found, e.g. if the query pn_unit(s) are not present due to a previous filtering step. Defaults to True. If False, a random pn_unit will be selected for the crop center.

Returns:

A dictionary containing crop information, including:
  • requires_crop (bool): Whether cropping was necessary.

  • crop_center_atom_id (int or np.nan): ID of the atom chosen as crop center.

  • crop_center_atom_idx (int or np.nan): Index of the atom chosen as crop center.

  • crop_center_token_idx (int or np.nan): Index of the token containing the crop center.

  • crop_token_idxs (np.ndarray): Indices of tokens included in the crop.

  • crop_atom_idxs (np.ndarray): Indices of atoms included in the crop.

Return type:

dict

Note

This function implements the spatial cropping procedure as described in AlphaFold 3 and AlphaFold 2 Multimer.

atomworks.ml.transforms.crop.get_spatial_crop_center(atom_array: AtomArray, query_pn_unit_iids: list[str], cutoff_distance: float = 15.0, raise_if_missing_query: bool = True) ndarray[source]#

Sample a crop center from a spatial region of the atom array.

Implements the selection of a crop center as described in AF3:

    “In this procedure, polymer residues and ligand atoms are selected that are within close spatial distance of an interface atom. The interface atom is selected at random from the set of token centre atoms (defined in subsection 2.6) with a distance under 15 Å to another chain’s token centre atom. For examples coming out of the Weighted PDB or Disordered protein PDB complex datasets, where a preferred chain or interface is provided (subsection 2.5), the reference atom is selected at random from interfacial token centre atoms that exist within this chain or interface.”

Parameters:
  • atom_array (AtomArray) – The array containing atom information.

  • query_pn_unit_iids (list[str]) – List of PN unit instance IDs to query.

  • cutoff_distance (float, optional) – The distance cutoff to consider for spatial proximity. Defaults to 15.0.

  • raise_if_missing_query (bool) – Whether to raise an Exception if no crop centers are found, e.g. if the query pn_unit(s) are not present due to a previous filtering step. Defaults to True. If False, a random pn_unit will be selected for the crop center.

Returns:

A boolean mask indicating the crop center.

Return type:

np.ndarray
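The interface rule quoted above can be sketched as follows. This is a hypothetical helper, not the atomworks implementation: it treats each pn_unit as a "chain", takes one token center coordinate per token, and returns the mask of candidate crop centers among the query pn_units. The fallback controlled by raise_if_missing_query is not modeled.

```python
import numpy as np


def interface_crop_center_mask(token_center_coords, pn_unit_iids, query_pn_unit_iids, cutoff_distance=15.0):
    """Hypothetical sketch: mask of query tokens eligible as AF3-style crop centers."""
    coords = np.asarray(token_center_coords, dtype=float)
    units = np.asarray(pn_unit_iids)
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    different_unit = units[:, None] != units[None, :]
    # Interface token: closer than the cutoff to a token center in another unit
    is_interface = ((dist < cutoff_distance) & different_unit).any(axis=1)
    # Restrict candidates to the queried pn_units (the preferred chain/interface)
    return is_interface & np.isin(units, query_pn_unit_iids)


centers = [[0, 0, 0], [30, 0, 0], [5, 0, 0], [100, 0, 0]]
units = ["A_1", "A_1", "B_1", "C_1"]
print(interface_crop_center_mask(centers, units, ["A_1"]))  # [ True False False False]
```

A crop center would then be drawn uniformly from the True entries.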

atomworks.ml.transforms.crop.get_spatial_crop_mask(coord: ndarray, crop_center_idx: int, crop_size: int, jitter_scale: float = 0.001) ndarray[source]#

Crop spatial tokens around a given crop_center by keeping the crop_size nearest neighbors (with jitter).

Implements crop_spatial as described in Algorithm 2 (section 7.2.1) of AF2 Multimer, as also used in AF3.

Parameters:
  • coord (np.ndarray) – A 2D numpy array of shape (N, 3) representing the 3D token-level coordinates. Coordinates are expected to be in Angstroms.

  • crop_center_idx (int) – The index of the token to be used as the center of the crop.

  • crop_size (int) – The number of nearest neighbors to include in the crop.

  • jitter_scale (float) – The scale of the jitter to add to the coordinates.

Returns:

A boolean mask of shape (N,) where True indicates that the token is within the crop.

Return type:

crop_mask (np.ndarray)

Example

>>> coord = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
>>> crop_center_idx = 0
>>> crop_size = 2
>>> crop_mask = get_spatial_crop_mask(coord, crop_center_idx, crop_size)
>>> print(crop_mask)
[ True  True False False]
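The nearest-neighbour selection can be sketched in a few lines. This assumes Gaussian jitter added to the coordinates (the docstring only says the jitter is added to the coordinates, not which distribution it uses); spatial_crop_mask_sketch is a hypothetical name.

```python
import numpy as np


def spatial_crop_mask_sketch(coord, crop_center_idx, crop_size, jitter_scale=1e-3, rng=None):
    """Sketch: keep the `crop_size` tokens nearest to the crop center."""
    rng = np.random.default_rng() if rng is None else rng
    coord = np.asarray(coord, dtype=float)
    # Gaussian jitter on the coordinates breaks ties between equidistant tokens
    jittered = coord + rng.normal(scale=jitter_scale, size=coord.shape)
    dist = np.linalg.norm(jittered - jittered[crop_center_idx], axis=-1)
    crop_mask = np.zeros(len(coord), dtype=bool)
    crop_mask[np.argsort(dist)[:crop_size]] = True
    return crop_mask


coord = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
print(spatial_crop_mask_sketch(coord, crop_center_idx=0, crop_size=2))  # [ True  True False False]
```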
atomworks.ml.transforms.crop.resize_crop_info_if_too_many_atoms(crop_info: dict, atom_array: AtomArray, max_atoms: int) dict[source]#

Resizes crops that exceed the maximum allowed number of atoms by removing tokens based on the distance to the crop center. If no crop center is provided, the center of mass of the tokens in the crop is used as center.

NOTE: This is mostly needed for AF3, where crops containing nucleic acids can have too many atoms for the current atom-level local attention to fit on GPUs with less memory.

Parameters:
  • crop_info (dict) – Dictionary containing crop information. Must include:

    • crop_atom_idxs: Array of atom indices in the crop

    • crop_token_idxs: Array of token indices in the crop

  • atom_array (AtomArray) – The atom array containing the full structure.

  • max_atoms (int) – Maximum number of atoms allowed in a crop. If None, no resizing is performed.

Returns:

Updated crop_info dictionary with resized crop indices, if resizing was necessary.

Return type:

dict
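A minimal sketch of the resizing idea, assuming that tokens are dropped farthest-first from the crop center and that the "center of mass" is the unweighted mean of the cropped token positions. resize_crop_sketch and its arguments are hypothetical and work on token indices only, whereas the real function updates crop_info for both tokens and atoms.

```python
import numpy as np


def resize_crop_sketch(crop_token_idxs, token_coords, atoms_per_token, max_atoms, center=None):
    """Hypothetical sketch: drop the farthest tokens until the crop fits `max_atoms`."""
    crop_token_idxs = np.asarray(crop_token_idxs)
    coords = np.asarray(token_coords, dtype=float)[crop_token_idxs]
    n_atoms = np.asarray(atoms_per_token)[crop_token_idxs]
    if center is None:
        center = coords.mean(axis=0)  # unweighted mean of the cropped token positions
    order = np.argsort(np.linalg.norm(coords - center, axis=-1))  # nearest first
    # Keep the longest nearest-first prefix whose atom count stays within budget
    within_budget = np.cumsum(n_atoms[order]) <= max_atoms
    return np.sort(crop_token_idxs[order[within_budget]])


token_coords = [[0, 0, 0], [1, 0, 0], [2, 0, 0], [10, 0, 0]]
atoms_per_token = [5, 5, 5, 5]
print(resize_crop_sketch([0, 1, 2, 3], token_coords, atoms_per_token, max_atoms=12, center=np.zeros(3)))  # [0 1]
```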

Transforms and helper functions to convert from AtomArray objects to various encoding schemes.

During encoding, sequences of tokens are converted to sequences of integers, and the AtomArray of coordinates is converted to a (N_token, N_atoms_per_token, 3) tensor.

The token type (residue-level or atom-level) is encoded as a boolean in the atomize flag.

class atomworks.ml.transforms.encoding.AddTokenAnnotation(encoding: TokenEncoding)[source]#

Bases: Transform

Add a token annotation to the atom array. This is mostly meant as a debug transform and not expected to be used in production.

Sets the token annotation to the token name for each atom in the atom array.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.encoding.EncodeAF3TokenLevelFeatures(sequence_encoding: AF3SequenceEncoding)[source]#

Bases: Transform

A transform that encodes token-level features like AF3. The token-level features are returned as:

  • feats:

    Standard AF3 token-level features:

    • residue_index: Residue number in the token’s original input chain (pre-crop)

    • token_index: Token number. Increases monotonically; does not restart at 1 for new chains. (Runs from 0 to N_tokens - 1)

    • asym_id: Unique integer for each distinct chain (pn_unit_iid)

      NOTE: We use pn_unit_iid rather than chain_iid to be more consistent with handling of multi-residue/multi-chain ligands (especially sugars)

    • entity_id: Unique integer for each distinct sequence (pn_unit entity)

    • sym_id: Unique integer within chains of this sequence. E.g., if pn_units A, B and C share a sequence but D does not, their sym_ids would be [0, 1, 2, 0].

    • restype: Integer encoding of the sequence. 32 possible values: 20 AA + unknown, 4 RNA nucleotides + unknown, 4 DNA nucleotides + unknown, and gap. Ligands are represented as unknown amino acid (UNK)

    • is_protein: whether a token is of protein type

    • is_rna: whether a token is of RNA type

    • is_dna: whether a token is of DNA type

    • is_ligand: whether a token is a ligand residue

    Custom token-level features:

    • is_atomized: whether a token is an atomized token

  • feat_metadata:
    • asym_name: The asymmetric unit name for each id in asym_id. Acts as a legend.

    • entity_name: The entity name for each id in entity_id. Acts as a legend.

    • sym_name: The symmetric unit name for each id in sym_id. Acts as a legend.
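To make the asym_id, entity_id and sym_id definitions concrete, here is a pure-Python sketch that assigns integers in order of first appearance. encode_chain_ids is hypothetical, and the first-seen ordering is an assumption; it reproduces the sym_id example given above.

```python
def encode_chain_ids(pn_unit_iids, entities):
    """Hypothetical sketch of asym_id / entity_id / sym_id (first-seen order)."""
    asym_map, entity_map, sym_maps = {}, {}, {}
    asym_id, entity_id, sym_id = [], [], []
    for unit, entity in zip(pn_unit_iids, entities):
        # setdefault assigns the next free integer the first time a key is seen
        asym_id.append(asym_map.setdefault(unit, len(asym_map)))
        entity_id.append(entity_map.setdefault(entity, len(entity_map)))
        # sym_id counts distinct pn_units separately within each entity
        units_of_entity = sym_maps.setdefault(entity, {})
        sym_id.append(units_of_entity.setdefault(unit, len(units_of_entity)))
    return asym_id, entity_id, sym_id


# One entry per pn_unit: A, B and C share a sequence (entity "1"), D does not
print(encode_chain_ids(["A", "B", "C", "D"], ["1", "1", "1", "2"]))
# ([0, 1, 2, 3], [0, 0, 0, 1], [0, 1, 2, 0])
```

Passing one entry per token instead of per pn_unit works the same way, since repeated pn_units reuse their integers.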

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.encoding.EncodeAtomArray(encoding: TokenEncoding, default_coord: float | ndarray = nan, occupancy_threshold: float = 0.0, extra_annotations: list[str] = ['chain_id', 'chain_entity', 'molecule_iid', 'chain_iid', 'transformation_id'])[source]#

Bases: Transform

Encode an atom array to an arbitrary TokenEncoding.

This will add the following information to the data dict:
  • encoding (dict)
    • xyz: Atom coordinates (xyz)

    • mask: Atom mask giving information about which atoms are resolved in the encoded sequence (mask)

    • seq: Token sequence (seq)

    • token_is_atom: Token type (atom or residue) (token_is_atom)

    • Various other optional annotations such as chain_id, chain_entity, etc. See atom_array_to_encoding for more details.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

atomworks.ml.transforms.encoding.atom_array_from_encoding(encoded_coord: Tensor | ndarray, encoded_mask: Tensor | ndarray, encoded_seq: Tensor | ndarray, encoding: TokenEncoding, chain_id: str = 'A', token_is_atom: Tensor | ndarray | None = None, **other_annotations: ndarray | None) AtomArray[source]#

Create an AtomArray from encoded coordinates, mask, and sequence.

This function takes encoded data and reconstructs an AtomArray, which is a structured representation of atomic information. The encoded coordinates, mask, and sequence are used to populate the AtomArray, ensuring that all relevant annotations are included.

Parameters:
  • encoded_coord (Tensor | ndarray) – Encoded coordinates tensor.

  • encoded_mask (Tensor | ndarray) – Encoded mask tensor.

  • encoded_seq (Tensor | ndarray) – Encoded sequence tensor.

  • encoding (TokenEncoding) – The encoding that was used to produce the encoded inputs; it is used to decode them.

  • chain_id (str | ndarray) – Chain ID. Can be a single string (e.g., “A”) or a numpy array of shape (n_res,) corresponding to each residue. Defaults to “A”.

  • token_is_atom (Tensor | ndarray | None) – Boolean mask indicating whether each token corresponds to an atom.

  • **other_annotations (ndarray | None) –

    Additional annotations to include in the AtomArray. The shape must match one of the following:

    • scalar, for global annotations

    • (n_atom,) for per-atom annotations

    • (n_res,) for per-residue annotations

    • (n_chain,) for per-chain annotations

Returns:

The created AtomArray containing the encoded atomic information.

Return type:

atom_array (AtomArray)

atomworks.ml.transforms.encoding.atom_array_to_encoding(atom_array: AtomArray, encoding: TokenEncoding, default_coord: ndarray | float = nan, occupancy_threshold: float = 0.0, extra_annotations: list[str] = ['chain_id', 'chain_entity', 'molecule_iid', 'chain_iid', 'transformation_id']) dict[source]#

Encode an atom array using a specified TokenEncoding.

This function processes an AtomArray to generate encoded representations, including coordinates, masks, sequences, and additional annotations. The encoded data is returned as numpy arrays, which can readily be converted to tensors for machine learning tasks.

Note

  • n_token refers to the number of tokens in the atom array. It equals the number of residues in the atom array, unless the atom array has the atomize annotation, in which case the number of tokens may exceed the number of residues.

  • n_atoms_per_token indicates the number of atoms associated with each token in the encoding.

Parameters:
  • atom_array (AtomArray) – The atom array containing polymer information. If the atom array has the atomize annotation (True for atoms that should be atomized), the number of tokens may differ from the number of residues.

  • encoding (TokenEncoding) – The encoding scheme to apply to the atom array.

  • default_coord (ndarray | float) – Default coordinate value to use for uninitialized coordinates. Defaults to float(“nan”).

  • occupancy_threshold (float) – Minimum occupancy for atoms to be considered resolved in the mask. Defaults to 0.0 (only completely unresolved atoms are masked).

  • extra_annotations (list[str]) – A list of additional annotations to encode. These must be ID-style annotations (e.g., chain_id, molecule_iid). Each annotation is encoded as integers in order of first appearance: the first distinct ID encountered is encoded as 0, the next new ID as 1, and so on. Defaults to [“chain_id”, “chain_entity”, “molecule_iid”, “chain_iid”, “transformation_id”].

Returns:

A dictionary containing the following keys:
  • xyz (np.ndarray): Encoded coordinates of shape [n_token, n_atoms_per_token, 3].

  • mask (np.ndarray): Encoded mask of shape [n_token, n_atoms_per_token], indicating which atoms are resolved in the encoded sequence.

  • seq (np.ndarray): Encoded sequence of shape [n_token].

  • token_is_atom (np.ndarray): Boolean array of shape [n_token] indicating whether each token corresponds to an atom.

  • Various additional annotations encoded as extra keys in the dictionary. Each exposed extra annotation results in two keys: one for the encoded annotation itself and one mapping the original annotation values (e.g., strings) to the integers used in the encoding. For example, the defaults above result in (among others):

    • chain_id (np.ndarray): Encoded chain IDs of shape [n_token].

    • chain_id_to_int (dict): Mapping of chain IDs to integers in the chain_id array.

    • chain_entity (np.ndarray): Encoded entity IDs of shape [n_token].

    • chain_entity_to_int (dict): Mapping of entity IDs to integers in the chain_entity array.

Return type:

dict
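The integer encoding of an ID-style annotation, together with its companion _to_int mapping, can be sketched as follows. encode_id_annotation is a hypothetical helper that follows the first-appearance rule described for extra_annotations; it is not the atomworks code path.

```python
import numpy as np


def encode_id_annotation(values):
    """Sketch: map each distinct ID to an integer in order of first appearance."""
    to_int = {}
    # setdefault hands out the next free integer only for IDs not seen before
    encoded = np.array([to_int.setdefault(v, len(to_int)) for v in values], dtype=int)
    return encoded, to_int


chain_id, chain_id_to_int = encode_id_annotation(["B", "B", "A", "C", "A"])
print(chain_id)         # [0 0 1 2 1]
print(chain_id_to_int)  # {'B': 0, 'A': 1, 'C': 2}
```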

Transforms to handle featurization of edge cases with unresolved residues.

NOTE: Transforms that “filter” based on unresolved residues will be found in the “filters” file, not here.

class atomworks.ml.transforms.featurize_unresolved_residues.MaskPolymerResiduesWithUnresolvedFrameAtoms(occupancy_threshold: float = 0.0)[source]#

Bases: Transform

For residues with at least one unresolved frame atom, mask (set to occupancy zero) the entire residue.

This is a backwards-compatible wrapper around MaskResiduesWithSpecificUnresolvedAtoms.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str | Transform]] = ['EncodeAtomArray', 'CropContiguousLikeAF3', 'CropSpatialLikeAF3', 'PlaceUnresolvedTokenOnClosestResolvedTokenInSequence']#
class atomworks.ml.transforms.featurize_unresolved_residues.MaskResiduesWithSpecificUnresolvedAtoms(chain_type_to_atom_names: dict[tuple | list | Any, list[str]] | None = None, occupancy_threshold: float = 0.0)[source]#

Bases: Transform

For residues with at least one unresolved atom from the specified list, mask (set to occupancy zero) the entire residue.

This is helpful, e.g., when backbone frame atoms are missing, since without frame atoms:
  • We cannot build residue frames

  • The local structure quality is likely poor

We (and AF-3) consider the frame atoms to be:
  • Proteins: N, CA, C

  • Nucleic Acids: C1’, C3’, C4’

As an example for proteins, see PDB ID 6Z3R, which has unresolved C and CA atoms. As an example for nucleic acids, see 7Z24, which has unresolved C1’, C2’, and C3’ atoms (but does have a resolved oxygen).

NOTE: This transform must be applied before other transforms that rely on the occupancy annotation.

This transform allows specification of which atoms to check for each chain type; for MPNN, we consider the backbone oxygen (O) as well.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str | Transform]] = ['EncodeAtomArray', 'CropContiguousLikeAF3', 'CropSpatialLikeAF3', 'PlaceUnresolvedTokenOnClosestResolvedTokenInSequence']#
class atomworks.ml.transforms.featurize_unresolved_residues.PlaceUnresolvedTokenAtomsOnRepresentativeAtom(annotation_to_update: str = 'coord_to_be_noised')[source]#

Bases: Transform

Place unresolved token atoms (e.g., side chain atoms) on the representative atom of the residue (token).

Note that this Transform has no impact on non-polymers, as all atoms are considered tokens.

Parameters:
  • annotation_to_update (str) – The annotation to update with the new coordinates. E.g., “coord” (if we want to modify the ground-truth),

  • "coord_to_be_noised" (or)

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

requires_previous_transforms: ClassVar[list[str | Transform]] = ['AtomizeByCCDName']#
class atomworks.ml.transforms.featurize_unresolved_residues.PlaceUnresolvedTokenOnClosestResolvedTokenInSequence(annotation_to_update: str = 'coord_to_be_noised', annotation_to_copy: str = 'coord_to_be_noised')[source]#

Bases: Transform

Place fully unresolved tokens on their closest resolved neighbor in sequence space, breaking ties by choosing the “leftmost” neighbor.

This heuristic is helpful to avoid noising unresolved residue coordinates from the origin during diffusion training.

Parameters:
  • annotation_to_update (str) – The annotation to update with the new coordinates. E.g., “coord” (if we want to modify the ground-truth), or “coord_to_be_noised” (if we want to modify only the coordinates that will be noised). NOTE: Must match the annotation used for PlaceUnresolvedTokenAtomsOnRepresentativeAtom.

  • annotation_to_copy (str) – The annotation to copy from the resolved atom to the unresolved atom.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

requires_previous_transforms: ClassVar[list[str | Transform]] = ['AtomizeByCCDName', 'PlaceUnresolvedTokenAtomsOnRepresentativeAtom']#
atomworks.ml.transforms.featurize_unresolved_residues.mask_polymer_residues_with_unresolved_frame_atoms(atom_array: AtomArray, occupancy_threshold: float = 0.0) AtomArray[source]#

If a polymer residue has an unresolved backbone atom (occupancy <= occupancy_threshold), set the occupancy of the entire residue to zero.

This is a backwards-compatible wrapper around mask_residues_with_specific_unresolved_atoms.

atomworks.ml.transforms.featurize_unresolved_residues.mask_residues_with_specific_unresolved_atoms(atom_array: AtomArray, chain_type_to_atom_names: dict[tuple | list | Any, list[str]] | None = None, occupancy_threshold: float = 0.0) AtomArray[source]#

If a residue has any unresolved atoms from the specified list, set the occupancy of the entire residue to zero.

Parameters:
  • atom_array (AtomArray) – The atom array to modify.

  • chain_type_to_atom_names (dict[tuple | list | ChainType, list[str]], optional) – A dictionary mapping chain types to lists of atom names that should be checked for resolution. Keys can be:

    • A single chain type (e.g., ChainType.POLYPEPTIDE_L)

    • A tuple/list of chain types (e.g., ChainTypeInfo.PROTEINS)

    If None, uses the default AF-3 frame atoms. Defaults to None.

  • occupancy_threshold (float) – Atoms with occupancy <= this value are considered unresolved. Defaults to 0.0.

Returns:

The modified atom array.

Return type:

AtomArray
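A minimal NumPy sketch of the masking rule, simplified to a single list of atom names (rather than one list per chain type) and a per-atom residue key assumed to be unique across chains. mask_residues_sketch is hypothetical. It only inspects atoms that are present in the arrays; how the real function treats atoms missing from the AtomArray entirely is not covered here.

```python
import numpy as np


def mask_residues_sketch(res_key, atom_name, occupancy, atom_names_to_check, occupancy_threshold=0.0):
    """Hypothetical sketch: zero the occupancy of residues with an unresolved listed atom."""
    res_key = np.asarray(res_key)
    occupancy = np.asarray(occupancy, dtype=float).copy()
    # Listed atoms (e.g. frame atoms) whose occupancy marks them as unresolved
    unresolved = np.isin(atom_name, atom_names_to_check) & (occupancy <= occupancy_threshold)
    # Mask every atom belonging to an affected residue
    occupancy[np.isin(res_key, res_key[unresolved])] = 0.0
    return occupancy


res_key = [0] * 4 + [1] * 4 + [2] * 4
atom_name = ["N", "CA", "C", "CB"] * 3
occupancy = [1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0]  # residue 1 lacks CA, residue 2 lacks CB
print(mask_residues_sketch(res_key, atom_name, occupancy, ["N", "CA", "C"]).tolist())
```

Residue 1 is fully masked because CA is a frame atom, while residue 2 keeps its backbone because CB is not in the checked list.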

atomworks.ml.transforms.featurize_unresolved_residues.place_unresolved_token_atoms_on_token_representative_atom(atom_array: AtomArray, annotation_to_update: str = 'coord_to_be_noised') AtomArray[source]#

Place unresolved token atoms (e.g., side chain atoms) on the representative atom of the corresponding residue (token).

In cases where the representative atom is unresolved, we also try the token center atom (e.g., if the CB is unresolved but the CA is resolved, like in 8E83 for chain A, residue 194). Helpful for diffusive models to avoid noising unresolved side-chain atoms from the origin.

NOTE: For non-polymers, all atoms are considered tokens (and are atomized); in such cases this Transform will have no effect.

Parameters:
  • atom_array (AtomArray) – The atom array to modify.

  • annotation_to_update (str) – The annotation to update with the new coordinates. E.g., “coord” (if we want to modify the ground-truth), or “coord_to_be_noised” (if we want to modify only the coordinates that will be noised).

Returns:

The modified atom array.

Return type:

AtomArray

atomworks.ml.transforms.featurize_unresolved_residues.place_unresolved_token_on_closest_resolved_token_in_sequence(atom_array: AtomArray, annotation_to_update: str = 'coord_to_be_noised', annotation_to_copy: str = 'coord') AtomArray[source]#

Place all atoms within fully-unresolved residues on the closest resolved neighbor in sequence space.

NOTE: For non-polymers, each atom is considered a token, so this transform will place unresolved atoms on the closest resolved token in sequence space (i.e., the previous or next atom).

NOTE: We only perform the operation WITHIN chains, such that we don’t resolve across chain boundaries.

Parameters:
  • atom_array (AtomArray) – The atom array to modify.

  • annotation_to_update (str) – The annotation to update with the new coordinates. E.g., “coord” (if we want to modify the ground-truth), or “coord_to_be_noised” (if we want to modify only the coordinates that will be noised).

  • annotation_to_copy (str) – The annotation to copy from the resolved atom to the unresolved atom. E.g., “coord” (if we want to copy the ground-truth), or “coord_to_be_noised” (if we want to copy the coordinates that will be noised, which may have been modified by previous transforms). In the AF-3 pipeline, we want to copy “coord_to_be_noised”, to correctly resolve residues after applying PlaceUnresolvedTokenAtomsOnRepresentativeAtom.

Returns:

The modified atom array.

Return type:

AtomArray
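The sequence-space search can be sketched with NumPy. Both helpers below are hypothetical and operate on a per-token resolved flag; the returned indices say which token to copy annotation_to_copy from when filling annotation_to_update. Ties go to the leftmost neighbour because np.argmin returns the first minimum, and the per-chain wrapper keeps the search from crossing chain boundaries.

```python
import numpy as np


def closest_resolved_in_sequence(is_resolved):
    """Index of the nearest resolved token for each token (leftmost on ties)."""
    is_resolved = np.asarray(is_resolved, dtype=bool)
    positions = np.arange(len(is_resolved))
    resolved = np.flatnonzero(is_resolved)
    if resolved.size == 0:
        return positions  # nothing to copy from; every token keeps its own index
    dist = np.abs(positions[:, None] - resolved[None, :])
    # argmin returns the first minimum, i.e. the leftmost neighbour on ties
    return resolved[np.argmin(dist, axis=1)]


def closest_resolved_within_chains(is_resolved, chain_ids):
    """Apply the search separately per chain so it never crosses chain boundaries."""
    is_resolved = np.asarray(is_resolved, dtype=bool)
    chain_ids = np.asarray(chain_ids)
    source = np.arange(len(is_resolved))
    for chain in np.unique(chain_ids):
        idx = np.flatnonzero(chain_ids == chain)
        source[idx] = idx[closest_resolved_in_sequence(is_resolved[idx])]
    return source


print(closest_resolved_in_sequence([False, True, False, False, False, True, False]))  # [1 1 1 1 5 5 5]
```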

Transforms that filter an AtomArray, removing chains, residues, or atoms based on some criteria.

class atomworks.ml.transforms.filters.FilterToProteins(min_size: int = 5)[source]#

Bases: ApplyFunctionToAtomArray

Filter atom array to only include protein residues.

class atomworks.ml.transforms.filters.FilterToSpecifiedPNUnits(extra_info_key_with_pn_unit_iids_to_keep: str = 'all_pn_unit_iids_after_processing')[source]#

Bases: Transform

Filter atom array to only include specific PN units, denoted via the row metadata (held in extra_info). Such a filter is useful, for example, when during pre-processing we have identified clashing PN Units that we may want to exclude from the AtomArray.

Parameters:

extra_info_key_with_pn_unit_iids_to_keep (str) – The key in the “extra_info” dictionary that contains the PN unit IDs to filter to. If the key does not exist, the AtomArray is not filtered.

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.filters.HandleUndesiredResTokens(undesired_res_tokens: list | tuple)[source]#

Bases: Transform

Remove, or otherwise handle, undesired residue tokens from the AtomArray.

For each undesired residue name (res_name), the following actions are taken:
  • For undesired residues in non-polymers:
    • Remove the entire non-polymer PN unit (pn_unit_iid).

  • For undesired residues in polymers:
    • Map to the closest canonical residue name, if possible.

    • Otherwise, map to an unknown residue name, if possible (i.e., if backbone atoms are present).

    • Otherwise, atomize the residue.
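The decision order above can be expressed as a small function. This is a hypothetical sketch: the three inputs stand in for checks that the real transform performs on the AtomArray, and the names Action and choose_undesired_res_action are illustrative, not atomworks API.

```python
from enum import Enum
from typing import Optional


class Action(Enum):
    REMOVE_PN_UNIT = "remove the entire non-polymer PN unit"
    MAP_TO_CANONICAL = "map to the closest canonical residue name"
    MAP_TO_UNKNOWN = "map to an unknown residue name"
    ATOMIZE = "atomize the residue"


def choose_undesired_res_action(
    is_polymer: bool, closest_canonical: Optional[str], has_backbone: bool
) -> Action:
    """Return the action for one undesired residue, checking options in the documented order."""
    if not is_polymer:
        return Action.REMOVE_PN_UNIT
    if closest_canonical is not None:
        return Action.MAP_TO_CANONICAL
    if has_backbone:
        return Action.MAP_TO_UNKNOWN
    return Action.ATOMIZE


print(choose_undesired_res_action(True, None, True).name)  # MAP_TO_UNKNOWN
```

The early returns make the fallback chain explicit: each option is tried only when the previous, more information-preserving one is unavailable.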

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.filters.RandomlyRemoveLigands(delete_probability: float = 1.0, rng_seed: int = 42)[source]#

Bases: RandomlyRemovePNUnitsByAnnotationQuery

Randomly remove free-floating ligands (non-polymer ligands that are not covalent modifications).

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

class atomworks.ml.transforms.filters.RandomlyRemovePNUnitsByAnnotationQuery(query: str, delete_probability: float = 1.0, rng_seed: int = 42)[source]#

Bases: Transform

Randomly remove pn_units from atom_array based on a query string with configurable probability.

Parameters:
  • query – Query string in atomworks.io query syntax to identify pn_units to potentially delete

  • delete_probability – Probability of deleting matched pn_units (0.0 = never delete, 1.0 = always delete)

  • rng_seed – Random seed for reproducibility (default: 42)

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.filters.RemoveHydrogens(hydrogen_names: tuple | list = ('1', 'H', 'D', 'T'))[source]#

Bases: Transform

Remove hydrogens from the atom array.

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.filters.RemoveNucleicAcidTerminalOxygen[source]#

Bases: Transform

Remove terminal oxygen atoms (OP3) in nucleic acids from the atom array.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.filters.RemovePolymersWithTooFewResolvedResidues(min_residues: int = 4)[source]#

Bases: Transform

From the AF-3 supplement, Section 2.5.4:

> “Any polymer chain containing fewer than 4 resolved residues is filtered out.”

We implement this filter as a Transform that removes polymer chains with fewer than min_residues resolved residues. Note that upstream, we must ensure that the chosen query PN units are not polymer chains with too few resolved residues themselves.
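A minimal sketch of the per-chain count, assuming that a residue counts as resolved when at least one of its atoms has nonzero occupancy and that chains are identified by a "chain_id" field. Both assumptions, the dict-based atom records, and the function name are hypothetical; the actual criterion in atomworks may differ.

```python
from typing import Any


def remove_sparse_polymers_sketch(
    atoms: list[dict[str, Any]], min_residues: int = 4
) -> list[dict[str, Any]]:
    # Collect the resolved residue IDs of each polymer chain.
    resolved: dict[str, set] = {}
    for atom in atoms:
        if atom["is_polymer"] and atom["occupancy"] > 0:
            resolved.setdefault(atom["chain_id"], set()).add(atom["res_id"])

    def keep(atom: dict[str, Any]) -> bool:
        # Non-polymer atoms are never removed by this filter.
        if not atom["is_polymer"]:
            return True
        return len(resolved.get(atom["chain_id"], set())) >= min_residues

    return [atom for atom in atoms if keep(atom)]


def make_atom(chain_id, res_id, occupancy, is_polymer=True):
    return {"chain_id": chain_id, "res_id": res_id, "occupancy": occupancy, "is_polymer": is_polymer}


atoms = (
    [make_atom("A", r, 1.0) for r in range(1, 5)]  # 4 resolved residues: kept
    + [make_atom("B", r, 1.0) for r in range(1, 4)]  # only 3 resolved residues...
    + [make_atom("B", 4, 0.0)]  # ...plus 1 unresolved: whole chain removed
    + [make_atom("L", 1, 0.0, is_polymer=False)]  # ligand: untouched
)
print(sorted({a["chain_id"] for a in remove_sparse_polymers_sketch(atoms)}))  # ['A', 'L']
```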

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.filters.RemoveTerminalOxygen[source]#

Bases: Transform

Remove terminal oxygen atoms (OXT) from the atom array.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.filters.RemoveUnresolvedAtoms(min_occupancy: float = 0.5)[source]#

Bases: ApplyFunctionToAtomArray

Remove atoms with occupancy less than min_occupancy from the atom array.

class atomworks.ml.transforms.filters.RemoveUnresolvedLigandAtomsIfTooMany(unresolved_ligand_atom_limit: int)[source]#

Bases: Transform

If the number of unresolved (zero-occupancy) ligand atoms exceeds a specified threshold, remove all masked (zero-occupancy) ligand atoms from the atom array.

This Transform prevents a significant proportion of the crop window from being filled with unresolved ligand atoms, which most commonly happens with poorly resolved liposomes.

Parameters:

unresolved_ligand_atom_limit (int) – The maximum number of unresolved ligand atoms allowed in the atom array.

Example: See PDB ID 6CLZ, which contains a liposome with many unresolved atoms.
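The threshold logic can be sketched as follows, assuming "ligand atom" means an atom of a non-polymer and "unresolved" means occupancy 0, as stated above. The dict-based atom records and the function name are hypothetical.

```python
from typing import Any


def drop_unresolved_ligand_atoms_if_too_many_sketch(
    atoms: list[dict[str, Any]], unresolved_ligand_atom_limit: int
) -> list[dict[str, Any]]:
    def is_unresolved_ligand_atom(atom: dict[str, Any]) -> bool:
        return not atom["is_polymer"] and atom["occupancy"] == 0

    n_unresolved = sum(is_unresolved_ligand_atom(atom) for atom in atoms)
    if n_unresolved <= unresolved_ligand_atom_limit:
        # At or below the limit: leave the atom array unchanged.
        return atoms
    # Above the limit: remove every unresolved ligand atom, keep everything else.
    return [atom for atom in atoms if not is_unresolved_ligand_atom(atom)]
```

Note that the removal is all-or-nothing: once the limit is exceeded, all unresolved ligand atoms go, not just the excess, while resolved ligand atoms and polymer atoms are untouched.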

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.filters.RemoveUnresolvedPNUnits[source]#

Bases: Transform

Filters out PN units in which all atoms are unresolved (i.e., have occupancy 0) from the AtomArray. Can be applied before or after cropping, since cropping may leave a PN unit whose remaining atoms are all unresolved even though the full PN unit was not. At training time, such unresolved PN units provide minimal value and can cause errors in the model.

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.filters.RemoveUnresolvedTokens[source]#

Bases: Transform

Filters out tokens in which all atoms are unresolved (i.e., have occupancy 0) from the AtomArray.

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.filters.RemoveUnsupportedChainTypes(supported_chain_types: Sequence[ChainType] = [ChainType.NON_POLYMER, ChainType.POLYPEPTIDE_L, ChainType.DNA, ChainType.RNA, ChainType.BRANCHED])[source]#

Bases: Transform

Filter out chains with unsupported chain types from the AtomArray.

Additionally, if query PN units are given, asserts that none of them has an unsupported chain type. Such PN units should have been filtered out upstream; otherwise, the example is not valid.

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

requires_previous_transforms: ClassVar[list[str | Transform]] = []#

atomworks.ml.transforms.filters.filter_to_specified_pn_units(atom_array: AtomArray, pn_unit_iids: list | set | ndarray) AtomArray[source]#

Filter atom array to only include specific PN units.

atomworks.ml.transforms.filters.random_remove_pn_units_by_annotation_query(atom_array: AtomArray, query: str, delete_probability: float = 1.0, rng: Generator | None = None) AtomArray[source]#

Randomly remove pn_units from atom_array based on a query string with configurable probability.

A pn_unit is considered to match the query if ALL atoms in the pn_unit satisfy the query condition.

Parameters:
  • atom_array – The AtomArray to filter

  • query – Query string in atomworks.io Query syntax to identify pn_units to potentially delete

  • delete_probability – Probability of deleting matched pn_units (0.0 = never delete, 1.0 = always delete)

  • rng – Random number generator for probabilistic deletion
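The matching rule (a PN unit matches only if all of its atoms satisfy the query) combined with the deletion probability can be sketched as below. This is a hypothetical illustration: a Python predicate stands in for an atomworks.io query string, atoms are plain dicts, Python's random.Random replaces the NumPy Generator, and each matched PN unit is assumed to be deleted independently with probability delete_probability, which the description above does not state explicitly.

```python
import random
from typing import Any, Callable, Optional


def random_remove_pn_units_sketch(
    atoms: list[dict[str, Any]],
    predicate: Callable[[dict[str, Any]], bool],
    delete_probability: float = 1.0,
    rng: Optional[random.Random] = None,
) -> list[dict[str, Any]]:
    if rng is None:
        rng = random.Random(42)
    # Group atoms by PN unit, preserving first-seen order.
    units: dict[str, list[dict[str, Any]]] = {}
    for atom in atoms:
        units.setdefault(atom["pn_unit_iid"], []).append(atom)

    to_delete = set()
    for pn_unit_iid, unit_atoms in units.items():
        # A PN unit matches only if ALL of its atoms satisfy the predicate.
        matches = all(predicate(atom) for atom in unit_atoms)
        # random() is in [0, 1): probability 1.0 always deletes, 0.0 never does.
        if matches and rng.random() < delete_probability:
            to_delete.add(pn_unit_iid)
    return [atom for atom in atoms if atom["pn_unit_iid"] not in to_delete]
```

With this "all atoms" rule, a PN unit that only partially satisfies the query (for example, one with mixed annotations) is never deleted.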

atomworks.ml.transforms.filters.remove_hydrogens(atom_array: AtomArray, hydrogen_names: tuple | list = ('H', 'H2', 'D', 'T')) AtomArray[source]#

Remove hydrogens from the atom array.

atomworks.ml.transforms.filters.remove_nucleic_acid_terminal_oxygen(atom_array: AtomArray | AtomArrayStack) AtomArray | AtomArrayStack[source]#

Remove terminal oxygen atoms (OP3) in nucleic acids from the atom array.

atomworks.ml.transforms.filters.remove_polymers_with_too_few_resolved_residues(atom_array: AtomArray, min_residues: int = 4) AtomArray[source]#

Remove polymer chains with fewer than min_residues resolved residues from the atom array.

atomworks.ml.transforms.filters.remove_protein_terminal_oxygen(atom_array: AtomArray | AtomArrayStack) AtomArray | AtomArrayStack[source]#

Remove terminal oxygen atoms (OXT) from protein chains.

Terminal oxygen atoms are only removed from protein residues that are not atomized.

atomworks.ml.transforms.filters.remove_unresolved_atoms(atom_array: AtomArray, min_occupancy: float = 0.5) AtomArray[source]#

Remove atoms with occupancy less than min_occupancy from the atom array.

atomworks.ml.transforms.filters.remove_unresolved_pn_units(atom_array: AtomArray) AtomArray[source]#

Filters out PN units in which all atoms are unresolved (i.e., have occupancy 0) from the AtomArray. Can be applied before or after cropping, since cropping may leave a PN unit whose remaining atoms are all unresolved even though the full PN unit was not. At training time, such unresolved PN units provide minimal value and can cause errors in the model.

atomworks.ml.transforms.filters.remove_unresolved_tokens(atom_array: AtomArray) AtomArray[source]#

Filters out tokens in which all atoms are unresolved (i.e., have occupancy 0) from the AtomArray.

A token is defined by the token utilities and can be either:
  • A residue (when atomize=False)

  • An individual atom (when atomize=True)
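A minimal sketch of the token-level rule, assuming a non-atomized token is identified by (chain_id, res_id) and each atomized atom is its own token. The dict fields, the token_key helper, and the function name are hypothetical; atomworks defines tokens through its own token utilities.

```python
from collections.abc import Hashable
from typing import Any


def token_key(atom: dict[str, Any], atom_index: int) -> Hashable:
    # Atomized atoms are their own token; otherwise the residue is the token.
    if atom["atomize"]:
        return ("atom", atom_index)
    return ("residue", atom["chain_id"], atom["res_id"])


def remove_unresolved_tokens_sketch(atoms: list[dict[str, Any]]) -> list[dict[str, Any]]:
    # A token survives if at least one of its atoms has nonzero occupancy.
    resolved = {token_key(atom, i) for i, atom in enumerate(atoms) if atom["occupancy"] > 0}
    return [atom for i, atom in enumerate(atoms) if token_key(atom, i) in resolved]
```

A partially resolved residue is kept whole, including its unresolved atoms; only tokens with no resolved atom at all are dropped.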

atomworks.ml.transforms.filters.remove_unsupported_chain_types(atom_array: AtomArray, query_pn_unit_iids: Sequence[str] | None = None, supported_chain_types: Sequence[ChainType] = [ChainType.NON_POLYMER, ChainType.POLYPEPTIDE_L, ChainType.DNA, ChainType.RNA, ChainType.BRANCHED]) AtomArray[source]#

Filter out chains with unsupported chain types from the AtomArray.

Additionally, if query PN units are given, asserts that none of them has an unsupported chain type. Such PN units should have been filtered out upstream; otherwise, the example is not valid.

Parameters:
  • atom_array (AtomArray) – The AtomArray to filter.

  • query_pn_unit_iids (Sequence[str] | None) – The PN unit IDs to check for unsupported chain types.

  • supported_chain_types (Sequence[ChainType]) – The chain types to keep; chains of any other type are removed.

Returns:

The filtered AtomArray.

Return type:

AtomArray
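The keep-or-assert behavior can be sketched as follows. This is a hypothetical illustration: the ChainTypeSketch enum only mirrors the member names in the default argument, plus an OTHER_POLYMER member invented for the example, and atoms are plain dicts.

```python
from enum import Enum, auto
from typing import Any, Optional, Sequence


class ChainTypeSketch(Enum):
    # Stand-in for atomworks' ChainType; OTHER_POLYMER is invented for this example.
    NON_POLYMER = auto()
    POLYPEPTIDE_L = auto()
    DNA = auto()
    RNA = auto()
    BRANCHED = auto()
    OTHER_POLYMER = auto()


DEFAULT_SUPPORTED = (
    ChainTypeSketch.NON_POLYMER,
    ChainTypeSketch.POLYPEPTIDE_L,
    ChainTypeSketch.DNA,
    ChainTypeSketch.RNA,
    ChainTypeSketch.BRANCHED,
)


def remove_unsupported_chain_types_sketch(
    atoms: list[dict[str, Any]],
    query_pn_unit_iids: Optional[Sequence[str]] = None,
    supported_chain_types: Sequence[ChainTypeSketch] = DEFAULT_SUPPORTED,
) -> list[dict[str, Any]]:
    supported = set(supported_chain_types)
    if query_pn_unit_iids is not None:
        # Query PN units should already have been validated upstream.
        query = set(query_pn_unit_iids)
        bad = {
            atom["pn_unit_iid"]
            for atom in atoms
            if atom["pn_unit_iid"] in query and atom["chain_type"] not in supported
        }
        assert not bad, f"Query PN units with unsupported chain types: {sorted(bad)}"
    return [atom for atom in atoms if atom["chain_type"] in supported]
```

The sketch uses a tuple for the default to avoid sharing a mutable list across calls; the documented signature uses a list.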

Transforms that add masks for an AtomArray to the data

class atomworks.ml.transforms.masks.AddSpatialKNNMask(num_neighbors: int, max_atoms_in_crop: int = 40000)[source]#

Bases: Transform

Generate a spatial k-nearest neighbors mask for each atom in the input atom array and add it to the data with the key ‘spatial_knn_masks’ (shape: (n_atoms, n_atoms)).

This mask is used, for example, as a coordinate-based local attention mask during diffusion.

Parameters:
  • num_neighbors (int) – The number of nearest neighbors to keep for each atom.

  • max_atoms_in_crop (int) – The maximum allowed number of atoms in the crop. Because this transform builds an (n_atoms, n_atoms) mask, this limit guards against unexpected memory blow-ups. The default of 40,000 atoms should allow any crop of up to 1,538 tokens to pass (worst case: RNA guanine, which has 26 heavy atoms per token, giving 39,988 atoms for a structure made up only of guanine).
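The numbers quoted for max_atoms_in_crop can be checked with a few lines of arithmetic. This sketch assumes the mask is a NumPy boolean array, which stores one byte per element; the helper name is hypothetical.

```python
def dense_mask_bytes(n_atoms: int, bytes_per_entry: int = 1) -> int:
    """Memory of a dense (n_atoms, n_atoms) mask; NumPy bools take 1 byte each."""
    return n_atoms * n_atoms * bytes_per_entry


max_atoms = 40_000
heavy_atoms_per_guanine = 26  # worst case quoted above

max_tokens = max_atoms // heavy_atoms_per_guanine
print(max_tokens, max_tokens * heavy_atoms_per_guanine)  # 1538 39988
print(dense_mask_bytes(max_atoms) / 1e9, "GB")  # 1.6 GB
```

A single mask at the limit already occupies about 1.6 GB, and memory grows quadratically with the atom count, so an uncapped crop could exhaust memory quickly.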

check_input(data: dict[str, Any]) None[source]#

Check if the input data contains the required keys and types.

Parameters:

data (dict[str, Any]) – The input data dictionary.

Raises:
  • KeyError – If a required key is missing from the input data.

  • TypeError – If a value in the input data is not of the expected type.

forward(data: dict[str, Any]) dict[str, Any][source]#

Generate a local attention mask for the input sequence based on the given coordinates, keeping only the k nearest neighbors of each atom.

Parameters:

data (dict[str, Any]) – The input data dictionary.

Returns:

The output data dictionary with the added ‘spatial_knn_masks’ key:
  • ‘spatial_knn_masks’ (np.ndarray): Boolean mask of shape (n_atoms, n_atoms), where True indicates that the atom is one of the k nearest neighbors of the other atom. NOTE: atoms without coordinates do not receive any neighbors in the mask (i.e., their row is all False).

Return type:

dict[str, Any]

atomworks.ml.transforms.masks.compute_spatial_knn_mask(coords: ndarray, k: int) ndarray[source]#

Compute the spatial k-nearest-neighbor (KNN) mask from the atom coordinates coords, keeping k neighbors per atom.
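A minimal pure-Python sketch of a spatial KNN mask, quadratic in the number of atoms. Assumptions not stated above: row i marks the k atoms closest to atom i, an atom counts as its own nearest neighbor (distance 0), distance ties are broken by atom index, and atoms with NaN coordinates neither receive nor act as neighbors. The real function takes an ndarray and may differ on any of these points; the function name is hypothetical.

```python
import math
from typing import Sequence


def knn_mask_sketch(coords: Sequence[Sequence[float]], k: int) -> list[list[bool]]:
    n = len(coords)
    # Atoms with any NaN coordinate are treated as missing.
    valid = [not any(math.isnan(c) for c in xyz) for xyz in coords]
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        if not valid[i]:
            continue  # all-False row, as documented for atoms without coordinates
        # (distance, index) pairs; sorting breaks distance ties by index.
        candidates = sorted(
            (math.dist(coords[i], coords[j]), j) for j in range(n) if valid[j]
        )
        for _, j in candidates[:k]:
            mask[i][j] = True
    return mask


nan = float("nan")
coords = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.0, 0.0, 0.0), (nan, nan, nan)]
for row in knn_mask_sketch(coords, k=2):
    print(row)
```

The resulting mask is not necessarily symmetric: in the example, atom 2 lists atom 1 as a neighbor, but atom 1 does not list atom 2.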

class atomworks.ml.transforms.sasa.CalculateSASA(probe_radius: float = 1.4, atom_radii: Literal['ProtOr'] | ndarray = 'ProtOr', point_number: int = 100)[source]#

Bases: Transform

Transform for calculating Solvent-Accessible Surface Area (SASA) for each atom in an AtomArray.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict, key_to_add_sasa_to: str = 'atom_array') dict[source]#

Calculates SASA and adds it to the data dictionary under the key key_to_add_sasa_to (by default, atom_array).

Parameters:
  • data (dict) – A dictionary containing the input AtomArray.

  • key_to_add_sasa_to (str) – The key in the data dictionary to add the SASA values to.

Returns:

The data dictionary with SASA values added.

Return type:

dict

atomworks.ml.transforms.sasa.calculate_atomwise_rasa(atom_array: AtomArray, probe_radius: float = 1.4, atom_radii: str | ndarray = 'ProtOr', point_number: int = 100) ndarray[source]#

Calculate the Relative Solvent-Accessible Surface Area (RASA) for each atom in atom_array.

The RASA is defined as the ratio of the SASA of a residue in a protein structure to the SASA of the same residue in an extended conformation.

The output will have the same length as the input AtomArray, with NaN values for excluded (invalid) atoms.

Parameters:
  • atom_array (AtomArray) – The input AtomArray containing the atomic coordinates.

  • probe_radius (float, optional) – Van-der-Waals radius of the probe in Angstrom. Defaults to 1.4 (for water).

  • atom_radii (str | np.ndarray, optional) – Atom radii set to use for calculation. Defaults to “ProtOr”. With “ProtOr”, no SASA values are computed for hydrogen atoms and some other atoms, such as ions or certain charged atoms.

  • point_number (int, optional) – Number of points in the Shrake-Rupley algorithm to sample for calculating SASA. Defaults to 100.
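Given per-atom SASA values and reference values for an extended conformation, the ratio itself is simple. This sketch assumes both inputs are already aligned atom by atom and returns NaN for atoms that are excluded or whose reference value is zero; how atomworks obtains the reference values is not described here, and the numbers in the example are invented.

```python
import math
from typing import Sequence


def relative_sasa_sketch(sasa: Sequence[float], reference_sasa: Sequence[float]) -> list[float]:
    """Element-wise SASA / reference SASA, with NaN wherever the ratio is undefined."""
    if len(sasa) != len(reference_sasa):
        raise ValueError("sasa and reference_sasa must have the same length")
    rasa = []
    for value, reference in zip(sasa, reference_sasa):
        if math.isnan(value) or math.isnan(reference) or reference == 0:
            rasa.append(math.nan)  # excluded atom, or undefined ratio
        else:
            rasa.append(value / reference)
    return rasa


print(relative_sasa_sketch([50.0, float("nan"), 0.0], [100.0, 80.0, 40.0]))  # [0.5, nan, 0.0]
```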

atomworks.ml.transforms.sasa.calculate_atomwise_sasa(atom_array: AtomArray, probe_radius: float = 1.4, atom_radii: str | ndarray = 'ProtOr', point_number: int = 100) ndarray[source]#

Calculate the SASA for each atom in atom_array, excluding those with NaN coordinates. The output will have the same length as the input AtomArray, with NaN values for excluded (invalid) atoms.

Parameters:
  • atom_array (AtomArray) – The input AtomArray containing the atomic coordinates.

  • probe_radius (float, optional) – Van-der-Waals radius of the probe in Angstrom. Defaults to 1.4 (for water).

  • atom_radii (str | np.ndarray, optional) – Atom radii set to use for calculation. Defaults to “ProtOr”. With “ProtOr”, no SASA values are computed for hydrogen atoms and some other atoms, such as ions or certain charged atoms.

  • point_number (int, optional) – Number of points in the Shrake-Rupley algorithm to sample for calculating SASA. Defaults to 100.
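The "same length, NaN for excluded atoms" contract can be sketched as follows: compute SASA on the atoms with valid coordinates only, then scatter the results back into a full-length list prefilled with NaN. In this hypothetical sketch, compute_sasa is a placeholder for a Shrake-Rupley routine (biotite.structure.sasa offers comparable probe_radius, point_number, and ProtOr radii options, but whether atomworks calls it is an assumption), and fake_sasa returns made-up values.

```python
import math
from typing import Callable, Sequence

Coord = Sequence[float]


def sasa_with_nan_for_invalid(
    coords: Sequence[Coord],
    compute_sasa: Callable[[list], list],
) -> list[float]:
    # Indices of atoms with no NaN coordinate.
    valid_idx = [i for i, xyz in enumerate(coords) if not any(math.isnan(c) for c in xyz)]
    valid_sasa = compute_sasa([coords[i] for i in valid_idx])
    # Full-length output: NaN everywhere except at the valid atoms.
    out = [math.nan] * len(coords)
    for i, value in zip(valid_idx, valid_sasa):
        out[i] = value
    return out


def fake_sasa(valid_coords):
    # Placeholder returning 10.0, 20.0, ... instead of real surface areas.
    return [10.0 * (n + 1) for n in range(len(valid_coords))]


nan = float("nan")
print(sasa_with_nan_for_invalid([(0, 0, 0), (nan, nan, nan), (1, 1, 1)], fake_sasa))  # [10.0, nan, 20.0]
```

Keeping the output aligned with the input AtomArray lets the values be stored directly as a per-atom annotation.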

class atomworks.ml.transforms.symmetry.AddPostCropMoleculeEntityToFreeFloatingLigands[source]#

Bases: Transform

Relabels the molecule entities of free-floating (i.e., not bonded to a polymer), cropped ligands. This is needed to identify identical ligands, which the RF2AA loss treats as swappable.

The relabeled molecule entity labels are stored in the post_crop_molecule_entity annotation of the AtomArray, so that downstream processes can reference the modified entities unambiguously.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.symmetry.CreateSymmetryCopyAxisLikeRF2AA(encoding: ~atomworks.ml.encoding_definitions.TokenEncoding = Encoding(n_tokens=80, n_atoms_per_token=36)  Token      | 0    | 1    | 2    | 3    | 4    | 5    | 6    | 7    | 8    | 9    | 10   | 11   | 12   | 13   | 14   | 15   | 16   | 17   | 18   | 19   | 20   | 21   | 22   | 23   | 24   | 25   | 26   | 27   | 28   | 29   | 30   | 31   | 32   | 33   | 34   | 35   -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   0 : ALA   | N    | CA   | C    | O    | CB   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | 3HB  |      |      |      |      |      |      |      |        1 : ARG   | N    | CA   | C    | O    | CB   | CG   | CD   | NE   | CZ   | NH1  | NH2  |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | 1HG  | 2HG  | 1HD  | 2HD  | HE   | 1HH1 | 2HH1 | 1HH2 | 2HH2   2 : ASN   | N    | CA   | C    | O    | CB   | CG   | OD1  | ND2  |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | 1HD2 | 2HD2 |      |      |      |      |      |      |        3 : ASP   | N    | CA   | C    | O    | CB   | CG   | OD1  | OD2  |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  |      |      |      |      |      |      |      |      |        4 : CYS   | N    | CA   | C    | O    | CB   | SG   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | HG   |      |      |      |      |   
   |      |      |        5 : GLN   | N    | CA   | C    | O    | CB   | CG   | CD   | OE1  | NE2  |      |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | 1HG  | 2HG  | 1HE2 | 2HE2 |      |      |      |      |        6 : GLU   | N    | CA   | C    | O    | CB   | CG   | CD   | OE1  | OE2  |      |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | 1HG  | 2HG  |      |      |      |      |      |      |        7 : GLY   | N    | CA   | C    | O    |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | 1HA  | 2HA  |      |      |      |      |      |      |      |      |      |        8 : HIS   | N    | CA   | C    | O    | CB   | CG   | ND1  | CD2  | CE1  | NE2  |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | 2HD  | 1HE  | 2HE  |      |      |      |      |      |        9 : ILE   | N    | CA   | C    | O    | CB   | CG1  | CG2  | CD1  |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | HB   | 1HG2 | 2HG2 | 3HG2 | 1HG1 | 2HG1 | 1HD1 | 2HD1 | 3HD1 |      |       10 : LEU   | N    | CA   | C    | O    | CB   | CG   | CD1  | CD2  |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | HG   | 1HD1 | 2HD1 | 3HD1 | 1HD2 | 2HD2 | 3HD2 |      |       11 : LYS   | N    | CA   | C    | O    | CB   | CG   | CD   | CE   | NZ   |      |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | 1HG  | 2HG  | 1HD  | 2HD  | 1HE  | 2HE  | 1HZ  | 2HZ  | 3HZ   12 : MET   | N    | CA   | C    | O    | CB   | CG   | SD   | CE   |      |      |      |      |      |      |      |      |   
   |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | 1HG  | 2HG  | 1HE  | 2HE  | 3HE  |      |      |      |       13 : PHE   | N    | CA   | C    | O    | CB   | CG   | CD1  | CD2  | CE1  | CE2  | CZ   |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | 1HD  | 2HD  | 1HE  | 2HE  | HZ   |      |      |      |       14 : PRO   | N    | CA   | C    | O    | CB   | CG   | CD   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      | HA   | 1HB  | 2HB  | 1HG  | 2HG  | 1HD  | 2HD  |      |      |      |      |      |       15 : SER   | N    | CA   | C    | O    | CB   | OG   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HG   | HA   | 1HB  | 2HB  |      |      |      |      |      |      |      |       16 : THR   | N    | CA   | C    | O    | CB   | OG1  | CG2  |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HG1  | HA   | HB   | 1HG2 | 2HG2 | 3HG2 |      |      |      |      |      |       17 : TRP   | N    | CA   | C    | O    | CB   | CG   | CD1  | CD2  | CE2  | CE3  | NE1  | CZ2  | CZ3  | CH2  |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | 1HD  | 1HE  | HZ2  | HH2  | HZ3  | HE3  |      |      |       18 : TYR   | N    | CA   | C    | O    | CB   | CG   | CD1  | CD2  | CE1  | CE2  | CZ   | OH   |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | 1HD  | 1HE  | 2HE  | 2HD  | HH   |      |      |      |       19 : VAL   | N    | CA   | C    | O    | CB   | CG1  | CG2  |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | HB   | 1HG1 | 2HG1 | 3HG1 | 1HG2 | 2HG2 | 3HG2 |      |      |      |       20 : UNK   | N 
   | CA   | C    | O    | CB   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | 3HB  |      |      |      |      |      |      |      |       21 : <M>   | N    | CA   | C    | O    | CB   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | 3HB  |      |      |      |      |      |      |      |       22 : DA    | OP1  | P    | OP2  | O5'  | C5'  | C4'  | O4'  | C3'  | O3'  | C2'  | C1'  | N9   | C4   | N3   | C2   | N1   | C6   | C5   | N7   | C8   | N6   |      |      | H5'' | H5'  | H4'  | H3'  | H2'' | H2'  | H1'  | H2   | H61  | H62  | H8   |      |       23 : DC    | OP1  | P    | OP2  | O5'  | C5'  | C4'  | O4'  | C3'  | O3'  | C2'  | C1'  | N1   | C2   | O2   | N3   | C4   | N4   | C5   | C6   |      |      |      |      | H5'' | H5'  | H4'  | H3'  | H2'' | H2'  | H1'  | H42  | H41  | H5   | H6   |      |       24 : DG    | OP1  | P    | OP2  | O5'  | C5'  | C4'  | O4'  | C3'  | O3'  | C2'  | C1'  | N9   | C4   | N3   | C2   | N1   | C6   | C5   | N7   | C8   | N2   | O6   |      | H5'' | H5'  | H4'  | H3'  | H2'' | H2'  | H1'  | H1   | H22  | H21  | H8   |      |       25 : DT    | OP1  | P    | OP2  | O5'  | C5'  | C4'  | O4'  | C3'  | O3'  | C2'  | C1'  | N1   | C2   | O2   | N3   | C4   | O4   | C5   | C7   | C6   |      |      |      | H5'' | H5'  | H4'  | H3'  | H2'' | H2'  | H1'  | H3   | H71  | H72  | H73  | H6   |       26 : DN    | OP1  | P    | OP2  | O5'  | C5'  | C4'  | O4'  | C3'  | O3'  | C2'  | C1'  |      |      |      |      |      |      |      |      |      |      |      |      | H5'' | H5'  | H4'  | H3'  | H2'' | H2'  | H1'  |      |      |      |      |      |       27 : A     | OP1  | P    | OP2  | O5'  | C5'  | C4'  | O4'  | C3'  | O3'  | C1'  | C2'  | O2'  | N1   | C2   | N3   | C4   | C5   | C6   | N6   | N7   | C8   | N9   | 
     | H5'  | H5'' | H4'  | H3'  | H2'  | HO2' | H1'  | H2   | H61  | H62  | H8   |      |       28 : C     | OP1  | P    | OP2  | O5'  | C5'  | C4'  | O4'  | C3'  | O3'  | C1'  | C2'  | O2'  | N1   | C2   | O2   | N3   | C4   | N4   | C5   | C6   |      |      |      | H5'  | H5'' | H4'  | H3'  | H2'  | HO2' | H1'  | H42  | H41  | H5   | H6   |      |       29 : G     | OP1  | P    | OP2  | O5'  | C5'  | C4'  | O4'  | C3'  | O3'  | C1'  | C2'  | O2'  | N1   | C2   | N2   | N3   | C4   | C5   | C6   | O6   | N7   | C8   | N9   | H5'  | H5'' | H4'  | H3'  | H2'  | HO2' | H1'  | H1   | H22  | H21  | H8   |      |       30 : U     | OP1  | P    | OP2  | O5'  | C5'  | C4'  | O4'  | C3'  | O3'  | C1'  | C2'  | O2'  | N1   | C2   | O2   | N3   | C4   | O4   | C5   | C6   |      |      |      | H5'  | H5'' | H4'  | H3'  | H2'  | HO2' | H1'  | H3   | H5   | H6   |      |      |       31 : N     | OP1  | P    | OP2  | O5'  | C5'  | C4'  | O4'  | C3'  | O3'  | C1'  | C2'  | O2'  |      |      |      |      |      |      |      |      |      |      |      | H5'  | H5'' | H4'  | H3'  | H2'  | HO2' | H1'  |      |      |      |      |      |       32 : HIS_D | N    | CA   | C    | O    | CB   | CG   | NE2  | CD2  | CE1  | ND1  |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | 2HD  | 1HE  | 1HD  |      |      |      |      |      |       33 : 13    |      | 13   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |       34 : 33    |      | 33   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |       35 : 79    |      | 79   |      |      |      |      | 
     |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |       36 : 5     |      | 5    |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |       37 : 4     |      | 4    |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |       38 : 35    |      | 35   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |       39 : 6     |      | 6    |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |       40 : 20    |      | 20   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |       41 : 17    |      | 17   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |       42 : 27    |      | 27   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      
|      |      |      |      |      |      |      |       43 : 24    |      | 24   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |       44 : 29    |      | 29   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |       45 : 9     |      | 9    |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |       46 : 26    |      | 26   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |       47 : 80    |      | 80   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |       48 : 53    |      | 53   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |       49 : 77    |      | 77   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |       50 : 19    |      | 19   |      |      |      |      |      |      |      |      |      |      
[default encoding table repr omitted], max_automorphs: int = 1000, max_isomorphisms: int = 1000)[source]#

Bases: Transform

Create the symmetry axis for the xyz and mask features that go into RF2AA. The RF2AA loss needs these copies to resolve equivalent assignments between identical polymers and between symmetric atoms within non-polymers.

This transform generates what we loosely refer to as ‘symmetry copies’ of the coordinates and mask values (True indicates an atom exists) by computing and applying isomorphisms between molecules (for polymers) and automorphisms within each molecule (for non-polymers). This transform is tailored to the RF2AA loss and implementation and is not intended to be used outside of the RF2AA codebase.

The transform roughly follows these steps:

  1. Input Validation:
    • Ensures the input data contains the necessary keys and types, including atom arrays, encoded data, and crop information. It also checks that the data satisfies the assumptions in RF2AA, namely that all polymer tokens occur before non-polymer (or atomized) tokens in the atom array and encoding. Atomized bits of polymers are treated as non-polymers in this transform.

  2. Polymer Symmetries:
    • Identifies which polymers are equivalent (isomorphic) based on molecule entities.

    • Generates all possible combinations of in-group permutations (isomorphisms) for these polymers.

    • Maps these isomorphisms from the instance level to the token level.

    • Encodes the pre-cropped atom array to get the xyz coordinates and masks.

    • Applies the isomorphisms to the pre-cropped xyz and mask, then subsets to the crop tokens for the permuted post-cropped xyz and mask.

    • Ensures the first column corresponds to the unpermuted, original polymer coordinates and mask.

  3. Non-Polymer Symmetries:
    • Checks if there are any non-polymers in the crop.

    • Identifies the full molecules (pre-crop) for each non-polymer that has tokens in the crop.

    • Computes the automorphisms for each of these full molecules.

    • Applies the automorphisms to the coordinates and masks of the encoded full molecules, then subsets to the crop tokens.

    • Concatenates all the automorphs together, padded to the maximum number of automorphs for any molecule.

    • Ensures all non-polymers are entirely atomized.

  4. Combining Results:
    • Combines the polymer and non-polymer xyz and masks by concatenating them along the token axis and padding the newly created symmetry axis.

    • Updates the encoded data with the combined xyz and mask, which will be used as input for RF2AA.
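Step 4 pads the symmetry axis and concatenates along the token axis. The sketch below illustrates that shape bookkeeping in NumPy; `combine_symmetry_copies` is a hypothetical helper, not the atomworks implementation, and the padding values (NaN coordinates, `False` masks) are assumptions.

```python
import numpy as np


def combine_symmetry_copies(poly_xyz, poly_mask, nonpoly_xyz, nonpoly_mask):
    """Join polymer and non-polymer symmetry copies (hypothetical helper).

    poly_xyz:    [n_poly_perms, n_poly_tokens, n_atoms_per_token, 3]
    nonpoly_xyz: [n_nonpoly_perms, n_nonpoly_tokens, n_atoms_per_token, 3]
    Masks have the same shapes without the trailing xyz dimension.
    """
    n_perms = max(poly_xyz.shape[0], nonpoly_xyz.shape[0])

    def pad_symmetry_axis(arr, fill_value):
        # Append filler copies along axis 0 until it has n_perms entries.
        n_missing = n_perms - arr.shape[0]
        filler = np.full((n_missing, *arr.shape[1:]), fill_value, dtype=arr.dtype)
        return np.concatenate([arr, filler], axis=0)

    # Assumed padding values: NaN coordinates and False masks, so padded
    # copies are ignored by any loss that respects the mask.
    xyz = np.concatenate(
        [pad_symmetry_axis(poly_xyz, np.nan), pad_symmetry_axis(nonpoly_xyz, np.nan)], axis=1
    )
    mask = np.concatenate(
        [pad_symmetry_axis(poly_mask, False), pad_symmetry_axis(nonpoly_mask, False)], axis=1
    )
    return xyz, mask


# Two polymer copies over 3 tokens, four non-polymer copies over 2 tokens.
xyz, mask = combine_symmetry_copies(
    np.zeros((2, 3, 4, 3)), np.ones((2, 3, 4), dtype=bool),
    np.ones((4, 2, 4, 3)), np.ones((4, 2, 4), dtype=bool),
)
print(xyz.shape, mask.shape)  # (4, 5, 4, 3) (4, 5, 4)
```

Row 0 of the symmetry axis keeps the unpermuted copies of both parts, because concatenation preserves the order of each input.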

The effect of this function is to:
  1. Update the ‘encoded’ key of the data dict with the symmetry copies of the xyz and mask.

  2. Add the ‘symmetry_info’ key to the data dict, which contains metadata on the symmetry.

Example

>>> transform = CreateSymmetryCopyAxisLikeRF2AA()
>>> data = {
...     "atom_array": AtomArray(...),
...     "encoded": {"xyz": torch.tensor(...), "mask": torch.tensor(...)},
...     "openbabel": {...},
...     "crop_info": {"atom_array": AtomArray(...), "crop_token_idxs": np.array(...)},
... }
>>> transformed_data = transform.forward(data)
>>> # transformed_data["encoded"]["xyz"] and transformed_data["encoded"]["mask"] now contain the symmetry copies,
>>> # i.e. they are of shape
>>> #  - [n_permutations, n_crop_tokens, n_atoms_per_token, 3]
>>> #  - [n_permutations, n_crop_tokens, n_atoms_per_token]
assert_nonpoly_come_after_polys(atom_array: AtomArray) None[source]#
check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

handle_nonpoly_automorphisms(pre_nonpoly_array: AtomArray, post_nonpoly_array: AtomArray, crop_tmask: ndarray, openbabel_data: dict[int, Any]) tuple[Tensor, Tensor][source]#

Handles non-polymer symmetries by computing automorphs within each non-polymer that is at least partially in the crop.

This function calculates the swapped coordinate and mask values for each molecule and concatenates all automorphs together, padding the n_permutations dimension to the maximum number of automorphs for any molecule.

WARNING: Unlike polymer symmetries, inter-molecule symmetries are not considered here, as they are managed by the RF2AA loss function through a greedy search.

For non-polymers, the following steps are performed:

  1. Subset the pre-cropped non-poly array to only include the non-poly molecules that are at least partially in the crop.

  2. Compute the automorphs for each identified full molecule (i.e. BEFORE cropping).

  3. Apply the computed automorphs to the coordinates and masks of the encoded full molecules.

  4. Crop to the relevant sections of the molecules that appear in the crop.

  5. Concatenate all automorphs together, padding to the maximum number of automorphs for any molecule.

  6. De-duplicate the automorphs. Duplicates arise when two automorphs differ only by atom swaps outside the crop, so their cropped tokens are identical.
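The de-duplication in step 6 can be illustrated with a small NumPy sketch. `crop_and_deduplicate_automorphs` is hypothetical; it assumes automorphs are stored as index permutations of the full (pre-crop) molecule with the identity in row 0. Two copies collapse when they differ only at atoms outside the crop.

```python
import numpy as np


def crop_and_deduplicate_automorphs(automorphs, crop_atom_mask):
    """Crop full-molecule automorphs and drop copies that became identical.

    automorphs: [n_automorphs, n_atoms] int array for ONE full (pre-crop) molecule;
        copy k places atom automorphs[k, i] at position i. Row 0 is the identity.
    crop_atom_mask: [n_atoms] bool array, True for atoms that survive the crop.
    """
    # Keep only the positions inside the crop; the values still index the full
    # molecule, so atoms outside the crop can be swapped in.
    cropped = automorphs[:, crop_atom_mask]
    # np.unique sorts rows; return_index gives first occurrences, which we
    # re-sort so the original order (identity first) is preserved.
    _, first_idx = np.unique(cropped, axis=0, return_index=True)
    return cropped[np.sort(first_idx)]


# 4-atom molecule: atoms 0<->1 are swappable, and so are atoms 2<->3.
automorphs = np.array([[0, 1, 2, 3], [0, 1, 3, 2], [1, 0, 2, 3], [1, 0, 3, 2]])
crop_atom_mask = np.array([True, True, False, False])  # only atoms 0 and 1 are in the crop
kept = crop_and_deduplicate_automorphs(automorphs, crop_atom_mask)
print(kept)
# [[0 1]
#  [1 0]]

# Gather coordinates from the FULL molecule, then the result is already cropped.
full_xyz = np.arange(12.0).reshape(4, 3)
cropped_copies = full_xyz[kept]  # [n_kept, n_cropped_atoms, 3]
```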

Parameters:
  • pre_nonpoly_array (AtomArray) – The pre-cropped non-polymer array to process.

  • post_nonpoly_array (AtomArray) – The post-cropped non-polymer array to process.

  • crop_tmask (np.ndarray) – A boolean mask indicating which tokens are in the crop.

  • openbabel_data (dict[int, Any]) – A dictionary containing Open Babel data for molecules.

Returns:

  • nonpoly_xyzs (torch.Tensor): A tensor containing the coordinates of the non-polymer automorphs.

  • nonpoly_masks (torch.Tensor): A tensor containing the masks of the non-polymer automorphs.

  • symmetry_info (dict[tuple[int, str], int]): A dictionary containing the symmetry information.

Return type:

tuple

handle_polymer_isomorphisms(pre_poly_array: AtomArray, post_poly_array: AtomArray, crop_tmask: ndarray) tuple[Tensor, Tensor][source]#

Handles polymer symmetries by computing all swaps between isomorphic (i.e., equivalent) polymers that are at least partially in the crop.

NOTE: This function only swaps full chains. Swaps within atoms of a polymer (e.g., residue naming ambiguities) are not considered and are handled elsewhere.

The process involves the following steps:

  1. Subset the crop mask and pre-cropped atom array to only include polymers that are at least partially in the crop.

  2. Among these, identify polymers that are equal to each other (i.e., symmetry groups).

  3. Generate all possible combinations of in-group permutations (isomorphisms).

  4. Apply these isomorphisms to the coordinates and masks of the pre-cropped, encoded polymers.

  5. Crop to the relevant bits that appear in the crop.

  6. De-duplicate the isomorphisms to remove any redundancies.
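Steps 1 and 2 (keep polymers that touch the crop, then group them by entity) could look roughly like the following. The per-token inputs and the helper `symmetry_groups_in_crop` are hypothetical; the real transform works on AtomArray annotations.

```python
import numpy as np


def symmetry_groups_in_crop(token_instance_id, token_entity, crop_token_mask):
    """Group polymer instances with >= 1 token in the crop by entity (hypothetical helper).

    All inputs are per-token arrays over the PRE-crop polymer tokens.
    Returns {entity: [instance ids in order of first appearance]}.
    """
    in_crop = set(token_instance_id[crop_token_mask].tolist())
    groups = {}
    for instance, entity in zip(token_instance_id.tolist(), token_entity.tolist()):
        if instance in in_crop:
            members = groups.setdefault(entity, [])
            if instance not in members:
                members.append(instance)
    return groups


# Instances 0, 1 and 3 share entity 1; instance 2 is entity 2. Instance 3 is not in the crop.
token_instance_id = np.array([0, 0, 1, 1, 2, 2, 3, 3])
token_entity = np.array([1, 1, 1, 1, 2, 2, 1, 1])
crop_token_mask = np.array([True, False, False, True, True, False, False, False])
groups = symmetry_groups_in_crop(token_instance_id, token_entity, crop_token_mask)
print(groups)  # {1: [0, 1], 2: [2]}
```

With these groups, step 3 would produce 2! · 1! = 2 chain-level isomorphisms; instance 3 is excluded because none of its tokens are in the crop.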

Parameters:
  • pre_poly_array (AtomArray) – The atom array representing the state before cropping, containing polymer tokens.

  • post_poly_array (AtomArray) – The atom array representing the state after cropping, containing polymer tokens.

  • crop_tmask (np.ndarray) – A boolean mask indicating which tokens are included in the crop.

Returns:

  • poly_xyz (torch.Tensor): The xyz coordinates of the polymers after applying the isomorphisms. It has shape [n_permutations, n_crop_tokens, n_atoms_per_token, 3].

  • poly_mask (torch.Tensor): The mask of the polymers after applying the isomorphisms. It has shape [n_permutations, n_crop_tokens, n_atoms_per_token].

Return type:

tuple[torch.Tensor, torch.Tensor]

previous_transforms_order_matters: bool = True#
requires_previous_transforms: ClassVar[list[str | Transform]] = ['SortLikeRF2AA', 'AddOpenBabelMoleculesForAtomizedMolecules', 'EncodeAtomArray']#
class atomworks.ml.transforms.symmetry.FindAutomorphismsWithNetworkX[source]#

Bases: Transform

Generates a list of automorphisms (including both polymer and non-polymer residues) for a given atom array. Used in AF-3/AF-Multimer-style symmetry resolution.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

requires_previous_transforms: ClassVar[list[str | Transform]] = [<class 'atomworks.ml.transforms.atomize.AtomizeByCCDName'>]#
atomworks.ml.transforms.symmetry.apply_automorphs(data: Tensor, automorphs: ndarray | Tensor) Tensor[source]#

Create data permutations of the input data for each of the automorphs.

This function generates permutations of the input tensor data based on the provided automorphisms. Each permutation corresponds to a different automorphism, effectively reordering the data according to the automorphisms.

Parameters:
  • data (torch.Tensor) – The input tensor to be permuted. The first dimension has to correspond to the number of atoms.

  • automorphs (np.ndarray | torch.Tensor) – A tensor or numpy array of shape [n_automorphs, n_atoms, 2] representing the automorphisms. Each automorphism is a list of paired atom indices (from_idx, to_idx). The from_idx column is essentially just a repetition of np.arange(n_atoms).

Returns:

data_automorphs – A tensor of shape [n_automorphs, *data.shape] containing the permuted data for each automorphism.

Return type:

torch.Tensor

Example

>>> import numpy as np
>>> import torch
>>> from atomworks.ml.transforms.symmetry import apply_automorphs
>>> data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
>>> # Example automorphisms (2 automorphisms for 3 atoms)
>>> automorphs = np.array([
...     [[0, 0],
...      [1, 1],
...      [2, 2]],
...     [[0, 2],
...      [1, 0],
...      [2, 1]],
... ])
>>> permuted_data = apply_automorphs(data, automorphs)
>>> print(permuted_data)
tensor([[[1., 2.],
         [3., 4.],
         [5., 6.]],

        [[5., 6.],
         [1., 2.],
         [3., 4.]]])
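Because the from_idx column is just np.arange(n_atoms), applying the automorphs amounts to one fancy-indexing gather on the to_idx column. The NumPy sketch below (a hypothetical `apply_automorphs_np`, not the library function) reproduces the permuted rows of the example above; the same indexing pattern works on torch tensors indexed with a LongTensor.

```python
import numpy as np


def apply_automorphs_np(data, automorphs):
    """NumPy sketch of the gather: out[k, i] = data[automorphs[k, i, 1]].

    data: [n_atoms, ...]; automorphs: [n_automorphs, n_atoms, 2] of (from_idx, to_idx).
    """
    automorphs = np.asarray(automorphs)
    # The from_idx column is assumed to be arange(n_atoms) for every automorph.
    assert (automorphs[..., 0] == np.arange(data.shape[0])).all()
    return data[automorphs[..., 1]]


data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
automorphs = np.array([
    [[0, 0], [1, 1], [2, 2]],
    [[0, 2], [1, 0], [2, 1]],
])
permuted = apply_automorphs_np(data, automorphs)
print(permuted.shape)  # (2, 3, 2)
```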

atomworks.ml.transforms.symmetry.find_automorphisms(atom_array: AtomArray) ndarray[source]#
atomworks.ml.transforms.symmetry.find_automorphisms_with_networkx(atom_array: AtomArray, max_automorphs: int = 1000) ndarray[source]#

Finds automorphisms in an AtomArray using NetworkX, returning indices of atoms that can be permuted.

Parameters:
  • atom_array (AtomArray) – The input AtomArray object. Must have the following annotations: pn_unit_iid, is_polymer, res_id, res_name, atom_name, and element.

  • max_automorphs (int, optional) – The maximum number of automorphisms to generate. Default is 1000.

Returns:

A Python list of arrays, each containing indices of atoms that can be permuted within the global frame of the input atom_array.

Return type:

np.ndarray

Example

>>> automorphisms = find_automorphisms_with_networkx(atom_array)
# Output:
# [
#     array([  # E.g., corresponding to the first residue
#         [0, 1, 2, 3, 4, 5],  # The first row is the identity permutation
#         [0, 1, 2, 3, 5, 4]   # Atoms with global indices 4 and 5 are swappable
#     ]),
#     array([  # E.g., corresponding to the second residue
#         [6, 7, 8, 9, 10, 11],  # The first row is the identity permutation. Indices are global (within the AtomArray).
#     ])
# ]
# Each sub-array represents indices of atoms that can be permuted within the global frame.
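The example output uses global atom indices. Assuming each residue's atoms are stored contiguously in the AtomArray, per-residue automorphisms computed in local indices can be shifted into the global frame by adding the index of the residue's first atom. `to_global_indices` is a hypothetical helper that reproduces the example above.

```python
import numpy as np


def to_global_indices(local_automorphs, residue_first_atom_idx):
    """Shift per-residue automorphisms (local atom indices) into the global AtomArray frame.

    Assumes each residue's atoms are contiguous, starting at residue_first_atom_idx[r].
    """
    return [np.asarray(a) + start for a, start in zip(local_automorphs, residue_first_atom_idx)]


local = [
    np.array([[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 5, 4]]),  # residue with two swappable atoms
    np.array([[0, 1, 2, 3, 4, 5]]),                      # residue with only the identity
]
global_automorphs = to_global_indices(local, residue_first_atom_idx=[0, 6])
```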
atomworks.ml.transforms.symmetry.generate_automorphisms_from_atom_array_with_networkx(atom_array: AtomArray, max_automorphs: int = 1000, node_features: str | list = 'element', ignore_bond_type: bool = True, hash_key: Hashable = None) ndarray[source]#

Generate automorphisms of a molecular graph using NetworkX.

In some cases, the automorphisms generated by RDKit or OpenBabel may be overly strict; e.g., they do not account for resonance. This function uses NetworkX to generate automorphisms of a molecular graph, which can be more flexible (but in some cases overly permissive).

Parameters:
  • atom_array (AtomArray) – The input molecular structure as an AtomArray object.

  • max_automorphs (int) – The maximum number of automorphisms to generate. Default is 1000.

  • node_features (str or list of str) – The node-level features to use for coloring nodes. Can be a single feature (e.g., ‘element’) or a list of features (e.g., [‘element’, ‘charge’]). Default is ‘element’.

  • ignore_bond_type (bool) – If True, the bond type is ignored when generating automorphisms. Must be True to detect some resonance-based automorphisms. Default is True.

  • hash_key (Hashable) – A hashable key to use for caching automorphisms. If None, no caching is used. Used by the decorator cache_based_on_subset_of_args, so cannot be deleted (even if unused in this function).

Returns:

An array where the first row is the identity permutation [0, 1, 2, …, n_atoms - 1], and subsequent rows are the permutations representing automorphisms.

Return type:

np.ndarray

Example

>>> atom_array = struc.info.residue("H2O")  # Water molecule; two identical hydrogen atoms
>>> automorphisms = generate_automorphisms_from_atom_array_with_networkx(atom_array)
>>> print(automorphisms)
[[0, 1, 2],
 [0, 2, 1]]  # Example output for a simple molecule like H2O
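A minimal sketch of the NetworkX approach, assuming a plain `networkx.Graph` with an `element` attribute per node and unlabelled edges (the equivalent of `node_features='element'` and `ignore_bond_type=True`). It enumerates self-isomorphisms with `networkx.algorithms.isomorphism.GraphMatcher`; the function name and input format are hypothetical, and caching (`hash_key`) is omitted.

```python
import networkx as nx
import numpy as np
from networkx.algorithms.isomorphism import GraphMatcher, categorical_node_match


def graph_automorphisms(elements, bonds, max_automorphs=1000):
    """Automorphisms of a molecular graph, colouring nodes by element only.

    elements: per-atom element symbols; bonds: iterable of (i, j) atom-index pairs.
    Bond orders are not stored, which mirrors ignore_bond_type=True.
    Returns [n_automorphs, n_atoms] with the identity permutation in row 0.
    """
    n_atoms = len(elements)
    graph = nx.Graph()
    graph.add_nodes_from((i, {"element": el}) for i, el in enumerate(elements))
    graph.add_edges_from(bonds)

    matcher = GraphMatcher(graph, graph, node_match=categorical_node_match("element", None))
    identity = list(range(n_atoms))
    perms = [identity]
    for mapping in matcher.isomorphisms_iter():
        if len(perms) >= max_automorphs:
            break
        perm = [mapping[i] for i in range(n_atoms)]
        if perm != identity:  # the identity is already in row 0
            perms.append(perm)
    return np.array(perms, dtype=np.int64)


# Water: O bonded to two H atoms, which can be swapped.
water = graph_automorphisms(["O", "H", "H"], [(0, 1), (0, 2)])
print(water)
# [[0 1 2]
#  [0 2 1]]
```

Because bond orders are dropped, the two oxygens of a carboxylate also become interchangeable (the resonance case above); the same setting can merge atoms that are chemically distinct, which is the "overly permissive" risk.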
atomworks.ml.transforms.symmetry.get_isomorphisms_from_symmetry_groups(group_to_instance_map: dict[int | str, Sequence[int]], max_isomorphisms: int = 1000) ndarray[source]#

Create an array of all possible isomorphisms for a given entity to instances mapping.

Parameters:
  • group_to_instance_map (dict[int | str, Sequence[int]]) – A dictionary mapping entities to their instances. For example, {1: [0, 1, 2], 2: [3], 3: [4, 5]}.

  • max_isomorphisms (int) – The maximum number of isomorphisms to return. Defaults to 1000.

Returns:

isomorphisms – A 2D array of shape (n_isomorphisms, n_instances) containing all possible isomorphisms, where n_isomorphisms is the number of possible isomorphisms and n_instances is the total number of instances. The values are the ids of the instances (i.e. the values in group_to_instance_map). The instances of each group appear consecutively in the array (column-wise), and the groups are ordered as in group_to_instance_map.

Return type:

np.ndarray

Example

>>> get_isomorphisms_from_symmetry_groups({1: [0, 1, 2], 2: [3], 3: [4, 5]})
#       1, 1, 1, 2, 3, 3  <-(symmetry group)
#     --------------------
array([[0, 1, 2, 3, 4, 5],
       [0, 2, 1, 3, 4, 5],
       [1, 0, 2, 3, 4, 5],
       [1, 2, 0, 3, 4, 5],
       [2, 0, 1, 3, 4, 5],
       [2, 1, 0, 3, 4, 5],
       [0, 1, 2, 3, 5, 4],
       [0, 2, 1, 3, 5, 4],
       [1, 0, 2, 3, 5, 4],
       [1, 2, 0, 3, 5, 4],
       [2, 0, 1, 3, 5, 4],
       [2, 1, 0, 3, 5, 4]], dtype=uint32)
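The isomorphism table is a Cartesian product of per-group permutations. The hypothetical sketch below builds it with `itertools`. For the example input it yields the same set of rows, but not necessarily in the same order (here the last group varies fastest). It also materialises every per-group permutation list, so it is only practical for small groups.

```python
import itertools

import numpy as np


def isomorphisms_from_groups(group_to_instance_map, max_isomorphisms=1000):
    """All in-group permutations of instance ids, groups laid out column-wise."""
    # Per-group permutations; itertools.permutations yields the unpermuted order first.
    per_group = [list(itertools.permutations(members)) for members in group_to_instance_map.values()]
    # Cartesian product across groups; each combination becomes one flattened row.
    rows = ([i for perm in combo for i in perm] for combo in itertools.product(*per_group))
    return np.array(list(itertools.islice(rows, max_isomorphisms)), dtype=np.uint32)


isos = isomorphisms_from_groups({1: [0, 1, 2], 2: [3], 3: [4, 5]})
print(isos.shape)  # (12, 6)
```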
atomworks.ml.transforms.symmetry.identify_isomorphic_chains_based_on_chain_entity(atom_array: AtomArray) dict[int | str, list[int | str]][source]#

Identifies isomorphic chains based on the chain entity annotation.

This function creates a dictionary mapping chain entities to their corresponding chain IDs. Chains with the same entity are considered isomorphic.

Parameters:

atom_array (AtomArray) – The atom array containing chain entity and chain ID annotations.

Returns:

A dictionary where keys are chain entities and values are lists of chain IDs belonging to that entity.

Return type:

dict[int | str, list[int | str]]

Example

>>> atom_array = AtomArray(...)  # AtomArray with chain_entity and chain_iid annotations
>>> isomorphic_chains = identify_isomorphic_chains_based_on_chain_entity(atom_array)
>>> print(isomorphic_chains)
{1: ['A', 'B', 'C'], 2: ['D', 'E', 'F'], 3: ['G']}
atomworks.ml.transforms.symmetry.identify_isomorphic_chains_based_on_molecule_entity(atom_array: AtomArray) dict[int | str, list[int | str]][source]#

Identifies isomorphic molecules based on the molecule entity annotation.

This function creates a dictionary mapping molecule entities to their corresponding molecule IDs. Molecules with the same entity are considered isomorphic.

Parameters:

atom_array (AtomArray) – The atom array containing molecule entity and molecule ID annotations.

Returns:

A dictionary where keys are molecule entities and values are lists of molecule IDs belonging to that entity.

Return type:

dict[int | str, list[int | str]]

Example

>>> atom_array = AtomArray(...)  # AtomArray with molecule_entity and molecule_iid annotations
>>> isomorphic_molecules = identify_isomorphic_chains_based_on_molecule_entity(atom_array)
>>> print(isomorphic_molecules)
{"A,B": [1, 2, 3], "C": [4, 5]}
atomworks.ml.transforms.symmetry.instance_to_token_lvl_isomorphisms(instance_isomorphisms: ndarray, instance_token_idxs: list[ndarray]) ndarray[source]#

Convert instance-level isomorphisms to token-level isomorphisms.

This function takes a set of instance-level isomorphisms and their corresponding token indices, and maps the instance isomorphisms to token-level indices.

Parameters:
  • instance_isomorphisms (np.ndarray) – A 2D array of shape (n_permutations, n_instances) where each row represents a permutation of instance indices.

  • instance_token_idxs (list of np.ndarray) – A list where each element is an array of token indices corresponding to each instance.

Returns:

token_lvl_isomorphisms – A 2D array of shape (n_permutations, total_tokens) containing the token-level isomorphisms.

Return type:

np.ndarray

Example

>>> instance_isomorphisms = np.array([[0, 1], [1, 0]])  # Example instance-level isomorphisms
>>> instance_token_idxs = [
...     np.array([0, 1]),
...     np.array([2, 3]),
... ]  # Example token indices for each instance
>>> token_lvl_isomorphisms = instance_to_token_lvl_isomorphisms(instance_isomorphisms, instance_token_idxs)
>>> print(token_lvl_isomorphisms)
[[0 1 2 3]
 [2 3 0 1]]
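A minimal NumPy sketch of the mapping (hypothetical `instance_to_token_isomorphisms`, not the library function): for every permutation row, concatenate the token indices of the instances in that row. It assumes permuted instances have the same number of tokens, as isomorphic chains do.

```python
import numpy as np


def instance_to_token_isomorphisms(instance_isomorphisms, instance_token_idxs):
    """Expand each instance-level permutation into the token indices of those instances."""
    # Row k lists, slot by slot, the tokens of the instance placed in that slot.
    # np.stack needs equal row lengths, which holds when permuted instances
    # have the same number of tokens (true for isomorphic chains).
    return np.stack([
        np.concatenate([instance_token_idxs[i] for i in row])
        for row in instance_isomorphisms
    ])


token_isos = instance_to_token_isomorphisms(
    np.array([[0, 1], [1, 0]]),
    [np.array([0, 1]), np.array([2, 3])],
)
print(token_isos)
# [[0 1 2 3]
#  [2 3 0 1]]
```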

Transforms for adding and featurizing templates.

class atomworks.ml.transforms.template.AddInputFileTemplate[source]#

Bases: Transform

If atoms from the input file have been marked as templates, add them to the template dictionary. This is useful when users want to use part of their design as a template via the template_selection_syntax argument in the inference script.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.template.AddRFTemplates(max_n_template: int = 1, pick_top: bool = True, min_seq_similarity: float = 0.0, max_seq_similarity: float = 100.0, min_template_length: int = 0, filter_by_query_length: bool = False, template_lookup_path: PathLike | None = None, template_base_dir: PathLike | None = None)[source]#

Bases: Transform

Adds RF templates to the data.

The templates are added to the data under the key template.

Output features:
  • template (dict): A dictionary with chain IDs as keys and a list of templates for that chain as values.
    Each template is a dictionary with the following keys:
    • id (str): The template ID.

    • pdb_id (str): The PDB ID of the template.

    • chain_id (str): The chain ID of the template.

    • template_lookup_id (str): The lookup ID for the template. This is the chid_to_hash ID used for MSAs and templates in the original RF2AA, and it is used to retrieve the template from disk.

    • seq_similarity (float): The sequence similarity of the template to the query.

    • atom_array (AtomArray): The atom array of the template.

    • n_res (int): The number of residues in the template.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.template.FeaturizeTemplatesLikeAF3(sequence_encoding: AF3SequenceEncoding, gap_token: str = '<G>', allowed_chain_type: list[ChainType] = [ChainType.POLYPEPTIDE_L, ChainType.RNA], distogram_bins: Tensor = tensor([3.2500, 4.5338, 5.8176, 7.1014, 8.3851, 9.6689, 10.9527, 12.2365, 13.5203, 14.8041, 16.0878, 17.3716, 18.6554, 19.9392, 21.2230, 22.5068, 23.7905, 25.0743, 26.3581, 27.6419, 28.9257, 30.2095, 31.4932, 32.7770, 34.0608, 35.3446, 36.6284, 37.9122, 39.1959, 40.4797, 41.7635, 43.0473, 44.3311, 45.6149, 46.8986, 48.1824, 49.4662, 50.7500]))[source]#

Bases: Transform

A transform that featurizes templates for AlphaFold 3.

This transform generates the following template features (as torch.Tensors):
  • template_restype: [N_templ, N_token] Residue type for each template token.

  • template_pseudo_beta_mask: [N_templ, N_token] Mask indicating if pseudo-beta atom exists.

  • template_backbone_frame_mask: [N_templ, N_token] Mask indicating if coordinates exist for all atoms required to compute the backbone frame.

  • template_distogram: [N_templ, N_token, N_token] A pairwise feature indicating the distance between Cβ atoms (CA for glycine), discretized into bins.

  • template_unit_vector: [N_templ, N_token, N_token, 3] The unit vector of the displacement of the CA atom of all residues within the local frame of each residue.
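The distogram and unit-vector features can be sketched in NumPy for a single template. The bin edges match the default distogram_bins above (`np.linspace(3.25, 50.75, 38)` reproduces those values). The Gram-Schmidt frame built from N, CA and C, the use of `np.digitize`, and all function names are assumptions for illustration rather than the exact AlphaFold 3 or atomworks conventions.

```python
import numpy as np

# Same values as the default distogram_bins: 38 edges evenly spaced on [3.25, 50.75].
DISTOGRAM_BINS = np.linspace(3.25, 50.75, 38)


def template_distogram(pseudo_beta_xyz, pseudo_beta_mask, bins=DISTOGRAM_BINS):
    """Bin index of every pairwise pseudo-beta distance for one template.

    pseudo_beta_xyz: [N_token, 3]; pseudo_beta_mask: [N_token] bool.
    np.digitize returns values in [0, len(bins)], i.e. len(bins) + 1 classes.
    """
    dist = np.linalg.norm(pseudo_beta_xyz[:, None, :] - pseudo_beta_xyz[None, :, :], axis=-1)
    pair_mask = pseudo_beta_mask[:, None] & pseudo_beta_mask[None, :]
    return np.digitize(dist, bins), pair_mask


def backbone_frames(n_xyz, ca_xyz, c_xyz):
    """Gram-Schmidt frames [N_token, 3, 3] whose columns are the local x, y, z axes."""
    e1 = c_xyz - ca_xyz
    e1 = e1 / np.linalg.norm(e1, axis=-1, keepdims=True)
    u2 = n_xyz - ca_xyz
    e2 = u2 - np.sum(u2 * e1, axis=-1, keepdims=True) * e1
    e2 = e2 / np.linalg.norm(e2, axis=-1, keepdims=True)
    return np.stack([e1, e2, np.cross(e1, e2)], axis=-1)


def template_unit_vector(ca_xyz, frames, eps=1e-8):
    """[N, N, 3]: unit vector from CA_i to CA_j, expressed in residue i's frame."""
    disp = ca_xyz[None, :, :] - ca_xyz[:, None, :]   # disp[i, j] = CA_j - CA_i
    local = np.einsum("iba,ijb->ija", frames, disp)  # R_i^T @ disp[i, j]
    return local / (np.linalg.norm(local, axis=-1, keepdims=True) + eps)
```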


check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

requires_previous_transforms: ClassVar[list[str | Transform]] = ['AddRFTemplates|AddInputFileTemplate', 'AddWithinChainInstanceResIdx', 'AddGlobalTokenIdAnnotation']#
class atomworks.ml.transforms.template.FeaturizeTemplatesLikeRF2AA(n_template: int, init_coords: ~torch.Tensor | float, mask_token_idx: int = 21, encoding: ~atomworks.ml.encoding_definitions.TokenEncoding = Encoding(n_tokens=80, n_atoms_per_token=36)  Token      | 0    | 1    | 2    | 3    | 4    | 5    | 6    | 7    | 8    | 9    | 10   | 11   | 12   | 13   | 14   | 15   | 16   | 17   | 18   | 19   | 20   | 21   | 22   | 23   | 24   | 25   | 26   | 27   | 28   | 29   | 30   | 31   | 32   | 33   | 34   | 35   -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   0 : ALA   | N    | CA   | C    | O    | CB   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | 3HB  |      |      |      |      |      |      |      |        1 : ARG   | N    | CA   | C    | O    | CB   | CG   | CD   | NE   | CZ   | NH1  | NH2  |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | 1HG  | 2HG  | 1HD  | 2HD  | HE   | 1HH1 | 2HH1 | 1HH2 | 2HH2   2 : ASN   | N    | CA   | C    | O    | CB   | CG   | OD1  | ND2  |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  | 1HD2 | 2HD2 |      |      |      |      |      |      |        3 : ASP   | N    | CA   | C    | O    | CB   | CG   | OD1  | OD2  |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      | H    | HA   | 1HB  | 2HB  |      |      |      |      |      |      |      |      |        4 : CYS   | N    | CA   | C    | O    | CB   | SG   |      |      |      |      |      |      |      |      |      |      |      |      |      |      |      |     
| H HA 1HB 2HB HG
   5 : GLN   | N CA C O CB CG CD OE1 NE2 | H HA 1HB 2HB 1HG 2HG 1HE2 2HE2
   6 : GLU   | N CA C O CB CG CD OE1 OE2 | H HA 1HB 2HB 1HG 2HG
   7 : GLY   | N CA C O | H 1HA 2HA
   8 : HIS   | N CA C O CB CG ND1 CD2 CE1 NE2 | H HA 1HB 2HB 2HD 1HE 2HE
   9 : ILE   | N CA C O CB CG1 CG2 CD1 | H HA HB 1HG2 2HG2 3HG2 1HG1 2HG1 1HD1 2HD1 3HD1
  10 : LEU   | N CA C O CB CG CD1 CD2 | H HA 1HB 2HB HG 1HD1 2HD1 3HD1 1HD2 2HD2 3HD2
  11 : LYS   | N CA C O CB CG CD CE NZ | H HA 1HB 2HB 1HG 2HG 1HD 2HD 1HE 2HE 1HZ 2HZ 3HZ
  12 : MET   | N CA C O CB CG SD CE | H HA 1HB 2HB 1HG 2HG 1HE 2HE 3HE
  13 : PHE   | N CA C O CB CG CD1 CD2 CE1 CE2 CZ | H HA 1HB 2HB 1HD 2HD 1HE 2HE HZ
  14 : PRO   | N CA C O CB CG CD | HA 1HB 2HB 1HG 2HG 1HD 2HD
  15 : SER   | N CA C O CB OG | H HG HA 1HB 2HB
  16 : THR   | N CA C O CB OG1 CG2 | H HG1 HA HB 1HG2 2HG2 3HG2
  17 : TRP   | N CA C O CB CG CD1 CD2 CE2 CE3 NE1 CZ2 CZ3 CH2 | H HA 1HB 2HB 1HD 1HE HZ2 HH2 HZ3 HE3
  18 : TYR   | N CA C O CB CG CD1 CD2 CE1 CE2 CZ OH | H HA 1HB 2HB 1HD 1HE 2HE 2HD HH
  19 : VAL   | N CA C O CB CG1 CG2 | H HA HB 1HG1 2HG1 3HG1 1HG2 2HG2 3HG2
  20 : UNK   | N CA C O CB | H HA 1HB 2HB 3HB
  21 : <M>   | N CA C O CB | H HA 1HB 2HB 3HB
  22 : DA    | OP1 P OP2 O5' C5' C4' O4' C3' O3' C2' C1' N9 C4 N3 C2 N1 C6 C5 N7 C8 N6 | H5'' H5' H4' H3' H2'' H2' H1' H2 H61 H62 H8
  23 : DC    | OP1 P OP2 O5' C5' C4' O4' C3' O3' C2' C1' N1 C2 O2 N3 C4 N4 C5 C6 | H5'' H5' H4' H3' H2'' H2' H1' H42 H41 H5 H6
  24 : DG    | OP1 P OP2 O5' C5' C4' O4' C3' O3' C2' C1' N9 C4 N3 C2 N1 C6 C5 N7 C8 N2 O6 | H5'' H5' H4' H3' H2'' H2' H1' H1 H22 H21 H8
  25 : DT    | OP1 P OP2 O5' C5' C4' O4' C3' O3' C2' C1' N1 C2 O2 N3 C4 O4 C5 C7 C6 | H5'' H5' H4' H3' H2'' H2' H1' H3 H71 H72 H73 H6
  26 : DN    | OP1 P OP2 O5' C5' C4' O4' C3' O3' C2' C1' | H5'' H5' H4' H3' H2'' H2' H1'
  27 : A     | OP1 P OP2 O5' C5' C4' O4' C3' O3' C1' C2' O2' N1 C2 N3 C4 C5 C6 N6 N7 C8 N9 | H5' H5'' H4' H3' H2' HO2' H1' H2 H61 H62 H8
  28 : C     | OP1 P OP2 O5' C5' C4' O4' C3' O3' C1' C2' O2' N1 C2 O2 N3 C4 N4 C5 C6 | H5' H5'' H4' H3' H2' HO2' H1' H42 H41 H5 H6
  29 : G     | OP1 P OP2 O5' C5' C4' O4' C3' O3' C1' C2' O2' N1 C2 N2 N3 C4 C5 C6 O6 N7 C8 N9 | H5' H5'' H4' H3' H2' HO2' H1' H1 H22 H21 H8
  30 : U     | OP1 P OP2 O5' C5' C4' O4' C3' O3' C1' C2' O2' N1 C2 O2 N3 C4 O4 C5 C6 | H5' H5'' H4' H3' H2' HO2' H1' H3 H5 H6
  31 : N     | OP1 P OP2 O5' C5' C4' O4' C3' O3' C1' C2' O2' | H5' H5'' H4' H3' H2' HO2' H1'
  32 : HIS_D | N CA C O CB CG NE2 CD2 CE1 ND1 | H HA 1HB 2HB 2HD 1HE 1HD
  33-79      | single-atom tokens (one atom, in the second slot, named after the token), in order: 13, 33, 79, 5, 4, 35, 6, 20, 17, 27, 24, 29, 9, 26, 80, 53, 77, 19, 3, 12, 25, 42, 7, 28, 8, 76, 15, 82, 46, 59, 78, 75, 45, 44, 16, 51, 34, 14, 50, 65, 52, 92, 74, 23, 39, 30, 0
, allowed_chain_types: list[~atomworks.enums.ChainType] = [ChainType.POLYPEPTIDE_L, ChainType.RNA])[source]#

Bases: Transform

A transform that featurizes RF-style templates for RF2AA.

This class takes the templates added by the AddRFTemplates transform and featurizes them for use in the RF2AA model. The templates are added to the data under the key template.

- n_template

The number of templates to use.

Type:

int

- mask_token_idx

The index of the mask token. Defaults to 21.

Type:

int

- init_coords

The initial coordinates for the templates.

Type:

torch.Tensor | float

- encoding

The encoding to use for the templates. Defaults to RF2AA_ATOM36_ENCODING.

Type:

TokenEncoding

check_input(data: dict[str, Any]) -> None: Checks the input data for the required keys and types.

forward(data: dict[str, Any]) -> dict[str, Any]: Featurizes the templates and adds them to the data.

Raises:
  • AssertionError – If n_template is not a positive integer.

  • AssertionError – If encoding is not an instance of TokenEncoding.

  • AssertionError – If init_coords is a tensor and its dimensions do not match the expected shape.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

requires_previous_transforms: ClassVar[list[str | Transform]] = [<class 'atomworks.ml.transforms.template.AddRFTemplates'>, <class 'atomworks.ml.transforms.atom_array.AddWithinPolyResIdxAnnotation'>]#
class atomworks.ml.transforms.template.OneHotTemplateRestype(encoding: AF3SequenceEncoding)[source]#

Bases: Transform

One-hot encodes residue types within templates. NOTE: This is kept as a separate transform because the AF-3 supplement does not explicitly mention one-hot encoding of template residue types.
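As an illustration of what this transform computes, the following NumPy sketch one-hot encodes an [N_templ, N_token] array of integer residue types. It is not the library's implementation: the function name one_hot_template_restype and the class count n_classes are hypothetical, and the real transform obtains residue indices through an AF3SequenceEncoding and operates on tensors.

```python
import numpy as np


def one_hot_template_restype(restype_idx: np.ndarray, n_classes: int) -> np.ndarray:
    """One-hot encode integer residue types of shape [N_templ, N_token].

    Returns a float32 array of shape [N_templ, N_token, n_classes].
    """
    restype_idx = np.asarray(restype_idx, dtype=np.int64)
    if restype_idx.min() < 0 or restype_idx.max() >= n_classes:
        raise ValueError("residue index out of range for the encoding")
    # Indexing rows of the identity matrix performs the one-hot lookup.
    return np.eye(n_classes, dtype=np.float32)[restype_idx]


# Two templates over three tokens, with a hypothetical 5-class encoding.
restype = np.array([[0, 2, 4], [1, 1, 3]])
features = one_hot_template_restype(restype, n_classes=5)
print(features.shape)  # (2, 3, 5)
```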

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

class atomworks.ml.transforms.template.RF2AATemplate(xyz: Tensor, mask: Tensor, qmap: Tensor, f0d: Tensor, f1d: Tensor, seq: Tensor, ids: list[tuple[str]], label: list[str])[source]#

Bases: object

Data class for holding template information in the RF, RF2 & RF2AA format.

Note

  • RF templates only exist for proteins.

  • This is a helper class that casts the templates into a more readable format and provides an interface layer for handling templates as atom arrays, should templates ever be re-created or added for non-proteins.

  • RF-style templates already come encoded in the atom14 representation (RFAtom14, not AF2Atom14).

Keys:

  • xyz: Tensor([1, n_templates x n_res_per_template, 14, 3]), raw coordinates of all templates.

  • mask: Tensor([1, n_templates x n_res_per_template, 14]), mask of all templates.

  • qmap: Tensor([1, n_templates x n_res_per_template, 2]), alignment mapping of all templates:

      • index 0: the query-protein residue index that this template position aligns to.

      • index 1: the template index that this position belongs to.

  • f0d: Tensor([1, n_templates, 8?]); [0, :, 4] holds sequence-identity information.

  • f1d: Tensor([1, n_templates x n_res_per_template, 3]); some channels may relate to template confidence or gaps (unconfirmed).

  • seq: Tensor([1, 100677]), encoded with the Chemdata.aa2num encoding.

  • ids: list[tuple[str]]; holds the f"{pdb_id}_{chain_id}" of each template.

  • label: list[str]; holds the lookup_id for this template.
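Because all templates are concatenated along a single axis, a consumer typically needs qmap to recover per-template, per-query-residue arrays. A minimal NumPy sketch, assuming qmap[..., 0] is the query residue index and qmap[..., 1] the template index as described above; unflatten_templates is a hypothetical helper, not part of atomworks.

```python
import numpy as np


def unflatten_templates(xyz: np.ndarray, qmap: np.ndarray,
                        n_templates: int, n_query_res: int) -> np.ndarray:
    """Scatter flattened RF-style template coordinates into a dense array.

    xyz:  [1, n_aligned, 14, 3] coordinates of all templates, concatenated.
    qmap: [1, n_aligned, 2]; column 0 = query residue index,
          column 1 = template index.
    Returns [n_templates, n_query_res, 14, 3], NaN where nothing is aligned.
    """
    dense = np.full((n_templates, n_query_res, xyz.shape[2], 3), np.nan)
    query_idx = qmap[0, :, 0]
    templ_idx = qmap[0, :, 1]
    # Each aligned template position lands at (template, query residue).
    dense[templ_idx, query_idx] = xyz[0]
    return dense


# Two templates, four query residues, three aligned positions in total.
xyz = np.arange(3 * 14 * 3, dtype=float).reshape(1, 3, 14, 3)
qmap = np.array([[[0, 0], [2, 0], [1, 1]]])
dense = unflatten_templates(xyz, qmap, n_templates=2, n_query_res=4)
print(dense.shape)  # (2, 4, 14, 3)
```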

RF2AA_INIT_TEMPLATE_COORDINATES = tensor([[-0.5272, 1.3593, 0.0000], [0.0000, 0.0000, 0.0000], [1.5233, 0.0000, 0.0000], [nan, nan, nan], ..., [nan, nan, nan]])#

One row per atom slot; every row after the first three is [nan, nan, nan].
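The first three rows of RF2AA_INIT_TEMPLATE_COORDINATES occupy the first three atom slots, which for amino acids are N, CA and C in the RF2AA encoding. Under that slot assumption, a quick geometry check shows an idealized backbone with CA at the origin: an N-CA bond of about 1.458 Å, a CA-C bond of 1.523 Å and an N-CA-C angle of about 111.2°.

```python
import numpy as np

# First three rows of RF2AA_INIT_TEMPLATE_COORDINATES (assumed N, CA, C slots).
n = np.array([-0.5272, 1.3593, 0.0])
ca = np.array([0.0, 0.0, 0.0])
c = np.array([1.5233, 0.0, 0.0])

bond_n_ca = np.linalg.norm(n - ca)
bond_ca_c = np.linalg.norm(c - ca)
# Angle at CA between the CA->N and CA->C vectors.
cos_angle = np.dot(n - ca, c - ca) / (bond_n_ca * bond_ca_c)
angle_n_ca_c = np.degrees(np.arccos(cos_angle))

print(f"N-CA = {bond_n_ca:.3f} A, CA-C = {bond_ca_c:.3f} A, N-CA-C = {angle_n_ca_c:.1f} deg")
# N-CA = 1.458 A, CA-C = 1.523 A, N-CA-C = 111.2 deg
```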
property alignment_confidence: ndarray#
property chain_ids: ndarray#
f0d: Tensor#
f1d: Tensor#
ids: list[tuple[str]]#
label: list[str]#
property lookup_id: str#
mask: Tensor#
property max_aligned_query_res_idx: ndarray#
property n_res_per_template: ndarray#
property n_templates: int#
property pdb_ids: ndarray#
qmap: Tensor#
seq: Tensor#
property seq_similarity_to_query: ndarray#
subset(template_idxs: list[int]) RF2AATemplate[source]#

Subset the template to only include the template indices specified in template_idxs.

property template_ids: list[str]#
to_atom_array(template_idx: int) AtomArray[source]#
xyz: Tensor#
class atomworks.ml.transforms.template.RandomSubsampleTemplates(n_template: int = 4)[source]#

Bases: Transform

Subsample the templates for each chain in the template dictionary.

Parameters:

n_template (int) – The maximum number of templates to keep per chain. Default is 4.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str | Transform]] = [<class 'atomworks.ml.transforms.template.FeaturizeTemplatesLikeAF3'>, <class 'atomworks.ml.transforms.template.FeaturizeTemplatesLikeRF2AA'>, 'OneHotTemplateRestype']#
atomworks.ml.transforms.template.add_input_file_template(atom_array: AtomArray) dict[str, list[dict[str, Any]]][source]#
atomworks.ml.transforms.template.blank_af3_template_features(n_templates: int, n_tokens: int, gap_token_index: int) dict[str, Tensor][source]#

Generates blank template features for AF3.

Parameters:
  • n_templates (int) – Number of templates.

  • n_tokens (int) – Number of tokens.

  • gap_token_index (int) – Index of the gap token in the sequence encoding.

Returns:

A dictionary containing initialized template features.

Return type:

dict
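A hypothetical NumPy analogue of blank AF3 template features, useful for seeing the expected shapes. The shapes follow the feature list of featurize_templates_like_af3; the fill values (the gap index for residue types, zeros elsewhere), storing template_restype as indices rather than one-hot vectors, and the default n_bins=39 (38 bin edges plus one) are all assumptions, as is the example gap index of 31.

```python
import numpy as np


def blank_af3_features_sketch(n_templates: int, n_tokens: int,
                              gap_token_index: int, n_bins: int = 39) -> dict:
    """Template features for templates that carry no information."""
    return {
        # Residue types stored as indices (assumption); every position is a gap.
        "template_restype": np.full((n_templates, n_tokens), gap_token_index, dtype=np.int64),
        # No pseudo-beta atoms and no complete backbone frames are present.
        "template_pseudo_beta_mask": np.zeros((n_templates, n_tokens), dtype=bool),
        "template_backbone_frame_mask": np.zeros((n_templates, n_tokens), dtype=bool),
        # Pairwise features are all zero.
        "template_distogram": np.zeros((n_templates, n_tokens, n_tokens, n_bins), dtype=np.float32),
        "template_unit_vector": np.zeros((n_templates, n_tokens, n_tokens, 3), dtype=np.float32),
    }


feats = blank_af3_features_sketch(n_templates=4, n_tokens=10, gap_token_index=31)
print({name: arr.shape for name, arr in feats.items()})
```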

atomworks.ml.transforms.template.blank_rf2aa_template_features(n_template: int, n_token: int, encoding: TokenEncoding, mask_token_idx: int, init_coords: Tensor | float) Tensor[source]#

Generates blank template features for RF2AA.

Parameters:
  • n_template (int) – Number of templates.

  • n_token (int) – Number of tokens in the structure.

  • encoding (TokenEncoding) – Encoding object containing token and atom information.

  • mask_token_idx (int, optional) – Index of the mask token. Defaults to 20.

  • init_coords (torch.Tensor | float, optional) – Initial coordinates for the atoms.

Returns:

A tuple containing the following elements:
  • xyz (torch.Tensor): Tensor of shape (n_template, n_token, encoding.n_atoms_per_token, 3) containing the coordinates of the atoms.

  • t1d (torch.Tensor): Tensor of shape (n_template, n_token, encoding.n_tokens) containing the 1D template features.

  • mask (torch.Tensor): Tensor of shape (n_template, n_token, encoding.n_atoms_per_token) containing the mask information.

  • template_origin (np.ndarray): Array of shape (n_template,) containing the origin of the templates.

Return type:

tuple
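A minimal NumPy sketch of blank RF2AA template features following the documented shapes. It assumes init_coords is either a scalar fill value or an [n_atoms_per_token, 3] array broadcast to every token, that t1d one-hot encodes mask_token_idx at every position, and that every atom is masked out; it omits the template_origin output, whose fill value is not documented here. The parameter n_token_types stands in for encoding.n_tokens, and the example sizes (36 atom slots, 80 token types) are illustrative.

```python
import numpy as np


def blank_rf2aa_features_sketch(n_template, n_token, n_atoms_per_token,
                                n_token_types, mask_token_idx, init_coords):
    """Return (xyz, t1d, mask) for templates that carry no information."""
    # xyz: a float is a fill value; an [n_atoms_per_token, 3] array is
    # broadcast over every template and token.
    xyz = np.empty((n_template, n_token, n_atoms_per_token, 3), dtype=np.float32)
    xyz[...] = np.asarray(init_coords, dtype=np.float32)
    # t1d: one-hot encode the mask token at every position.
    t1d = np.zeros((n_template, n_token, n_token_types), dtype=np.float32)
    t1d[..., mask_token_idx] = 1.0
    # mask: no atom has real template coordinates.
    mask = np.zeros((n_template, n_token, n_atoms_per_token), dtype=bool)
    return xyz, t1d, mask


# Initial coordinates shaped like RF2AA_INIT_TEMPLATE_COORDINATES.
init = np.full((36, 3), np.nan)
init[:3] = [[-0.5272, 1.3593, 0.0], [0.0, 0.0, 0.0], [1.5233, 0.0, 0.0]]
xyz, t1d, mask = blank_rf2aa_features_sketch(2, 5, 36, 80, 21, init)
print(xyz.shape, t1d.shape, mask.shape)  # (2, 5, 36, 3) (2, 5, 80) (2, 5, 36)
```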

atomworks.ml.transforms.template.featurize_templates_like_af3(atom_array: AtomArray, templates_by_chain: dict[str, list[dict[str, Any]]], sequence_encoding: AF3SequenceEncoding, gap_token: str = '<G>', allowed_chain_type: list[ChainType] = [ChainType.POLYPEPTIDE_L, ChainType.RNA], distogram_bins: Tensor = tensor([3.2500, 4.5338, 5.8176, 7.1014, 8.3851, 9.6689, 10.9527, 12.2365, 13.5203, 14.8041, 16.0878, 17.3716, 18.6554, 19.9392, 21.2230, 22.5068, 23.7905, 25.0743, 26.3581, 27.6419, 28.9257, 30.2095, 31.4932, 32.7770, 34.0608, 35.3446, 36.6284, 37.9122, 39.1959, 40.4797, 41.7635, 43.0473, 44.3311, 45.6149, 46.8986, 48.1824, 49.4662, 50.7500])) dict[str, Tensor][source]#

Generate AF3 template features for a given (cropped) atom array and the corresponding templates.

NOTE: Number of templates (n_template) is determined by the number of templates in the templates_by_chain dict.

This function adds the following features to the returned dictionary:
  • template_restype: [N_templ, N_token] One-hot encoding of the template sequence.

  • template_pseudo_beta_mask: [N_templ, N_token] Mask indicating whether the CB (CA for glycine) has coordinates for the template at this residue.

  • template_backbone_frame_mask: [N_templ, N_token] Mask indicating whether coordinates exist for all atoms required to compute the backbone frame (used in the template_unit_vector feature).

  • template_distogram: [N_templ, N_token, N_token, n_bins] A pairwise feature encoding the distance between Cβ atoms (CA for glycine). AF3 uses 38 bins between 3.25 Å and 50.75 Å, with one extra bin for distances beyond 50.75 Å.

  • template_unit_vector: [N_templ, N_token, N_token, 3] The unit vector of the displacement of the CA atom of all residues within the local frame of each residue.
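The template_distogram feature can be sketched as pairwise distances between pseudo-beta coordinates, one-hot encoded against the bin edges. This NumPy sketch assumes the default distogram_bins equals np.linspace(3.25, 50.75, 38), which reproduces the values in the signature above, and uses np.searchsorted, so distances below the first edge fall into class 0 and distances beyond the last edge into the overflow class. How the library treats the lowest bin is an assumption, and one_hot_distogram is a hypothetical name.

```python
import numpy as np


def one_hot_distogram(coords: np.ndarray, bin_edges: np.ndarray) -> np.ndarray:
    """One-hot binned pairwise distances.

    coords: [N_token, 3] pseudo-beta coordinates (CB, or CA for glycine).
    bin_edges: increasing edges; returns [N_token, N_token, len(bin_edges) + 1].
    """
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    # searchsorted gives the index of the first edge >= each distance;
    # index len(bin_edges) is the overflow bin beyond the last edge.
    bin_idx = np.searchsorted(bin_edges, dist)
    n_classes = len(bin_edges) + 1
    return np.eye(n_classes, dtype=np.float32)[bin_idx]


edges = np.linspace(3.25, 50.75, 38)
coords = np.array([[0.0, 0.0, 0.0], [4.0, 0.0, 0.0], [60.0, 0.0, 0.0]])
dgram = one_hot_distogram(coords, edges)
print(dgram.shape)  # (3, 3, 39)
```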

Parameters:
  • atom_array (AtomArray) – The input atom array.

  • templates_by_chain (dict[str, list[dict[str, Any]]]) – Dictionary of templates for each chain.

  • sequence_encoding (AF3SequenceEncoding) – Encoding for the sequence.

  • gap_token (str) – Token used for gaps in the sequence and as the default for padding empty template tokens. NOTE: For templates, a token is always a residue.

  • allowed_chain_type (list[ChainType]) – List of allowed chain types.

  • distogram_bins (Tensor) – Bins for discretizing distances in the distogram.

Returns:

A dictionary containing the template features.

Return type:

dict

NOTE: For templates a token is always a residue since we never align ligands, non-canonicals, PTMs, etc.

atomworks.ml.transforms.template.random_subsample_templates(template_dictionary: dict[str, list[dict[str, Any]]], n_template: int = 4) dict[str, list[dict[str, Any]]][source]#

Subsample the templates for each chain in the template dictionary. This function implements the training-time behavior; do not use it at inference time (instead, e.g., set max_n_template=4 to take the first 4 templates directly).

From the AF-3 supplement:

> “Templates are sorted by e-value. At most 20 templates can be returned by our search, and the model uses up to 4 (Ntempl ≤ 4). At inference time we take the first 4. At training time we choose k random templates out of the available n, where k ~ min(Uniform[0, n], 4). This reduces the efficacy of simply copying the template.”
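The training-time sampling rule quoted above can be sketched as follows. This assumes Uniform[0, n] means a discrete uniform draw from {0, ..., n}, keeps the chosen templates in their original (e-value) order, and uses hypothetical function names; it is not the library's implementation.

```python
import numpy as np


def subsample_templates(templates: list, max_templates: int = 4, rng=None) -> list:
    """Keep k random templates, with k ~ min(Uniform{0, ..., n}, max_templates)."""
    rng = np.random.default_rng() if rng is None else rng
    n = len(templates)
    # integers() excludes the upper bound, so n + 1 makes n itself reachable.
    k = min(int(rng.integers(0, n + 1)), max_templates)
    if k == 0:
        return []
    chosen = rng.choice(n, size=k, replace=False)
    # Sorting preserves the original (e-value) order of the survivors.
    return [templates[i] for i in sorted(chosen)]


def random_subsample_templates_sketch(template_dictionary: dict, n_template: int = 4, rng=None) -> dict:
    """Apply the subsampling independently to each chain's template list."""
    rng = np.random.default_rng() if rng is None else rng
    return {chain: subsample_templates(templates, n_template, rng)
            for chain, templates in template_dictionary.items()}


rng = np.random.default_rng(0)
sampled = random_subsample_templates_sketch({"A": list(range(10)), "B": []}, rng=rng)
print(sampled)
```

At inference time, the quoted rule instead corresponds to simply taking templates[:4].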

Utility Modules#

Convenience utils for common validation checks in transforms.

All checks take a data dictionary as input and raise an error if the check fails.

atomworks.ml.transforms._checks.check_atom_array_annotation(data: dict[str, Any], required: list[str], forbidden: list[str] = [], n_body: int = 1) None[source]#

Check that the atom array under the atom_array key has the annotations listed in required.

atomworks.ml.transforms._checks.check_atom_array_has_bonds(data: dict[str, Any]) None[source]#

Check that the atom array under the atom_array key has bonds.

atomworks.ml.transforms._checks.check_contains_keys(data: dict[str, Any], keys: list[str]) None[source]#

Check that every key in keys is present in the dictionary.

atomworks.ml.transforms._checks.check_does_not_contain_keys(data: dict[str, Any], keys: list[str]) None[source]#

Check that none of the keys in keys is present in the dictionary.

atomworks.ml.transforms._checks.check_is_instance(data: dict[str, Any], key: str, expected_type: type) None[source]#

Check if the value of a key in a dictionary is of a certain type.

atomworks.ml.transforms._checks.check_is_shape(data: dict[str, Any], key: str, expected_shape: tuple[int, ...]) None[source]#

Check if the value of a key in a dictionary has a certain shape.

atomworks.ml.transforms._checks.check_nonzero_length(data: dict[str, Any], key: str) None[source]#

Check if the length of the value of a key in a dictionary is nonzero.
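The checks follow a simple pattern: inspect data and raise if the condition fails. The sketch below re-implements four of them in plain Python for illustration only; the exception types (KeyError, TypeError, ValueError) and messages are assumptions, and the library's versions may differ.

```python
from typing import Any


def check_contains_keys(data: dict[str, Any], keys: list[str]) -> None:
    """Raise if any of keys is missing from data."""
    missing = [k for k in keys if k not in data]
    if missing:
        raise KeyError(f"Missing required keys: {missing}")


def check_does_not_contain_keys(data: dict[str, Any], keys: list[str]) -> None:
    """Raise if any of keys is already present in data."""
    present = [k for k in keys if k in data]
    if present:
        raise KeyError(f"Unexpected keys present: {present}")


def check_is_instance(data: dict[str, Any], key: str, expected_type: type) -> None:
    """Raise if data[key] is not an instance of expected_type."""
    if not isinstance(data[key], expected_type):
        raise TypeError(f"{key!r} is {type(data[key]).__name__}, expected {expected_type.__name__}")


def check_nonzero_length(data: dict[str, Any], key: str) -> None:
    """Raise if data[key] is empty."""
    if len(data[key]) == 0:
        raise ValueError(f"{key!r} has zero length")


data = {"atom_array": [1, 2, 3], "extra_info": {}}
check_contains_keys(data, ["atom_array"])
check_does_not_contain_keys(data, ["rdkit"])
check_is_instance(data, "extra_info", dict)
check_nonzero_length(data, "atom_array")
```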

class atomworks.ml.transforms.center_random_augmentation.CenterRandomAugmentation(batch_size: int, scale: int = 1, **kwargs)[source]#

Bases: Transform

Centers coordinates and then randomly rotates and translates the input coordinates.

Parameters:
  • batch_size (int) – Number of samples in the batch.

  • scale (int) – Scaling factor for the random rotation and translation. Default is 1.
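A NumPy sketch of the augmentation in the style of AF3's CentreRandomAugmentation: center the coordinates, apply a uniformly random rotation, then add a Gaussian translation. Here scale multiplies only the translation, as AF3's translation scale does; whether the library's scale also affects the rotation is not stated, so treat that as an assumption. The QR-based rotation sampler and the function names are illustrative.

```python
import numpy as np


def random_rotation(rng: np.random.Generator) -> np.ndarray:
    """Sample a rotation matrix uniformly from SO(3) via QR decomposition."""
    q, r = np.linalg.qr(rng.standard_normal((3, 3)))
    q = q * np.sign(np.diag(r))  # make the QR factorization unique
    if np.linalg.det(q) < 0:     # turn a reflection into a proper rotation
        q[:, 0] = -q[:, 0]
    return q


def center_random_augmentation(coords: np.ndarray, batch_size: int,
                               scale: float = 1.0, rng=None) -> np.ndarray:
    """Center [n_atoms, 3] coordinates, then return batch_size randomly
    rotated and translated copies with shape [batch_size, n_atoms, 3]."""
    rng = np.random.default_rng() if rng is None else rng
    centered = coords - coords.mean(axis=0, keepdims=True)
    out = np.empty((batch_size, *coords.shape))
    for b in range(batch_size):
        rot = random_rotation(rng)
        trans = scale * rng.standard_normal(3)  # scale applied to translation only
        out[b] = centered @ rot.T + trans
    return out


rng = np.random.default_rng(0)
coords = rng.standard_normal((5, 3))
augmented = center_random_augmentation(coords, batch_size=2, rng=rng)
print(augmented.shape)  # (2, 5, 3)
```

Because the transform is rigid, pairwise distances within each copy are unchanged, which makes a convenient sanity check.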

check_input(data: dict) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) dict[source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

requires_previous_transforms: ClassVar[list[str | Transform]] = [<class 'atomworks.ml.transforms.diffusion.batch_structures.BatchStructuresForDiffusionNoising'>]#
class atomworks.ml.transforms.rdkit_utils.AddRDKitMoleculesForAtomizedMolecules(hydrogen_policy: Literal['infer', 'remove', 'keep'] = 'keep')[source]#

Bases: Transform

Add RDKit molecules for atomized molecules in the atom array.

This transform converts atomized molecules in the atom array to RDKit Mol objects and stores them in the data dictionary under the “rdkit” key. Each molecule is identified by its pn_unit_iid.

Note

This transform requires the AtomizeByCCDName transform to be applied previously.

Parameters:

data (dict[str, Any]) – A dictionary containing the input data, including the atom array.

Returns:

The updated data dictionary with the added RDKit molecules under the "rdkit" key.

Return type:

dict[str, Any]

Example

>>> data = {
...     "atom_array": AtomArray(...),  # Your atom array here
... }
>>> transform = AddRDKitMoleculesForAtomizedMolecules()
>>> data = transform(data)
>>> print(data["rdkit"])
{
    'A_1': <rdkit.Chem.rdchem.Mol object at 0x...>,
    'B_1': <rdkit.Chem.rdchem.Mol object at 0x...>,
    ...
}
check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str | Transform]] = ['CropContiguousLikeAF3', 'CropSpatialLikeAF3']#
requires_previous_transforms: ClassVar[list[str | Transform]] = ['AtomizeByCCDName']#
class atomworks.ml.transforms.rdkit_utils.GenerateRDKitConformers(n_conformers: int = 1, optimize_conformers: bool = True, optimize_kwargs: dict[str, Any] | None = None)[source]#

Bases: Transform

Generate conformers for RDKit molecules stored in the data["rdkit"] dictionary.

This transform generates conformers for each RDKit molecule in the data dictionary and updates the molecules with the new conformers. The random seed for conformer generation is derived from the global numpy RNG state.

Parameters:
  • data (dict[str, Any]) – A dictionary containing the input data, including RDKit molecules under the “rdkit” key.

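  • n_conformers (int) – Number of conformers to generate for each molecule. Default is 1.

  • optimize_conformers (bool) – Whether to optimize the generated conformers. Default is True.

  • optimize_kwargs (dict[str, Any] | None) – Optional keyword arguments for the conformer optimization. Default is None.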

Returns:

The updated data dictionary with RDKit molecules containing generated conformers.

Return type:

dict[str, Any]

Example

>>> data = {
...     "rdkit": {
...         "A_1": mol_a,  # RDKit Mol objects, e.g. from AddRDKitMoleculesForAtomizedMolecules
...         "B_1": mol_b,
...     }
... }
>>> transform = GenerateRDKitConformers(n_conformers=3)
>>> data = transform(data)
>>> print(data["rdkit"]["A_1"].GetNumConformers())
3
check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

requires_previous_transforms: ClassVar[list[str | Transform]] = ['AddRDKitMoleculesForAtomizedMolecules']#
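The seed for conformer generation is derived from the global numpy RNG state, but the exact derivation is not specified. A minimal sketch, assuming a single `np.random.randint` draw per seed (the helper `derive_seed_from_global_rng` and its upper bound are assumptions), shows why seeding NumPy once at the start of a pipeline makes these seeds reproducible:

```python
import numpy as np


def derive_seed_from_global_rng(high: int = 2**31 - 1) -> int:
    """Draw a conformer-generation seed from NumPy's global RNG (hypothetical helper).

    Each call consumes global RNG state, so the sequence of seeds is fully
    determined by the most recent np.random.seed(...) call.
    """
    return int(np.random.randint(0, high))


np.random.seed(42)  # seed the global RNG once, e.g. at the start of a pipeline
first_run = [derive_seed_from_global_rng() for _ in range(3)]
np.random.seed(42)
second_run = [derive_seed_from_global_rng() for _ in range(3)]
print(first_run == second_run)  # True: same global state, same seeds
```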
class atomworks.ml.transforms.rdkit_utils.GetRDKitChiralCenters[source]#

Bases: Transform

Identify chiral centers in the RDKit molecules stored in the data[“rdkit”] dictionary and store, under data[“chiral_centers”], a dictionary mapping each residue name to a list of chiral centers, e.g.:

data["chiral_centers"] = {
    "ILE": [
        {"chiral_center_idx": 1, "bonded_explicit_atom_idxs": [0, 2, 4], "chirality": "S"},
        {"chiral_center_idx": 4, "bonded_explicit_atom_idxs": [1, 5, 6], "chirality": "S"},
    ],
}

Each chiral center is a dict with a center atom index, 3 or 4 bonded atom indices, and the RDKit-determined chirality.

Uses RDKit molecules first computed in GetAF3ReferenceMoleculeFeatures.

Parameters:

data (dict[str, Any]) – A dictionary containing the input data, including RDKit molecules under the “rdkit” key.

Returns:

The updated data dictionary with a “chiral_centers” entry containing the chiral centers for each molecule.

Return type:

dict[str, Any]

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

requires_previous_transforms: ClassVar[list[str | Transform]] = ['GetAF3ReferenceMoleculeFeatures']#
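Downstream code can rely on the documented shape of data[“chiral_centers”]. The sketch below flattens it into rows and checks the documented invariants (3 or 4 bonded atoms, an ‘R’ or ‘S’ label). `flatten_chiral_centers` is a hypothetical helper, not part of atomworks.

```python
from typing import Any

# Example payload in the documented shape of data["chiral_centers"].
chiral_centers: dict[str, list[dict[str, Any]]] = {
    "ILE": [
        {"chiral_center_idx": 1, "bonded_explicit_atom_idxs": [0, 2, 4], "chirality": "S"},
        {"chiral_center_idx": 4, "bonded_explicit_atom_idxs": [1, 5, 6], "chirality": "S"},
    ],
}


def flatten_chiral_centers(centers: dict[str, list[dict[str, Any]]]) -> list[tuple[str, int, str]]:
    """Flatten into (residue name, center index, chirality) rows, validating each entry."""
    rows = []
    for res_name, entries in centers.items():
        for entry in entries:
            n_bonded = len(entry["bonded_explicit_atom_idxs"])
            # 3 explicit neighbours occur e.g. when the fourth substituent is an implicit hydrogen.
            if n_bonded not in (3, 4):
                raise ValueError(f"{res_name}: unexpected number of bonded atoms ({n_bonded})")
            if entry["chirality"] not in ("R", "S"):
                raise ValueError(f"{res_name}: unexpected chirality label {entry['chirality']!r}")
            rows.append((res_name, entry["chiral_center_idx"], entry["chirality"]))
    return rows


print(flatten_chiral_centers(chiral_centers))
# [('ILE', 1, 'S'), ('ILE', 4, 'S')]
```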
atomworks.ml.transforms.rdkit_utils.ccd_code_to_rdkit_with_conformers(ccd_code: str, n_conformers: int, *, seed: int | None = None, timeout: float | None | tuple[float, float] = (3.0, 0.15), timeout_strategy: Literal['signal', 'subprocess'] = 'subprocess', skip_rdkit_conformer_generation: bool = False, **generate_conformers_kwargs) Mol[source]#

Generate an RDKit molecule with conformers for a given CCD code (residue name).

This function attempts to generate the specified number of conformers for the given CCD code using RDKit’s conformer generation (based on ETKDGv3 by default). If conformer generation fails or times out, it falls back to the idealized conformer from the CCD entry, if one is available.

Parameters:
  • ccd_code – The CCD code to generate conformers for, e.g. ‘ALA’, ‘GLY’ or ‘9RH’.

  • n_conformers – The number of conformers to generate for the given CCD code.

  • seed – The seed for conformer generation. If None, a random seed is generated using the global numpy RNG.

  • timeout – The timeout for conformer generation. If None, no timeout is applied and the timeout strategy is ignored (no subprocesses will be spawned). If a tuple, the first element is the offset and the second element is the slope.

  • timeout_strategy – The strategy to use for the timeout. Defaults to “subprocess”.

  • **generate_conformers_kwargs – Additional keyword arguments to pass to the generate_conformers function.

Returns:

An RDKit molecule with the specified number of conformers.

Return type:

Chem.Mol
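The timeout and fallback behavior described above can be sketched as follows. This is an illustration under stated assumptions, not the library's code: the docs name the tuple elements "offset" and "slope" but not what the slope multiplies, so `resolve_timeout` assumes a linear dependence on n_conformers, and `generate_or_fall_back` stands in for the fallback to the idealized CCD conformer. Both helpers are hypothetical.

```python
from __future__ import annotations

from typing import Callable, TypeVar

T = TypeVar("T")


def resolve_timeout(timeout: float | None | tuple[float, float], n_conformers: int) -> float | None:
    """Convert a timeout specification into seconds (hypothetical helper).

    ASSUMPTION: a tuple (offset, slope) is interpreted as offset + slope * n_conformers;
    the documentation names the two elements but not what the slope multiplies.
    """
    if timeout is None or isinstance(timeout, (int, float)):
        return timeout  # None means "no timeout"; a number is used directly
    offset, slope = timeout
    return offset + slope * n_conformers


def generate_or_fall_back(generate: Callable[[], T], fallback: Callable[[], T | None]) -> T:
    """Run `generate`; if it raises (e.g. a TimeoutError), return `fallback()` instead.

    `fallback` stands in for loading the idealized CCD conformer and returns None
    when no such conformer exists, in which case the original error is re-raised.
    """
    try:
        return generate()
    except Exception:
        result = fallback()
        if result is None:
            raise  # re-raise the error from `generate`
        return result


# With the default (3.0, 0.15) and 10 conformers, the assumed budget is 4.5 s.
print(resolve_timeout((3.0, 0.15), 10))
```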

atomworks.ml.transforms.rdkit_utils.find_automorphisms_with_rdkit(mol: Mol, max_automorphs: int = 1000, timeout: float | None = None, timeout_strategy: Literal['signal', 'subprocess'] = 'subprocess') ndarray[source]#

Find automorphisms of a given RDKit molecule.

This function identifies the automorphisms (symmetry-related atom swaps) of the input molecule and returns them as a numpy array. If the search for automorphisms times out, it returns a single automorphism representing the identity (no swaps).

Parameters:
  • mol (Chem.Mol) – The RDKit molecule for which to find automorphisms.

  • max_automorphs (int) – The maximum number of automorphisms to return. These are deterministically chosen as the first max_automorphs automorphisms found by RDKit. For model training, selecting the automorphisms deterministically (as done here) is recommended: otherwise the model may be pulled towards one automorph in one training step, only for that automorph to be missing in the next, creating a moving-target problem.

  • timeout (float | None) – The timeout for the automorphism search. If None, no timeout is applied and the timeout strategy is ignored (no subprocesses will be spawned).

  • timeout_strategy (Literal["signal", "subprocess"]) – The strategy to use for the timeout. Defaults to “subprocess”.

Returns:

A numpy array of shape [n_automorphs, n_atoms, 2], where each element represents an automorphism as a list of paired atom indices (from_idx, to_idx). If the search fails (e.g. due to running out of memory), returns an array with a single automorphism representing the identity (no swaps).

Return type:

automorphs (np.ndarray)

Example

>>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles("c1c(O)cccc1(O)")
>>> automorphisms = find_automorphisms_with_rdkit(mol)
>>> print(automorphisms)
[[[0 0]
  [1 1]
  [2 2]
  [3 3]
  [4 4]
  [5 5]
  [6 6]
  [7 7]]

 [[0 0]
  [1 6]
  [2 7]
  [3 5]
  [4 4]
  [5 3]
  [6 1]
  [7 2]]]
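To make the idea of an automorphism concrete, the following pure-Python sketch enumerates label- and bond-preserving atom permutations of a small graph by backtracking. It is a swapped-in brute-force technique for illustration, not how find_automorphisms_with_rdkit works; it ignores bond orders and is only practical for small molecules. Like the documented function, it returns a deterministic, capped list of (from_idx, to_idx) pairs with the identity first; wrapping the result in np.array gives the [n_automorphs, n_atoms, 2] layout. The helper name is hypothetical.

```python
def enumerate_automorphisms_bruteforce(
    labels: list[str], edges: list[tuple[int, int]], max_automorphs: int = 1000
) -> list[list[tuple[int, int]]]:
    """Enumerate atom permutations that preserve element labels and bonds (illustrative only).

    Backtracking over atoms in index order, trying candidates in ascending order, yields
    automorphisms in lexicographic order: the identity always comes first and truncation
    to max_automorphs is deterministic. Bond orders are ignored.
    """
    n = len(labels)
    neighbors = [set() for _ in range(n)]
    for i, j in edges:
        neighbors[i].add(j)
        neighbors[j].add(i)

    results: list[list[tuple[int, int]]] = []
    mapping: list[int] = []  # mapping[i] = image of atom i
    used = [False] * n

    def extend() -> bool:
        i = len(mapping)
        if i == n:
            results.append(list(enumerate(mapping)))
            return len(results) >= max_automorphs  # True -> stop searching
        for cand in range(n):
            if used[cand] or labels[cand] != labels[i] or len(neighbors[cand]) != len(neighbors[i]):
                continue
            # Bonded/non-bonded relations to all already-mapped atoms must be preserved.
            if all((j in neighbors[i]) == (mapping[j] in neighbors[cand]) for j in range(i)):
                mapping.append(cand)
                used[cand] = True
                if extend():
                    return True
                mapping.pop()
                used[cand] = False
        return False

    if n > 0 and max_automorphs > 0:
        extend()
    return results


# Resorcinol, "c1c(O)cccc1(O)": ring atoms 0-1-3-4-5-6, hydroxyl oxygens 2 and 7.
labels = ["C", "C", "O", "C", "C", "C", "C", "O"]
edges = [(0, 1), (1, 3), (3, 4), (4, 5), (5, 6), (6, 0), (1, 2), (6, 7)]
for automorph in enumerate_automorphisms_bruteforce(labels, edges):
    print(automorph)
# [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (7, 7)]
# [(0, 0), (1, 6), (2, 7), (3, 5), (4, 4), (5, 3), (6, 1), (7, 2)]
```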

atomworks.ml.transforms.rdkit_utils.generate_conformers(mol: Mol, *, seed: int | None = None, n_conformers: int = 1, method: str = 'ETKDGv3', num_threads: int = 1, hydrogen_policy: Literal['infer', 'remove', 'keep', 'auto'] = 'remove', optimize: bool = True, attempts_with_distance_geometry: int = 10, attempts_with_random_coordinates: int = 10000, **uff_optimize_kwargs: dict) Mol[source]#

Generate conformations for the given molecule.

Parameters:
  • mol (Chem.Mol) – The RDKit molecule to generate conformations for.

  • seed (int | None) – Random seed for reproducibility. If None, a random seed is used.

  • n_conformers (int) – Number of conformations to generate. Default is 1.

  • method (str) – The method to use for conformer generation. Default is “ETKDGv3”. Allowed methods are “ETDG”, “ETKDG”, “ETKDGv2”, “ETKDGv3” and “srETKDGv3”. See https://rdkit.org/docs/RDKit_Book.html#conformer-generation for details.

  • num_threads (int) – Number of threads to use for parallel computation. Default is 1.

  • hydrogen_policy (Literal["infer", "remove", "keep", "auto"]) – How to handle explicit hydrogens. If “remove” (the default), hydrogens are temporarily added for conformer generation and removed again before the molecule is returned. If “keep”, the molecule is used as-is (hydrogens are neither added nor removed). If “auto”, the policy is set to “keep” if the molecule already has explicit hydrogens and to “remove” otherwise. If “infer”, the behavior matches “remove”, except that the added hydrogens are not removed before the molecule is returned.

  • optimize (bool) – Whether to optimize the generated conformers using UFF. Default is True.

  • **uff_optimize_kwargs (dict) – Additional keyword arguments for UFF optimization:

    • maxIters (int): Maximum number of iterations (default 200).

    • vdwThresh (float): Used to exclude long-range van der Waals interactions (default 10.0).

    • ignoreInterfragInteractions (bool): If True, nonbonded terms between fragments will not be added to the forcefield (default True).

Returns:

The molecule with generated conformations.

Return type:

rdkit.Chem.Mol

Note

  • Optimizing conformers (optimize=True) is recommended for obtaining more realistic and lower-energy conformations. However, it may increase computation time.

  • By default, conformers are generated with ETKDGv3, which combines experimental torsion-angle preferences with basic chemical knowledge (e.g. aromatic rings are planar) for improved accuracy.

  • For macrocycles or complex ring systems, you may need to increase the number of conformers generated to ensure good sampling of the conformational space (if a representative ensemble of conformers is what you are after).

Best Practices:
  1. Always add hydrogens before generating conformers unless you have a specific reason not to (e.g., you’re working with a protein structure where hydrogens are already correctly placed).

  2. Use a fixed seed for reproducibility in research or production environments.

  3. Generate multiple conformers (e.g., 50-100) for flexible molecules to sample the conformational space more thoroughly.

  4. Optimize conformers using UFF or MMFF94 for more realistic geometries, especially if the conformers will be used for further calculations or analysis.

  5. For very large or complex molecules, you may need to adjust parameters such as maxIters or use more advanced sampling techniques.

References

  1. Conformer tutorial: https://rdkit.org/docs/RDKit_Book.html#conformer-generation

  2. RDKit Cookbook: https://www.rdkit.org/docs/Cookbook.html

  3. Riniker and Landrum, “Better Informed Distance Geometry: Using What We Know To Improve Conformation Generation”, JCIM, 2015.
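The parameters and notes above map onto standard RDKit calls. The sketch below approximates the documented flow (hydrogen handling per hydrogen_policy, ETKDGv3 embedding with a fixed seed, optional UFF optimization with the listed defaults) using only public RDKit functions. `embed_conformers` is hypothetical and not the atomworks implementation; in particular it does not model attempts_with_distance_geometry or attempts_with_random_coordinates, and it raises instead of retrying when embedding fails.

```python
from rdkit import Chem
from rdkit.Chem import AllChem


def embed_conformers(
    mol: Chem.Mol,
    n_conformers: int = 1,
    seed: int = 42,
    hydrogen_policy: str = "remove",
    optimize: bool = True,
) -> Chem.Mol:
    """Approximate the documented generate_conformers flow with public RDKit calls (hypothetical helper)."""
    if hydrogen_policy == "auto":
        # "auto" -> "keep" if explicit hydrogens are present, otherwise "remove".
        has_explicit_h = any(atom.GetAtomicNum() == 1 for atom in mol.GetAtoms())
        hydrogen_policy = "keep" if has_explicit_h else "remove"
    if hydrogen_policy not in ("remove", "keep", "infer"):
        raise ValueError(f"unknown hydrogen_policy: {hydrogen_policy!r}")

    # "remove" and "infer" add hydrogens for embedding; "keep" works on an unchanged copy.
    work = Chem.AddHs(mol) if hydrogen_policy in ("remove", "infer") else Chem.Mol(mol)

    params = AllChem.ETKDGv3()
    params.randomSeed = seed  # fixed seed -> reproducible coordinates
    params.numThreads = 1
    conf_ids = AllChem.EmbedMultipleConfs(work, n_conformers, params)
    if len(conf_ids) == 0:
        raise RuntimeError("ETKDGv3 embedding failed")

    if optimize:
        # UFF settings matching the defaults listed for **uff_optimize_kwargs.
        AllChem.UFFOptimizeMoleculeConfs(
            work, numThreads=1, maxIters=200, vdwThresh=10.0, ignoreInterfragInteractions=True
        )

    # Only "remove" strips the added hydrogens again; heavy-atom coordinates are kept.
    return Chem.RemoveHs(work) if hydrogen_policy == "remove" else work


mol = embed_conformers(Chem.MolFromSmiles("CCO"), n_conformers=3, seed=7)
print(mol.GetNumAtoms(), mol.GetNumConformers())
```

Passing the same seed twice should give identical coordinates, which is the point of best practice 2 above.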

atomworks.ml.transforms.rdkit_utils.get_chiral_centers(mol: Mol) list[int][source]#

Identify and return the tetrahedral chiral centers in an RDKit molecule.

Finds all tetrahedral chiral centers in the given molecule and returns their information, including the chiral center atom index and the indices of the atoms bonded to it.

Parameters:

mol (rdkit.Chem.Mol) – The RDKit molecule to analyze.

Returns:

A list of dictionaries, where each dictionary contains:
  • “chiral_center_idx” (int): The index of the chiral center atom.

  • “bonded_explicit_atom_idxs” (list[int]): A list of indices of the atoms bonded to the chiral center.

  • “chirality” (str): The chirality of the center (‘R’ or ‘S’).

Return type:

list[dict]

Note

This function will generate a 3D conformation if one is not present, as chirality assignment requires 3D coordinates in RDKit to break the conditional tie between multiple possible chirality centers.
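For molecules whose stereochemistry is already encoded (for example from an isomeric SMILES string), RDKit's Chem.FindMolChiralCenters can produce entries in the documented format without generating a 3D conformer. A minimal sketch follows; `list_chiral_centers` is hypothetical and skips the conformer generation that get_chiral_centers performs when coordinates are missing.

```python
from rdkit import Chem


def list_chiral_centers(mol: Chem.Mol) -> list[dict]:
    """Collect CIP-labelled tetrahedral centers in the documented dict shape (hypothetical helper)."""
    centers = []
    # (atom index, 'R' or 'S') pairs for centers whose stereochemistry is assigned.
    for idx, label in Chem.FindMolChiralCenters(mol, includeUnassigned=False):
        atom = mol.GetAtomWithIdx(idx)
        centers.append(
            {
                "chiral_center_idx": idx,
                "bonded_explicit_atom_idxs": sorted(n.GetIdx() for n in atom.GetNeighbors()),
                "chirality": label,
            }
        )
    return centers


# L-alanine: the hydrogen on the stereocenter is implicit, so only 3 bonded atoms are listed.
print(list_chiral_centers(Chem.MolFromSmiles("C[C@@H](C(=O)O)N")))
```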

atomworks.ml.transforms.rdkit_utils.get_rdkit_chiral_centers(rdkit_mols: dict[str, Mol]) dict[source]#

Computes the chiral centers for a dictionary of RDKit molecules.

See the GetRDKitChiralCenters transform for more details.

atomworks.ml.transforms.rdkit_utils.optimize_conformers(mol: Mol, numThreads: int = 1, maxIters: int = 200, vdwThresh: float = 10.0, ignoreInterfragInteractions: bool = True) Mol[source]#

Optimize the conformers of an RDKit molecule.

Parameters:
  • mol (Chem.Mol) – The RDKit molecule to optimize.

  • numThreads (int) – Number of threads to use for parallel computation. Default is 1.

  • maxIters (int) – Maximum number of iterations for UFF optimization. Defaults to 200.

  • vdwThresh (float) – Used to exclude long-range van der Waals interactions. Defaults to 10.0.

  • ignoreInterfragInteractions (bool) – If True, nonbonded terms between fragments will not be added to the forcefield. Defaults to True.

Returns:

The optimized RDKit molecule.

Return type:

Mol

atomworks.ml.transforms.rdkit_utils.sample_rdkit_conformer_for_atom_array(atom_array: AtomArray, n_conformers: int = 1, seed: int | None = None, timeout: float | None | tuple[float, float] = (3.0, 0.15), timeout_strategy: Literal['signal', 'subprocess'] = 'subprocess', return_mol: bool = False, **generate_conformers_kwargs) AtomArray[source]#

Sample a conformer for a Biotite AtomArray using RDKit.

Parameters:
  • atom_array (AtomArray) – The Biotite AtomArray to sample a conformer for.

  • n_conformers (int) – The number of conformers to sample.

  • seed (int | None) – The seed for conformer generation. If None, a random seed is generated using the global numpy RNG.

  • timeout (float | None | tuple[float, float]) – The timeout for conformer generation. If None, no timeout is applied. If a tuple, the first element is the offset and the second element is the slope.

  • timeout_strategy (Literal["signal", "subprocess"]) – The strategy to use for the timeout. Defaults to “subprocess”.

  • return_mol (bool) – Whether to also return the RDKit molecule with the generated conformer. Default is False.

  • **generate_conformers_kwargs – Additional keyword arguments to pass to the generate_conformers function.

Returns:

The AtomArray with updated coordinates from the sampled conformer and, if return_mol is True, the RDKit molecule (Chem.Mol) with the generated conformer.

Return type:

  • AtomArray

Note

This function preserves the original atom order and properties of the input AtomArray.
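One way to preserve atom order is to keep an explicit mapping from each AtomArray atom to its RDKit atom and read coordinates back through it. The sketch assumes such a mapping is available and uses plain lists in place of AtomArray and RDKit conformer objects; `coords_in_atom_array_order` is hypothetical.

```python
def coords_in_atom_array_order(
    mol_coords: list[tuple[float, float, float]],
    atom_array_to_mol_idx: list[int],
) -> list[tuple[float, float, float]]:
    """Read conformer coordinates back out in AtomArray order (hypothetical helper).

    atom_array_to_mol_idx[i] is the RDKit atom index of AtomArray atom i. RDKit atoms with
    no AtomArray counterpart (e.g. hydrogens added only for embedding) are simply ignored.
    """
    if len(set(atom_array_to_mol_idx)) != len(atom_array_to_mol_idx):
        raise ValueError("each AtomArray atom must map to a distinct RDKit atom")
    if any(not 0 <= j < len(mol_coords) for j in atom_array_to_mol_idx):
        raise IndexError("mapping refers to an atom outside the RDKit molecule")
    return [mol_coords[j] for j in atom_array_to_mol_idx]


# RDKit molecule with 4 atoms (the last one, e.g. an added H, has no AtomArray counterpart).
mol_coords = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0), (9.0, 9.0, 9.0)]
print(coords_in_atom_array_order(mol_coords, [2, 0, 1]))
# [(2.0, 0.0, 0.0), (0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
```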

class atomworks.ml.transforms.rf2aa_assumptions.AssertRF2AAAssumptions[source]#

Bases: Transform

Assert that the given sample satisfies the assumptions required for a successful forward and backward pass through RF2AA.

check_input(data: dict[str, Any]) None[source]#

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) dict[str, Any][source]#

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

atomworks.ml.transforms.rf2aa_assumptions.assert_satisfies_rf2aa_assumptions(sample: dict[str, Any]) None[source]#

Asserts that the given sample satisfies the assumptions required for a successful forward and backward pass through RF2AA.

Submodules#

DNA Transforms#

Diffusion Transforms#

ESM Transforms#

Feature Aggregation Transforms#

MSA Transforms#