Source code for atomworks.ml.transforms.diffusion.edm
from typing import Any, ClassVar
import torch
from atomworks.ml.transforms._checks import check_contains_keys
from atomworks.ml.transforms.base import Transform
from atomworks.ml.transforms.diffusion.batch_structures import BatchStructuresForDiffusionNoising
[docs]
def sample_t_edm(
    sigma_data: float,
    diffusion_batch_size: int,
    *,
    p_mean: float = -1.2,
    p_std: float = 1.5,
) -> torch.Tensor:
    """
    Sample timesteps (noise levels) from a log-normal distribution, EDM-style.

    Each timestep is computed as ``sigma_data * exp(p_mean + p_std * z)`` with ``z ~ N(0, 1)``.

    Args:
        sigma_data (float): The sigma data parameter; multiplies every sampled value.
        diffusion_batch_size (int): The size of the batch for diffusion. We will sample this many timesteps.
        p_mean (float): Mean of the normal distribution in log space. Defaults to -1.2, the value that
            used to be hard-coded here.
        p_std (float): Standard deviation of the normal distribution in log space. Defaults to 1.5, the
            value that used to be hard-coded here.

    Returns:
        torch.Tensor: A tensor of shape (diffusion_batch_size,) containing sampled time values.
    """
    # Reference for the log-normal schedule: NVIDIA EDM paper (https://arxiv.org/pdf/2206.00364).
    # NOTE(review): the defaults keep this file's previous constants; confirm p_std against the
    # reference you intend to follow before changing it.
    standard_normal = torch.randn(diffusion_batch_size)
    log_scale = p_mean + p_std * standard_normal
    return sigma_data * torch.exp(log_scale)
[docs]
def sample_noise_edm(t: torch.Tensor, num_atoms: int) -> torch.Tensor:
    """
    Based on the timestep t, sample noise for the diffusion process.

    A standard-normal tensor is drawn and every batch entry is scaled by its own timestep, so
    entry ``i`` of the result is ``t[i]`` times a unit Gaussian sample.

    Args:
        t (torch.Tensor): A tensor of shape (diffusion_batch_size,) containing time values.
        num_atoms (int): The number of atoms.

    Returns:
        torch.Tensor: A tensor of shape (diffusion_batch_size, num_atoms, 3) containing sampled noise,
            located on the same device as ``t``.
    """
    # Draw the unit Gaussian on t's device; creating it on the default device would make the
    # product below fail whenever t lives elsewhere (e.g. on a GPU).
    unit_noise = torch.randn(t.shape[0], num_atoms, 3, device=t.device)
    # (B, 1, 1) broadcasts against (B, num_atoms, 3), so no tiled copy of t is needed.
    return unit_noise * t[:, None, None]
[docs]
class SampleEDMNoise(Transform):
    """
    Transform that samples EDM timesteps and matching per-atom noise.

    On each call it draws ``diffusion_batch_size`` timesteps via ``sample_t_edm`` and a noise
    tensor via ``sample_noise_edm``, then stores them in the data dictionary under "t" and "noise".

    Args:
        sigma_data (float): Forwarded to ``sample_t_edm`` as the scaling parameter.
        diffusion_batch_size (int): Number of timesteps (and noise samples) drawn per call.
        **kwargs: Passed through unchanged to ``Transform.__init__``.
    """

    # Declared prerequisite; presumably this is the transform that provides
    # "coord_atom_lvl_to_be_noised" -- verify against BatchStructuresForDiffusionNoising.
    requires_previous_transforms: ClassVar[list[str | Transform]] = [BatchStructuresForDiffusionNoising]

    def __init__(self, sigma_data: float, diffusion_batch_size: int, **kwargs):
        super().__init__(**kwargs)
        self.sigma_data, self.diffusion_batch_size = sigma_data, diffusion_batch_size

    def forward(self, data: dict) -> dict:
        """
        Sample timesteps and noise for the coordinates that are to be noised.

        Args:
            data (Dict[str, Any]): The input data dictionary. Must contain "coord_atom_lvl_to_be_noised";
                the size of its second dimension is used as the number of atoms.

        Returns:
            Dict[str, Any]: The same dictionary (modified in place) with two added keys:
                - t (torch.Tensor): A tensor of shape (diffusion_batch_size,) containing sampled time values.
                - noise (torch.Tensor): A tensor of shape (diffusion_batch_size, num_atoms, 3) containing
                  sampled noise for each atom.
        """
        timesteps = sample_t_edm(self.sigma_data, self.diffusion_batch_size)
        # The atom count is read from dimension 1 of the coordinates to be noised.
        atom_count = data["coord_atom_lvl_to_be_noised"].shape[1]
        per_atom_noise = sample_noise_edm(timesteps, atom_count)

        data["t"] = timesteps
        data["noise"] = per_atom_noise
        return data