Diffusion Transforms

This module contains transforms that prepare structures for diffusion training: tiling coordinates across the diffusion batch and sampling EDM timesteps and noise.

class atomworks.ml.transforms.diffusion.batch_structures.BatchStructuresForDiffusionNoising(batch_size: int, **kwargs)

Bases: Transform

Tiles the ground-truth structure to match the diffusion batch size.

In AF-3, we first batch the input structure (broadcasting the ground truth down the batch dimension) and then apply data augmentations, such as noising each copy to a different level and rotating each copy independently.

Precise behavior depends on whether the data dictionary already contains the key coord_atom_lvl_to_be_noised:
  • If the data dictionary already contains the key coord_atom_lvl_to_be_noised, we will batch the coordinates found in that key.

  • Otherwise, we will batch the coordinates found in ground_truth.coord_atom_lvl.

Performs the following shape transformation, where diffusion_batch_size is the batch_size parameter: (n_atoms, 3) -> (diffusion_batch_size, n_atoms, 3)
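A minimal NumPy sketch of the batching step described above. The function name batch_structures_for_noising is hypothetical, as are the assumptions that ground_truth is a nested dict and that the tiled array is written back under coord_atom_lvl_to_be_noised; the real transform works on the atomworks data dictionary and may differ in detail.

```python
import numpy as np


def batch_structures_for_noising(data: dict, batch_size: int) -> dict:
    """Hypothetical sketch: tile (n_atoms, 3) coordinates to (batch_size, n_atoms, 3)."""
    # Prefer coordinates already staged for noising; otherwise fall back to the ground truth.
    # Assumption: ground_truth is a nested dict with a "coord_atom_lvl" entry.
    if "coord_atom_lvl_to_be_noised" in data:
        coords = np.asarray(data["coord_atom_lvl_to_be_noised"])
    else:
        coords = np.asarray(data["ground_truth"]["coord_atom_lvl"])
    if coords.ndim != 2 or coords.shape[-1] != 3:
        raise ValueError(f"expected shape (n_atoms, 3), got {coords.shape}")

    # Add a leading batch axis and copy it batch_size times (a real copy, not a view).
    batched = np.repeat(coords[None, ...], batch_size, axis=0)

    # Assumption: the result is stored back under coord_atom_lvl_to_be_noised.
    data["coord_atom_lvl_to_be_noised"] = batched
    return data


# Usage: 5 atoms tiled across a diffusion batch of 4.
example = {"ground_truth": {"coord_atom_lvl": np.zeros((5, 3))}}
print(batch_structures_for_noising(example, batch_size=4)["coord_atom_lvl_to_be_noised"].shape)  # (4, 5, 3)
```

np.repeat returns a fresh copy rather than a broadcast view, so each batch element can later be noised or rotated in place without affecting the others. In PyTorch the equivalent would be coords.unsqueeze(0).repeat(batch_size, 1, 1).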

Parameters:

batch_size (int) – The size of the diffusion batch.

check_input(data: dict[str, Any]) → None

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict[str, Any]) → dict[str, Any]

Apply a transformation to the input dictionary and return the transformed dictionary.

Parameters:

data (dict) – The input dictionary to transform.

Returns:

The transformed dictionary.

Return type:

dict

incompatible_previous_transforms: ClassVar[list[str | Transform]] = ['BatchStructuresForDiffusionNoising']

class atomworks.ml.transforms.diffusion.edm.SampleEDMNoise(sigma_data: float, diffusion_batch_size: int, **kwargs)

Bases: Transform

Samples EDM timesteps and per-atom noise for the batched coordinates that are to be noised.

check_input(data: dict[str, Any]) → None

Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.

forward(data: dict) → dict

Apply EDM noise sampling to the coordinates that are to be noised.

Parameters:

data (Dict[str, Any]) – The input data dictionary containing the coordinates to be noised.

Returns:

The input data dictionary with two added keys, “t” and “noise”, holding the sampled timesteps and noise:
  • t (torch.Tensor): A tensor of shape (diffusion_batch_size,) containing sampled time values.

  • noise (torch.Tensor): A tensor of shape (diffusion_batch_size, num_atoms, 3) containing sampled noise for each atom.

Return type:

Dict[str, Any]
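The forward contract above can be sketched as follows. This is a hypothetical NumPy-only illustration: it assumes the batched coordinates live under coord_atom_lvl_to_be_noised with shape (diffusion_batch_size, num_atoms, 3), and it borrows the AlphaFold 3 log-normal constants (-1.2, 1.5) and t-scaled Gaussian noise as assumptions, not as confirmed details of the atomworks implementation.

```python
import numpy as np
from typing import Optional

# Assumed AF3-style constants for the log-normal noise-level distribution.
P_MEAN, P_STD = -1.2, 1.5


def sample_edm_noise_forward(data: dict, sigma_data: float, diffusion_batch_size: int,
                             rng: Optional[np.random.Generator] = None) -> dict:
    """Hypothetical sketch of SampleEDMNoise.forward: adds "t" and "noise" to data."""
    rng = np.random.default_rng() if rng is None else rng
    coords = data["coord_atom_lvl_to_be_noised"]  # assumed (diffusion_batch_size, num_atoms, 3)
    if coords.shape[0] != diffusion_batch_size:
        raise ValueError("coordinates must be batched to diffusion_batch_size first")
    num_atoms = coords.shape[1]

    # One noise level per batch element.
    t = sigma_data * np.exp(P_MEAN + P_STD * rng.standard_normal(diffusion_batch_size))
    # Isotropic Gaussian noise per atom, scaled by its batch element's noise level.
    noise = t[:, None, None] * rng.standard_normal((diffusion_batch_size, num_atoms, 3))

    data["t"] = t
    data["noise"] = noise
    return data
```

The shape check mirrors the requires_previous_transforms dependency: noise can only be sampled once the coordinates have been tiled across the diffusion batch.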

requires_previous_transforms: ClassVar[list[str | Transform]] = [<class 'atomworks.ml.transforms.diffusion.batch_structures.BatchStructuresForDiffusionNoising'>]

atomworks.ml.transforms.diffusion.edm.sample_noise_edm(t: Tensor, num_atoms: int) → Tensor

Based on the timestep t, sample noise for the diffusion process.

Parameters:
  • t (torch.Tensor) – A tensor of shape (diffusion_batch_size,) containing time values.

  • num_atoms (int) – The number of atoms.

Returns:

A tensor of shape (diffusion_batch_size, num_atoms, 3) containing sampled noise.

Return type:

torch.Tensor
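A sketch of what this function plausibly computes, assuming (this page does not confirm it) that t acts as the noise standard deviation for each batch element, as in EDM-style training where the noised coordinates are x + t·ε with ε ~ N(0, I). The name sample_noise_edm_sketch and the rng argument are hypothetical.

```python
import numpy as np
from typing import Optional


def sample_noise_edm_sketch(t: np.ndarray, num_atoms: int,
                            rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Hypothetical sketch: noise[i] ~ N(0, t[i]**2 I) for each batch element i."""
    rng = np.random.default_rng() if rng is None else rng
    # Standard normal draws, one per (batch element, atom, xyz coordinate).
    eps = rng.standard_normal((t.shape[0], num_atoms, 3))
    # Broadcast t from (B,) to (B, 1, 1) so every atom in element i shares scale t[i].
    return t[:, None, None] * eps
```

Scaling a single standard-normal draw by t keeps the noise isotropic across x, y and z, and the output matches the documented (diffusion_batch_size, num_atoms, 3) shape.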

atomworks.ml.transforms.diffusion.edm.sample_t_edm(sigma_data: float, diffusion_batch_size: int) → Tensor

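Sample timesteps (noise levels) following the EDM paper (Karras et al., 2022, "Elucidating the Design Space of Diffusion-Based Generative Models").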

Parameters:
  • sigma_data (float) – The standard deviation of the data distribution (σ_data), used to scale the sampled noise levels.

  • diffusion_batch_size (int) – The diffusion batch size; one timestep is sampled per batch element.

Returns:

A tensor of shape (diffusion_batch_size,) containing sampled time values.

Return type:

torch.Tensor
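A sketch of EDM-style timestep sampling. In EDM the training noise level σ is log-normal, ln σ ~ N(P_mean, P_std²) with P_mean = -1.2 and P_std = 1.2; AlphaFold 3 instead samples σ = σ_data · exp(-1.2 + 1.5·N(0, 1)). The sketch assumes the AF3 constants; whether sample_t_edm uses exactly these values is an assumption, and the function name is hypothetical.

```python
import numpy as np
from typing import Optional

# Assumed constants (AF3's variant; the EDM paper itself uses P_mean = -1.2, P_std = 1.2).
P_MEAN = -1.2
P_STD = 1.5


def sample_t_edm_sketch(sigma_data: float, diffusion_batch_size: int,
                        rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Hypothetical sketch: log-normal noise levels, one per diffusion batch element."""
    rng = np.random.default_rng() if rng is None else rng
    # log(t / sigma_data) ~ N(P_MEAN, P_STD**2)
    log_scale = P_MEAN + P_STD * rng.standard_normal(diffusion_batch_size)
    return sigma_data * np.exp(log_scale)
```

Because σ_data multiplies every sample, the median noise level under these assumed constants is σ_data · e^(-1.2) ≈ 0.30 σ_data, while the wide log-space spread still produces occasional very large noise levels.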