Utilities#
This module contains various utility functions and classes used throughout the atomworks.ml package.
Debug Utilities#
- atomworks.ml.utils.debug.save_failed_example_to_disk(example_id: str, fail_dir: str, *, data: dict = {}, rng_state_dict: dict = {}, error_msg: str = '') None [source]#
Attempts to save a failed example to disk as a pickle file.
- Parameters:
example_id (-) – The ID of the example.
fail_dir (-) – The directory where the failed example should be saved.
rng_state_dict (-) – The random number generator state dictionary.
error_msg (-) – The error message associated with the failure.
- Returns:
None
Error Handling#
- atomworks.ml.utils.error.context(msg: str, cleanup: ~collections.abc.Callable[[], None] = <function <lambda>>, raise_error: bool = True, log_level: int = 40, exc_types: tuple = (<class 'Exception'>, ), logger: ~logging.Logger = <Logger atomworks.ml.utils.error (WARNING)>) Any [source]#
Context manager for handling exceptions with configurable error handling and logging.
This context manager lets you prepend a custom msg to error messages, choose whether caught exceptions are re-raised or only logged, and register a cleanup function that is called when an error occurs.
- Parameters:
msg (-) – Message to prepend to the error description
cleanup (-) – Optional cleanup function to call when an exception occurs. Defaults to no-op
raise_error (-) – If True, logs and re-raises the exception. If False, only logs the exception
log_level (-) – Logging level to use (from logging module constants). Defaults to logging.ERROR
exc_types (-) – Tuple of exception types to catch. Defaults to (Exception,)
logger (-) – Logger to use for logging. Defaults to the module-level logger atomworks.ml.utils.error.
- Yields:
Any – The yielded value from the context block
- Raises:
Exception – Re-raises the caught exception if raise_error is True
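The behaviour described above can be sketched with the standard library alone. This is a minimal illustration, not the atomworks implementation: `error_context` is a hypothetical stand-in name, and calling `cleanup` on every caught exception is an assumption.

```python
import logging
from contextlib import contextmanager


@contextmanager
def error_context(msg, cleanup=lambda: None, raise_error=True,
                  log_level=logging.ERROR, exc_types=(Exception,),
                  logger=logging.getLogger("error_context_sketch")):
    """Hypothetical stand-in for the documented `context` manager."""
    try:
        yield
    except exc_types as exc:
        # Prepend the custom message to the error description and log it.
        logger.log(log_level, "%s: %s", msg, exc)
        # Assumption: cleanup runs for every caught exception.
        cleanup()
        if raise_error:
            raise


cleaned = []

# raise_error=False: the failure is logged, cleanup runs, execution continues.
with error_context("loading example", cleanup=lambda: cleaned.append(True), raise_error=False):
    raise ValueError("bad file")

# raise_error=True (the default): the original exception propagates after logging.
reraised = False
try:
    with error_context("parsing example"):
        raise KeyError("missing field")
except KeyError:
    reraised = True
```

Exceptions whose type is not listed in `exc_types` pass through this sketch untouched, without logging or cleanup.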
Geometry Utilities#
Various geometry utility functions to deal with rigid body transformations in 3D.
- atomworks.ml.utils.geometry.align_atom_arrays(mbl_sele: AtomArray, tgt_sele: AtomArray, mbl_full: AtomArray) tuple[AtomArray, float] [source]#
Computes the transformation that aligns mbl_sele to tgt_sele, then applies that transformation to mbl_full and returns it along with the alignment RMSD.
- Parameters:
mbl_sele (AtomArray) – An atom array containing atomic coordinates of the array to be transformed, pre-masked to contain only the portion to be aligned.
tgt_sele (AtomArray) – An atom array containing coordinates for mbl_sele to be aligned to. Must be the same size as mbl_sele; should be the same residues / molecules.
mbl_full (AtomArray) – The full atom array to be transformed based on the alignment between mbl_sele and tgt_sele.
- Returns:
A tuple containing:
AtomArray: An atom array of the same shape as mbl_full, containing the transformed coordinates.
float: The RMSD between mbl_sele and tgt_sele following alignment.
- Return type:
tuple[AtomArray, float]
- atomworks.ml.utils.geometry.apply_batched_rigid(rigid: tuple[Tensor, Tensor], points: Tensor) Tensor [source]#
Apply a batch of rigid body transformations to a set of batched points via (p -> R @ p + t). (i.e. first rotate then translate)
- Parameters:
rigid (-) – A tuple containing the rotation matrix (R) and translation vector (t) representing the rigid body transformation.
points (-) – A tensor of shape [batch_size, …, 3] representing the points to transform.
- Returns:
A tensor of shape [batch_size, …, 3] representing the transformed points.
- Return type:
torch.Tensor
NOTE: This transforms p from the local frame of the rigid to the global frame.
- atomworks.ml.utils.geometry.apply_inverse_rigid(rigid: tuple[Tensor, Tensor], points: Tensor) Tensor [source]#
Apply the inverse of a rigid body transformation to a set of points via (p -> R^T @ (p - t)).
- Parameters:
rigid (-) – A tuple containing the rotation matrix (R) and translation vector (t) of the rigid body transformation.
points (-) – The points to transform, with shape (…, 3).
- Returns:
The transformed points, with the same shape as the input points.
- Return type:
torch.Tensor
- atomworks.ml.utils.geometry.apply_rigid(rigid: tuple[Tensor, Tensor], points: Tensor) Tensor [source]#
Apply a rigid body transformation to a set of points via (p -> R @ p + t). (i.e. first rotate then translate)
- Parameters:
rigid (-) – A tuple containing the rotation matrix (R) and translation vector (t) representing the rigid body transformation.
points (-) – A tensor of shape […, 3] representing the points to transform.
- Returns:
A tensor of shape […, 3] representing the transformed points.
- Return type:
torch.Tensor
NOTE: This transforms p from the local frame of the rigid to the global frame.
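As an illustration of the two formulas above, here is a NumPy sketch (the documented functions operate on torch tensors). The names `apply_rigid_np` and `apply_inverse_rigid_np` are hypothetical and exist only for this example. It checks that the inverse transform undoes the forward one.

```python
import numpy as np


def apply_rigid_np(rigid, points):
    R, t = rigid
    # p -> R @ p + t, written for row-vector points of shape (..., 3)
    return points @ R.T + t


def apply_inverse_rigid_np(rigid, points):
    R, t = rigid
    # p -> R^T @ (p - t), again for row-vector points
    return (points - t) @ R


# 90 degree rotation about the z-axis followed by a translation
R = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
t = np.array([1.0, 2.0, 3.0])
points = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

moved = apply_rigid_np((R, t), points)              # [[1, 3, 3], [0, 2, 3]]
restored = apply_inverse_rigid_np((R, t), moved)    # back to the input points
```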
- atomworks.ml.utils.geometry.compose_rigids(rigid1: tuple[Tensor, Tensor], rigid2: tuple[Tensor, Tensor]) tuple[Tensor, Tensor] [source]#
Compose two rigid body transformations (R1, t1) and (R2, t2) to (R2 @ R1, R2 @ t1 + t2).
- Parameters:
rigid1 (-) – First rigid body transformation (R1, t1).
rigid2 (-) – Second rigid body transformation (R2, t2).
- Returns:
Composed rigid body transformation (R_composed, t_composed).
- Return type:
tuple[torch.Tensor, torch.Tensor]
Example
>>> R1, t1 = torch.eye(3), torch.tensor([1.0, 0.0, 0.0])
>>> R2, t2 = torch.eye(3), torch.tensor([0.0, 1.0, 0.0])
>>> R_composed, t_composed = compose_rigids((R1, t1), (R2, t2))
>>> print(R_composed)
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
>>> print(t_composed)
tensor([1., 1., 0.])
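The composition rule (R2 @ R1, R2 @ t1 + t2) means "apply rigid1 first, then rigid2". A small NumPy check of that reading, using a hypothetical helper `compose_rigids_np` rather than the torch-based function documented here:

```python
import numpy as np


def compose_rigids_np(rigid1, rigid2):
    R1, t1 = rigid1
    R2, t2 = rigid2
    # R2 @ (R1 @ p + t1) + t2 == (R2 @ R1) @ p + (R2 @ t1 + t2)
    return R2 @ R1, R2 @ t1 + t2


Rz = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
rigid1 = (Rz, np.array([1.0, 0.0, 0.0]))
rigid2 = (Rz, np.array([0.0, 1.0, 0.0]))
R_c, t_c = compose_rigids_np(rigid1, rigid2)

p = np.array([1.0, 2.0, 3.0])
step_by_step = rigid2[0] @ (rigid1[0] @ p + rigid1[1]) + rigid2[1]
all_at_once = R_c @ p + t_c
```

With non-identity rotations the order matters: swapping `rigid1` and `rigid2` generally gives a different result.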
- atomworks.ml.utils.geometry.get_random_rigid(batch_size: int, scale: float = 1.0, **tensor_kwargs) tuple[Tensor, Tensor] [source]#
Generate random rigid body transformations (R, t).
- Parameters:
batch_size (-) – Number of rigid transformations to generate.
scale (-) – Scale factor for the translation vectors. Defaults to 1.0.
**tensor_kwargs (-) –
Additional keyword arguments to pass to tensor creation functions.
- Returns:
- A `rigid` tuple containing:
rots (torch.Tensor): Batch of random rotation matrices with shape (batch_size, 3, 3).
trans (torch.Tensor): Batch of random translation vectors with shape (batch_size, 3).
- Return type:
tuple[torch.Tensor, torch.Tensor]
Note
If batch_size is 1, the output tensors are squeezed to remove the batch dimension.
- atomworks.ml.utils.geometry.get_random_rots(batch_size: int, **tensor_kwargs) Tensor [source]#
Generate random 3D rotation matrices.
- Parameters:
batch_size (-) – Number of rotation matrices to generate.
**tensor_kwargs (-) – Additional keyword arguments to pass to tensor creation functions (for example, device).
- Returns:
Batch of random rotation matrices with shape (batch_size, 3, 3).
- Return type:
torch.Tensor
Example
>>> R = get_random_rots(5)
>>> print(R.shape)
torch.Size([5, 3, 3])
>>> print(torch.allclose(torch.det(R), torch.ones(5)))
True
- atomworks.ml.utils.geometry.get_torch_eps(dtype: dtype) float [source]#
Get the smallest positive representable value for a given torch dtype.
- atomworks.ml.utils.geometry.invert_rigid(rigid: tuple[Tensor, Tensor]) tuple[Tensor, Tensor] [source]#
Invert a rigid body transformation (R, t) to (R^T, -R^T @ t).
- Parameters:
rigid (-) – A tuple containing the rotation matrix (R) and translation vector (t) representing the rigid body transformation.
- Returns:
- A tuple containing the inverted rotation matrix (R^T) and
inverted translation vector (-R^T @ t).
- Return type:
tuple[torch.Tensor, torch.Tensor]
Example
>>> R = torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]])
>>> t = torch.tensor([1, 2, 3])
>>> R_inv, t_inv = invert_rigid((R, t))
>>> print(R_inv)
tensor([[ 0,  1,  0],
        [-1,  0,  0],
        [ 0,  0,  1]])
>>> print(t_inv)
tensor([-2,  1, -3])
- atomworks.ml.utils.geometry.masked_center(coord_atom_lvl: ndarray | Tensor, mask_atom_lvl: ndarray | Tensor = None) ndarray | Tensor [source]#
Center the coordinates of the atoms in coord_atom_lvl around the origin using the mask mask_atom_lvl.
Supports both NumPy and PyTorch tensors.
- atomworks.ml.utils.geometry.random_rigid_augmentation(coord_atom_lvl: Tensor, batch_size: int, s: float = 1.0) Tensor [source]#
Apply random rigid body transformations to atomic coordinates.
Generates random rigid body transformations (rotation and translation) for a batch of atomic coordinates and applies these transformations to the input coordinates.
- Parameters:
coord_atom_lvl (torch.Tensor) – A tensor containing atomic coordinates to be transformed. The shape is expected to be (batch_size, num_atoms, 3).
batch_size (int) – The number of transformations to generate and apply, corresponding to the number of coordinate sets in coord_atom_lvl.
s (float, optional) – The translational scale in Angstrom. Random translations are drawn from a normal distribution with mean 0 and standard deviation s. The rotation is sampled uniformly at random. Defaults to 1.0.
- Returns:
- A tensor of the same shape as coord_atom_lvl, containing the transformed
atomic coordinates.
- Return type:
torch.Tensor
- atomworks.ml.utils.geometry.rigid_from_3_points(x1: Tensor, x2: Tensor, x3: Tensor, eps: float | None = None) tuple[Tensor, Tensor] [source]#
Compute the rigid body transformation (R, t) that leads from the origin into the local frame via the Gram-Schmidt process.
The local frame is centered at x2 with the x-axis pointing towards x3, the y-axis in the plane defined by x1, x2, and x3, and the z-axis perpendicular to this plane.
E.g. if x1=N, x2=CA, x3=C, then the x-axis is the vector pointing CA -> C, the y-axis is in the N-CA-C plane and the z-axis is perpendicular to this plane.
- Parameters:
x1 – torch.Tensor of shape […, 3], coordinates of the first point
x2 – torch.Tensor of shape […, 3], coordinates of the second point (origin of local frame)
x3 – torch.Tensor of shape […, 3], coordinates of the third point
eps – float, small value to avoid division by zero
- Returns:
A tuple (R, t) containing:
R: torch.Tensor of shape […, 3, 3], rotation matrix
t: torch.Tensor of shape […, 3], translation vector
- Return type:
tuple[torch.Tensor, torch.Tensor]
- Reference:
AF2 supplementary, Algorithm 21 https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf
Example
>>> x1 = torch.tensor([0.0, 0.0, 1.0])
>>> x2 = torch.tensor([0.0, 0.0, 0.0])
>>> x3 = torch.tensor([1.0, 0.0, 0.0])
>>> R, t = rigid_from_3_points(x1, x2, x3)
>>> print(R)
tensor([[ 1.,  0.,  0.],
        [ 0.,  0., -1.],
        [ 0.,  1.,  0.]])
>>> print(t)
tensor([0., 0., 0.])
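The Gram-Schmidt construction described above can be written out step by step. The following NumPy sketch handles single points only (no batch dimensions) and uses the hypothetical name `rigid_from_3_points_np`. The default `eps` value is an assumption, and the actual torch implementation may differ in details.

```python
import numpy as np


def rigid_from_3_points_np(x1, x2, x3, eps=1e-8):
    v1 = x3 - x2                          # direction of the local x-axis
    v2 = x1 - x2                          # second direction in the x1-x2-x3 plane
    e1 = v1 / (np.linalg.norm(v1) + eps)
    u2 = v2 - e1 * np.dot(e1, v2)         # remove the component of v2 along e1
    e2 = u2 / (np.linalg.norm(u2) + eps)
    e3 = np.cross(e1, e2)                 # right-handed z-axis, normal to the plane
    R = np.stack([e1, e2, e3], axis=-1)   # basis vectors become the columns of R
    return R, x2                          # the local frame is centered at x2


x1 = np.array([0.0, 0.0, 1.0])
x2 = np.array([0.0, 0.0, 0.0])
x3 = np.array([1.0, 0.0, 0.0])
R, t = rigid_from_3_points_np(x1, x2, x3)
```

For these inputs the sketch gives the same rotation as the documented example, up to the small `eps` perturbation.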
I/O Utilities#
- atomworks.ml.utils.io.cache_based_on_subset_of_args(cache_keys: list[str], maxsize: int | None = None) Callable [source]#
Decorator to cache function results based on a subset of its keyword arguments. Most helpful when some arguments may be unhashable types (e.g., dictionaries, AtomArray).
If the value of any of the cache keys is None, the function is executed and the result is not cached.
Note
The wrapped function must use keyword arguments for those specified in cache_keys. Positional arguments are not supported for cache key extraction.
- Parameters:
cache_keys (List[str]) – The names of the keyword arguments to use as the cache key.
maxsize (Optional[int]) – The maximum number of entries to store in the cache. If None, the cache size is unlimited.
- Returns:
A decorator that caches the function results based on the specified keyword arguments.
- Return type:
Callable
Example
@cache_based_on_subset_of_args(['arg1'], maxsize=2)
def function(*, arg1, arg2):
    return arg1 + arg2

result1 = function(arg1=1, arg2=2)  # Caches with key 1
result2 = function(arg1=1, arg2=3)  # Retrieves from cache
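To make the caching rule concrete, here is a standard-library sketch of a decorator with the same documented behaviour: the key is built from selected keyword arguments, and a None key value bypasses the cache. The name `cache_on_kwargs` is hypothetical and the FIFO eviction is an assumption; the real decorator may evict differently.

```python
import functools


def cache_on_kwargs(cache_keys, maxsize=None):
    """Hypothetical sketch: cache results keyed on selected keyword arguments only."""

    def decorator(func):
        cache = {}

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            key = tuple(kwargs.get(name) for name in cache_keys)
            # A None key value means "do not cache this call".
            if any(part is None for part in key):
                return func(*args, **kwargs)
            if key not in cache:
                if maxsize is not None and len(cache) >= maxsize:
                    # Assumption: evict the oldest entry (FIFO).
                    del cache[next(iter(cache))]
                cache[key] = func(*args, **kwargs)
            return cache[key]

        return wrapper

    return decorator


calls = []


@cache_on_kwargs(["name"], maxsize=2)
def describe(*, name, payload):
    calls.append(name)
    return f"{name}: {len(payload)} entries"


first = describe(name="a", payload={"x": 1})           # computed, cached under ("a",)
second = describe(name="a", payload={"x": 1, "y": 2})  # cache hit: payload is not in the key
uncached = describe(name=None, payload={})             # None key value: always executed
```

Because `payload` is not part of the key, the second call returns the first result even though its payload differs. The same caveat applies to any decorator that keys on a subset of the arguments.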
- atomworks.ml.utils.io.cache_to_disk_as_pickle(cache_dir: PathLike | None = None, use_gzip: bool = True, directory_depth: int = 2) Callable [source]#
A decorator to cache the results of a function to disk as a pickle file.
Creates a unique cached pickle file for each set of function arguments using an MD5 hash. If the cache file exists, the result is loaded from the file. Otherwise, the function is called, and the result is saved to the cache file.
If cache_dir is None, caching is disabled and the function is always executed.
- Parameters:
cache_dir (PathLike or None) – The directory where cache files will be stored, or None to disable caching.
use_gzip (bool) – Whether to use gzip compression for the cache files.
directory_depth (int) – The depth of the directory structure for sharding cache files.
- Returns:
The wrapped function with optional disk caching enabled.
- Return type:
function
- atomworks.ml.utils.io.convert_af3_model_output_to_atom_array_stack(atom_to_token_map: ndarray[int], pn_unit_iids: ndarray[str], decoded_restypes: ndarray[str], xyz: ndarray, elements: ndarray[int | str], token_is_atomized: ndarray[bool] = None) AtomArrayStack [source]#
Create an AtomArrayStack from AlphaFold-3-type model outputs. Specific to AF-3; may not work with other formats.
- Parameters:
atom_to_token_map (-) – Mapping from atoms to tokens [n_atom]
pn_unit_iids (-) – PN unit IIDs for each token [n_token]
decoded_restypes (-) – Decoded residue types for each token [n_token]
xyz (-) – Coordinates of atoms [n_atom, 3] or [batch, n_atom, 3], where batch is the number of structures
elements (-) – Element types for each atom [n_atom]
token_is_atomized (-) – Flags indicating if tokens are atomized [n_token]. If not provided or None, residues with a single atom are considered atomized.
- Returns:
Constructed AtomArrayStack.
- Return type:
AtomArrayStack
- atomworks.ml.utils.io.filter_residue_atoms(residue_atom_array: AtomArray, chem_type: str, elements: ndarray[str]) AtomArray [source]#
Filter out unwanted atoms from a residue (e.g., hydrogens, leaving groups).
- Parameters:
residue_atom_array (-) – The AtomArray to filter.
chem_type (-) – Type of the chemical chain.
elements (-) – Element types (as strings, e.g., “C”) for each atom in the residue.
- Returns:
Filtered AtomArray.
- Return type:
struc.AtomArray
- atomworks.ml.utils.io.get_sharded_file_path(base_dir: Path, file_hash: str, extension: str, depth: int, chars_per_dir: int = 2, include_subdirectory: bool = False) Path [source]#
Construct a nested file path based on the directory depth.
- Parameters:
base_dir (Path) – The base directory where the files are stored.
file_hash (str) – The hash of the file content or identifier.
extension (str) – The file extension.
depth (int) – The directory nesting depth.
chars_per_dir (int) – The number of characters to use for each directory level.
include_subdirectory (bool) – If True, creates an additional directory with the full hash name.
- Returns:
The constructed path to the file.
- Return type:
Path
Example
>>> get_sharded_file_path("/path/to/cache", "abcdef123456", ".pkl", 2)
Path("/path/to/cache/ab/cd/abcdef123456.pkl")
>>> get_sharded_file_path("/path/to/cache", "abcdef123456", ".pkl", 3, chars_per_dir=1)
Path("/path/to/cache/a/b/c/abcdef123456.pkl")
>>> get_sharded_file_path("/path/to/cache", "abcdef123456", ".pkl", 2, include_subdirectory=True)
Path("/path/to/cache/ab/cd/abcdef123456/abcdef123456.pkl")
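The sharding scheme in these examples takes consecutive slices of the hash as directory names. A minimal `pathlib` sketch that reproduces the three documented paths follows. `sharded_path` is a hypothetical name, and this is an illustration of the scheme rather than the library's code.

```python
from pathlib import Path


def sharded_path(base_dir, file_hash, extension, depth, chars_per_dir=2,
                 include_subdirectory=False):
    path = Path(base_dir)
    # Each directory level consumes the next `chars_per_dir` characters of the hash.
    for level in range(depth):
        start = level * chars_per_dir
        path = path / file_hash[start:start + chars_per_dir]
    if include_subdirectory:
        # Optional extra directory named after the full hash.
        path = path / file_hash
    return path / f"{file_hash}{extension}"


two_levels = sharded_path("/path/to/cache", "abcdef123456", ".pkl", 2)
three_levels = sharded_path("/path/to/cache", "abcdef123456", ".pkl", 3, chars_per_dir=1)
with_subdir = sharded_path("/path/to/cache", "abcdef123456", ".pkl", 2, include_subdirectory=True)
```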
- atomworks.ml.utils.io.open_file(filename: PathLike) TextIO [source]#
Open a file, handling gzipped files if necessary.
- atomworks.ml.utils.io.read_parquet_with_metadata(filepath: PathLike, **kwargs: Any) DataFrame [source]#
Convenience wrapper around pd.read_parquet that preserves metadata.
- Parameters:
filepath – Path to the parquet file.
**kwargs – Additional arguments to pass to pd.read_parquet.
- Returns:
pandas DataFrame with metadata in .attrs attribute
- atomworks.ml.utils.io.to_parquet_with_metadata(df: DataFrame, filepath: PathLike, **kwargs: Any) None [source]#
Convenience wrapper around df.to_parquet that saves table-wide metadata (df.attrs) to the parquet file.
- Parameters:
df – pandas DataFrame to save.
filepath – Path where to save the parquet file.
**kwargs – Additional arguments to pass to df.to_parquet.
Miscellaneous Utilities#
- atomworks.ml.utils.misc.argunsort(s: ndarray) ndarray [source]#
Returns the permutation necessary to undo a sort given the argsort array.
An argsort array is an array of indices that sorts another array. This function allows you to get the argsort array, sort your array with it, and then undo the sort without the overhead of sorting again.
- Parameters:
s (numpy.ndarray) – The argsort array.
- Returns:
The permutation array that can be used to undo the sort.
- Return type:
numpy.ndarray
Example
>>> arr = np.array([3, 1, 2])
>>> s = np.argsort(arr)
>>> sorted_arr = arr[s]
>>> undo_sort = argunsort(s)
>>> original_arr = sorted_arr[undo_sort]
>>> np.array_equal(original_arr, arr)
True
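One way to obtain such an inverse permutation without a second sort is a single scatter assignment. The sketch below uses the hypothetical name `argunsort_np` and shows the idea; whether atomworks implements it this way is not stated in this reference.

```python
import numpy as np


def argunsort_np(s):
    inverse = np.empty_like(s)
    # Position s[i] of the original array ended up at sorted position i.
    inverse[s] = np.arange(len(s))
    return inverse


arr = np.array([30, 10, 20, 10])
s = np.argsort(arr, kind="stable")
undo = argunsort_np(s)
restored = arr[s][undo]
```

The scatter runs in linear time, while calling `np.argsort(s)` would give the same permutation at the cost of another sort.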
- atomworks.ml.utils.misc.convert_pn_unit_iids_to_pn_unit_ids(pn_unit_iids: list[str]) list[str] [source]#
Convert a list of pn_unit_iid strings to pn_unit_id strings.
Example
>>> pn_unit_iids = ["B_1,C_1", "A_11,B_11"]
>>> convert_pn_unit_iids_to_pn_unit_ids(pn_unit_iids)
['B,C', 'A,B']
- atomworks.ml.utils.misc.cumcount(a: ndarray) ndarray [source]#
Helper function to compute the cumulative count of each unique element in an array.
- atomworks.ml.utils.misc.dfill(a: ndarray) ndarray [source]#
Takes an array and returns the indices at which the value changes, repeating each index until the next change occurs.
- Parameters:
a (numpy.ndarray) – The input array.
- Returns:
An array of indices where each index is repeated until a change in value occurs in the input array.
- Return type:
numpy.ndarray
Example
>>> short_list = np.array(list("aaabaaacaaadaaac"))
>>> dfill(short_list)
array([ 0,  0,  0,  3,  4,  4,  4,  7,  8,  8,  8, 11, 12, 12, 12, 15])
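The "repeat the start index of each run" behaviour can be sketched with a few NumPy calls. `dfill_np` is a hypothetical name, and the code below is an illustration that reproduces the documented output, not necessarily the library's implementation.

```python
import numpy as np


def dfill_np(a):
    n = len(a)
    # A run starts at position 0 and wherever the value differs from its predecessor.
    starts = np.flatnonzero(np.r_[True, a[1:] != a[:-1]])
    # Length of each run = distance to the next run start (or to the end).
    run_lengths = np.diff(np.r_[starts, n])
    return np.repeat(starts, run_lengths)


filled = dfill_np(np.array(list("aaabaaacaaadaaac")))
```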
- atomworks.ml.utils.misc.extract_transformation_id_from_pn_unit_iid(pn_unit_iid: str) str [source]#
Extracts the transformation ID from a pn_unit_iid string.
Example
>>> extract_transformation_id_from_pn_unit_iid("A_1,B_1")
'1'
- atomworks.ml.utils.misc.get_msa_tax_id(pdb_id: str, chain_id: str) int [source]#
Retrieves the taxonomy ID for a given PDB and chain ID combination.
- Parameters:
pdb_id (str) – The PDB ID of the protein structure, e.g., “1A2K”.
chain_id (str) – The chain ID within the PDB structure, e.g., “A”. Notably, this does not include a transformation ID.
- Returns:
The taxonomy ID corresponding to the combined PDB and chain ID (e.g., “79015”).
- Return type:
str
- atomworks.ml.utils.misc.grouped_count(data: Tensor, *, mask: Tensor | None = None, groups: list[Tensor] | None = None, n_tokens: int | None = None, dtype: dtype = torch.int64) Tensor [source]#
Counts the occurrence of each token in a data tensor, optionally within specified groups and masked positions. (Time- and memory-efficient implementation of grouped_sum across one-hot tokens.)
NOTE: The special case where groups=None and mask=None corresponds to one-hot token counting.
- Parameters:
data (-) – The input tensor containing token data for which we want to count the occurrence of each token.
mask (-) – A boolean mask tensor with True values for all positions to include when counting. If None, all positions are considered (i.e. mask = True for all positions).
groups (-) – A list of tensors specifying the group assignments for each dimension of the data tensor. If None, each position is its own group for each dimension.
n_tokens (-) – The number of unique tokens. If None, it is inferred from the data tensor.
- Returns:
A tensor containing the count of each token in each group. The shape of the tensor is determined by the group sizes and the number of tokens.
- Return type:
torch.Tensor
Example
>>> msa = torch.tensor(
...     [
...         [0, 1, 3, 1, 2],
...         [1, 0, 0, 3, 2],
...         [2, 2, 1, 0, 1],
...         [3, 1, 2, 2, 3],
...         [1, 0, 0, 0, 1],
...         [2, 1, 3, 3, 1],
...     ]
... )
>>> groups = [
...     [0, 1, 2, 2, 1, 0],  # groups for dim=0 (=rows)
...     [0, 1, 2, 3, 4],  # groups for dim=1 (=cols)
... ]
>>> group_counts = grouped_count(msa, mask=None, groups=groups)
>>> group_counts[0]
tensor([
    [1, 0, 1, 0],  # (corresponds to 0x1 & 2x1 at position 0 in rows 0 & 5)
    [0, 2, 0, 0],  # (corresponds to 1x2 at position 1 in rows 0 & 5)
    [0, 0, 0, 2],  # (corresponds to 3x2 at position 2 in rows 0 & 5)
    [0, 1, 0, 1],  # (corresponds to 1x1 & 3x1 at position 3 in rows 0 & 5)
    [0, 1, 1, 0]   # (corresponds to 2x1 & 1x1 at position 4 in rows 0 & 5)
])
- atomworks.ml.utils.misc.grouped_sum(data: Tensor, assignment: Tensor, num_groups: int, as_float: bool = True) Tensor [source]#
Computes the sum along a tensor, given group indices.
- Parameters:
data (torch.Tensor) – A tensor whose groups are to be summed. Shape: (N, …, D), where N is the number of elements.
assignment (torch.Tensor) – A 1-D tensor containing group indices. Must be int64 (to be compatible with the scatter operation). Shape: (N,).
num_groups (int) – The number of groups.
as_float (bool) – If True, the input data will be converted to float before summing. If not True, then booleans will be added as booleans, not integers.
- Returns:
- A tensor of the same data type as the input data, containing
the sum of elements for each group (cluster). Shape: (num_groups, …, D).
- Return type:
torch.Tensor
Example
>>> data = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
>>> assignment = torch.tensor([0, 1, 0, 1])
>>> num_groups = 2
>>> grouped_sum(data, assignment, num_groups)
tensor([[ 6,  8],
        [10, 12]])
# Explanation:
# Group 0: [1, 2] + [5, 6] = [6, 8]
# Group 1: [3, 4] + [7, 8] = [10, 12]
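A grouped sum is a scatter-add along the first axis. The NumPy sketch below uses `np.add.at` for that step, with rows 0 and 2 assigned to group 0 and rows 1 and 3 to group 1. `grouped_sum_np` is a hypothetical name; the documented function works on torch tensors and additionally offers the as_float option, which this sketch omits.

```python
import numpy as np


def grouped_sum_np(data, assignment, num_groups):
    out = np.zeros((num_groups,) + data.shape[1:], dtype=data.dtype)
    # Unbuffered scatter-add: out[assignment[i]] += data[i] for every row i.
    np.add.at(out, assignment, data)
    return out


data = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
assignment = np.array([0, 1, 0, 1])
sums = grouped_sum_np(data, assignment, num_groups=2)
# Group 0: [1, 2] + [5, 6]; group 1: [3, 4] + [7, 8]
```

`np.add.at` is used instead of `out[assignment] += data` because the latter applies only one update per repeated index.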
- atomworks.ml.utils.misc.hash_sequence(sequence: str) str [source]#
Generate a SHA-256 hash for the given sequence and return a compressed string format of the hash.
- Parameters:
sequence (str) – The sequence to be hashed.
- Returns:
The compressed hash string format.
- Return type:
str
- atomworks.ml.utils.misc.masked_mean(*, mask: Tensor, value: Tensor, axis: int | list[int] | None = None, drop_mask_channel: bool = False, eps: float = 1e-10) Tensor [source]#
Compute the masked mean of a tensor along specified axes.
- Parameters:
mask (torch.Tensor) – A mask tensor with the same shape as value or with dimensions that can be broadcast to value.
value (torch.Tensor) – The input tensor for which the masked mean is to be computed. If memory is a concern, this can be float16 or even bool; the sensitive parts of the computation are done in float32.
axis (Optional[Union[int, List[int]]]) – The axis or axes along which to compute the mean. If None, the mean is computed over all dimensions.
drop_mask_channel (bool) – If True, drops the last channel of the mask (assumes the last dimension is a singleton).
eps (float) – A small constant to avoid division by zero.
- Returns:
The masked mean of value along the specified axes. Given in full precision (float32).
- Return type:
torch.Tensor
Example
>>> import torch
>>> mask = torch.tensor([[1, 0], [1, 1]], dtype=bool)
>>> value = torch.tensor([[2.0, 3.0], [4.0, 5.0]], dtype=torch.float16)
>>> masked_mean(mask=mask, value=value, axis=0)
tensor([3., 5.])  # float32
- Reference:
AF2 Multimer Code (google-deepmind/alphafold)
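A masked mean divides the masked sum by the number of unmasked entries, with eps guarding against an empty mask. The NumPy sketch below assumes exactly that formula, sum(mask * value) / (sum(mask) + eps), accumulated in float32. `masked_mean_np` is a hypothetical name and the drop_mask_channel option is left out.

```python
import numpy as np


def masked_mean_np(*, mask, value, axis=None, eps=1e-10):
    # Do the sensitive arithmetic in float32, whatever the input dtypes are.
    mask = np.broadcast_to(mask, value.shape).astype(np.float32)
    value = value.astype(np.float32)
    total = (mask * value).sum(axis=axis)
    count = mask.sum(axis=axis)
    return total / (count + np.float32(eps))


mask = np.array([[True, False], [True, True]])
value = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=np.float16)
col_means = masked_mean_np(mask=mask, value=value, axis=0)
# Column 0 averages 2 and 4; column 1 only sees the unmasked 5.
```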
Nested Dictionary Utilities#
Tools to work with nested dictionaries.
- atomworks.ml.utils.nested_dict.flatten(d: dict[str, Any], *, fuse_keys: str | None = None) dict[tuple[str, ...], Any] | dict[str, Any] [source]#
Flatten a nested dictionary into a single level dictionary with tuple keys, preserving non-dict values.
- Parameters:
d (-) – A nested dictionary to flatten.
fuse_keys (-) – If provided, joins the key tuple elements with this string to create string keys. If None, returns tuple keys.
- Returns:
- A flattened dictionary where nested dict keys become either tuple keys or fused string keys,
but other values remain intact.
- Return type:
dict
Example
>>> d = {"a": {"b": [1, 2]}, "c": {"d": {"e": 3}}, "f": [4, 5]}
>>> flatten(d)
{('a', 'b'): [1, 2], ('c', 'd', 'e'): 3, ('f',): [4, 5]}
>>> flatten(d, fuse_keys=".")
{'a.b': [1, 2], 'c.d.e': 3, 'f': [4, 5]}
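The flattening rule is a depth-first walk that records the key path to every non-dict value. The pure-Python sketch below reproduces the documented outputs. `flatten_sketch` is a hypothetical name, and how the real function treats empty nested dictionaries is not covered here.

```python
def flatten_sketch(d, *, fuse_keys=None):
    flat = {}

    def walk(node, prefix):
        for key, value in node.items():
            path = prefix + (key,)
            if isinstance(value, dict):
                walk(value, path)      # descend into nested dictionaries
            else:
                flat[path] = value     # lists and other values are kept intact

    walk(d, ())
    if fuse_keys is None:
        return flat
    # Join each key tuple into a single string key.
    return {fuse_keys.join(path): value for path, value in flat.items()}


d = {"a": {"b": [1, 2]}, "c": {"d": {"e": 3}}, "f": [4, 5]}
tuple_keys = flatten_sketch(d)
dotted_keys = flatten_sketch(d, fuse_keys=".")
```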
- atomworks.ml.utils.nested_dict.get(d: dict[tuple[str, ...], Any], key: tuple[str, ...], default: Any = None) Any [source]#
Get a value from a nested dictionary using a tuple key.
Equivalent behavior to .get for nested dictionaries.
- Parameters:
d (-) – A nested dictionary.
key (-) – A tuple of keys to navigate through the dictionary.
default (-) – The value to return if the key is not found.
- Returns:
The value at the specified key. If the key is not found, the default value is returned.
- Return type:
Any
- atomworks.ml.utils.nested_dict.getitem(d: dict[tuple[str, ...], Any], key: tuple[str, ...]) Any [source]#
Get a value from a nested dictionary using a tuple key.
Equivalent behavior to __getitem__ for nested dictionaries.
- Parameters:
d (-) – A nested dictionary.
key (-) – A tuple of keys to navigate through the dictionary.
- Returns:
The value at the specified key.
- Return type:
Any
- atomworks.ml.utils.nested_dict.set(d: dict[tuple[str, ...], Any], key: tuple[str, ...], value: Any) None [source]#
Set a value in a nested dictionary using a tuple key.
Equivalent behavior to __setitem__ for nested dictionaries. Creates intermediate dictionaries if they don’t exist yet.
- Parameters:
d (-) – A nested dictionary.
key (-) – A tuple of keys to navigate through the dictionary.
value (-) – The value to set at the specified key.
- atomworks.ml.utils.nested_dict.unflatten(d: dict[tuple[str, ...] | str, Any], *, split_keys: str | None = None) dict[str, Any] [source]#
Unflatten a flattened dictionary into a nested dictionary.
- Parameters:
d (-) – A flattened dictionary with either tuple keys or string keys.
split_keys (-) – If provided, splits string keys with this string to create tuple keys. If None, expects tuple keys.
- Returns:
A nested dictionary reconstructed from the flattened keys.
- Return type:
dict
Example
>>> d = {("a", "b"): [1, 2], ("c", "d", "e"): 3, ("f",): [4, 5]}
>>> unflatten(d)
{'a': {'b': [1, 2]}, 'c': {'d': {'e': 3}}, 'f': [4, 5]}
>>> d = {"a.b": [1, 2], "c.d.e": 3, "f": [4, 5]}
>>> unflatten(d, split_keys=".")
{'a': {'b': [1, 2]}, 'c': {'d': {'e': 3}}, 'f': [4, 5]}
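Unflattening amounts to a nested "set" for every key path, creating intermediate dictionaries on the way down. The following pure-Python sketch uses the hypothetical name `unflatten_sketch` and is an illustration only.

```python
def unflatten_sketch(d, *, split_keys=None):
    nested = {}
    for key, value in d.items():
        # String keys are split into a path; tuple keys are used as they are.
        path = tuple(key.split(split_keys)) if split_keys is not None else key
        node = nested
        for part in path[:-1]:
            node = node.setdefault(part, {})   # create intermediate levels on demand
        node[path[-1]] = value
    return nested


from_tuples = unflatten_sketch({("a", "b"): [1, 2], ("c", "d", "e"): 3, ("f",): [4, 5]})
from_strings = unflatten_sketch({"a.b": [1, 2], "c.d.e": 3, "f": [4, 5]}, split_keys=".")
```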
NumPy Utilities#
General utility functions for working with numpy arrays.
- atomworks.ml.utils.numpy.get_connected_components_from_adjacency(adjacency: ndarray) list[ndarray] [source]#
Return a list of indices for each connected component according to the given adjacency matrix.
- atomworks.ml.utils.numpy.get_indices_of_non_constant_columns(arr: ndarray) ndarray [source]#
Identify columns where values change between consecutive rows.
- Parameters:
arr (np.ndarray) – A 2D NumPy array where you want to find columns with changing values.
- Returns:
An array of column indices where values change between consecutive rows.
- Return type:
np.ndarray
Example
>>> arr = np.array(
...     [
...         [151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161],
...         [151, 152, 153, 154, 155, 156, 157, 158, 159, 161, 160],
...     ]
... )
>>> get_indices_of_non_constant_columns(arr)
array([ 9, 10])
- atomworks.ml.utils.numpy.get_nearest_true_index_for_each_false(arr: ndarray) ndarray [source]#
Get the index of the nearest True for each False in the array, breaking ties by choosing the nearest True to the left.
- Parameters:
arr (-) – A boolean numpy array.
- Returns:
An array of length np.sum(~arr) where each entry is the index of the nearest True.
- Return type:
np.ndarray
Example
>>> arr = np.array([False, True, True, False, False, True, False])
>>> get_nearest_true_index_for_each_false(arr)
array([1, 2, 5, 5])
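The tie-breaking rule can be checked with a brute-force sketch that compares every False position with every True position. `nearest_true_for_false` is a hypothetical name. The approach uses memory proportional to (number of False) x (number of True) and assumes the array contains at least one True; it is meant as an illustration, not an efficient implementation.

```python
import numpy as np


def nearest_true_for_false(arr):
    true_idx = np.flatnonzero(arr)
    false_idx = np.flatnonzero(~arr)
    # Distance from every False position to every True position.
    dist = np.abs(false_idx[:, None] - true_idx[None, :])
    # argmin returns the first minimum; true_idx is ascending, so ties go to the left.
    return true_idx[dist.argmin(axis=1)]


nearest = nearest_true_for_false(np.array([False, True, True, False, False, True, False]))
tie = nearest_true_for_false(np.array([True, False, True]))
```

In the second call the single False sits exactly between two True values, and the left one (index 0) is chosen.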
- atomworks.ml.utils.numpy.insert_data_by_id_(to_fill: ndarray, to_fill_ids: ndarray, from_data: ndarray, from_data_ids: ndarray, axis: int = 0) ndarray [source]#
Insert data into an array based on matching IDs.
- Parameters:
to_fill (np.ndarray) – Array to be filled.
to_fill_ids (np.ndarray) – Array of IDs corresponding to to_fill.
from_data (np.ndarray) – Data array from which to insert.
from_data_ids (np.ndarray) – Array of IDs corresponding to from_data.
axis (int, optional) – Axis along which to insert data. Defaults to 0.
- Returns:
Array with inserted data.
- Return type:
np.ndarray
Example
>>> to_array = np.zeros((7, 6))
>>> to_ids = np.array([1, 5, 2, 17, 20, 20, 2])
>>> from_array = np.arange(10).repeat(6).reshape(10, 6)
>>> from_ids = np.array([1, 2, 5, 6, 7, 21, 22, 23, 20, 25])
>>> insert_data_by_id_(to_array, to_ids, from_array, from_ids)
>>> print(to_array)
array([[0., 0., 0., 0., 0., 0.],
       [2., 2., 2., 2., 2., 2.],
       [1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0.],
       [8., 8., 8., 8., 8., 8.],
       [8., 8., 8., 8., 8., 8.],
       [1., 1., 1., 1., 1., 1.]])
- atomworks.ml.utils.numpy.is_mask_contiguous(mask: ndarray) bool [source]#
Check if a mask is contiguous.
- atomworks.ml.utils.numpy.not_isin(element: ndarray, test_element: ndarray) ndarray [source]#
Return a boolean mask indicating where elements of element are not in test_element.
- Parameters:
element (np.ndarray) – Array to check.
test_element (np.ndarray) – Array to check against.
- Returns:
Boolean mask.
- Return type:
np.ndarray
Example
>>> not_isin(np.array([1, 2, 3, 4, 5]), np.array([2, 4, 6]))
array([ True, False,  True, False,  True])
- atomworks.ml.utils.numpy.select_data_by_id(select_ids: ndarray, data_ids: ndarray, data: ndarray, axis: int = 0) ndarray [source]#
Select data from an array based on matching IDs.
- Parameters:
select_ids (np.ndarray) – Array of IDs to select.
data_ids (np.ndarray) – Array of IDs corresponding to the data.
data (np.ndarray) – Data array from which to select.
axis (int, optional) – Axis along which to select data. Defaults to 0.
- Returns:
Array of selected data.
- Return type:
np.ndarray
- Raises:
AssertionError – If the shape of data along axis does not match the length of data_ids.
AssertionError – If data_ids contains duplicate values.
Example
>>> to_ids = np.array([1, 5, 2, 20, 20, 2])
>>> from_array = np.arange(10).repeat(6).reshape(10, 6)
>>> from_ids = np.array([1, 2, 5, 6, 7, 21, 22, 23, 20, 25])
>>> select_data_by_id(to_ids, from_ids, from_array)
array([[0., 0., 0., 0., 0., 0.],
       [2., 2., 2., 2., 2., 2.],
       [1., 1., 1., 1., 1., 1.],
       [8., 8., 8., 8., 8., 8.],
       [8., 8., 8., 8., 8., 8.],
       [1., 1., 1., 1., 1., 1.]])
- atomworks.ml.utils.numpy.unique_by_first_occurrence(arr: ndarray) ndarray [source]#
Return unique elements of an array while preserving the order of their first occurrence.
- Parameters:
arr (np.ndarray) – Input array.
- Returns:
Array of unique elements in the order of their first occurrence.
- Return type:
np.ndarray
Example
>>> unique_by_first_occurrence(np.array([4, 2, 2, 3, 1, 4]))
array([4, 2, 3, 1])
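`np.unique` sorts its output, so preserving first-occurrence order needs one extra step. A common approach, sketched here under the hypothetical name `unique_first_np`, is to ask `np.unique` for the first-occurrence indices and then sort those indices. The library may or may not do it this way.

```python
import numpy as np


def unique_first_np(arr):
    # return_index gives, for each unique value, the index of its first occurrence.
    _, first_idx = np.unique(arr, return_index=True)
    # Sorting the indices restores the order of first appearance.
    return arr[np.sort(first_idx)]


ordered = unique_first_np(np.array([4, 2, 2, 3, 1, 4]))
```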
Random Number Generation#
Utilities for managing the random number generators in the current process. Inspired by: Lightning-AI/pytorch-lightning
- atomworks.ml.utils.rng.capture_rng_states(include_cuda: bool = False) dict[str, Any] [source]#
Collect the global random state of torch, torch.cuda, numpy and Python in current process.
- Parameters:
include_cuda (bool) – Whether to include the state of the CUDA RNG. If CUDA is not available, the state of the CUDA RNG is not included. Defaults to False.
- atomworks.ml.utils.rng.create_rng_state_from_seeds(np_seed: int | None = None, torch_seed: int | None = None, py_seed: int | None = None) dict[str, Any] [source]#
Create a dictionary of RNG states from the provided seeds. If no seed is provided, the current state of the RNGs is used.
- Parameters:
np_seed (int | None) – The seed for the Numpy RNG.
torch_seed (int | None) – The seed for the PyTorch RNG.
py_seed (int | None) – The seed for the Python built-in RNG.
- atomworks.ml.utils.rng.get_numpy_rng_state_hash(rng: RandomState | None = None) int [source]#
Get the hash of the current state of the Numpy RNG.
- atomworks.ml.utils.rng.get_rng_state_hash(rng_state_dict: dict[str, Any]) int [source]#
Get the hash of the RNG state dictionary.
- atomworks.ml.utils.rng.rng_state(rng_state_dict: dict[str, Any] | None = None, include_cuda: bool = True) Generator[dict[str, Any], None, None] [source]#
A context manager that resets the global random state on exit to what it was before entering.
Within the context manager, the RNG states are set to the provided RNG state in the dictionary.
It supports isolating the states for PyTorch, Numpy, and Python built-in random number generators.
- Parameters:
rng_state_dict (dict[str, Any] | None) – A dictionary of RNG states to set. It can have the following keys:
“torch”: The state of the PyTorch RNG.
“torch.cuda”: The state of the PyTorch CUDA RNG.
“numpy”: The state of the Numpy RNG.
“python”: The state of the Python built-in RNG.
If no rng_state_dict is provided, the RNG states are set to the current state of the RNGs. If the rng_state_dict only contains a subset of the RNG states, the other RNG states are set to the current state of the RNGs.
include_cuda (bool) – Whether to allow this function to also control the torch.cuda random number generator. Set this to False when using the function in a forked process where CUDA re-initialization is prohibited. Defaults to True.
Example
```python
# Outside the context manager
print("NumPy:", np.random.random(3))  # [0.04810046 0.99270597 0.70612995]
print("PyTorch:", torch.rand(3))  # tensor([0.1405, 0.4602, 0.4284])
print("Python random:", [random.random() for _ in range(3)])  # [0.7406435863188185, 0.5632059276194807, 0.8537007637060476]

# Inside the context manager with fixed seeds
with rng_state(create_rng_state_from_seeds(np_seed=42, torch_seed=42, py_seed=42)) as rng_state_dict:
    my_state = serialize_rng_state_dict(rng_state_dict)
    print("\nWithin context manager:")
    print("NumPy:", np.random.random(3))  # [0.37454012 0.95071431 0.73199394]
    print("PyTorch:", torch.rand(3))  # tensor([0.8823, 0.9150, 0.3829])
    print("Python random:", [random.random() for _ in range(3)])  # [0.6394267984578837, 0.025010755222666936, 0.27502931836911926]

# Back to the original state outside the context manager
print("\nBack outside the context manager:")
print("NumPy:", np.random.random(3))  # [0.75479377 0.99594641 0.70411424]
print("PyTorch:", torch.rand(3))  # tensor([0.2757, 0.5345, 0.1754])
print("Python random:", [random.random() for _ in range(3)])  # [0.2194923914916147, 0.8731837332486028, 0.47700011905124995]

# Inside the context manager with fixed seeds
with rng_state(eval(my_state)):
    print("\nWithin context manager:")
    print("NumPy:", np.random.random(3))  # [0.37454012 0.95071431 0.73199394]
    print("PyTorch:", torch.rand(3))  # tensor([0.8823, 0.9150, 0.3829])
    print("Python random:", [random.random() for _ in range(3)])  # [0.6394267984578837, 0.025010755222666936, 0.27502931836911926]
```
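The save, set, and restore pattern behind such a context manager can be illustrated with the standard library alone. The sketch below is a simplification that covers only the Python built-in RNG; the names `python_rng_state` and `state_from_seed` are hypothetical and are not part of atomworks.

```python
import random
from contextlib import contextmanager
from typing import Iterator, Optional


@contextmanager
def python_rng_state(state: Optional[tuple] = None) -> Iterator[tuple]:
    """Temporarily set the built-in RNG state and restore the old one on exit."""
    saved = random.getstate()  # remember the caller's state
    if state is not None:
        random.setstate(state)  # switch to the requested state
    try:
        yield random.getstate()  # the state active at the start of the block
    finally:
        random.setstate(saved)  # always restore, even if the block raises


def state_from_seed(seed: int) -> tuple:
    """Build a state object from a seed without touching the global RNG."""
    return random.Random(seed).getstate()


random.seed(123)
state_before = random.getstate()

with python_rng_state(state_from_seed(42)):
    first = random.random()
with python_rng_state(state_from_seed(42)):
    second = random.random()

state_after = random.getstate()
print(first == second, state_after == state_before)
```

Both blocks draw the same number because they start from the same state, and the global state after the blocks is identical to the state before them.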
Testing Utilities#
- atomworks.ml.utils.testing.cached_parse(pdb_id: str, **kwargs) dict [source]#
Wrapper around _cached_parse that caches results and returns an immutable copy of the output dict.
- atomworks.ml.utils.testing.get_pdb_mirror_path(pdbid: str, base_dir: str = None) str [source]#
Convenience util to get the path to a CIF file on the DIGS
- atomworks.ml.utils.testing.is_clash(atom_array_1: AtomArray, atom_array_2: AtomArray, clash_distance: float = 1.0) bool [source]#
Checks for clashes between two arrays. Based on atomworks.ml.preprocessing.process.DataPreprocessor. Recommended to pass in minimal masks of arrays to check to reduce runtime.
Timer Utilities#
Timeout utilities for applying time limits to functions and blocks of code.
Adapted from pnpnpn/timeout-decorator (MIT License) and from chaidiscovery/chai-lab
- exception atomworks.ml.utils.timer.ChildProcessError[source]#
Bases: Exception
Exception raised when a child process dies unexpectedly.
- atomworks.ml.utils.timer.do_nothing(*args, **kwargs) Callable [source]#
A decorator that does nothing and simply returns the original function.
This decorator can be used as a placeholder or for testing purposes when you want to conditionally apply decorators without changing the code structure.
- Returns:
A decorator function that returns the original function unchanged.
- Return type:
Callable
Example
```python
@do_nothing()
def my_function():
    return "Hello, World!"
```
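A no-op decorator factory is mostly useful for switching a real decorator on and off without touching the decorated code. The following sketch illustrates that pattern under stated assumptions: `do_nothing_sketch`, `log_calls`, and the `DEBUG` flag are hypothetical stand-ins, not atomworks code.

```python
import functools
from typing import Callable

calls = []  # records which functions were called through log_calls


def do_nothing_sketch(*args, **kwargs) -> Callable:
    """Accept and ignore any arguments; return an identity decorator."""
    def decorator(func: Callable) -> Callable:
        return func  # the function is returned unchanged
    return decorator


def log_calls(prefix: str) -> Callable:
    """Hypothetical 'real' decorator, shown only for contrast."""
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*a, **kw):
            calls.append(f"{prefix}:{func.__name__}")
            return func(*a, **kw)
        return wrapper
    return decorator


DEBUG = False
# Pick the decorator once; the decorated code below stays the same either way.
maybe_log = log_calls if DEBUG else do_nothing_sketch


@maybe_log("debug")
def greet() -> str:
    return "Hello, World!"


greeting = greet()
print(greeting, calls)
```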
- atomworks.ml.utils.timer.timeout(timeout: float | int | None = None, strategy: Literal['signal', 'subprocess'] = 'subprocess') Callable [source]#
Decorator to apply a timeout to a function.
The signal strategy is more efficient and slightly faster, but does not work in all contexts (e.g. with some C dependencies like RDKit, on certain operating systems). The subprocess strategy is always available, but slightly slower and with a higher overhead.
- atomworks.ml.utils.timer.timeout_using_signal(timeout: float | int | None) Callable [source]#
Build a decorator that applies a timeout to a function using the signal module.
This decorator sets up a signal handler to raise a TimeoutError if the decorated function exceeds the specified timeout duration. It uses the SIGALRM signal to implement the timeout.
Use for example as:
```python
result = timeout_using_signal(timeout=10.0)(my_function)(*args, **kwargs)
```
- Parameters:
timeout (float | int | None) – The timeout duration in seconds.
- Returns:
A decorator function that can be applied to other functions to add timeout functionality.
- Return type:
Callable
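The SIGALRM mechanism described above can be sketched with the standard library. This is an illustration under assumptions, not the library's implementation: it requires a Unix-like OS and the main thread, the name `timeout_using_signal_sketch` is hypothetical, and treating `timeout=None` as "no limit" is a guess.

```python
import functools
import signal
import time
from typing import Callable, Optional


def timeout_using_signal_sketch(timeout: Optional[float]) -> Callable:
    """Build a decorator that raises TimeoutError after `timeout` seconds."""
    def decorator(func: Callable) -> Callable:
        if timeout is None:
            return func  # assumption: None means no time limit

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            def handler(signum, frame):
                raise TimeoutError(f"{func.__name__} exceeded {timeout} s")

            old_handler = signal.signal(signal.SIGALRM, handler)
            signal.setitimer(signal.ITIMER_REAL, timeout)  # arm the alarm
            try:
                return func(*args, **kwargs)
            finally:
                signal.setitimer(signal.ITIMER_REAL, 0)  # cancel a pending alarm
                signal.signal(signal.SIGALRM, old_handler)  # restore old handler
        return wrapper
    return decorator


def fast() -> str:
    return "done"


def slow() -> str:
    time.sleep(2)
    return "done"


fast_result = timeout_using_signal_sketch(1.0)(fast)()
try:
    timeout_using_signal_sketch(0.1)(slow)()
    timed_out = False
except TimeoutError:
    timed_out = True
print(fast_result, timed_out)
```

The handler can only run between Python bytecode instructions, which is one reason a long call inside a C extension may not be interrupted by this approach.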
- atomworks.ml.utils.timer.timeout_using_subprocess(timeout: float | int | None) Callable [source]#
Force function to timeout after specified time.
The returned decorator uses a subprocess to execute the function, allowing for timeout functionality even for CPU-bound operations that cannot be interrupted by signals.
- Parameters:
timeout (float | int | None) – The maximum time in seconds allowed for the function to execute.
- Returns:
A decorator that can be applied to a function.
- Return type:
Callable
- Raises:
TimeoutError – If the function does not return before the timeout.
ChildProcessError – If the child process dies unexpectedly.
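The subprocess strategy can be sketched with `multiprocessing`. This is a minimal illustration under assumptions rather than the library's code: it assumes a Unix-like OS with the "fork" start method, the names `timeout_using_subprocess_sketch` and `_run_in_child` are hypothetical, and a plain `RuntimeError` stands in for the library's own child-process exception.

```python
import functools
import multiprocessing as mp
import time
from typing import Callable, Optional

# Assumption: Unix-like OS. With "fork" the child can run functions that
# are not picklable (for example ones defined in a script or a closure).
_CTX = mp.get_context("fork")


def _run_in_child(conn, func, args, kwargs):
    """Run func in the child and send back (ok, result_or_exception)."""
    try:
        conn.send((True, func(*args, **kwargs)))
    except Exception as exc:
        conn.send((False, exc))
    finally:
        conn.close()


def timeout_using_subprocess_sketch(timeout: Optional[float]) -> Callable:
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            recv_end, send_end = _CTX.Pipe(duplex=False)
            proc = _CTX.Process(
                target=_run_in_child, args=(send_end, func, args, kwargs)
            )
            proc.start()
            send_end.close()  # the parent only reads
            try:
                # poll() returns False only if nothing arrived within `timeout`.
                if not recv_end.poll(timeout):
                    raise TimeoutError(f"{func.__name__} exceeded {timeout} s")
                try:
                    ok, payload = recv_end.recv()
                except EOFError:
                    # The pipe closed without a message: the child died.
                    raise RuntimeError("child process died unexpectedly") from None
            finally:
                if proc.is_alive():
                    proc.terminate()  # stop a child that is still running
                proc.join()
                recv_end.close()
            if ok:
                return payload
            raise payload  # re-raise the exception from the child
        return wrapper
    return decorator


def square(x: int) -> int:
    return x * x


def slow() -> int:
    time.sleep(5)
    return 1


squared = timeout_using_subprocess_sketch(5.0)(square)(7)
try:
    timeout_using_subprocess_sketch(0.2)(slow)()
    timed_out = False
except TimeoutError:
    timed_out = True
print(squared, timed_out)
```

Because the work runs in a separate process, the parent can terminate it regardless of what the function is doing, at the cost of process start-up and of sending the result back through a pipe.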
Token Utilities#
- atomworks.ml.utils.token.apply_and_spread_token_wise(atom_array: AtomArray, data: ndarray, function: Callable, axis: int | None = None, token_starts: ndarray | None = None) ndarray [source]#
Apply a function token wise and then spread the result to the atoms.
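The "apply, then spread" idea can be shown on plain NumPy arrays. The sketch below is a simplified illustration: it takes the token start indices directly instead of an AtomArray, and the name `apply_and_spread_sketch` is hypothetical.

```python
import numpy as np
from typing import Callable


def apply_and_spread_sketch(
    data: np.ndarray, token_starts: np.ndarray, function: Callable
) -> np.ndarray:
    """Reduce `data` per token, then repeat each result for every atom of the token."""
    bounds = np.append(token_starts, len(data))  # add the exclusive stop
    # Token-wise step: one value per contiguous segment.
    per_token = np.array(
        [function(data[start:stop]) for start, stop in zip(bounds[:-1], bounds[1:])]
    )
    # Spread step: repeat each token value once per atom in that token.
    return np.repeat(per_token, np.diff(bounds))


# Toy input: 6 atoms grouped into 3 tokens of sizes 3, 1 and 2.
data = np.array([1.0, 2.0, 3.0, 10.0, 5.0, 7.0])
token_starts = np.array([0, 3, 4])
per_atom = apply_and_spread_sketch(data, token_starts, np.mean)
print(per_atom)  # [ 2.  2.  2. 10.  6.  6.]
```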
- atomworks.ml.utils.token.apply_segment_wise_2d(array: ndarray, segment_start_end_idxs: ndarray, reduce_func: Callable) ndarray [source]#
Reduces a 2D array by applying a reduction function to specified segments along both rows and columns.
NOTE: Segments must be contiguous, rectangular blocks (sub-matrices) of the 2D array.
- Parameters:
array (np.ndarray) – A 2D numpy array to be reduced.
segment_start_end_idxs (np.ndarray) – A 1D numpy array of indices indicating the start and end of each block. The first element must be 0 and the last element must be the number of rows in the array.
reduce_func (Callable) – A function to apply to each segment. This function should take an array and return a reduced value.
- Returns:
A 2D numpy array that has been reduced along both rows and columns.
- Return type:
np.ndarray
Example
>>> array = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> segment_start_end_idxs = np.array([0, 2, 3])
>>> apply_segment_wise_2d(array, segment_start_end_idxs, reduce_func=np.sum)
array([[12,  9],
       [15,  9]])
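A straightforward way to obtain the result in the example is a double loop over the blocks. This is a readable sketch, not the library's implementation (which may well be vectorized); it assumes the same boundaries apply to rows and columns, and the name `apply_segment_wise_2d_sketch` is hypothetical.

```python
import numpy as np
from typing import Callable


def apply_segment_wise_2d_sketch(
    array: np.ndarray, idxs: np.ndarray, reduce_func: Callable
) -> np.ndarray:
    """Reduce each rectangular block delimited by `idxs` to a single value."""
    n_segments = len(idxs) - 1
    return np.array(
        [
            [
                # Block (i, j): rows of segment i, columns of segment j.
                reduce_func(array[idxs[i]:idxs[i + 1], idxs[j]:idxs[j + 1]])
                for j in range(n_segments)
            ]
            for i in range(n_segments)
        ]
    )


array = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
reduced = apply_segment_wise_2d_sketch(array, np.array([0, 2, 3]), np.sum)
print(reduced)
```

For the 3x3 input, the top-left block sums 1 + 2 + 4 + 5 = 12 and the bottom-left block sums 7 + 8 = 15, matching the example output.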
- atomworks.ml.utils.token.apply_token_wise(array: AtomArray, data: ndarray, function: Callable, axis: int | None = None, token_starts: ndarray | None = None) ndarray [source]#
Analogous to biotite’s apply_residue_wise.
- atomworks.ml.utils.token.get_af3_token_center_coords(atom_array: AtomArray) ndarray [source]#
Returns the center coordinates of the tokens in the atom array as per the AF3 definition.
- For each token we also designate a token center atom, used in various places below:
CA for standard amino acids
C1’ for standard nucleotides
For other cases take the first and only atom as they are tokenized per-atom.
If a token center cannot be assigned (e.g. because the token center atom is unoccupied), the center coordinate is set to np.nan.
- Parameters:
atom_array (AtomArray) – The atom array to get the center coordinates of.
- Returns:
The center coordinates of the tokens in the atom array.
- Return type:
np.ndarray
- Reference:
Example
>>> # Contrived example showing only a few tokens and annotations per residue for illustration
>>> array = AtomArray(
...     res_name="ALA", atom_name="CA", coord=np.array([0, 0, 0]),
...     res_name="ALA", atom_name="C", coord=np.array([1, 0, 0]),
...     res_name="ALA", atom_name="O", coord=np.array([2, 0, 0]),
...     res_name="NAP", atom_name="P1", coord=np.array([3, 0, 0]),
...     res_name="U", atom_name="C1'", coord=np.array([4, 0, 0]),
... )
>>> get_af3_token_center_coords(array)
array([[0, 0, 0], [3, 0, 0], [4, 0, 0]])
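The NaN rule for unassignable centers can be illustrated on plain arrays. In this sketch the center mask is given up front, and "unoccupied" is modeled as an occupancy of 0; both the `occupancy` array and that interpretation are assumptions made for illustration.

```python
import numpy as np

# Toy data: 5 atoms, of which atoms 0, 3 and 4 are token centers.
coords = np.array(
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [3.0, 0.0, 0.0], [4.0, 0.0, 0.0]]
)
is_center = np.array([True, False, False, True, True])
occupancy = np.array([1.0, 1.0, 1.0, 0.0, 1.0])  # assumption: 0 means unoccupied

# One row per token, taken from that token's center atom.
center_coords = coords[is_center].copy()
# Centers that are not occupied get NaN coordinates.
center_coords[occupancy[is_center] == 0] = np.nan
print(center_coords)
```

Downstream code can then filter such tokens with `np.isnan` instead of silently using a placeholder position.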
- atomworks.ml.utils.token.get_af3_token_center_idxs(atom_array: AtomArray) ndarray [source]#
Returns the indices of the center atoms of the tokens in the atom array as per the AF3 definition.
- atomworks.ml.utils.token.get_af3_token_center_masks(atom_array: AtomArray) ndarray [source]#
Returns a boolean mask indicating the center atoms of the tokens in the atom array as per the AF3 definition.
NOTE: “Center” atoms are distinct from “representative” atoms, which are used during distogram prediction (and more closely represent the center of mass).
- For each token we also designate a token center atom, used in various places below:
CA for standard amino acids
C1’ for standard nucleotides
For other cases take the first and only atom as they are tokenized per-atom.
- Parameters:
atom_array (AtomArray) – The atom array to get the center atoms of.
- Returns:
A boolean mask indicating the center atoms of the tokens in the atom array.
- Return type:
np.ndarray
- atomworks.ml.utils.token.get_af3_token_representative_coords(atom_array: AtomArray) ndarray [source]#
Returns the representative coordinates of the tokens in the atom array.
See “get_af3_token_representative_masks” for more details on what constitutes a representative atom.
- Parameters:
atom_array (AtomArray) – The atom array to get the representative coordinates of.
- Returns:
The representative coordinates of the tokens in the atom array.
- Return type:
np.ndarray
- atomworks.ml.utils.token.get_af3_token_representative_idxs(atom_array: AtomArray) ndarray [source]#
Returns the indices of the representative atoms of the tokens in the atom array.
See “get_af3_token_representative_masks” for more details on what constitutes a representative atom.
- Parameters:
atom_array (AtomArray) – The atom array to get the representative atom indices from.
- Returns:
An array of indices corresponding to the representative atoms of the tokens.
- Return type:
np.ndarray
- atomworks.ml.utils.token.get_af3_token_representative_masks(atom_array: AtomArray) ndarray [source]#
Returns a boolean mask indicating the representative atoms of the tokens in the atom array.
- From the AF-3 supplement, section 4.4. (Distogram prediction):
> …where the pairwise token distances use the representative atom for each token: CB for protein residues (CA for glycine), C4 for purines and C2 for pyrimidines. All ligands already have a single atom per token.
NOTE: “Representative” atoms are distinct from “center” atoms, which are used during cropping.
- Parameters:
atom_array (AtomArray) – The atom array to get the representative atoms of.
- Returns:
A boolean mask indicating the representative atoms of the tokens in the atom array.
- Return type:
np.ndarray
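The selection rule quoted from the AF-3 supplement can be written as a small lookup. The sketch below works on plain tuples instead of an AtomArray and is only an illustration; the helper names and the residue-name sets used to classify purines and pyrimidines are assumptions.

```python
from typing import Optional

STANDARD_AA = {
    "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
}
PURINES = {"A", "G", "DA", "DG"}  # assumption: standard RNA/DNA residue names
PYRIMIDINES = {"C", "U", "DC", "DT"}


def representative_atom_name(res_name: str) -> Optional[str]:
    """Return the representative atom name for a residue-level token."""
    if res_name == "GLY":
        return "CA"  # glycine has no CB
    if res_name in STANDARD_AA:
        return "CB"
    if res_name in PURINES:
        return "C4"
    if res_name in PYRIMIDINES:
        return "C2"
    return None  # not a residue-level token in this sketch


def is_representative(res_name: str, atom_name: str, atomize: bool) -> bool:
    # Per-atom tokens (e.g. ligand atoms) are their own representative.
    return atomize or atom_name == representative_atom_name(res_name)


# Toy atoms: (res_name, atom_name, atomize)
atoms = [
    ("ALA", "CA", False), ("ALA", "CB", False),
    ("GLY", "N", False), ("GLY", "CA", False),
    ("U", "C2", False), ("U", "C4", False),
    ("G", "C2", False), ("G", "C4", False),
    ("NAP", "P1", True),
]
mask = [is_representative(*atom) for atom in atoms]
print(mask)
```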
- atomworks.ml.utils.token.get_token_count(array: AtomArray) int [source]#
Returns the number of distinct tokens in the atom array.
This function counts the number of tokens based on the changes in the atom array’s annotations. It will match the behavior of biotite.structure.get_residue_count when the atom array does not have the atomize annotation or if atomize is False for all atoms.
- Returns:
The number of distinct tokens in the atom array.
- Return type:
int
- atomworks.ml.utils.token.get_token_masks(array: AtomArray, indices: ndarray) ndarray [source]#
Get boolean masks indicating the tokens to which the given atom indices belong.
- Parameters:
array (-) – The atom array (stack) to determine the residues from.
indices (-) – An array of indices indicating the atoms to get the corresponding residues for. Negative indices are not allowed.
- Returns:
A 2D boolean array where each row corresponds to a given index in indices. Each row masks the atoms that belong to the same residue as the atom at the given index.
- Return type:
residues_masks (ndarray, dtype=bool)
See also
get_residue_masks_for
get_token_starts
get_token_starts_for
get_token_positions
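One way to build such masks is to give every atom a token id and compare ids. The sketch below starts from token boundaries rather than an AtomArray and is only meant to illustrate the output shape; the variable names are hypothetical.

```python
import numpy as np

# Toy layout: 7 atoms in 4 tokens; the last entry is the exclusive stop.
bounds = np.array([0, 2, 5, 6, 7])
indices = np.array([1, 4, 6])  # atoms whose tokens we want to mask

# Token id of every atom: token k is repeated once per atom it contains.
token_id = np.repeat(np.arange(len(bounds) - 1), np.diff(bounds))

# Row r is True wherever an atom shares its token with atom indices[r].
masks = token_id[indices][:, None] == token_id[None, :]
print(masks.astype(int))
```

The result has shape `(len(indices), n_atoms)`, here `(3, 7)`, with one row per requested atom index.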
- atomworks.ml.utils.token.get_token_starts(array: AtomArray, add_exclusive_stop: bool = False) ndarray [source]#
Get indices for an atom array, each indicating the beginning of a token.
Inspired by biotite.structure.get_residue_starts.
- A new token starts:
If atomize is True
If either the chain ID, residue ID, insertion code or residue name changes from one to the next atom.
- Parameters:
array (AtomArray) – The atom array to get the token starts from.
add_exclusive_stop (bool, optional) – If True, add an exclusive stop to the token starts for the last residue. Defaults to False.
- Returns:
An array of indices indicating the beginning of each token.
- Return type:
np.ndarray
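The two rules for starting a new token can be expressed as a vectorized comparison of neighboring atoms. The following sketch operates on plain annotation arrays instead of an AtomArray and assumes a non-empty input; the name `token_starts_sketch` is hypothetical.

```python
import numpy as np


def token_starts_sketch(
    chain_id, res_id, ins_code, res_name, atomize, add_exclusive_stop=False
) -> np.ndarray:
    """Return the index of the first atom of each token."""
    # Compare every atom with its predecessor.
    changed = (
        (chain_id[1:] != chain_id[:-1])
        | (res_id[1:] != res_id[:-1])
        | (ins_code[1:] != ins_code[:-1])
        | (res_name[1:] != res_name[:-1])
        | atomize[1:]  # an atomized atom always starts its own token
    )
    # Atom 0 always starts a token; the others start one where `changed` is True.
    starts = np.concatenate(([0], np.flatnonzero(changed) + 1))
    if add_exclusive_stop:
        starts = np.append(starts, len(chain_id))
    return starts


# Toy structure: ALA (2 atoms), GLY (3 atoms), then a 2-atom atomized ligand.
chain_id = np.array(["A", "A", "A", "A", "A", "B", "B"])
res_id = np.array([1, 1, 2, 2, 2, 1, 1])
ins_code = np.array(["", "", "", "", "", "", ""])
res_name = np.array(["ALA", "ALA", "GLY", "GLY", "GLY", "LIG", "LIG"])
atomize = np.array([False, False, False, False, False, True, True])

starts = token_starts_sketch(chain_id, res_id, ins_code, res_name, atomize)
starts_with_stop = token_starts_sketch(
    chain_id, res_id, ins_code, res_name, atomize, add_exclusive_stop=True
)
print(starts, starts_with_stop)
```

The number of tokens is simply `len(starts)`, which is 4 here: two residues plus two per-atom ligand tokens.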
- atomworks.ml.utils.token.get_token_starts_for(array: AtomArray, indices: ndarray) ndarray [source]#
Retrieves the indices that point to the start of the token for each specified atom index.
This function is useful for identifying the beginning of the token associated with each atom in the provided indices. It is particularly relevant in contexts where atoms are grouped into tokens based on their annotations.
- Parameters:
array (-) – The atom array (or stack) from which to determine the residue starts.
indices (-) – An array of atom indices for which the corresponding residue starts are to be retrieved. Negative indices are not permitted.
- Returns:
An array of indices pointing to the start of the tokens corresponding to the input indices.
- Return type:
start_indices (ndarray, dtype=int, shape=(k,))
See also
get_residue_starts_for
get_token_starts
get_token_masks
get_token_positions
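Given the sorted token start indices, the start of the token containing an arbitrary atom can be found with a binary search. This sketch takes the starts as input rather than an AtomArray; the name `token_starts_for_sketch` is hypothetical.

```python
import numpy as np


def token_starts_for_sketch(token_starts: np.ndarray, indices: np.ndarray) -> np.ndarray:
    """Map each atom index to the start index of the token that contains it."""
    indices = np.asarray(indices)
    if (indices < 0).any():
        raise ValueError("Negative indices are not allowed")
    # side="right" counts the starts that are <= index; minus one selects the token.
    token_idx = np.searchsorted(token_starts, indices, side="right") - 1
    return token_starts[token_idx]


token_starts = np.array([0, 2, 5, 6])
start_indices = token_starts_for_sketch(token_starts, np.array([1, 4, 6, 2]))
print(start_indices)  # [0 2 6 2]
```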