Utilities#

This module contains various utility functions and classes used throughout the atomworks.ml package.

Debug Utilities#

atomworks.ml.utils.debug.save_failed_example_to_disk(example_id: str, fail_dir: str, *, data: dict = {}, rng_state_dict: dict = {}, error_msg: str = '') None[source]#

Attempts to save a failed example to disk as a pickle file.

Parameters:
  • example_id (-) – The ID of the example.

  • fail_dir (-) – The directory where the failed example should be saved.

  • rng_state_dict (-) – The random number generator state dictionary.

  • error_msg (-) – The error message associated with the failure.

Returns:

None

Error Handling#

atomworks.ml.utils.error.context(msg: str, cleanup: ~collections.abc.Callable[[], None] = <function <lambda>>, raise_error: bool = True, log_level: int = 40, exc_types: tuple = (<class 'Exception'>, ), logger: ~logging.Logger = <Logger atomworks.ml.utils.error (WARNING)>) Any[source]#

Context manager for handling exceptions with configurable error handling and logging.

This context manager allows you to pass a custom ‘msg’ to error messages, make them try-exceptable and add a cleanup function that will be called when an unrecoverable error occurs.

Parameters:
  • msg (-) – Message to prepend to the error description

  • cleanup (-) – Optional cleanup function to call when an exception occurs. Defaults to no-op

  • raise_error (-) – If True, logs and re-raises the exception. If False, only logs the exception

  • log_level (-) – Logging level to use (from logging module constants). Defaults to logging.ERROR

  • exc_types (-) – Tuple of exception types to catch. Defaults to (Exception,)

  • logger (-) – Logger to use for logging. Defaults to the module-level logger (atomworks.ml.utils.error), as shown in the signature.

Yields:

Any – The yielded value from the context block

Raises:

Exception – Re-raises the caught exception if raise_error is True
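The behaviour described above can be sketched with the standard library alone. The name `error_context` and its body are assumptions for illustration: they mirror the documented parameters (msg, cleanup, raise_error, log_level, exc_types) but are not the atomworks implementation.

```python
import logging
from collections.abc import Callable, Iterator
from contextlib import contextmanager

logger = logging.getLogger("error_context_sketch")


@contextmanager
def error_context(
    msg: str,
    cleanup: Callable[[], None] = lambda: None,
    raise_error: bool = True,
    log_level: int = logging.ERROR,
    exc_types: tuple = (Exception,),
) -> Iterator[None]:
    """Hypothetical stand-in: log with a prefix, run cleanup, optionally re-raise."""
    try:
        yield
    except exc_types as exc:
        logger.log(log_level, "%s: %r", msg, exc)
        cleanup()
        if raise_error:
            raise


cleaned = []
with error_context("loading example 42", cleanup=lambda: cleaned.append(True), raise_error=False):
    raise ValueError("bad input")  # logged and swallowed because raise_error=False
```

With raise_error left at True, the same block would log, call cleanup, and then propagate the original exception.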

atomworks.ml.utils.error.format_traceback(tb: str) str[source]#

Format a traceback string with syntax highlighting.

Parameters:

tb (-) – The traceback string to format

Returns:

The formatted traceback with ANSI color codes

Return type:

str

Geometry Utilities#

Various geometry utility functions to deal with rigid body transformations in 3D.

atomworks.ml.utils.geometry.align_atom_arrays(mbl_sele: AtomArray, tgt_sele: AtomArray, mbl_full: AtomArray) tuple[AtomArray, float][source]#

Computes the transformation that aligns mbl_sele to tgt_sele, then applies that transformation to mbl_full and returns it along with the alignment RMSD.

Parameters:
  • mbl_sele (AtomArray) – An atom array containing atomic coordinates of the array to be transformed, pre-masked to contain only the portion to be aligned.

  • tgt_sele (AtomArray) – An atom array containing coordinates for mbl_sele to be aligned to. Must be the same size as mbl_sele; should be the same residues / molecules.

  • mbl_full (AtomArray) – The full atom array to be transformed based on the alignment between mbl_sele and tgt_sele.

Returns:

A tuple of (1) an atom array of the same shape as mbl_full, containing the transformed coordinates, and (2) a float, the RMSD between mbl_sele and tgt_sele following alignment.

Return type:

tuple[AtomArray, float]

atomworks.ml.utils.geometry.apply_batched_rigid(rigid: tuple[Tensor, Tensor], points: Tensor) Tensor[source]#

Apply a batch of rigid body transformations to a set of batched points via (p -> R @ p + t). (i.e. first rotate then translate)

Parameters:
  • rigid (-) – A tuple containing the rotation matrix (R) and translation vector (t) representing the rigid body transformation.

  • points (-) – A tensor of shape [batch_size, …, 3] representing the points to transform.

Returns:

A tensor of shape [batch_size, …, 3] representing the transformed points.

Return type:

  • torch.Tensor

NOTE: This transforms p from the local frame of the rigid to the global frame.

atomworks.ml.utils.geometry.apply_inverse_rigid(rigid: tuple[Tensor, Tensor], points: Tensor) Tensor[source]#

Apply the inverse of a rigid body transformation to a set of points via (p -> R^T @ (p - t)).

Parameters:
  • rigid (-) – A tuple containing the rotation matrix (R) and translation vector (t) of the rigid body transformation.

  • points (-) – The points to transform, with shape (…, 3).

Returns:

The transformed points, with the same shape as the input points.

Return type:

  • torch.Tensor

atomworks.ml.utils.geometry.apply_rigid(rigid: tuple[Tensor, Tensor], points: Tensor) Tensor[source]#

Apply a rigid body transformation to a set of points via (p -> R @ p + t). (i.e. first rotate then translate)

Parameters:
  • rigid (-) – A tuple containing the rotation matrix (R) and translation vector (t) representing the rigid body transformation.

  • points (-) – A tensor of shape […, 3] representing the points to transform.

Returns:

A tensor of shape […, 3] representing the transformed points.

Return type:

  • torch.Tensor

NOTE: This transforms p from the local frame of the rigid to the global frame.
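The two maps documented for apply_rigid and apply_inverse_rigid can be checked against each other with a small round trip. This is a sketch under assumptions: it uses NumPy instead of PyTorch to stay dependency-light, and the helper names `apply_rigid_np` and `apply_inverse_rigid_np` are hypothetical.

```python
import numpy as np


def apply_rigid_np(rigid, points):
    """p -> R @ p + t for points of shape (..., 3)."""
    R, t = rigid
    return points @ R.T + t  # row-vector form of R @ p


def apply_inverse_rigid_np(rigid, points):
    """p -> R^T @ (p - t) for points of shape (..., 3)."""
    R, t = rigid
    return (points - t) @ R  # row-vector form of R^T @ (p - t)


# 90-degree rotation about the z-axis, followed by a translation.
R = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
t = np.array([1.0, 2.0, 3.0])
local_points = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

global_points = apply_rigid_np((R, t), local_points)       # local -> global
recovered = apply_inverse_rigid_np((R, t), global_points)  # global -> local
```

Applying the inverse map to the transformed points recovers the original local coordinates.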

atomworks.ml.utils.geometry.compose_rigids(rigid1: tuple[Tensor, Tensor], rigid2: tuple[Tensor, Tensor]) tuple[Tensor, Tensor][source]#

Compose two rigid body transformations (R1, t1) and (R2, t2) to (R2 @ R1, R2 @ t1 + t2).

Parameters:
  • rigid1 (-) – First rigid body transformation (R1, t1).

  • rigid2 (-) – Second rigid body transformation (R2, t2).

Returns:

Composed rigid body transformation (R_composed, t_composed).

Return type:

  • tuple[torch.Tensor, torch.Tensor]

Example

>>> R1, t1 = torch.eye(3), torch.tensor([1.0, 0.0, 0.0])
>>> R2, t2 = torch.eye(3), torch.tensor([0.0, 1.0, 0.0])
>>> R_composed, t_composed = compose_rigids((R1, t1), (R2, t2))
>>> print(R_composed)
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
>>> print(t_composed)
tensor([1., 1., 0.])
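
The composition rule can be verified numerically: applying rigid1 and then rigid2 to a point should match applying the composed transform once. The sketch below uses NumPy and hypothetical helper names (`compose_np`, `invert_np`); the inverse uses the (R^T, -R^T @ t) formula documented for invert_rigid.

```python
import numpy as np


def compose_np(rigid1, rigid2):
    """(R1, t1) followed by (R2, t2) -> (R2 @ R1, R2 @ t1 + t2)."""
    (R1, t1), (R2, t2) = rigid1, rigid2
    return R2 @ R1, R2 @ t1 + t2


def invert_np(rigid):
    """(R, t) -> (R^T, -R^T @ t)."""
    R, t = rigid
    return R.T, -R.T @ t


Rz = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
rigid1 = (Rz, np.array([1.0, 0.0, 0.0]))
rigid2 = (Rz, np.array([0.0, 1.0, 0.0]))
p = np.array([1.0, 2.0, 3.0])

R12, t12 = compose_np(rigid1, rigid2)
sequential = Rz @ (Rz @ p + rigid1[1]) + rigid2[1]  # apply rigid1, then rigid2
composed = R12 @ p + t12                            # apply the composition once

# Composing a transform with its inverse should give the identity.
R_id, t_id = compose_np(rigid1, invert_np(rigid1))
```
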
atomworks.ml.utils.geometry.get_random_rigid(batch_size: int, scale: float = 1.0, **tensor_kwargs) tuple[Tensor, Tensor][source]#

Generate random rigid body transformations (R, t).

Parameters:
  • batch_size (-) – Number of rigid transformations to generate.

  • scale (-) – Scale factor for the translation vectors. Defaults to 1.0.

  • **tensor_kwargs (-) –

    Additional keyword arguments to pass to tensor creation functions.

Returns:

A `rigid` tuple containing:
  • rots (torch.Tensor): Batch of random rotation matrices with shape (batch_size, 3, 3).

  • trans (torch.Tensor): Batch of random translation vectors with shape (batch_size, 3).

Return type:

  • tuple[torch.Tensor, torch.Tensor]

Note

If batch_size is 1, the output tensors are squeezed to remove the batch dimension.

atomworks.ml.utils.geometry.get_random_rots(batch_size: int, **tensor_kwargs) Tensor[source]#

Generate random 3D rotation matrices.

Parameters:
  • batch_size (-) – Number of rotation matrices to generate.

  • **tensor_kwargs (-) – Additional keyword arguments (for example, device) to pass to tensor creation functions.

Returns:

Batch of random rotation matrices with shape (batch_size, 3, 3).

Return type:

  • torch.Tensor

Example

>>> R = get_random_rots(5)
>>> print(R.shape)
torch.Size([5, 3, 3])
>>> print(torch.allclose(torch.det(R), torch.ones(5)))
True
atomworks.ml.utils.geometry.get_torch_eps(dtype: dtype) float[source]#

Get the smallest positive representable value for a given torch dtype.

atomworks.ml.utils.geometry.invert_rigid(rigid: tuple[Tensor, Tensor]) tuple[Tensor, Tensor][source]#

Invert a rigid body transformation (R, t) to (R^T, -R^T @ t).

Parameters:

rigid (-) – A tuple containing the rotation matrix (R) and translation vector (t) representing the rigid body transformation.

Returns:

A tuple containing the inverted rotation matrix (R^T) and inverted translation vector (-R^T @ t).

Return type:

  • tuple[torch.Tensor, torch.Tensor]

Example

>>> R = torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]])
>>> t = torch.tensor([1, 2, 3])
>>> R_inv, t_inv = invert_rigid((R, t))
>>> print(R_inv)
tensor([[ 0,  1,  0],
        [-1,  0,  0],
        [ 0,  0,  1]])
>>> print(t_inv)
tensor([-2,  1, -3])
atomworks.ml.utils.geometry.masked_center(coord_atom_lvl: ndarray | Tensor, mask_atom_lvl: ndarray | Tensor = None) ndarray | Tensor[source]#

Center the coordinates of the atoms in coord_atom_lvl around the origin using the mask mask_atom_lvl.

Supports both NumPy and PyTorch tensors.

atomworks.ml.utils.geometry.random_rigid_augmentation(coord_atom_lvl: Tensor, batch_size: int, s: float = 1.0) Tensor[source]#

Apply random rigid body transformations to atomic coordinates.

Generates random rigid body transformations (rotation and translation) for a batch of atomic coordinates and applies these transformations to the input coordinates.

Parameters:
  • coord_atom_lvl (torch.Tensor) – A tensor containing atomic coordinates to be transformed. The shape is expected to be (batch_size, num_atoms, 3).

  • batch_size (int) – The number of transformations to generate and apply, corresponding to the number of coordinate sets in coord_atom_lvl.

  • s (float, optional) – The translational scale in Angstrom. Random translations will be drawn from N(0, s), i.e. with standard deviation s. The rotational degree of freedom is sampled uniformly at random. Defaults to 1.0.

Returns:

A tensor of the same shape as coord_atom_lvl, containing the transformed atomic coordinates.

Return type:

torch.Tensor
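A rough NumPy sketch of this kind of augmentation follows. Everything here is an assumption for illustration: the helper names are hypothetical, and the QR-based rotation sampler is a common recipe rather than the sampler the library necessarily uses.

```python
import numpy as np


def random_rotation(rng):
    """Draw one rotation matrix via QR decomposition of a Gaussian matrix."""
    Q, Rm = np.linalg.qr(rng.normal(size=(3, 3)))
    Q = Q * np.sign(np.diag(Rm))  # fix the sign ambiguity of the QR factors
    if np.linalg.det(Q) < 0:      # reflect one axis to obtain a proper rotation
        Q[:, 0] = -Q[:, 0]
    return Q


def random_rigid_augmentation_np(coords, s=1.0, seed=0):
    """coords: (batch, n_atoms, 3). One random (R, t) per batch element."""
    rng = np.random.default_rng(seed)
    batch = coords.shape[0]
    rots = np.stack([random_rotation(rng) for _ in range(batch)])
    trans = rng.normal(scale=s, size=(batch, 1, 3))  # N(0, s) translations
    return np.einsum("bij,bnj->bni", rots, coords) + trans


coords = np.random.default_rng(1).normal(size=(4, 10, 3))
augmented = random_rigid_augmentation_np(coords, s=2.0)
```

Because each transformation is rigid, all pairwise distances within a structure are unchanged, which makes a convenient sanity check.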

atomworks.ml.utils.geometry.rigid_from_3_points(x1: Tensor, x2: Tensor, x3: Tensor, eps: float | None = None) tuple[Tensor, Tensor][source]#

Compute the rigid body transformation (R, t) that leads from the origin into the local frame via the Gram-Schmidt process.

The local frame is centered at x2 with the x-axis pointing towards x3, the y-axis in the plane defined by x1, x2, and x3, and the z-axis perpendicular to this plane.

E.g. if x1=N, x2=CA, x3=C, then the x-axis is the vector pointing CA -> C, the y-axis is in the N-CA-C plane and the z-axis is perpendicular to this plane.

Parameters:
  • x1 – torch.Tensor of shape […, 3], coordinates of the first point

  • x2 – torch.Tensor of shape […, 3], coordinates of the second point (origin of local frame)

  • x3 – torch.Tensor of shape […, 3], coordinates of the third point

  • eps – float, small value to avoid division by zero

Returns:

A tuple (R, t), where R is a torch.Tensor of shape […, 3, 3] (rotation matrix) and t is a torch.Tensor of shape […, 3] (translation vector).

Return type:

tuple[torch.Tensor, torch.Tensor]

Reference:

Example

>>> x1 = torch.tensor([0.0, 0.0, 1.0])
>>> x2 = torch.tensor([0.0, 0.0, 0.0])
>>> x3 = torch.tensor([1.0, 0.0, 0.0])
>>> R, t = rigid_from_3_points(x1, x2, x3)
>>> print(R)
tensor([[ 1., 0., 0.],
        [ 0., 0.,-1.],
        [ 0., 1., 0.]])
>>> print(t)
tensor([0., 0., 0.])
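
The Gram-Schmidt construction can be sketched in NumPy as follows. The function name, the placement of eps, and the choice to store the basis vectors as columns of R are assumptions, chosen so that the sketch reproduces the example output above.

```python
import numpy as np


def rigid_from_3_points_np(x1, x2, x3, eps=1e-8):
    """Frame at x2: x-axis towards x3, y-axis in the x1-x2-x3 plane."""
    v1 = x3 - x2
    v2 = x1 - x2
    e1 = v1 / (np.linalg.norm(v1, axis=-1, keepdims=True) + eps)
    # Remove the component of v2 along e1, then normalise.
    u2 = v2 - np.sum(e1 * v2, axis=-1, keepdims=True) * e1
    e2 = u2 / (np.linalg.norm(u2, axis=-1, keepdims=True) + eps)
    e3 = np.cross(e1, e2)  # perpendicular to the plane
    R = np.stack([e1, e2, e3], axis=-1)  # basis vectors as columns
    return R, x2


R, t = rigid_from_3_points_np(
    np.array([0.0, 0.0, 1.0]), np.array([0.0, 0.0, 0.0]), np.array([1.0, 0.0, 0.0])
)
```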

I/O Utilities#

atomworks.ml.utils.io.cache_based_on_subset_of_args(cache_keys: list[str], maxsize: int | None = None) Callable[source]#

Decorator to cache function results based on a subset of its keyword arguments. Most helpful when some arguments may be unhashable types (e.g., dictionaries, AtomArray).

If the value of any of the cache keys is None, the function is executed and the result is not cached.

Note

The wrapped function must use keyword arguments for those specified in cache_keys. Positional arguments are not supported for cache key extraction.

Parameters:
  • cache_keys (List[str]) – The names of the keyword arguments to use as the cache key.

  • maxsize (Optional[int]) – The maximum number of entries to store in the cache. If None, the cache size is unlimited.

Returns:

A decorator that caches the function results based on the specified keyword arguments.

Return type:

Callable

Example

>>> @cache_based_on_subset_of_args(['arg1'], maxsize=2)
... def function(*, arg1, arg2):
...     return arg1 + arg2
...
>>> result1 = function(arg1=1, arg2=2)  # Caches with key 1
>>> result2 = function(arg1=1, arg2=3)  # Retrieves from cache

atomworks.ml.utils.io.cache_to_disk_as_pickle(cache_dir: PathLike | None = None, use_gzip: bool = True, directory_depth: int = 2) Callable[source]#

A decorator to cache the results of a function to disk as a pickle file.

Creates a unique cached pickle file for each set of function arguments using an MD5 hash. If the cache file exists, the result is loaded from the file. Otherwise, the function is called, and the result is saved to the cache file.

If cache_dir is None, caching is disabled and the function is always executed.

Parameters:
  • cache_dir (PathLike or None) – The directory where cache files will be stored, or None to disable caching.

  • use_gzip (bool) – Whether to use gzip compression for the cache files.

  • directory_depth (int) – The depth of the directory structure for sharding cache files.

Returns:

The wrapped function with optional disk caching enabled.

Return type:

function
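The load-or-compute flow described above can be sketched as follows. This is an illustration under assumptions: `cache_to_disk_sketch` is a hypothetical name, the way arguments are serialised before hashing is a guess, and directory sharding is omitted for brevity.

```python
import gzip
import hashlib
import pickle
import tempfile
from pathlib import Path


def cache_to_disk_sketch(cache_dir, use_gzip=True):
    """Hypothetical sketch: one pickle file per distinct argument set."""

    def decorator(func):
        def wrapper(*args, **kwargs):
            if cache_dir is None:
                return func(*args, **kwargs)  # caching disabled
            # Assumption: hash a pickle of the function name and arguments.
            payload = pickle.dumps((func.__name__, args, sorted(kwargs.items())))
            digest = hashlib.md5(payload).hexdigest()
            path = Path(cache_dir) / (digest + (".pkl.gz" if use_gzip else ".pkl"))
            opener = gzip.open if use_gzip else open
            if path.exists():
                with opener(path, "rb") as fh:
                    return pickle.load(fh)
            result = func(*args, **kwargs)
            path.parent.mkdir(parents=True, exist_ok=True)
            with opener(path, "wb") as fh:
                pickle.dump(result, fh)
            return result

        return wrapper

    return decorator


calls = []
tmp_dir = tempfile.mkdtemp()


@cache_to_disk_sketch(tmp_dir)
def square(x):
    calls.append(x)
    return x * x


a = square(7)  # computed, then written to disk
b = square(7)  # loaded from the pickle file
```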

atomworks.ml.utils.io.convert_af3_model_output_to_atom_array_stack(atom_to_token_map: ndarray[int], pn_unit_iids: ndarray[str], decoded_restypes: ndarray[str], xyz: ndarray, elements: ndarray[int | str], token_is_atomized: ndarray[bool] = None) AtomArrayStack[source]#

Create an AtomArrayStack from AlphaFold-3-type model outputs. Specific to AF-3; may not work with other formats.

Parameters:
  • atom_to_token_map (-) – Mapping from atoms to tokens [n_atom]

  • pn_unit_iids (-) – PN unit IIDs for each token [n_token]

  • decoded_restypes (-) – Decoded residue types for each token [n_token]

  • xyz (-) – Coordinates of atoms [n_atom, 3] or [batch, n_atom, 3], where batch is the number of structures

  • elements (-) – Element types for each atom [n_atom]

  • token_is_atomized (-) – Flags indicating if tokens are atomized [n_token]. If not provided or None, residues with a single atom are considered atomized.

Returns:

Constructed AtomArrayStack.

Return type:

  • AtomArrayStack

atomworks.ml.utils.io.filter_residue_atoms(residue_atom_array: AtomArray, chem_type: str, elements: ndarray[str]) AtomArray[source]#

Filter out unwanted atoms from a residue (e.g., hydrogens, leaving groups).

Parameters:
  • residue_atom_array (-) – The AtomArray to filter.

  • chem_type (-) – Type of the chemical chain.

  • elements (-) – Element types (as strings, e.g., “C”) for each atom in the residue.

Returns:

Filtered AtomArray.

Return type:

  • struc.AtomArray

atomworks.ml.utils.io.get_sharded_file_path(base_dir: Path, file_hash: str, extension: str, depth: int, chars_per_dir: int = 2, include_subdirectory: bool = False) Path[source]#

Construct a nested file path based on the directory depth.

Parameters:
  • base_dir (Path) – The base directory where the files are stored.

  • file_hash (str) – The hash of the file content or identifier.

  • extension (str) – The file extension.

  • depth (int) – The directory nesting depth.

  • chars_per_dir (int) – The number of characters to use for each directory level.

  • include_subdirectory (bool) – If True, creates an additional directory with the full hash name.

Returns:

The constructed path to the file.

Return type:

Path

Example

>>> get_sharded_file_path("/path/to/cache", "abcdef123456", ".pkl", 2)
Path("/path/to/cache/ab/cd/abcdef123456.pkl")
>>> get_sharded_file_path("/path/to/cache", "abcdef123456", ".pkl", 3, chars_per_dir=1)
Path("/path/to/cache/a/b/c/abcdef123456.pkl")
>>> get_sharded_file_path("/path/to/cache", "abcdef123456", ".pkl", 2, include_subdirectory=True)
Path("/path/to/cache/ab/cd/abcdef123456/abcdef123456.pkl")
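
The sharding scheme implied by these examples can be sketched with pathlib. The function name `sharded_path` is hypothetical, and the logic is inferred from the three examples above rather than taken from the library source.

```python
from pathlib import Path


def sharded_path(base_dir, file_hash, extension, depth, chars_per_dir=2, include_subdirectory=False):
    """Hypothetical sketch: nest directories by successive chunks of the hash."""
    path = Path(base_dir)
    for level in range(depth):
        start = level * chars_per_dir
        path = path / file_hash[start : start + chars_per_dir]
    if include_subdirectory:
        path = path / file_hash  # extra directory named after the full hash
    return path / f"{file_hash}{extension}"


p1 = sharded_path("/path/to/cache", "abcdef123456", ".pkl", 2)
p2 = sharded_path("/path/to/cache", "abcdef123456", ".pkl", 3, chars_per_dir=1)
p3 = sharded_path("/path/to/cache", "abcdef123456", ".pkl", 2, include_subdirectory=True)
```

Sharding keeps any single directory from accumulating a very large number of files, which many file systems handle poorly.
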
atomworks.ml.utils.io.open_file(filename: PathLike) TextIO[source]#

Open a file, handling gzipped files if necessary.

atomworks.ml.utils.io.read_parquet_with_metadata(filepath: PathLike, **kwargs: Any) DataFrame[source]#

Convenience wrapper around pd.read_parquet that preserves metadata.

Parameters:
  • filepath – Path to the parquet file.

  • **kwargs – Additional arguments to pass to pd.read_parquet.

Returns:

pandas DataFrame with metadata in .attrs attribute

atomworks.ml.utils.io.to_parquet_with_metadata(df: DataFrame, filepath: PathLike, **kwargs: Any) None[source]#

Convenience wrapper around df.to_parquet that saves table-wide metadata (df.attrs) to the parquet file.

Parameters:
  • df – pandas DataFrame to save.

  • filepath – Path where to save the parquet file.

  • **kwargs – Additional arguments to pass to df.to_parquet.

Miscellaneous Utilities#

atomworks.ml.utils.misc.argunsort(s: ndarray) ndarray[source]#

Returns the permutation necessary to undo a sort given the argsort array.

An argsort array is an array of indices that sorts another array. This function allows you to get the argsort array, sort your array with it, and then undo the sort without the overhead of sorting again.

Parameters:

s (numpy.ndarray) – The argsort array.

Returns:

The permutation array that can be used to undo the sort.

Return type:

numpy.ndarray

Example

>>> arr = np.array([3, 1, 2])
>>> s = np.argsort(arr)
>>> sorted_arr = arr[s]
>>> undo_sort = argunsort(s)
>>> original_arr = sorted_arr[undo_sort]
>>> np.array_equal(original_arr, arr)
True
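
One linear-time way to invert a permutation, shown here as a sketch (the name `argunsort_np` is hypothetical and the library may implement this differently), is to scatter positions rather than sort a second time:

```python
import numpy as np


def argunsort_np(s):
    """Invert a permutation in O(n): write position i to index s[i]."""
    inverse = np.empty_like(s)
    inverse[s] = np.arange(s.size)
    return inverse


arr = np.array([30, 10, 20])
s = np.argsort(arr)     # permutation that sorts arr
undo = argunsort_np(s)  # permutation that restores the original order
restored = arr[s][undo]
```
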
atomworks.ml.utils.misc.convert_pn_unit_iids_to_pn_unit_ids(pn_unit_iids: list[str]) list[str][source]#

Convert a list of pn_unit_iid strings to pn_unit_id strings.

Example

>>> pn_unit_iids = ["B_1,C_1", "A_11,B_11"]
>>> convert_pn_unit_iids_to_pn_unit_ids(pn_unit_iids)
['B,C', 'A,B']
atomworks.ml.utils.misc.cumcount(a: ndarray) ndarray[source]#

Helper function to compute the cumulative count of each unique element in an array.

Source: https://stackoverflow.com/questions/40602269/how-to-use-numpy-to-get-the-cumulative-count-by-unique-values-in-linear-time

atomworks.ml.utils.misc.dfill(a: ndarray) ndarray[source]#

Takes an array and returns the indices at which the value changes, repeating each index until the next change occurs.

Parameters:

a (numpy.ndarray) – The input array.

Returns:

An array of indices where each index is repeated until a change in value occurs in the input array.

Return type:

numpy.ndarray

Example

>>> short_list = np.array(list("aaabaaacaaadaaac"))
>>> dfill(short_list)
array([ 0,  0,  0,  3,  4,  4,  4,  7,  8,  8,  8, 11, 12, 12, 12, 15])
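
The sketch below shows how dfill and cumcount can fit together, following the approach in the Stack Overflow answer linked above. The `_np` function names are hypothetical, and the details may differ from the library's code.

```python
import numpy as np


def dfill_np(a):
    """Index of the most recent value change, repeated until the next change."""
    n = a.size
    # Run boundaries: 0, every position where the value changes, and n.
    b = np.concatenate([[0], np.where(a[:-1] != a[1:])[0] + 1, [n]])
    return np.arange(n)[b[:-1]].repeat(np.diff(b))


def cumcount_np(a):
    """Running count of how often each value has been seen so far."""
    n = a.size
    s = a.argsort(kind="mergesort")  # stable sort groups equal values in order
    inverse = np.empty(n, dtype=np.int64)
    inverse[s] = np.arange(n)        # undo the sort afterwards
    return (np.arange(n) - dfill_np(a[s]))[inverse]


filled = dfill_np(np.array(list("aaabaaacaaadaaac")))
counts = cumcount_np(np.array(list("abacab")))
```

In the sorted array each run of equal values starts at the index dfill reports, so subtracting it from the position gives the within-group count.
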
atomworks.ml.utils.misc.extract_transformation_id_from_pn_unit_iid(pn_unit_iid: str) str[source]#

Extracts the transformation ID from a pn_unit_iid string.

Example

>>> extract_transformation_id_from_pn_unit_iid("A_1,B_1")
'1'
atomworks.ml.utils.misc.get_msa_tax_id(pdb_id: str, chain_id: str) int[source]#

Retrieves the taxonomy ID for a given PDB and chain ID combination.

Parameters:
  • pdb_id (str) – The PDB ID of the protein structure. E.g., “1A2K”.

  • chain_id (str) – The chain ID within the PDB structure. E.g., “A”. Notably, no transformation ID.

Returns:

The taxonomy ID corresponding to the combined PDB and chain ID (e.g., “79015”).

Return type:

str

atomworks.ml.utils.misc.grouped_count(data: Tensor, *, mask: Tensor | None = None, groups: list[Tensor] | None = None, n_tokens: int | None = None, dtype: dtype = torch.int64) Tensor[source]#

Counts the occurrence of each token in a data tensor, optionally within specified groups and masked positions. (Time- and memory-efficient implementation of grouped_sum across one-hot tokens.)

NOTE: The special case where groups=None and mask=None corresponds to one-hot token counting.

Parameters:
  • data (-) – The input tensor containing token data for which we want to count the occurrence of each token.

  • mask (-) – A boolean mask tensor with True values for all positions to include when counting. If None, all positions are considered (i.e. mask = True for all positions).

  • groups (-) – A list of tensors specifying the group assignments for each dimension of the data tensor. If None, each position is its own group for each dimension.

  • n_tokens (-) – The number of unique tokens. If None, it is inferred from the data tensor.

Returns:

A tensor containing the count of each token in each group. The shape of the tensor is determined by the group sizes and the number of tokens.

Return type:

  • torch.Tensor

Example

>>> msa = torch.tensor(
...     [
...         [0, 1, 3, 1, 2],
...         [1, 0, 0, 3, 2],
...         [2, 2, 1, 0, 1],
...         [3, 1, 2, 2, 3],
...         [1, 0, 0, 0, 1],
...         [2, 1, 3, 3, 1],
...     ]
... )
>>> groups = [
...     [0, 1, 2, 2, 1, 0],  # groups for dim=0 (=rows)
...     [0, 1, 2, 3, 4],  # groups for dim=1 (=cols)
... ]
>>> group_counts = grouped_count(msa, mask=None, groups=groups)
>>> group_counts[0]
tensor([
    [1, 0, 1, 0],  # (corresponds to 0x1 & 2x1 at position 0 in rows 0 & 5)
    [0, 2, 0, 0],  # (corresponds to 1x2 at position 1 in rows 0 & 5)
    [0, 0, 0, 2],  # (corresponds to 3x2 at position 2 in rows 0 & 5)
    [0, 1, 0, 1],  # (corresponds to 1x1 & 3x1 at position 3 in rows 0 & 5)
    [0, 1, 1, 0]   # (corresponds to 2x1 & 1x1 at position 4 in rows 0 & 5)
])
atomworks.ml.utils.misc.grouped_sum(data: Tensor, assignment: Tensor, num_groups: int, as_float: bool = True) Tensor[source]#

Computes the sum along a tensor, given group indices.

Parameters:
  • data (torch.Tensor) – A tensor whose groups are to be summed. Shape: (N, …, D), where N is the number of elements.

  • assignment (torch.Tensor) – A 1-D tensor containing group indices. Must be int64 (to be compatible with the scatter operation). Shape: (N,).

  • num_groups (int) – The number of groups.

  • as_float (bool) – If True, the input data will be converted to float before summing. If not True, then booleans will be added as booleans, not integers.

Returns:

A tensor of the same data type as the input data, containing the sum of elements for each group (cluster). Shape: (num_groups, …, D).

Return type:

torch.Tensor

Example

>>> data = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
>>> assignment = torch.tensor([0, 1, 0, 1])
>>> num_groups = 2
>>> grouped_sum(data, assignment, num_groups)
tensor([[ 6,  8],
        [10, 12]])
# Explanation:
# Group 0: [1, 2] + [5, 6] = [6, 8]
# Group 1: [3, 4] + [7, 8] = [10, 12]
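
The scatter-style summation can be illustrated in NumPy with an unbuffered add. This is only a sketch of the idea: the name `grouped_sum_np` is hypothetical, and the as_float conversion is left out.

```python
import numpy as np


def grouped_sum_np(data, assignment, num_groups):
    """Sum rows of `data` into `num_groups` buckets given by `assignment`."""
    out = np.zeros((num_groups,) + data.shape[1:], dtype=data.dtype)
    np.add.at(out, assignment, data)  # unbuffered scatter-add along axis 0
    return out


data = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
assignment = np.array([0, 1, 0, 1])  # rows 0 and 2 -> group 0, rows 1 and 3 -> group 1
sums = grouped_sum_np(data, assignment, num_groups=2)
```
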
atomworks.ml.utils.misc.hash_sequence(sequence: str) str[source]#

Generate a SHA-256 hash for the given sequence and return a compressed string format of the hash.

Parameters:

sequence (str) – The sequence to be hashed.

Returns:

The compressed hash string format.

Return type:

str

atomworks.ml.utils.misc.masked_mean(*, mask: Tensor, value: Tensor, axis: int | list[int] | None = None, drop_mask_channel: bool = False, eps: float = 1e-10) Tensor[source]#

Compute the masked mean of a tensor along specified axes.

Parameters:
  • mask (torch.Tensor) – A mask tensor with the same shape as value or with dimensions that can be broadcast to value.

  • value (torch.Tensor) – The input tensor for which the masked mean is to be computed. If memory is a concern, can be float16 or even a bool; the sensitive parts of the computation are in float32.

  • axis (Optional[Union[int, List[int]]]) – The axis or axes along which to compute the mean. If None, the mean is computed over all dimensions.

  • drop_mask_channel (bool) – If True, drops the last channel of the mask (assumes the last dimension is a singleton).

  • eps (float) – A small constant to avoid division by zero.

Returns:

The masked mean of value along the specified axes. Given in full precision (float32).

Return type:

torch.Tensor

Example

>>> import torch
>>> mask = torch.tensor([[1, 0], [1, 1]], dtype=bool)
>>> value = torch.tensor([[2.0, 3.0], [4.0, 5.0]], dtype=torch.float16)
>>> masked_mean(mask=mask, value=value, axis=0)
tensor([3., 5.])  # float32

Reference: - AF2 Multimer Code (google-deepmind/alphafold)
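The arithmetic behind a masked mean can be sketched in NumPy as follows. The name `masked_mean_np` is hypothetical, drop_mask_channel is omitted, and adding eps to the denominator is an assumption about where the constant enters.

```python
import numpy as np


def masked_mean_np(*, mask, value, axis=None, eps=1e-10):
    """Mean of `value` over `axis`, counting only positions where mask is True."""
    axis = tuple(axis) if isinstance(axis, list) else axis
    mask = np.broadcast_to(mask, value.shape).astype(np.float32)
    value = value.astype(np.float32)  # accumulate in float32
    total = np.sum(mask * value, axis=axis)
    count = np.sum(mask, axis=axis)
    return total / (count + eps)


mask = np.array([[True, False], [True, True]])
value = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=np.float16)
mean = masked_mean_np(mask=mask, value=value, axis=0)
```

Column 0 averages 2 and 4; column 1 keeps only the unmasked 5.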

Nested Dictionary Utilities#

Tools to work with nested dictionaries.

atomworks.ml.utils.nested_dict.flatten(d: dict[str, Any], *, fuse_keys: str | None = None) dict[tuple[str, ...], Any] | dict[str, Any][source]#

Flatten a nested dictionary into a single level dictionary with tuple keys, preserving non-dict values.

Parameters:
  • d (-) – A nested dictionary to flatten.

  • fuse_keys (-) – If provided, joins the key tuple elements with this string to create string keys. If None, returns tuple keys.

Returns:

A flattened dictionary where nested dict keys become either tuple keys or fused string keys, but other values remain intact.

Return type:

  • dict

Example

>>> d = {"a": {"b": [1, 2]}, "c": {"d": {"e": 3}}, "f": [4, 5]}
>>> flatten(d)
{('a', 'b'): [1, 2], ('c', 'd', 'e'): 3, ('f',): [4, 5]}
>>> flatten(d, fuse_keys=".")
{'a.b': [1, 2], 'c.d.e': 3, 'f': [4, 5]}
atomworks.ml.utils.nested_dict.get(d: dict[tuple[str, ...], Any], key: tuple[str, ...], default: Any = None) Any[source]#

Get a value from a nested dictionary using a tuple key.

Equivalent behavior to .get for nested dictionaries.

Parameters:
  • d (-) – A nested dictionary.

  • key (-) – A tuple of keys to navigate through the dictionary.

  • default (-) – The value to return if the key is not found.

Returns:

The value at the specified key. If the key is not found, the default value is returned.

Return type:

  • Any

atomworks.ml.utils.nested_dict.getitem(d: dict[tuple[str, ...], Any], key: tuple[str, ...]) Any[source]#

Get a value from a nested dictionary using a tuple key.

Equivalent behavior to __getitem__ for nested dictionaries.

Parameters:
  • d (-) – A nested dictionary.

  • key (-) – A tuple of keys to navigate through the dictionary.

Returns:

The value at the specified key.

Return type:

  • Any

atomworks.ml.utils.nested_dict.set(d: dict[tuple[str, ...], Any], key: tuple[str, ...], value: Any) None[source]#

Set a value in a nested dictionary using a tuple key.

Equivalent behavior to __setitem__ for nested dictionaries. Creates intermediate dictionaries if they don’t exist yet.

Parameters:
  • d (-) – A nested dictionary.

  • key (-) – A tuple of keys to navigate through the dictionary.

  • value (-) – The value to set at the specified key.

atomworks.ml.utils.nested_dict.unflatten(d: dict[tuple[str, ...] | str, Any], *, split_keys: str | None = None) dict[str, Any][source]#

Unflatten a flattened dictionary into a nested dictionary.

Parameters:
  • d (-) – A flattened dictionary with either tuple keys or string keys.

  • split_keys (-) – If provided, splits string keys with this string to create tuple keys. If None, expects tuple keys.

Returns:

A nested dictionary reconstructed from the flattened keys.

Return type:

  • dict

Example

>>> d = {("a", "b"): [1, 2], ("c", "d", "e"): 3, ("f",): [4, 5]}
>>> unflatten(d)
{'a': {'b': [1, 2]}, 'c': {'d': {'e': 3}}, 'f': [4, 5]}
>>> d = {"a.b": [1, 2], "c.d.e": 3, "f": [4, 5]}
>>> unflatten(d, split_keys=".")
{'a': {'b': [1, 2]}, 'c': {'d': {'e': 3}}, 'f': [4, 5]}
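
The flatten/unflatten round trip can be sketched in plain Python. The `_sketch` function names are hypothetical and the edge-case handling (for instance, empty nested dicts are dropped here) is an assumption rather than documented behaviour.

```python
def flatten_sketch(d, fuse_keys=None, _prefix=()):
    """Hypothetical sketch: walk nested dicts, collecting key paths as tuples."""
    flat = {}
    for key, value in d.items():
        path = _prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten_sketch(value, _prefix=path))
        else:
            flat[path] = value  # non-dict values are kept as-is
    if fuse_keys is not None:
        return {fuse_keys.join(path): value for path, value in flat.items()}
    return flat


def unflatten_sketch(flat, split_keys=None):
    """Hypothetical sketch: rebuild the nesting from tuple or fused string keys."""
    nested = {}
    for key, value in flat.items():
        path = tuple(key.split(split_keys)) if split_keys is not None else key
        node = nested
        for part in path[:-1]:
            node = node.setdefault(part, {})  # create intermediate dicts
        node[path[-1]] = value
    return nested


d = {"a": {"b": [1, 2]}, "c": {"d": {"e": 3}}, "f": [4, 5]}
flat = flatten_sketch(d)
fused = flatten_sketch(d, fuse_keys=".")
```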

NumPy Utilities#

General utility functions for working with numpy arrays.

atomworks.ml.utils.numpy.get_connected_components_from_adjacency(adjacency: ndarray) list[ndarray][source]#

Return a list of indices for each connected component according to the given adjacency matrix.

atomworks.ml.utils.numpy.get_indices_of_non_constant_columns(arr: ndarray) ndarray[source]#

Identify columns where values change between consecutive rows.

Parameters:

arr (np.ndarray) – A 2D NumPy array where you want to find columns with changing values.

Returns:

An array of column indices where values change between consecutive rows.

Return type:

np.ndarray

Example

>>> arr = np.array(
...     [
...         [151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161],
...         [151, 152, 153, 154, 155, 156, 157, 158, 159, 161, 160],
...     ]
... )
>>> get_indices_of_non_constant_columns(arr)
array([ 9, 10])
atomworks.ml.utils.numpy.get_nearest_true_index_for_each_false(arr: ndarray) ndarray[source]#

Get the index of the nearest True for each False in the array, breaking ties by choosing the nearest True to the left.

Parameters:

arr (-) – A boolean numpy array.

Returns:

An array of length np.sum(~arr) where each entry is the index of the nearest True.

Return type:

  • np.ndarray

Example

>>> arr = np.array([False, True, True, False, False, True, False])
>>> get_nearest_true_index_for_each_false(arr)
array([1, 2, 5, 5])
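
One way to obtain this result without a Python loop is a binary search over the True positions. The sketch below is an assumption about a possible approach, with the hypothetical name `nearest_true_index`; it breaks ties towards the left neighbour as described above.

```python
import numpy as np


def nearest_true_index(arr):
    """For each False entry, return the index of the closest True entry."""
    true_idx = np.flatnonzero(arr)
    false_idx = np.flatnonzero(~arr)
    if true_idx.size == 0:
        raise ValueError("arr must contain at least one True value")
    # Position of each False index among the sorted True indices.
    pos = np.searchsorted(true_idx, false_idx)
    left = true_idx[np.clip(pos - 1, 0, true_idx.size - 1)]
    right = true_idx[np.clip(pos, 0, true_idx.size - 1)]
    d_left = np.abs(false_idx - left)
    d_right = np.abs(right - false_idx)
    return np.where(d_left <= d_right, left, right)  # ties go to the left


arr = np.array([False, True, True, False, False, True, False])
nearest = nearest_true_index(arr)
tie = nearest_true_index(np.array([True, False, True]))
```
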
atomworks.ml.utils.numpy.insert_data_by_id_(to_fill: ndarray, to_fill_ids: ndarray, from_data: ndarray, from_data_ids: ndarray, axis: int = 0) ndarray[source]#

Insert data into an array based on matching IDs.

Parameters:
  • to_fill (np.ndarray) – Array to be filled.

  • to_fill_ids (np.ndarray) – Array of IDs corresponding to to_fill.

  • from_data (np.ndarray) – Data array from which to insert.

  • from_data_ids (np.ndarray) – Array of IDs corresponding to from_data.

  • axis (int, optional) – Axis along which to insert data. Defaults to 0.

Returns:

Array with inserted data.

Return type:

np.ndarray

Example

>>> to_array = np.zeros((7, 6))
>>> to_ids = np.array([1, 5, 2, 17, 20, 20, 2])
>>> from_array = np.arange(10).repeat(6).reshape(10, 6)
>>> from_ids = np.array([1, 2, 5, 6, 7, 21, 22, 23, 20, 25])
>>> insert_data_by_id_(to_array, to_ids, from_array, from_ids)
>>> print(to_array)
array([[0., 0., 0., 0., 0., 0.],
       [2., 2., 2., 2., 2., 2.],
       [1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0.],
       [8., 8., 8., 8., 8., 8.],
       [8., 8., 8., 8., 8., 8.],
       [1., 1., 1., 1., 1., 1.]])
atomworks.ml.utils.numpy.is_mask_contiguous(mask: ndarray) bool[source]#

Check if a mask is contiguous.

atomworks.ml.utils.numpy.not_isin(element: ndarray, test_element: ndarray) ndarray[source]#

Return a boolean mask indicating where elements of element are not in test_element.

Parameters:
  • element (np.ndarray) – Array to check.

  • test_element (np.ndarray) – Array to check against.

Returns:

Boolean mask.

Return type:

np.ndarray

Example

>>> not_isin(np.array([1, 2, 3, 4, 5]), np.array([2, 4, 6]))
array([ True, False,  True, False,  True])
atomworks.ml.utils.numpy.select_data_by_id(select_ids: ndarray, data_ids: ndarray, data: ndarray, axis: int = 0) ndarray[source]#

Select data from an array based on matching IDs.

Parameters:
  • select_ids (np.ndarray) – Array of IDs to select.

  • data_ids (np.ndarray) – Array of IDs corresponding to the data.

  • data (np.ndarray) – Data array from which to select.

  • axis (int, optional) – Axis along which to select data. Defaults to 0.

Returns:

Array of selected data.

Return type:

np.ndarray

Raises:
  • AssertionError – If the shape of data along axis does not match the length of data_ids.

  • AssertionError – If data_ids contains duplicate values.

Example

>>> to_ids = np.array([1, 5, 2, 20, 20, 2])
>>> from_array = np.arange(10).repeat(6).reshape(10, 6)
>>> from_ids = np.array([1, 2, 5, 6, 7, 21, 22, 23, 20, 25])
>>> select_data_by_id(to_ids, from_ids, from_array)
array([[0., 0., 0., 0., 0., 0.],
       [2., 2., 2., 2., 2., 2.],
       [1., 1., 1., 1., 1., 1.],
       [8., 8., 8., 8., 8., 8.],
       [8., 8., 8., 8., 8., 8.],
       [1., 1., 1., 1., 1., 1.]])
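
The ID-based lookup can be sketched as a position map followed by a gather. The name `select_by_id_np` is hypothetical, integer IDs are assumed, and an ID missing from data_ids raises a KeyError in this sketch, which may not match the library's behaviour.

```python
import numpy as np


def select_by_id_np(select_ids, data_ids, data, axis=0):
    """Hypothetical sketch: look up each requested ID's position, then gather."""
    assert data.shape[axis] == len(data_ids), "data and data_ids must align"
    assert len(np.unique(data_ids)) == len(data_ids), "data_ids must be unique"
    position = {int(i): pos for pos, i in enumerate(data_ids)}
    index = np.array([position[int(i)] for i in select_ids])
    return np.take(data, index, axis=axis)


from_ids = np.array([1, 2, 5, 6, 7, 21, 22, 23, 20, 25])
from_array = np.arange(10).repeat(6).reshape(10, 6)  # row i is filled with i
picked = select_by_id_np(np.array([1, 5, 2, 20, 20, 2]), from_ids, from_array)
```

Repeated IDs in select_ids are allowed and simply gather the same row more than once.
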
atomworks.ml.utils.numpy.unique_by_first_occurrence(arr: ndarray) ndarray[source]#

Return unique elements of an array while preserving the order of their first occurrence.

Parameters:

arr (np.ndarray) – Input array.

Returns:

Array of unique elements in the order of their first occurrence.

Return type:

np.ndarray

Example

>>> unique_by_first_occurrence(np.array([4, 2, 2, 3, 1, 4]))
array([4, 2, 3, 1])
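
np.unique sorts its output, so preserving first-occurrence order takes one extra step. A minimal sketch (the name `unique_first_np` is hypothetical):

```python
import numpy as np


def unique_first_np(arr):
    """Unique values ordered by where they first appear in `arr`."""
    _, first_index = np.unique(arr, return_index=True)
    return arr[np.sort(first_index)]  # restore original ordering


result = unique_first_np(np.array([4, 2, 2, 3, 1, 4]))
```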

Random Number Generation#

Utilities for managing the random number generators in the current process. Inspired by: Lightning-AI/pytorch-lightning

atomworks.ml.utils.rng.capture_rng_states(include_cuda: bool = False) dict[str, Any][source]#

Collect the global random state of torch, torch.cuda, numpy and Python in current process.

Parameters:

include_cuda (bool) – Whether to include the state of the CUDA RNG. If CUDA is not available, the state of the CUDA RNG is not included. Defaults to False.

atomworks.ml.utils.rng.create_rng_state_from_seeds(np_seed: int | None = None, torch_seed: int | None = None, py_seed: int | None = None) dict[str, Any][source]#

Create a dictionary of RNG states from the provided seeds. If no seed is provided, the current state of the RNGs is used.

Parameters:
  • np_seed (int | None) – The seed for the Numpy RNG.

  • torch_seed (int | None) – The seed for the PyTorch RNG.

  • py_seed (int | None) – The seed for the Python built-in RNG.

atomworks.ml.utils.rng.get_numpy_rng_state_hash(rng: RandomState | None = None) int[source]#

Get the hash of the current state of the Numpy RNG.

atomworks.ml.utils.rng.get_rng_state_hash(rng_state_dict: dict[str, Any]) int[source]#

Get the hash of the RNG state dictionary.

atomworks.ml.utils.rng.rng_state(rng_state_dict: dict[str, Any] | None = None, include_cuda: bool = True) Generator[dict[str, Any], None, None][source]#

A context manager that resets the global random state on exit to what it was before entering.

Within the context manager, the RNG states are set to the provided rng state in the dictionary.

It supports isolating the states for PyTorch, Numpy, and Python built-in random number generators.

Parameters:
  • rng_state_dict (dict[str, Any] | None) – A dictionary of RNG states to set. It can have the following keys:
    • “torch”: The state of the PyTorch RNG.

    • “torch.cuda”: The state of the PyTorch CUDA RNG.

    • “numpy”: The state of the Numpy RNG.

    • “python”: The state of the Python built-in RNG.

    If no rng_state_dict is provided, all RNGs keep their current state. If rng_state_dict contains only a subset of the keys, the remaining RNGs keep their current state.

  • include_cuda (bool) – Whether to allow this function to also control the torch.cuda random number generator.

    Set this to False when using the function in a forked process where CUDA re-initialization is prohibited. Defaults to True.

Example:

```
import random

import numpy as np
import torch

from atomworks.ml.utils.rng import create_rng_state_from_seeds, rng_state, serialize_rng_state_dict

# Outside the context manager
print("NumPy:", np.random.random(3))  # [0.04810046 0.99270597 0.70612995]
print("PyTorch:", torch.rand(3))  # tensor([0.1405, 0.4602, 0.4284])
print("Python random:", [random.random() for _ in range(3)])  # [0.7406435863188185, 0.5632059276194807, 0.8537007637060476]

# Inside the context manager with fixed seeds
with rng_state(create_rng_state_from_seeds(np_seed=42, torch_seed=42, py_seed=42)) as rng_state_dict:
    my_state = serialize_rng_state_dict(rng_state_dict)
    print("\nWithin context manager:")
    print("NumPy:", np.random.random(3))  # [0.37454012 0.95071431 0.73199394]
    print("PyTorch:", torch.rand(3))  # tensor([0.8823, 0.9150, 0.3829])
    print("Python random:", [random.random() for _ in range(3)])  # [0.6394267984578837, 0.025010755222666936, 0.27502931836911926]

# Back to the original state outside the context manager
print("\nBack outside the context manager:")
print("NumPy:", np.random.random(3))  # [0.75479377 0.99594641 0.70411424]
print("PyTorch:", torch.rand(3))  # tensor([0.2757, 0.5345, 0.1754])
print("Python random:", [random.random() for _ in range(3)])  # [0.2194923914916147, 0.8731837332486028, 0.47700011905124995]

# Inside the context manager again, with the state re-created from its serialized form
with rng_state(eval(my_state)):
    print("\nWithin context manager:")
    print("NumPy:", np.random.random(3))  # [0.37454012 0.95071431 0.73199394]
    print("PyTorch:", torch.rand(3))  # tensor([0.8823, 0.9150, 0.3829])
    print("Python random:", [random.random() for _ in range(3)])  # [0.6394267984578837, 0.025010755222666936, 0.27502931836911926]
```
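
The save, set, and restore pattern behind such a context manager can be sketched with the standard library alone. The helper below is hypothetical and covers only Python's built-in `random` module; for NumPy and PyTorch the analogous calls would be `np.random.get_state` / `np.random.set_state` and `torch.get_rng_state` / `torch.set_rng_state`.

```python
import random
from contextlib import contextmanager


@contextmanager
def isolated_python_rng(state=None):
    """Hypothetical helper: run a block under `state`, then restore the previous state."""
    saved = random.getstate()      # remember the state on entry
    if state is not None:
        random.setstate(state)     # optionally switch to the requested state
    try:
        yield
    finally:
        random.setstate(saved)     # always restore, even if the block raises


random.seed(0)
state_before = random.getstate()
with isolated_python_rng():
    random.seed(42)                # reseeding inside the block ...
    random.random()
state_after = random.getstate()    # ... leaves no trace outside it
```

Restoring in a `finally` clause means the outer random stream is unaffected even when the block fails.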

atomworks.ml.utils.rng.serialize_rng_state_dict(rng_state_dict: dict[str, Any]) str[source]#

Convert the RNG state dictionary to a string so it can be re-created via eval.
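
As an illustration of the string round-trip idea, the sketch below serializes only the built-in `random` state, which is a plain tuple. It deliberately swaps `eval` for `ast.literal_eval`, which parses literals without executing code; both helper names are hypothetical, and NumPy or PyTorch states would need extra handling that is not shown here.

```python
import ast
import random


def python_state_to_text(state):
    """Hypothetical helper: render a `random.getstate()` tuple as a Python literal."""
    return repr(state)


def python_state_from_text(text):
    """Parse the literal back into a tuple without executing arbitrary code."""
    return ast.literal_eval(text)


random.seed(7)
text = python_state_to_text(random.getstate())
expected = random.random()                      # advances the generator

random.setstate(python_state_from_text(text))   # rewind to the serialized state
replayed = random.random()                      # repeats the same draw
```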

Testing Utilities#

atomworks.ml.utils.testing.cached_parse(pdb_id: str, **kwargs) dict[source]#

Wrapper around _cached_parse with caching to return an immutable copy of the output dict

atomworks.ml.utils.testing.get_pdb_mirror_path(pdbid: str, base_dir: str = None) str[source]#

Convenience util to get the path to a CIF file on the DIGS

atomworks.ml.utils.testing.is_clash(atom_array_1: AtomArray, atom_array_2: AtomArray, clash_distance: float = 1.0) bool[source]#

Checks for clashes between two arrays. Based on atomworks.ml.preprocessing.process.DataPreprocessor. Recommended to pass in minimal masks of arrays to check to reduce runtime.

Timer Utilities#

Timeout utilities for applying time limits to functions and blocks of code.

Adapted from pnpnpn/timeout-decorator (MIT License) and from chaidiscovery/chai-lab

exception atomworks.ml.utils.timer.ChildProcessError[source]#

Bases: Exception

Exception raised when a child process dies unexpectedly.

atomworks.ml.utils.timer.do_nothing(*args, **kwargs) Callable[source]#

A decorator that does nothing and simply returns the original function.

This decorator can be used as a placeholder or for testing purposes when you want to conditionally apply decorators without changing the code structure.

Returns:

A decorator function that returns the original function unchanged.

Return type:

Callable

Example

```python
@do_nothing()
def my_function():
    return "Hello, World!"

# or: do_nothing(bla=123, blub=456)(my_function)
```
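
A decorator factory of this kind is only a few lines long. The sketch below shows the general pattern together with the conditional use case described above; `noop_decorator`, `tagging` and `USE_TAGS` are hypothetical names chosen for illustration.

```python
import functools


def noop_decorator(*args, **kwargs):
    """Hypothetical sketch: accept any arguments and return a decorator that changes nothing."""
    def decorator(func):
        return func
    return decorator


def tagging(label):
    """A 'real' decorator factory to switch against."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return f"[{label}] {func(*args, **kwargs)}"
        return wrapper
    return decorator


USE_TAGS = False
factory = tagging if USE_TAGS else noop_decorator  # same call site either way


@factory(label="debug")
def greet():
    return "Hello, World!"
```

Because both factories accept the same keyword arguments, the decorated code does not need to change when the feature is toggled.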

atomworks.ml.utils.timer.timeout(timeout: float | int | None = None, strategy: Literal['signal', 'subprocess'] = 'subprocess') Callable[source]#

Decorator to apply a timeout to a function.

The signal strategy is more efficient and slightly faster, but does not work in all contexts (e.g. with some C dependencies like RDKit, on certain operating systems). The subprocess strategy is always available, but slightly slower and with a higher overhead.

atomworks.ml.utils.timer.timeout_using_signal(timeout: float | int | None) Callable[source]#

Build a decorator that applies a timeout to a function using the signal module.

This decorator sets up a signal handler to raise a TimeoutError if the decorated function exceeds the specified timeout duration. It uses the SIGALRM signal to implement the timeout.

Use for example as:

```python
result = timeout_using_signal(timeout=10.0)(my_function)(*args, **kwargs)
```

Parameters:

timeout (float | int | None) – The timeout duration in seconds.

Returns:

A decorator function that can be applied to other functions to add timeout functionality.

Return type:

Callable
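
The SIGALRM mechanism described above can be sketched as follows. This is a simplified illustration, not the atomworks code: `alarm_timeout` is a hypothetical name, and the approach works only on Unix-like systems and only in the main thread.

```python
import functools
import signal
import time


def alarm_timeout(seconds):
    """Hypothetical sketch of a SIGALRM-based timeout decorator (Unix, main thread only)."""

    def decorator(func):
        if seconds is None:
            return func  # no limit requested

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            def on_alarm(signum, frame):
                raise TimeoutError(f"{func.__name__} exceeded {seconds} s")

            previous = signal.signal(signal.SIGALRM, on_alarm)
            signal.setitimer(signal.ITIMER_REAL, seconds)  # accepts fractional seconds
            try:
                return func(*args, **kwargs)
            finally:
                signal.setitimer(signal.ITIMER_REAL, 0)    # cancel the pending alarm
                signal.signal(signal.SIGALRM, previous)    # restore the old handler

        return wrapper

    return decorator


@alarm_timeout(0.2)
def slow():
    time.sleep(2)
    return "done"


try:
    outcome = slow()
except TimeoutError:
    outcome = "timed out"
```

The handler only runs when the interpreter regains control, which is one reason long-running C extension calls may not be interruptible this way.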

atomworks.ml.utils.timer.timeout_using_subprocess(timeout: float | int | None) Callable[source]#

Force a function to time out after the specified time.

The returned decorator uses a subprocess to execute the function, allowing for timeout functionality even for CPU-bound operations that cannot be interrupted by signals.

Parameters:

timeout (float | int | None) – The maximum time in seconds allowed for the function to execute.

Returns:

A decorator that can be applied to a function.

Return type:

Callable

Raises:
  • TimeoutError – If the function does not return before the timeout.

  • ChildProcessError – If the child process dies unexpectedly.
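
The subprocess strategy can be illustrated with `multiprocessing`. The sketch below is a simplified, hypothetical version (`child_timeout` is not an atomworks name). It assumes a Unix-like system where the `fork` start method is available, and it raises a plain `RuntimeError` where the real module documents its own exception type.

```python
import functools
import multiprocessing as mp
import time


def child_timeout(seconds):
    """Hypothetical sketch: run the function in a forked child and stop waiting after `seconds`."""
    ctx = mp.get_context("fork")  # assumption: Unix; fork lets the child call closures without pickling

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            reader, writer = ctx.Pipe(duplex=False)

            def target():
                try:
                    writer.send((True, func(*args, **kwargs)))
                except Exception as exc:  # forward the failure to the parent
                    writer.send((False, exc))

            proc = ctx.Process(target=target)
            proc.start()
            writer.close()  # the parent only reads; the child keeps its own copy
            try:
                if not reader.poll(seconds):  # seconds=None waits indefinitely
                    raise TimeoutError(f"{func.__name__} exceeded {seconds} s")
                try:
                    ok, payload = reader.recv()
                except EOFError:
                    raise RuntimeError("child process died without a result") from None
            finally:
                if proc.is_alive():
                    proc.terminate()  # stop a child that is still running
                proc.join()
                reader.close()
            if ok:
                return payload
            raise payload

        return wrapper

    return decorator


@child_timeout(5)
def double(x):
    return 2 * x


@child_timeout(0.3)
def sleepy():
    time.sleep(5)


result = double(21)
try:
    sleepy()
    timed_out = False
except TimeoutError:
    timed_out = True
```

Terminating the child works even when the function is stuck in code that ignores signals, at the cost of process start-up and result pickling.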

Token Utilities#

atomworks.ml.utils.token.apply_and_spread_token_wise(atom_array: AtomArray, data: ndarray, function: Callable, axis: int | None = None, token_starts: ndarray | None = None) ndarray[source]#

Apply a function token-wise and then spread the result to the atoms.
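
Conceptually this is a per-token reduction followed by a repeat. The sketch below works on plain arrays instead of an AtomArray and assumes the token boundaries are already known, including an exclusive stop; `apply_then_spread` is a hypothetical name.

```python
import numpy as np


def apply_then_spread(data, starts_with_stop, function):
    """Hypothetical sketch: reduce `data` per token, then repeat each result for every atom of the token."""
    data = np.asarray(data)
    bounds = np.asarray(starts_with_stop)
    # One reduced value per token, e.g. a mean over the token's atoms.
    per_token = np.array([function(data[a:b]) for a, b in zip(bounds[:-1], bounds[1:])])
    # Repeat each token value once per atom in that token.
    return np.repeat(per_token, np.diff(bounds), axis=0)


per_atom = apply_then_spread([1.0, 3.0, 2.0, 4.0, 6.0], [0, 2, 5], np.mean)
```

Here atoms 0 and 1 form the first token and atoms 2 to 4 the second, so the result has one entry per atom again.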

atomworks.ml.utils.token.apply_segment_wise_2d(array: ndarray, segment_start_end_idxs: ndarray, reduce_func: Callable) ndarray[source]#

Reduces a 2D array by applying a reduction function to specified segments along both rows and columns.

NOTE: Segments must be contiguous, rectangular blocks (sub-matrices) of the 2D array.

Parameters:
  • array (np.ndarray) – A 2D numpy array to be reduced.

  • segment_start_end_idxs (np.ndarray) – A 1D numpy array of indices indicating the start and end of each block. The first element must be 0 and the last element must be the number of rows in the array.

  • reduce_func (Callable) – A function to apply to each segment. This function should take an array and return a reduced value.

Returns:

A 2D numpy array that has been reduced along both rows and columns.

Return type:

np.ndarray

Example

>>> array = np.array([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
])
>>> segment_start_end_idxs = np.array([0, 2, 3])
>>> apply_segment_wise_2d(array, segment_start_end_idxs, reduce_func=np.sum)
array([
    [12, 9],
    [15, 9]
])
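
A straightforward, unoptimized way to express this block reduction is a double loop over the segment boundaries. The sketch below assumes a square array whose rows and columns share the same boundaries, as in the example above; `reduce_blocks_2d` is a hypothetical name and not the library's implementation.

```python
import numpy as np


def reduce_blocks_2d(array, boundaries, reduce_func):
    """Hypothetical sketch: apply `reduce_func` to every rectangular block defined by `boundaries`."""
    array = np.asarray(array)
    starts, stops = boundaries[:-1], boundaries[1:]
    return np.array([
        [reduce_func(array[r0:r1, c0:c1]) for c0, c1 in zip(starts, stops)]
        for r0, r1 in zip(starts, stops)
    ])


grid = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
blocks = reduce_blocks_2d(grid, np.array([0, 2, 3]), np.sum)
```

With boundaries `[0, 2, 3]` the top-left block is `grid[0:2, 0:2]`, whose sum is 1 + 2 + 4 + 5 = 12.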
atomworks.ml.utils.token.apply_token_wise(array: AtomArray, data: ndarray, function: Callable, axis: int | None = None, token_starts: ndarray | None = None) ndarray[source]#

Analogous to biotite’s apply_residue_wise.

atomworks.ml.utils.token.get_af3_token_center_coords(atom_array: AtomArray) ndarray[source]#

Returns the center coordinates of the tokens in the atom array as per the AF3 definition.

For each token we also designate a token center atom, used in various places below:
  • CA for standard amino acids

  • C1’ for standard nucleotides

  • For other cases take the first and only atom as they are tokenized per-atom.

If a token center cannot be assigned (e.g. because the token center atom is unoccupied), the center coordinate is set to np.nan.

Parameters:

atom_array (AtomArray) – The atom array to get the center coordinates of.

Returns:

The center coordinates of the tokens in the atom array.

Return type:

np.ndarray

Reference:

Example

>>> # Contrived example showing only a few tokens and annotations per residue for illustration
>>> array = AtomArray(
    res_name="ALA", atom_name="CA", coord=np.array([0, 0, 0]),
    res_name="ALA", atom_name="C", coord=np.array([1, 0, 0]),
    res_name="ALA", atom_name="O", coord=np.array([2, 0, 0]),
    res_name="NAP", atom_name="P1", coord=np.array([3, 0, 0]),
    res_name="U",   atom_name="C1'", coord=np.array([4, 0, 0]),
)
>>> get_af3_token_center_coords(array)
array([[0, 0, 0], [3, 0, 0], [4, 0, 0]])
atomworks.ml.utils.token.get_af3_token_center_idxs(atom_array: AtomArray) ndarray[source]#

Returns the indices of the center atoms of the tokens in the atom array as per the AF3 definition.

atomworks.ml.utils.token.get_af3_token_center_masks(atom_array: AtomArray) ndarray[source]#

Returns a boolean mask indicating the center atoms of the tokens in the atom array as per the AF3 definition.

NOTE: “Center” atoms are distinct from “representative” atoms, which are used during distogram prediction (and more closely represent the center of mass).

For each token we also designate a token center atom, used in various places below:
  • CA for standard amino acids

  • C1’ for standard nucleotides

  • For other cases take the first and only atom as they are tokenized per-atom.

Parameters:

atom_array (AtomArray) – The atom array to get the center atoms of.

Returns:

A boolean mask indicating the center atoms of the tokens in the atom array.

Return type:

np.ndarray

Reference:
atomworks.ml.utils.token.get_af3_token_representative_coords(atom_array: AtomArray) ndarray[source]#

Returns the representative coordinates of the tokens in the atom array.

See “get_af3_token_representative_masks” for more details on what constitutes a representative atom.

Parameters:

atom_array (AtomArray) – The atom array to get the representative coordinates of.

Returns:

The representative coordinates of the tokens in the atom array.

Return type:

np.ndarray

atomworks.ml.utils.token.get_af3_token_representative_idxs(atom_array: AtomArray) ndarray[source]#

Returns the indices of the representative atoms of the tokens in the atom array.

See “get_af3_token_representative_masks” for more details on what constitutes a representative atom.

Parameters:

atom_array (AtomArray) – The atom array to get the representative atom indices from.

Returns:

An array of indices corresponding to the representative atoms of the tokens.

Return type:

np.ndarray

atomworks.ml.utils.token.get_af3_token_representative_masks(atom_array: AtomArray) ndarray[source]#

Returns a boolean mask indicating the representative atoms of the tokens in the atom array.

From the AF-3 supplement, section 4.4. (Distogram prediction):

> …where the pairwise token distances use the representative atom for each token: CB for protein residues (CA for glycine), C4 for purines and C2 for pyrimidines. All ligands already have a single atom per token.

NOTE: “Representative” atoms are distinct from “center” atoms, which are used during cropping.

Parameters:

atom_array (AtomArray) – The atom array to get the representative atoms of.

Returns:

A boolean mask indicating the representative atoms of the tokens in the atom array.

Return type:

np.ndarray
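
The selection rule quoted above can be written as a small lookup. This sketch only maps residue names to atom names and does not build a mask; the function name and the residue-name sets are assumptions, using the standard PDB three-letter amino acid codes and the A/C/G/U and DA/DC/DG/DT nucleotide names.

```python
STANDARD_AMINO_ACIDS = {
    "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
}
PURINES = {"A", "G", "DA", "DG"}
PYRIMIDINES = {"C", "U", "DC", "DT"}


def representative_atom_name(res_name):
    """Hypothetical sketch: name of the representative atom, or None for per-atom tokens."""
    if res_name == "GLY":
        return "CA"  # glycine has no CB
    if res_name in STANDARD_AMINO_ACIDS:
        return "CB"
    if res_name in PURINES:
        return "C4"
    if res_name in PYRIMIDINES:
        return "C2"
    return None      # e.g. ligands: each atom is already its own token


names = [representative_atom_name(r) for r in ("ALA", "GLY", "DG", "U", "NAP")]
```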

atomworks.ml.utils.token.get_token_count(array: AtomArray) int[source]#

Returns the number of distinct tokens in the atom array.

This function counts the number of tokens based on the changes in the atom array’s annotations. It will match the behavior of biotite.structure.get_residue_count when the atom array does not have the atomize annotation or if atomize is False for all atoms.

Returns:

The number of distinct tokens in the atom array.

Return type:

int

atomworks.ml.utils.token.get_token_masks(array: AtomArray, indices: ndarray) ndarray[source]#

Get boolean masks indicating the tokens to which the given atom indices belong.

Parameters:
  • array (-) – The atom array (stack) to determine the residues from.

  • indices (-) – An array of indices indicating the atoms to get the corresponding residues for. Negative indices are not allowed.

Returns:

A 2D boolean array where each row corresponds to a given index in indices. Each row masks the atoms that belong to the same token as the atom at the given index.

Return type:

  • residues_masks (ndarray, dtype=bool)

See also

  • get_residue_masks_for

  • get_token_starts

  • get_token_starts_for

  • get_token_positions
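
Given token start indices with an exclusive stop, such masks can be built with a binary search and broadcasting. The sketch below operates on plain index arrays rather than an AtomArray; `token_masks_sketch` is a hypothetical name.

```python
import numpy as np


def token_masks_sketch(starts_with_stop, indices):
    """Hypothetical sketch: one boolean row per query index, True for atoms in the same token."""
    bounds = np.asarray(starts_with_stop)
    indices = np.asarray(indices)
    # Token number of each queried atom: last start that is <= the index.
    token = np.searchsorted(bounds, indices, side="right") - 1
    lo, hi = bounds[token], bounds[token + 1]
    atom_pos = np.arange(bounds[-1])
    # Broadcast (k, 1) bounds against (n,) atom positions to get a (k, n) mask.
    return (atom_pos >= lo[:, None]) & (atom_pos < hi[:, None])


masks = token_masks_sketch([0, 2, 5], [1, 3])
```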

atomworks.ml.utils.token.get_token_starts(array: AtomArray, add_exclusive_stop: bool = False) ndarray[source]#

Get indices for an atom array, each indicating the beginning of a token.

Inspired by biotite.structure.get_residue_starts.

A new token starts:
  • If atomize is True

  • If the chain ID, residue ID, insertion code or residue name changes from one atom to the next.

Parameters:
  • array (AtomArray) – The atom array to get the token starts from.

  • add_exclusive_stop (bool, optional) – If True, add an exclusive stop to the token starts for the last residue. Defaults to False.

Returns:

An array of indices indicating the beginning of each token.

Return type:

np.ndarray
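
The two rules above translate into a vectorized comparison of neighbouring atoms. The following sketch takes the annotations as plain arrays instead of an AtomArray and assumes at least one atom; `token_starts_sketch` is a hypothetical name and the real function may handle more cases.

```python
import numpy as np


def token_starts_sketch(chain_id, res_id, ins_code, res_name, atomize=None, add_exclusive_stop=False):
    """Hypothetical sketch: indices at which a new token begins."""
    chain_id, res_id, ins_code, res_name = map(np.asarray, (chain_id, res_id, ins_code, res_name))
    # True at position i if atom i + 1 differs from atom i in any identifying annotation.
    changed = (
        (chain_id[1:] != chain_id[:-1])
        | (res_id[1:] != res_id[:-1])
        | (ins_code[1:] != ins_code[:-1])
        | (res_name[1:] != res_name[:-1])
    )
    if atomize is not None:
        # Every atomized atom starts its own token.
        changed = changed | np.asarray(atomize, dtype=bool)[1:]
    starts = np.concatenate(([0], np.flatnonzero(changed) + 1))
    if add_exclusive_stop:
        starts = np.append(starts, len(chain_id))
    return starts


chains = ["A"] * 5
res_ids = [1, 1, 2, 2, 2]
ins = [""] * 5
res_names = ["ALA", "ALA", "GLY", "GLY", "GLY"]
plain = token_starts_sketch(chains, res_ids, ins, res_names, add_exclusive_stop=True)
atomized = token_starts_sketch(chains, res_ids, ins, res_names, atomize=[False, False, True, True, True])
```

Without `atomize` the result coincides with residue starts; with the second residue atomized, each of its atoms becomes a token of its own.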

atomworks.ml.utils.token.get_token_starts_for(array: AtomArray, indices: ndarray) ndarray[source]#

Retrieves the indices that point to the start of the token for each specified atom index.

This function is useful for identifying the beginning of the token associated with each atom in the provided indices. It is particularly relevant in contexts where atoms are grouped into tokens based on their annotations.

Parameters:
  • array (-) – The atom array (or stack) from which to determine the token starts.

  • indices (-) – An array of atom indices for which the corresponding token starts are to be retrieved. Negative indices are not permitted.

Returns:

An array of indices pointing to the start of the tokens corresponding to the input indices.

Return type:

  • start_indices (ndarray, dtype=int, shape=(k,))

See also

  • get_residue_starts_for

  • get_token_starts

  • get_token_masks

  • get_token_positions
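
When the token starts are already known, mapping atom indices to their token start is a binary search. A minimal sketch on plain arrays, with `token_starts_for_sketch` as a hypothetical name:

```python
import numpy as np


def token_starts_for_sketch(starts, indices):
    """Hypothetical sketch: for each atom index, the start index of the token containing it."""
    starts = np.asarray(starts)
    indices = np.asarray(indices)
    # side="right" places an index that equals a start inside that token, not the previous one.
    token = np.searchsorted(starts, indices, side="right") - 1
    return starts[token]


start_indices = token_starts_for_sketch([0, 2, 5], [0, 1, 2, 4, 6])
```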

atomworks.ml.utils.token.spread_token_wise(array: AtomArray, input_data: ndarray, token_starts: ndarray | None = None) ndarray[source]#

Analogous to biotite’s spread_residue_wise.

atomworks.ml.utils.token.token_iter(array: AtomArray) Iterator[AtomArray][source]#

Returns an iterator over the tokens in the atom array.

This will match biotite.structure.residue_iter in the case where the atom array does not have the atomize annotation, or if atomize is False everywhere.