Source code for atomworks.ml.transforms.masks
"""Transforms that add masks for an AtomArray to the data"""
import logging
import warnings
from typing import Any
import numpy as np
from biotite.structure import AtomArray
from scipy.spatial import KDTree
from atomworks.ml.transforms._checks import (
check_atom_array_annotation,
check_contains_keys,
check_is_instance,
)
from atomworks.ml.transforms.base import Transform
logger = logging.getLogger("atomworks.ml")
[docs]
def compute_spatial_knn_mask(coords: np.ndarray, k: int) -> np.ndarray:
    """Compute a spatial k-nearest-neighbor mask from atom coordinates.

    Args:
        coords (np.ndarray): 2D coordinate array with one row per atom. Rows containing
            any non-finite value ('nan' or 'inf') are treated as atoms without coordinates.
        k (int): Number of neighbors (excluding the atom itself) to assign to every atom
            with finite coordinates. Must be non-negative.

    Returns:
        np.ndarray: Boolean mask of shape (n_atoms, n_atoms) where ``mask[i, j]`` is True
            if atom ``j`` is one of the ``k`` nearest neighbors of atom ``i``. Every row of
            an atom with finite coordinates has exactly ``k`` True entries; atoms without
            finite coordinates have all-False rows and columns. The diagonal is always False.

    Raises:
        AssertionError: If ``coords`` is not 2D, ``k`` is negative, or fewer than ``k + 1``
            atoms have finite coordinates.
    """
    assert coords.ndim == 2, "Coordinates must be a 2D array"
    assert k >= 0, f"Number of neighbors must be non-negative, got {k}."
    num_atoms = coords.shape[0]

    # ... filter out 'nan' or 'inf' coordinates, as the KD-tree cannot be built on
    # non-finite coordinates and no neighbors can be assigned to atoms without coordinates
    is_finite = np.isfinite(coords).all(axis=1)
    idx_finite = np.flatnonzero(is_finite)
    n_finite = len(idx_finite)
    if n_finite < num_atoms:
        warnings.warn(
            "Some atoms have no coordinates, they will not receive any neighbors in the spatial KNN mask", stacklevel=2
        )
    # ... each atom needs k neighbors besides itself, i.e. at least k + 1 atoms with coordinates
    assert n_finite >= k + 1, (
        f"Not enough atoms to calculate KNN mask with {k} neighbors, "
        f"but only {n_finite} atoms with coordinates found."
    )

    # ... get the k+1 nearest neighbors for each atom ('self' is normally among them)
    finite_coords = coords[idx_finite]
    kdtree = KDTree(finite_coords)
    _, knn_indices = kdtree.query(finite_coords, k=k + 1)
    # ... scipy squeezes the output to 1D when k + 1 == 1; restore the (n_finite, k + 1) shape
    knn_indices = np.asarray(knn_indices).reshape(n_finite, k + 1)

    # ... drop 'self' from every row. If several atoms share the exact same position, ties
    # may push 'self' out of the k + 1 results; in that case drop the farthest hit instead,
    # so that exactly one entry per row is removed and k neighbors remain.
    is_dropped = knn_indices == np.arange(n_finite)[:, None]
    is_dropped[~is_dropped.any(axis=1), -1] = True
    neighbor_indices = knn_indices[~is_dropped].reshape(n_finite, k)

    # ... map local (finite-only) indices back to the full atom index space; atoms which had
    # non-finite coords keep all-False rows and columns
    mask = np.zeros((num_atoms, num_atoms), dtype=bool)
    mask[idx_finite[:, None], idx_finite[neighbor_indices]] = True

    # ... check each atom that had coordinates was assigned k neighbors
    assert np.all(mask[idx_finite].sum(axis=1) == k), "Not all rows have k neighbors."
    # ... check 'self' is never its own neighbor
    assert not mask.diagonal().any(), "Diagonal is not zero."
    return mask
[docs]
class AddSpatialKNNMask(Transform):
    """
    Generate a spatial k-nearest neighbors mask for each atom in the input atom array
    and add it to the data with the key 'spatial_knn_masks' (shape: (n_atoms, n_atoms)).
    This mask is e.g. used as a local attention mask in diffusion for the input sequence
    based on given coordinates.

    Args:
        num_neighbors (int): The number of neighbors to keep for each atom with coordinates.
        max_atoms_in_crop (int): The maximum allowed number of atoms in the crop. This
            transform builds an `(n_atoms, n_atoms)` mask, so this limit on the number of
            atoms avoids unexpected memory balloons; inputs exceeding it raise a `ValueError`.
            The default of 40,000 atoms should allow any crop of up to 1,538 tokens to pass
            (worst case: RNA guanine, which has 26 heavy atoms per token,
            resulting in 39,988 atoms for a structure made up of only guanine)
    """

    def __init__(self, num_neighbors: int, max_atoms_in_crop: int = 40_000):
        self.num_neighbors = num_neighbors
        self.max_atoms_in_crop = max_atoms_in_crop

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        """
        Generate a local attention mask for the input sequence based on given coordinates,
        keeping only the k nearest neighbors of each atom.

        Args:
            data (dict[str, Any]): The input data dictionary; its 'atom_array' entry must
                provide a `coord` array.

        Returns:
            dict[str, Any]: The output data dictionary with the added 'spatial_knn_masks' key
                - 'spatial_knn_masks' (np.ndarray): Boolean mask of shape (n_atoms, n_atoms)
                  where `mask[i, j]` is True if atom j is one of the k nearest neighbors of atom i.
                  NOTE: atoms with no coordinates will not receive any neighbors in the mask
                  (i.e. a row of all False values)

        Raises:
            ValueError: If the atom array has more than `max_atoms_in_crop` atoms.
        """
        atom_array = data["atom_array"]
        coords = atom_array.coord
        num_atoms = coords.shape[0]
        # ... enforce the documented size limit before allocating the quadratic (n, n) mask
        if num_atoms > self.max_atoms_in_crop:
            raise ValueError(
                f"Atom array has {num_atoms} atoms, which exceeds max_atoms_in_crop="
                f"{self.max_atoms_in_crop}. Building a ({num_atoms}, {num_atoms}) spatial KNN "
                "mask would use too much memory; crop the structure first or raise the limit."
            )
        # ... compute the masks
        data["spatial_knn_masks"] = compute_spatial_knn_mask(coords, self.num_neighbors)
        return data