# Source code for atomworks.ml.pipelines.af3

from os import PathLike
from pathlib import Path

import numpy as np
import torch

from atomworks.enums import ChainType
from atomworks.io.constants import AF3_EXCLUDED_LIGANDS, GAP, STANDARD_AA, STANDARD_DNA, STANDARD_RNA
from atomworks.ml.common import exists
from atomworks.ml.encoding_definitions import RF2AA_ATOM36_ENCODING, AF3SequenceEncoding
from atomworks.ml.transforms.af3_reference_molecule import GetAF3ReferenceMoleculeFeatures
from atomworks.ml.transforms.atom_array import (
    AddGlobalAtomIdAnnotation,
    AddGlobalTokenIdAnnotation,
    AddWithinChainInstanceResIdx,
    AddWithinPolyResIdxAnnotation,
    ComputeAtomToTokenMap,
    CopyAnnotation,
)
from atomworks.ml.transforms.atom_frames import (
    AddAtomFrames,
    AddIsRealAtom,
    AddPolymerFrameIndices,
)
from atomworks.ml.transforms.atomize import AtomizeByCCDName, FlagNonPolymersForAtomization
from atomworks.ml.transforms.base import (
    AddData,
    Compose,
    ConditionalRoute,
    ConvertToTorch,
    Identity,
    RandomRoute,
    SubsetToKeys,
    Transform,
)
from atomworks.ml.transforms.bfactor_conditioned_transforms import SetOccToZeroOnBfactor
from atomworks.ml.transforms.bonds import AddAF3TokenBondFeatures
from atomworks.ml.transforms.center_random_augmentation import CenterRandomAugmentation
from atomworks.ml.transforms.chirals import AddAF3ChiralFeatures
from atomworks.ml.transforms.covalent_modifications import (
    FlagAndReassignCovalentModifications,
)
from atomworks.ml.transforms.crop import CropContiguousLikeAF3, CropSpatialLikeAF3
from atomworks.ml.transforms.diffusion.batch_structures import (
    BatchStructuresForDiffusionNoising,
)
from atomworks.ml.transforms.diffusion.edm import SampleEDMNoise
from atomworks.ml.transforms.dna.pad_dna import PadDNA
from atomworks.ml.transforms.encoding import EncodeAF3TokenLevelFeatures, EncodeAtomArray
from atomworks.ml.transforms.feature_aggregation.af3 import AggregateFeaturesLikeAF3
from atomworks.ml.transforms.feature_aggregation.confidence import PackageConfidenceFeats
from atomworks.ml.transforms.featurize_unresolved_residues import (
    MaskPolymerResiduesWithUnresolvedFrameAtoms,
    PlaceUnresolvedTokenAtomsOnRepresentativeAtom,
    PlaceUnresolvedTokenOnClosestResolvedTokenInSequence,
)
from atomworks.ml.transforms.filters import (
    FilterToSpecifiedPNUnits,
    HandleUndesiredResTokens,
    RemoveHydrogens,
    RemoveNucleicAcidTerminalOxygen,
    RemovePolymersWithTooFewResolvedResidues,
    RemoveTerminalOxygen,
    RemoveUnresolvedPNUnits,
)
from atomworks.ml.transforms.msa.msa import (
    EncodeMSA,
    FeaturizeMSALikeAF3,
    FillFullMSAFromEncoded,
    LoadPolymerMSAs,
    PairAndMergePolymerMSAs,
)
from atomworks.ml.transforms.rdkit_utils import GetRDKitChiralCenters
from atomworks.ml.transforms.symmetry import FindAutomorphismsWithNetworkX
from atomworks.ml.transforms.template import (
    AddInputFileTemplate,
    AddRFTemplates,
    FeaturizeTemplatesLikeAF3,
    OneHotTemplateRestype,
    RandomSubsampleTemplates,
)


def build_af3_transform_pipeline(
    *,
    # Training or inference (required)
    is_inference: bool,  # If True, we skip cropping, etc.
    # MSA dirs
    protein_msa_dirs: list[dict],
    rna_msa_dirs: list[dict],
    # Recycles
    n_recycles: int = 5,
    # Crop params
    crop_size: int | None = 384,
    crop_center_cutoff_distance: float = 15.0,
    crop_contiguous_probability: float = 0.5,
    crop_spatial_probability: float = 0.5,
    max_atoms_in_crop: int | None = None,
    # Undesired res names
    undesired_res_names: list[str] = AF3_EXCLUDED_LIGANDS,
    # Conformer generation params
    conformer_generation_timeout: float = 5.0,  # seconds
    use_element_for_atom_names_of_atomized_tokens: bool = False,
    # Template params
    max_n_template: int = 20,  # Maximum number of templates to return from our template search (distinct from n_template)
    n_template: int = 4,
    template_max_seq_similarity: float = 60.0,
    template_min_seq_similarity: float = 10.0,
    template_min_length: int = 10,
    template_allowed_chain_types: list[ChainType] = [  # noqa: B006 (copied before use, see below)
        ChainType.POLYPEPTIDE_L,
        ChainType.RNA,
    ],
    template_distogram_bins: torch.Tensor = torch.linspace(3.25, 50.75, 38),  # noqa: B008
    template_default_token: str = GAP,
    template_lookup_path: PathLike | None = None,
    template_base_dir: PathLike | None = None,
    # MSA parameters
    max_msa_sequences: int = 10_000,  # Paper: 16,000, but we only have 10K stored on disk
    n_msa: int = 10_000,  # Paper: ?? I think ~12K?
    dense_msa: bool = True,  # True for AF3
    # Cache paths
    msa_cache_dir: PathLike | str | None = None,
    sigma_data: float = 16.0,
    diffusion_batch_size: int = 48,
    # Whether to include features for confidence head
    run_confidence_head: bool = False,
    return_atom_array: bool = True,
    # DNA
    pad_dna_p_skip: float = 0.0,
    b_factor_min: float | None = None,
    b_factor_max: float | None = None,
) -> Transform:
    """Build the AF3 pipeline with specified parameters.

    This function constructs a pipeline of transforms for processing protein structures in a manner
    similar to AlphaFold 3. The pipeline includes steps for removing hydrogens, adding annotations,
    atomizing residues, cropping, adding templates, encoding features, and generating reference
    molecule features.

    All parameters are keyword-only.

    Args:
        is_inference (bool): Stored in the data dictionary under ``"is_inference"``. When True, the
            DNA-padding and cropping steps are routed to ``Identity``, templates are loaded through
            the inference route, and the crop-parameter validation is skipped.
        protein_msa_dirs (list[dict]): Forwarded to ``LoadPolymerMSAs``.
        rna_msa_dirs (list[dict]): Forwarded to ``LoadPolymerMSAs``.
        n_recycles (int, optional): Forwarded to ``FeaturizeMSALikeAF3``. Defaults to 5.
        crop_size (int | None, optional): The size of the crop. If None, no cropping transform is
            built. Defaults to 384.
        crop_center_cutoff_distance (float, optional): The cutoff distance for spatial cropping.
            Defaults to 15.0.
        crop_contiguous_probability (float, optional): The probability of using contiguous cropping.
            Defaults to 0.5.
        crop_spatial_probability (float, optional): The probability of using spatial cropping.
            Defaults to 0.5.
        max_atoms_in_crop (int | None, optional): Forwarded to both cropping transforms.
            Defaults to None.
        undesired_res_names (list[str], optional): Forwarded to ``HandleUndesiredResTokens``.
            Defaults to ``AF3_EXCLUDED_LIGANDS``.
        conformer_generation_timeout (float, optional): The timeout for conformer generation in
            seconds. Defaults to 5.0.
        use_element_for_atom_names_of_atomized_tokens (bool, optional): Forwarded to
            ``GetAF3ReferenceMoleculeFeatures``. Defaults to False.
        max_n_template (int, optional): Maximum number of templates returned by the template search
            during training (before subsampling to ``n_template``). Defaults to 20.
        n_template (int, optional): Number of templates kept: the subsampling target during
            training, and the search limit during inference. Defaults to 4.
        template_max_seq_similarity (float, optional): Forwarded to ``AddRFTemplates``.
            Defaults to 60.0.
        template_min_seq_similarity (float, optional): Forwarded to ``AddRFTemplates``.
            Defaults to 10.0.
        template_min_length (int, optional): Forwarded to ``AddRFTemplates`` as
            ``min_template_length``. Defaults to 10.
        template_allowed_chain_types (list[ChainType], optional): Forwarded (as a copy) to
            ``FeaturizeTemplatesLikeAF3``. Defaults to ``[ChainType.POLYPEPTIDE_L, ChainType.RNA]``.
        template_distogram_bins (torch.Tensor, optional): Forwarded to ``FeaturizeTemplatesLikeAF3``.
            Defaults to ``torch.linspace(3.25, 50.75, 38)``.
        template_default_token (str, optional): Forwarded to ``FeaturizeTemplatesLikeAF3`` as
            ``gap_token``. Defaults to ``GAP``.
        template_lookup_path (PathLike | None, optional): Forwarded to ``AddRFTemplates``.
        template_base_dir (PathLike | None, optional): Forwarded to ``AddRFTemplates``.
        max_msa_sequences (int, optional): Maximum number of MSA sequences to load (subsampled
            further later). Defaults to 10_000.
        n_msa (int, optional): Forwarded to ``FeaturizeMSALikeAF3``. Defaults to 10_000.
        dense_msa (bool, optional): Forwarded to ``PairAndMergePolymerMSAs`` as ``dense``.
            Defaults to True.
        msa_cache_dir (PathLike | str | None, optional): Converted to ``Path`` and forwarded to
            ``LoadPolymerMSAs`` when it passes ``exists``; otherwise None is forwarded.
        sigma_data (float, optional): Forwarded to ``SampleEDMNoise``. Defaults to 16.0.
        diffusion_batch_size (int, optional): Batch size used by the diffusion batching,
            augmentation and noise-sampling transforms. Defaults to 48.
        run_confidence_head (bool, optional): Stored in the data dictionary; when True, the
            confidence-feature transforms run and ``"confidence_feats"`` is kept in the output.
            Defaults to False.
        return_atom_array (bool, optional): When True *and* ``is_inference`` is True,
            ``"atom_array"`` is kept in the output. Defaults to True.
        pad_dna_p_skip (float, optional): Forwarded to ``PadDNA`` as ``p_skip`` during training.
            ``PadDNA`` is only added when this is greater than 0. Defaults to 0.0.
        b_factor_min (float | None, optional): Forwarded to ``SetOccToZeroOnBfactor``.
        b_factor_max (float | None, optional): Forwarded to ``SetOccToZeroOnBfactor``.

    Returns:
        Transform: A composed pipeline of transforms.

    Raises:
        AssertionError: Only checked when building a training pipeline (``is_inference`` is False)
            with ``crop_size`` set and at least one crop probability greater than 0: if the crop
            probabilities do not sum to 1.0, if the crop size is not positive, or if the crop
            center cutoff distance is not positive.

    Note:
        The cropping method is chosen randomly based on the provided probabilities.
        The pipeline includes steps for processing the structure, adding annotations,
        and generating features required for AF3-like predictions.

    References:
        - AlphaFold 3 Supplementary Information.
          https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf
    """
    # Validate crop parameters only when cropping will actually happen: training mode,
    # a crop size was provided, and at least one crop strategy has non-zero probability.
    # (Guarding on `crop_size is not None` avoids a TypeError from `None > 0`.)
    if (
        crop_size is not None
        and (crop_contiguous_probability > 0 or crop_spatial_probability > 0)
        and not is_inference
    ):
        assert np.isclose(
            crop_contiguous_probability + crop_spatial_probability, 1.0, atol=1e-6
        ), "Crop probabilities must sum to 1.0"
        assert crop_size > 0, "Crop size must be greater than 0"
        assert crop_center_cutoff_distance > 0, "Crop center cutoff distance must be greater than 0"

    # The default for `template_allowed_chain_types` is a single list object shared across calls;
    # copy it so that nothing downstream can mutate the shared default (or the caller's list).
    template_allowed_chain_types = list(template_allowed_chain_types)
    # NOTE(review): `undesired_res_names` defaults to the module-level AF3_EXCLUDED_LIGANDS object and is
    # passed through as-is; this assumes HandleUndesiredResTokens does not mutate it -- TODO confirm.

    af3_sequence_encoding = AF3SequenceEncoding()
    rf2aa_sequence_encoding = RF2AA_ATOM36_ENCODING

    transforms = [
        AddData({"is_inference": is_inference, "run_confidence_head": run_confidence_head}),
        RemoveHydrogens(),
        FilterToSpecifiedPNUnits(
            extra_info_key_with_pn_unit_iids_to_keep="all_pn_unit_iids_after_processing"
        ),  # Filter to non-clashing PN units
        RemoveTerminalOxygen(),
        SetOccToZeroOnBfactor(b_factor_min, b_factor_max),
        RemoveUnresolvedPNUnits(),
        RemovePolymersWithTooFewResolvedResidues(min_residues=4),
        MaskPolymerResiduesWithUnresolvedFrameAtoms(),
        # NOTE: For inference, we must keep UNL to support ligands that are not in the CCD
        HandleUndesiredResTokens(undesired_res_tokens=undesired_res_names),  # e.g., non-standard residues
        # NOTE(review): with the default `pad_dna_p_skip=0.0` the training branch below is Identity,
        # i.e. PadDNA is never applied unless `pad_dna_p_skip > 0` -- confirm this is intended.
        ConditionalRoute(
            condition_func=lambda data: data.get("is_inference", False),
            transform_map={
                True: Identity(),
                False: PadDNA(p_skip=pad_dna_p_skip) if pad_dna_p_skip > 0 else Identity(),
            },
        ),
        FlagAndReassignCovalentModifications(),
        FlagNonPolymersForAtomization(),
        AddGlobalAtomIdAnnotation(allow_overwrite=True),
        AtomizeByCCDName(
            atomize_by_default=True,
            res_names_to_ignore=STANDARD_AA + STANDARD_RNA + STANDARD_DNA,
            move_atomized_part_to_end=False,
            validate_atomize=False,
        ),
        RemoveNucleicAcidTerminalOxygen(),
        AddWithinChainInstanceResIdx(),
        AddWithinPolyResIdxAnnotation(),
    ]

    # Crop
    # ... crop around our query pn_unit(s) early, since we don't need the full structure moving forward
    cropping_transform = Identity()
    if crop_size is not None:
        cropping_transform = RandomRoute(
            transforms=[
                CropContiguousLikeAF3(
                    crop_size=crop_size,
                    keep_uncropped_atom_array=True,
                    max_atoms_in_crop=max_atoms_in_crop,
                ),
                CropSpatialLikeAF3(
                    crop_size=crop_size,
                    crop_center_cutoff_distance=crop_center_cutoff_distance,
                    keep_uncropped_atom_array=True,
                    max_atoms_in_crop=max_atoms_in_crop,
                ),
            ],
            probs=[crop_contiguous_probability, crop_spatial_probability],
        )

    transforms.append(
        ConditionalRoute(
            condition_func=lambda data: data.get("is_inference", False),
            transform_map={
                True: Identity(),  # Default to Identity during inference (`is_inference == True`)
                False: cropping_transform,
            },
        )
    )

    # Training: search for up to `max_n_template` templates, then randomly subsample to `n_template`
    training_template_loading_transforms = Compose(
        [
            AddRFTemplates(
                max_n_template=max_n_template,  # return at most max_n_template (e.g., 20 in AF-3) from our template search (we will then subsample)
                pick_top=False,
                max_seq_similarity=template_max_seq_similarity,
                min_seq_similarity=template_min_seq_similarity,
                min_template_length=template_min_length,
                template_lookup_path=template_lookup_path,
                template_base_dir=template_base_dir,
            ),
            # Subsample templates to n_template (from 20)
            RandomSubsampleTemplates(n_template=n_template),
        ]
    )

    # Inference: either take the top `n_template` templates from disk, or use the input file itself
    # when the atom array carries the "is_input_file_templated" annotation
    inference_template_loading_from_disk = AddRFTemplates(
        max_n_template=n_template,  # return at most n_template (e.g., 4 in AF-3) from our template search (no subsampling)
        pick_top=True,
        max_seq_similarity=template_max_seq_similarity,
        min_seq_similarity=template_min_seq_similarity,
        min_template_length=template_min_length,
        template_lookup_path=template_lookup_path,
        template_base_dir=template_base_dir,
    )
    inference_template_load_from_structure = AddInputFileTemplate()

    inference_template_loading_transforms = ConditionalRoute(
        condition_func=lambda data: "is_input_file_templated" in data["atom_array"].get_annotation_categories(),
        transform_map={
            True: inference_template_load_from_structure,
            False: inference_template_loading_from_disk,
        },
    )

    transforms += [
        AddGlobalTokenIdAnnotation(),  # required for reference molecule features and TokenToAtomMap
        EncodeAF3TokenLevelFeatures(sequence_encoding=af3_sequence_encoding),
        GetAF3ReferenceMoleculeFeatures(
            conformer_generation_timeout=conformer_generation_timeout,
            use_element_for_atom_names_of_atomized_tokens=use_element_for_atom_names_of_atomized_tokens,
        ),
        FindAutomorphismsWithNetworkX(),  # Adds the "automorphisms" key to the data dictionary
        ComputeAtomToTokenMap(),
        GetRDKitChiralCenters(),
        AddAF3ChiralFeatures(),
        ConditionalRoute(
            condition_func=lambda data: data["is_inference"],
            transform_map={
                False: training_template_loading_transforms,
                True: inference_template_loading_transforms,
            },
        ),
        FeaturizeTemplatesLikeAF3(
            sequence_encoding=af3_sequence_encoding,
            gap_token=template_default_token,
            allowed_chain_type=template_allowed_chain_types,
            distogram_bins=template_distogram_bins,
        ),
    ]

    transforms += [
        # ... load and pair MSAs
        LoadPolymerMSAs(
            protein_msa_dirs=protein_msa_dirs,
            rna_msa_dirs=rna_msa_dirs,
            max_msa_sequences=max_msa_sequences,  # maximum number of sequences to load (we later subsample further)
            msa_cache_dir=Path(msa_cache_dir) if exists(msa_cache_dir) else None,
            use_paths_in_chain_info=True,  # if there are paths specified in the `chain_info` for a given chain, use them
        ),
        PairAndMergePolymerMSAs(dense=dense_msa),
        # ... encode MSA to AF-3 format
        EncodeMSA(
            encoding=af3_sequence_encoding,
            token_to_use_for_gap=af3_sequence_encoding.token_to_idx["<G>"],
        ),
        # ... fill MSA, indexing into only the portions of the polymers that are present in the cropped structure
        FillFullMSAFromEncoded(pad_token=af3_sequence_encoding.token_to_idx["<G>"]),
        AddAF3TokenBondFeatures(),
        # ... featurize MSA
        ConvertToTorch(
            keys=[
                "encoded",
                "feats",
                "full_msa_details",
            ]
        ),
        FeaturizeMSALikeAF3(
            encoding=af3_sequence_encoding,
            n_recycles=n_recycles,
            n_msa=n_msa,
        ),
        # Prepare coordinates for noising (without modifying the ground truth)
        # ... add placeholder coordinates for noising
        CopyAnnotation(annotation_to_copy="coord", new_annotation="coord_to_be_noised"),
        # ... handling of unresolved residues (note that these Transforms create the "atom_array_to_noise" dictionary, if not already present)
        PlaceUnresolvedTokenAtomsOnRepresentativeAtom(annotation_to_update="coord_to_be_noised"),
        PlaceUnresolvedTokenOnClosestResolvedTokenInSequence(
            annotation_to_update="coord_to_be_noised",
            annotation_to_copy="coord_to_be_noised",
        ),
        # Feature aggregation
        AggregateFeaturesLikeAF3(),
        OneHotTemplateRestype(encoding=af3_sequence_encoding),
        # ... batching and noise sampling for diffusion
        BatchStructuresForDiffusionNoising(batch_size=diffusion_batch_size),
        CenterRandomAugmentation(batch_size=diffusion_batch_size),
        SampleEDMNoise(sigma_data=sigma_data, diffusion_batch_size=diffusion_batch_size),
    ]

    confidence_transforms = Compose(
        [
            # Additions required for confidence calculation
            EncodeAtomArray(rf2aa_sequence_encoding),
            AddAtomFrames(),
            AddIsRealAtom(rf2aa_sequence_encoding),
            AddPolymerFrameIndices(),
            # wrap it all together
            PackageConfidenceFeats(),
        ]
    )

    # Confidence features are computed per-example, based on the flag stored in the data dictionary
    transforms.append(
        ConditionalRoute(
            condition_func=lambda data: data.get("run_confidence_head", False),
            transform_map={
                True: confidence_transforms,
                False: Identity(),
            },
        )
    )

    keys_to_keep = [
        "example_id",
        "feats",
        "t",
        "noise",
        "ground_truth",
        "coord_atom_lvl_to_be_noised",
        "automorphisms",
        "symmetry_resolution",
        "extra_info",
    ]

    if run_confidence_head:
        keys_to_keep.append("confidence_feats")

    # The atom array is only returned at inference time
    if return_atom_array and is_inference:
        keys_to_keep.append("atom_array")

    transforms += [
        # Subset to only keys necessary
        SubsetToKeys(keys_to_keep)
    ]

    # ... compose final pipeline
    pipeline = Compose(transforms)

    return pipeline