Diffusion Transforms
This module contains transformations for diffusion-based structure processing.
- class atomworks.ml.transforms.diffusion.batch_structures.BatchStructuresForDiffusionNoising(batch_size: int, **kwargs)
Bases: Transform
Tiles the ground truth structures to match the diffusion batch size.
In AF-3, we first batch the input structures (broadcasting the ground truth along the batch dimension) and then apply data augmentations independently to each copy, such as noising it to a different level and rotating it.
- Precise behavior depends on whether the data dictionary already contains the key coord_atom_lvl_to_be_noised:
If it does, we batch the coordinates found in that key.
Otherwise, we batch the coordinates found in ground_truth.coord_atom_lvl.
Performs the following transformation: (n_atoms, 3) -> (diffusion_batch_size, n_atoms, 3), where diffusion_batch_size is the batch_size passed to the constructor.
- Parameters:
batch_size (int) – The size of the diffusion batch.
- check_input(data: dict[str, Any]) → None
Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.
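The tiling step can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the atomworks implementation: the helper name tile_coords_for_noising is hypothetical, and we assume ground_truth is a nested dictionary holding an (n_atoms, 3) tensor under coord_atom_lvl.

```python
import torch


def tile_coords_for_noising(data: dict, batch_size: int) -> torch.Tensor:
    """Hypothetical sketch: broadcast (n_atoms, 3) coordinates to (batch_size, n_atoms, 3)."""
    # Coordinates already staged for noising take precedence over the ground truth.
    if "coord_atom_lvl_to_be_noised" in data:
        coords = data["coord_atom_lvl_to_be_noised"]
    else:
        # Assumption: ground_truth is a nested dict, not an attribute-style object.
        coords = data["ground_truth"]["coord_atom_lvl"]
    coords = torch.as_tensor(coords)
    if coords.ndim != 2 or coords.shape[-1] != 3:
        raise ValueError(f"expected shape (n_atoms, 3), got {tuple(coords.shape)}")
    # repeat() copies the data (expand() would return a view that shares memory),
    # so each batch member can later be noised and rotated independently.
    return coords.unsqueeze(0).repeat(batch_size, 1, 1)


gt = torch.randn(5, 3)
batched = tile_coords_for_noising({"ground_truth": {"coord_atom_lvl": gt}}, batch_size=4)
print(batched.shape)  # torch.Size([4, 5, 3])
```

Using repeat rather than expand trades memory for safety: PyTorch documents that in-place operations on an expanded tensor can behave incorrectly because several elements may refer to a single memory location.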
- class atomworks.ml.transforms.diffusion.edm.SampleEDMNoise(sigma_data: float, diffusion_batch_size: int, **kwargs)
Bases: Transform
- check_input(data: dict[str, Any]) → None
Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.
- forward(data: dict) → dict
Apply EDM noise sampling to the coordinates that are to be noised.
- Parameters:
data (Dict[str, Any]) – The input data dictionary containing the coordinates to be noised.
- Returns:
- The input data dictionary with the added keys “t” and “noise” containing the sampled timesteps and noise.
t (torch.Tensor): A tensor of shape (diffusion_batch_size,) containing sampled time values.
noise (torch.Tensor): A tensor of shape (diffusion_batch_size, num_atoms, 3) containing sampled noise for each atom.
- Return type:
Dict[str, Any]
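The contract of forward (read the coordinates to be noised, then write "t" and "noise" with the documented shapes) can be sketched as a standalone function. Everything beyond the documented keys and shapes is an assumption: the function name is hypothetical, the coordinates are assumed to sit under coord_atom_lvl_to_be_noised, and the sampling follows an AlphaFold 3-style schedule (log-normal noise levels scaled by sigma_data, Gaussian noise scaled by each level) that we have not verified against the atomworks source. The value sigma_data=16.0 in the usage line is only illustrative.

```python
import torch


def sample_edm_noise_sketch(data: dict, sigma_data: float, diffusion_batch_size: int) -> dict:
    """Hypothetical stand-in for SampleEDMNoise.forward; not the atomworks code."""
    # Assumption: the coordinates to be noised are stored under this key.
    coords = data["coord_atom_lvl_to_be_noised"]
    # shape[-2] is n_atoms for both (n_atoms, 3) and (batch, n_atoms, 3) inputs.
    num_atoms = coords.shape[-2]
    # One noise level per batch member (assumed AF3-style log-normal schedule).
    t = sigma_data * torch.exp(-1.2 + 1.5 * torch.randn(diffusion_batch_size))
    # Isotropic Gaussian noise per atom, scaled by that batch member's noise level.
    noise = t[:, None, None] * torch.randn(diffusion_batch_size, num_atoms, 3)
    data["t"] = t
    data["noise"] = noise
    return data


out = sample_edm_noise_sketch(
    {"coord_atom_lvl_to_be_noised": torch.zeros(4, 10, 3)}, sigma_data=16.0, diffusion_batch_size=4
)
print(out["t"].shape, out["noise"].shape)  # torch.Size([4]) torch.Size([4, 10, 3])
```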
- atomworks.ml.transforms.diffusion.edm.sample_noise_edm(t: Tensor, num_atoms: int) → Tensor
Based on the timestep t, sample noise for the diffusion process.
- Parameters:
t (torch.Tensor) – A tensor of shape (diffusion_batch_size,) containing time values.
num_atoms (int) – The number of atoms.
- Returns:
A tensor of shape (diffusion_batch_size, num_atoms, 3) containing sampled noise.
- Return type:
torch.Tensor
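A minimal sketch of sample_noise_edm, assuming the EDM convention in which the noise standard deviation equals the timestep (sigma(t) = t), so each batch member's noise is a standard normal draw scaled by its own t. The function name and the generator argument are our additions for illustration and reproducibility; they are not part of the documented signature.

```python
from typing import Optional

import torch


def sample_noise_edm_sketch(
    t: torch.Tensor, num_atoms: int, generator: Optional[torch.Generator] = None
) -> torch.Tensor:
    """Sketch: noise of shape (diffusion_batch_size, num_atoms, 3) with std t[i] for member i."""
    if t.ndim != 1:
        raise ValueError("t must have shape (diffusion_batch_size,)")
    eps = torch.randn(
        t.shape[0], num_atoms, 3, generator=generator, dtype=t.dtype, device=t.device
    )
    # Broadcast t from (B,) to (B, 1, 1) so it scales every atom and axis of its member.
    return t[:, None, None] * eps


g = torch.Generator().manual_seed(0)
t = torch.tensor([0.5, 2.0, 10.0])
noise = sample_noise_edm_sketch(t, num_atoms=20_000, generator=g)
print(noise.shape)  # torch.Size([3, 20000, 3])
print(noise.std(dim=(1, 2)))  # roughly 0.5, 2.0 and 10.0
```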
- atomworks.ml.transforms.diffusion.edm.sample_t_edm(sigma_data: float, diffusion_batch_size: int) → Tensor
Sample timesteps (noise levels) following the EDM paper (Karras et al., 2022).
- Parameters:
sigma_data (float) – The EDM sigma_data parameter (the assumed standard deviation of the data distribution), used to scale the sampled timesteps.
diffusion_batch_size (int) – The diffusion batch size; one timestep is sampled per batch member.
- Returns:
A tensor of shape (diffusion_batch_size,) containing sampled time values.
- Return type:
torch.Tensor
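In the EDM framework, training noise levels are drawn log-normally. AlphaFold 3 uses the variant t = sigma_data * exp(-1.2 + 1.5 * z) with z ~ N(0, 1). The sketch below assumes sample_t_edm follows that AF3 parameterization; the function name is hypothetical, and the constants p_mean=-1.2 and p_std=1.5 are our assumption, not checked against the atomworks source.

```python
from typing import Optional

import torch


def sample_t_edm_sketch(
    sigma_data: float,
    diffusion_batch_size: int,
    p_mean: float = -1.2,  # assumed AF3 value
    p_std: float = 1.5,  # assumed AF3 value
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Sketch: log-normal timesteps, t = sigma_data * exp(p_mean + p_std * z), z ~ N(0, 1)."""
    z = torch.randn(diffusion_batch_size, generator=generator)
    return sigma_data * torch.exp(p_mean + p_std * z)


g = torch.Generator().manual_seed(0)
t = sample_t_edm_sketch(sigma_data=16.0, diffusion_batch_size=200_000, generator=g)
log_ratio = torch.log(t / 16.0)
print(log_ratio.mean().item(), log_ratio.std().item())  # close to -1.2 and 1.5
```

Because t is log-normal under this assumption, its median is sigma_data * exp(-1.2), roughly 0.3 * sigma_data, while the long upper tail still exposes the model to heavily noised structures.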