"""Various geometry utility functions to deal with rigid body transformations in 3D."""
import numpy as np
import torch
from biotite.structure import AtomArray, rmsd, superimpose
from einops import einsum, rearrange
from torch.nn.functional import normalize
from atomworks.ml.common import default
[docs]
def get_torch_eps(dtype: torch.dtype) -> float:
    """Return the machine epsilon of a floating-point torch dtype.

    Machine epsilon is the gap between 1.0 and the next representable value of
    ``dtype``, as reported by ``torch.finfo``. It is used in this module as a tiny
    safeguard value, e.g. to avoid division by zero when normalizing vectors.

    Args:
        dtype: A floating-point torch dtype such as ``torch.float32``.

    Returns:
        float: ``torch.finfo(dtype).eps``.
    """
    type_info = torch.finfo(dtype)
    return type_info.eps
[docs]
def rigid_from_3_points(
    x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor, eps: float | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Build the rigid transformation (R, t) of the local frame spanned by three points.

    The frame is constructed with a single Gram-Schmidt step:
      - origin: ``x2``
      - x-axis: unit vector pointing from ``x2`` towards ``x3``
      - y-axis: the part of ``x1 - x2`` orthogonal to the x-axis, normalized; it lies in
        the plane through ``x1``, ``x2`` and ``x3``
      - z-axis: cross product of the x- and y-axis, perpendicular to that plane

    The three axes become the columns of ``R``, so ``(R, t)`` maps local coordinates to
    global ones (``p_global = R @ p_local + t``).
    Example: with x1=N, x2=CA, x3=C the x-axis points CA -> C, the y-axis lies in the
    N-CA-C plane and the z-axis is normal to it.

    NOTE: If the three points are (nearly) collinear the y-axis is ill-defined, and the
    returned ``R`` is then not guaranteed to be a proper rotation.

    Args:
        x1: torch.Tensor of shape [..., 3], point that fixes the xy-plane.
        x2: torch.Tensor of shape [..., 3], origin of the local frame. This same tensor
            object is returned as the translation (no copy is made).
        x3: torch.Tensor of shape [..., 3], point that fixes the x-axis direction.
        eps: Small value forwarded to ``normalize`` to avoid division by zero. If None,
            ``get_torch_eps(x1.dtype)`` is used.

    Returns:
        R: torch.Tensor of shape [..., 3, 3], rotation matrix (axes as columns).
        t: torch.Tensor of shape [..., 3], translation vector (``x2``).

    Reference:
        - AF2 supplementary, Algorithm 21
          https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf

    Example:
        >>> x1 = torch.tensor([0.0, 0.0, 1.0])
        >>> x2 = torch.tensor([0.0, 0.0, 0.0])
        >>> x3 = torch.tensor([1.0, 0.0, 0.0])
        >>> R, t = rigid_from_3_points(x1, x2, x3)
        >>> print(R)
        tensor([[ 1.,  0.,  0.],
                [ 0.,  0., -1.],
                [ 0.,  1.,  0.]])
        >>> print(t)
        tensor([0., 0., 0.])
    """
    eps = default(eps, get_torch_eps(x1.dtype))

    # First basis vector: unit direction x2 -> x3.
    e1 = normalize(x3 - x2, dim=-1, eps=eps)

    # Second basis vector: strip the e1 component from x2 -> x1 (Gram-Schmidt),
    # leaving an in-plane vector orthogonal to e1, then normalize it.
    v2 = x1 - x2
    v2_along_e1 = e1 * torch.sum(e1 * v2, dim=-1, keepdim=True)
    e2 = normalize(v2 - v2_along_e1, dim=-1, eps=eps)

    # Third basis vector completes a right-handed frame; it has unit length
    # whenever e1 and e2 are orthonormal.
    e3 = torch.cross(e1, e2, dim=-1)

    # Basis vectors as columns -> R maps local coordinates to global ones.
    rots = torch.stack((e1, e2, e3), dim=-1)
    return rots, x2
[docs]
def apply_rigid(
    rigid: tuple[torch.Tensor, torch.Tensor],
    points: torch.Tensor,
) -> torch.Tensor:
    """
    Map points through a rigid body transformation, ``p -> R @ p + t``.

    The rotation is applied first and the translation afterwards.

    Args:
        - rigid (tuple[torch.Tensor, torch.Tensor]): ``(R, t)`` with ``R`` of shape
          [..., 3, 3] and ``t`` broadcastable against the rotated points.
        - points (torch.Tensor): Points of shape [..., 3].

    Returns:
        - torch.Tensor: The transformed points, shape [..., 3].

    NOTE: ``points`` are interpreted as coordinates in the local frame of ``rigid``;
    the result is expressed in the global frame.
    """
    rotation, translation = rigid
    rotated = einsum(rotation, points, "... i j, ... j -> ... i")
    return rotated + translation
[docs]
def apply_batched_rigid(
    rigid: tuple[torch.Tensor, torch.Tensor],
    points: torch.Tensor,
) -> torch.Tensor:
    """
    Apply one rigid body transformation per batch element via ``p -> R[b] @ p + t[b]``.

    Every point in ``points[b]`` is first rotated by ``R[b]`` and then translated by
    ``t[b]``.

    Args:
        - rigid (tuple[torch.Tensor, torch.Tensor]): ``(R, t)`` with ``R`` of shape
          [batch_size, 3, 3] and ``t`` of shape [batch_size, 3].
        - points (torch.Tensor): Points of shape [batch_size, ..., 3]. Any number of
          intermediate dimensions is supported, including none (i.e. [batch_size, 3]).

    Returns:
        - torch.Tensor: The transformed points, with the same shape as ``points``.

    Raises:
        AssertionError: If ``points`` is not of shape [batch_size, ..., 3], or if the
            shapes of ``R`` / ``t`` do not match the batch size of ``points``.

    NOTE: This transforms `p` from the local frame of the `rigid` to the global frame.
    """
    rots, trans = rigid
    assert points.ndim >= 2 and points.shape[-1] == 3, "points must have shape [batch_size, ..., 3]"
    batch = points.shape[0]
    assert rots.shape == (batch, 3, 3), "rotation dimension must match the points dimension"
    assert trans.shape == (batch, 3), "translation dimension must match the points dimension"

    # Insert singleton dims between the batch dim and the matrix/vector dims so that
    # R[b] and t[b] broadcast over every point belonging to batch element b.
    inner_ones = [1] * (points.ndim - 2)
    rots = rots.reshape(batch, *inner_ones, 3, 3)
    trans = trans.reshape(batch, *inner_ones, 3)

    # Treat each point as a column vector: (..., 3, 3) @ (..., 3, 1) -> (..., 3, 1).
    rotated = torch.matmul(rots, points.unsqueeze(-1)).squeeze(-1)
    return rotated + trans
[docs]
def invert_rigid(rigid: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Return the inverse ``(R^T, -R^T @ t)`` of a rigid body transformation ``(R, t)``.

    Because a rotation matrix is orthogonal, its inverse is its transpose; the
    inverse translation undoes ``t`` in the rotated-back frame.

    Args:
        - rigid (tuple[torch.Tensor, torch.Tensor]): ``(R, t)`` with ``R`` of shape
          [..., 3, 3] and ``t`` of shape [..., 3].

    Returns:
        - tuple[torch.Tensor, torch.Tensor]: ``(R^T, -R^T @ t)``.

    Example:
        >>> R = torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]])
        >>> t = torch.tensor([1, 2, 3])
        >>> R_inv, t_inv = invert_rigid((R, t))
        >>> print(R_inv)
        tensor([[ 0,  1,  0],
                [-1,  0,  0],
                [ 0,  0,  1]])
        >>> print(t_inv)
        tensor([-2,  1, -3])
    """
    rots, trans = rigid
    # Swap the last two axes: R^T.
    rots_t = rearrange(rots, "... i j->... j i")
    # R^T @ t, negated on return.
    rotated_trans = einsum(rots_t, trans, "... i j, ... j->... i")
    return rots_t, -rotated_trans
[docs]
def compose_rigids(
    rigid1: tuple[torch.Tensor, torch.Tensor],
    rigid2: tuple[torch.Tensor, torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Chain two rigid body transformations: first ``(R1, t1)``, then ``(R2, t2)``.

    The result is ``(R2 @ R1, R2 @ t1 + t2)``, i.e. applying it to a point equals
    applying ``rigid1`` and then ``rigid2``.

    Args:
        - rigid1 (tuple[torch.Tensor, torch.Tensor]): Transformation applied first, ``(R1, t1)``.
        - rigid2 (tuple[torch.Tensor, torch.Tensor]): Transformation applied second, ``(R2, t2)``.

    Returns:
        - tuple[torch.Tensor, torch.Tensor]: The composed transformation ``(R, t)``.

    Example:
        >>> R1, t1 = torch.eye(3), torch.tensor([1.0, 0.0, 0.0])
        >>> R2, t2 = torch.eye(3), torch.tensor([0.0, 1.0, 0.0])
        >>> R_composed, t_composed = compose_rigids((R1, t1), (R2, t2))
        >>> print(R_composed)
        tensor([[1., 0., 0.],
                [0., 1., 0.],
                [0., 0., 1.]])
        >>> print(t_composed)
        tensor([1., 1., 0.])
    """
    (first_rots, first_trans), (second_rots, second_trans) = rigid1, rigid2
    # Rotation: R2 @ R1.
    combined_rots = einsum(second_rots, first_rots, "... i j, ... j k->... i k")
    # Translation: t1 is carried through R2, then shifted by t2.
    combined_trans = einsum(second_rots, first_trans, "... i j, ... j->... i") + second_trans
    return combined_rots, combined_trans
[docs]
def apply_inverse_rigid(
    rigid: tuple[torch.Tensor, torch.Tensor],
    points: torch.Tensor,
) -> torch.Tensor:
    """
    Map points through the inverse of a rigid transformation, ``p -> R^T @ (p - t)``.

    This is the reverse of ``apply_rigid``: points given in the global frame are
    expressed in the local frame of ``rigid``.

    Args:
        - rigid (tuple[torch.Tensor, torch.Tensor]): ``(R, t)`` of the transformation to undo.
        - points (torch.Tensor): Points of shape (..., 3).

    Returns:
        - torch.Tensor: The transformed points, same shape as ``points``.
    """
    return apply_rigid(invert_rigid(rigid), points)
[docs]
def get_random_rots(batch_size: int, **tensor_kwargs) -> torch.Tensor:
    """
    Generate random 3D rotation matrices, uniformly (Haar) distributed over SO(3).

    Gaussian random matrices are orthonormalized with a QR decomposition. The sign
    ambiguity of QR is then removed so that Q follows the Haar measure on O(3)
    (Mezzadri, "How to generate random matrices from the classical compact groups").
    Finally, Q is mapped to SO(3) by flipping the sign of improper rotations.

    Args:
        - batch_size (int): Number of rotation matrices to generate.
        - **tensor_kwargs: Forwarded to ``torch.randn`` (e.g. ``device``, ``dtype``,
          ``generator``).

    Returns:
        - torch.Tensor: Batch of random rotation matrices with shape (batch_size, 3, 3).

    Example:
        >>> R = get_random_rots(5)
        >>> print(R.shape)
        torch.Size([5, 3, 3])
        >>> print(torch.allclose(torch.det(R), torch.ones(5)))
        True
    """
    gaussian = torch.randn(batch_size, 3, 3, **tensor_kwargs)
    q, r = torch.linalg.qr(gaussian)

    # QR is only unique up to the signs of diag(R). Householder-based QR fixes those signs
    # by convention, which biases Q (with LAPACK's convention the first entry of Q's first
    # column is never positive). Rescaling Q's columns by sign(diag(R)) yields the
    # unique factorization with positive diag(R), whose Q is Haar-distributed on O(3).
    diag_sign = torch.sign(torch.diagonal(r, dim1=-2, dim2=-1))
    # A zero diagonal entry has probability zero; treat it as positive to stay orthogonal.
    diag_sign = torch.where(diag_sign == 0, torch.ones_like(diag_sign), diag_sign)
    q = q * diag_sign.unsqueeze(-2)

    # O(3) -> SO(3): negating a 3x3 matrix flips the sign of its determinant.
    det_sign = torch.sign(torch.det(q))
    return q * det_sign[..., None, None]
[docs]
def get_random_rigid(batch_size: int, scale: float = 1.0, **tensor_kwargs) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sample random rigid body transformations ``(R, t)``.

    Rotations are produced by ``get_random_rots``. Translations are drawn from a
    standard normal distribution and multiplied by ``scale``.

    Args:
        - batch_size (int): Number of rigid transformations to generate.
        - scale (float, optional): Multiplier for the standard-normal translations.
          Defaults to 1.0.
        - **tensor_kwargs: Forwarded to the tensor creation calls (e.g. ``device``, ``dtype``).

    Returns:
        - tuple[torch.Tensor, torch.Tensor]: A `rigid` tuple ``(rots, trans)``:
            - rots (torch.Tensor): Rotation matrices with shape (batch_size, 3, 3).
            - trans (torch.Tensor): Translation vectors with shape (batch_size, 3).

    Note:
        When ``batch_size == 1`` the leading batch dimension is squeezed out, so the
        shapes are (3, 3) and (3,).
    """
    rots = get_random_rots(batch_size, **tensor_kwargs)
    trans = torch.randn(batch_size, 3, **tensor_kwargs) * scale
    if batch_size != 1:
        return rots, trans
    # Single sample: drop the batch dimension.
    return rots.squeeze(0), trans.squeeze(0)
[docs]
def random_rigid_augmentation(coord_atom_lvl: torch.Tensor, batch_size: int, s: float = 1.0) -> torch.Tensor:
    """
    Apply random rigid body transformations to atomic coordinates.

    Draws one random rigid transformation (rotation + translation) per batch element
    and applies it to the corresponding coordinate set in ``coord_atom_lvl``. The
    transformations are sampled on the same device as the coordinates and, for
    floating-point inputs, returned in the same dtype.

    Args:
        coord_atom_lvl (torch.Tensor): Atomic coordinates to transform, of shape
            (batch_size, num_atoms, 3).
        batch_size (int): Number of transformations to generate and apply; must match
            the number of coordinate sets in ``coord_atom_lvl``.
        s (float, optional): Translational scale in Angstrom. Translations are drawn
            from N(0, s), i.e. with standard deviation ``s``. Rotations are sampled with
            ``get_random_rots``. Defaults to 1.0.

    Returns:
        torch.Tensor: A tensor of the same shape as ``coord_atom_lvl`` containing the
            transformed atomic coordinates.
    """
    # Sample on the coordinates' device so CUDA inputs are not mixed with CPU rotations.
    # Low-precision dtypes (e.g. float16/bfloat16) are not supported by the QR used to
    # sample rotations, so sample those in float32 and cast afterwards.
    if coord_atom_lvl.dtype in (torch.float32, torch.float64):
        sample_dtype = coord_atom_lvl.dtype
    else:
        sample_dtype = torch.float32
    rots, trans = get_random_rigid(batch_size, scale=s, device=coord_atom_lvl.device, dtype=sample_dtype)

    # `get_random_rigid` squeezes the batch dimension for batch_size == 1; restore it.
    if batch_size == 1:
        rots, trans = rots.unsqueeze(0), trans.unsqueeze(0)

    if coord_atom_lvl.is_floating_point():
        rots, trans = rots.to(coord_atom_lvl.dtype), trans.to(coord_atom_lvl.dtype)
    return apply_batched_rigid((rots, trans), coord_atom_lvl)
[docs]
def masked_center(
coord_atom_lvl: np.ndarray | torch.Tensor, mask_atom_lvl: np.ndarray | torch.Tensor = None
) -> np.ndarray | torch.Tensor:
"""Center the coordinates of the atoms in coord_atom_lvl around the origin using the mask mask_atom_lvl.
Supports both NumPy and PyTorch tensors.
"""
if mask_atom_lvl is None:
mask_atom_lvl = (
np.ones(coord_atom_lvl.shape[0], dtype=bool)
if isinstance(coord_atom_lvl, np.ndarray)
else torch.ones(coord_atom_lvl.shape[0], dtype=torch.bool)
)
atoms = coord_atom_lvl[mask_atom_lvl]
center = atoms.mean(axis=0) if isinstance(coord_atom_lvl, np.ndarray) else atoms.mean(dim=0)
coord_atom_lvl = coord_atom_lvl - center
return coord_atom_lvl
[docs]
def align_atom_arrays(mbl_sele: AtomArray, tgt_sele: AtomArray, mbl_full: AtomArray) -> tuple[AtomArray, float]:
    """
    Superimpose ``mbl_sele`` onto ``tgt_sele`` and move ``mbl_full`` along with it.

    The superimposition transform is fitted on the selections only. The same transform
    is then applied to the complete mobile structure ``mbl_full``.

    Args:
        mbl_sele (AtomArray): Mobile atoms used for fitting, i.e. the part of the
            structure to be transformed that should be aligned.
        tgt_sele (AtomArray): Fixed reference atoms that ``mbl_sele`` is aligned to.
            Must have the same size as ``mbl_sele`` and should correspond to the same
            residues / molecules.
        mbl_full (AtomArray): The full mobile structure to which the fitted transform
            is applied.

    Returns:
        AtomArray: ``mbl_full`` with the fitted transformation applied.
        float: The RMSD between ``mbl_sele`` (after fitting) and ``tgt_sele``.
    """
    fitted_sele, transform = superimpose(tgt_sele, mbl_sele)
    moved_full = transform.apply(mbl_full)
    fit_rmsd = rmsd(fitted_sele, tgt_sele)
    return moved_full, fit_rmsd