Samplers

This module contains various sampling strategies and utilities for data processing.

class atomworks.ml.samplers.DistributedMixedSampler(datasets_info: list[dict[str, any]], num_replicas: int, rank: int, n_examples_per_epoch: int | None, shuffle: bool = True, drop_last: bool = True)

Bases: Sampler

Custom DistributedSampler implementation that samples from an arbitrary list of samplers with specified probabilities.

Child samplers can be any type of non-distributed sampler, including a MixedSampler. After gathering all indices from its children, the sampler shards them across nodes so that each node receives a unique slice of the dataset.

Example

Imagine we have the following sampling tree, where each branch is labeled with its sampling probability:

DistributedMixedSampler
  ├── 0.8: Sampler1
  └── 0.2: MixedSampler
        ├── 0.9: Sampler2
        └── 0.1: Sampler3

If we initialized DistributedMixedSampler with n_examples_per_epoch=100 and num_replicas=2, it would collect 80 samples from Sampler1 and 20 samples from the MixedSampler. The MixedSampler would in turn collect 18 samples from Sampler2 and 2 samples from Sampler3. After collecting those 100 samples, the DistributedMixedSampler would shard the samples across the two nodes, ensuring each node receives a unique slice of 50 examples.
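The arithmetic in this example can be reproduced with a short sketch. It assumes, hypothetically, that each parent splits its draw count in proportion to the child probabilities (using largest-remainder rounding so the counts always sum to the total) and that sharding uses strided slicing, as torch's DistributedSampler does. `allocate_counts` and `shard` are illustrative helpers, not atomworks functions, and the real sampler may allocate draws differently (for example, stochastically).

```python
import random


def allocate_counts(total: int, probabilities: list[float]) -> list[int]:
    """Split `total` draws across children in proportion to `probabilities`.

    Largest-remainder rounding keeps the counts summing to `total`
    (an illustrative choice, not necessarily what atomworks does).
    """
    raw = [total * p for p in probabilities]
    counts = [int(r) for r in raw]
    # Hand any leftover draws to the children with the largest fractional parts.
    leftovers = total - sum(counts)
    order = sorted(range(len(raw)), key=lambda i: raw[i] - counts[i], reverse=True)
    for i in order[:leftovers]:
        counts[i] += 1
    return counts


def shard(indices: list[int], num_replicas: int, rank: int) -> list[int]:
    """Give each rank a disjoint, strided slice of the gathered indices."""
    return indices[rank::num_replicas]


top = allocate_counts(100, [0.8, 0.2])        # Sampler1 vs. MixedSampler
nested = allocate_counts(top[1], [0.9, 0.1])  # Sampler2 vs. Sampler3
print(top, nested)  # [80, 20] [18, 2]

gathered = list(range(100))  # stand-in for the 100 gathered indices
random.Random(0).shuffle(gathered)
shards = [shard(gathered, num_replicas=2, rank=r) for r in range(2)]
print([len(s) for s in shards])  # [50, 50]
```

Each rank sees 50 indices and no index appears on both ranks, matching the per-node slice described above.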

If any of the child samplers were distributed samplers, the DistributedMixedSampler would not receive n_examples_per_epoch indices, and an error would be raised.

NOTE: The order of the datasets in datasets_info MUST match the order of the datasets in the ConcatDataset associated with this sampler.
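torch's ConcatDataset lays its member datasets end to end in one index space: dataset i starts at the total length of datasets 0 through i-1 (the running totals are exposed as `ConcatDataset.cumulative_sizes`). If a mixed sampler shifts each child sampler's local indices by the offset of the dataset in the same position, a mismatch between datasets_info and the ConcatDataset would send indices into the wrong dataset. The sketch below assumes that translation scheme, which is an inference rather than documented behavior; `offsets_from_lengths` and `to_global` are hypothetical helpers.

```python
from itertools import accumulate


def offsets_from_lengths(lengths: list[int]) -> list[int]:
    """Start position of each member dataset in the concatenated index space."""
    return [0, *accumulate(lengths)][:-1]


def to_global(local_indices: list[int], dataset_position: int, lengths: list[int]) -> list[int]:
    """Shift one child sampler's local indices by its dataset's offset."""
    offset = offsets_from_lengths(lengths)[dataset_position]
    return [offset + i for i in local_indices]


lengths = [1000, 50]  # dataset 0 has 1000 examples, dataset 1 has 50
print(to_global([0, 7], dataset_position=1, lengths=lengths))  # [1000, 1007]
# Swapping the order in datasets_info but not in the ConcatDataset would apply
# offset 0 instead of 1000, silently reading examples 0 and 7 of the wrong dataset.
```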

Parameters:
  • datasets_info – List of dictionaries, each of which must contain at a minimum:
      - “sampler”: Sampler object for the dataset
      - “dataset”: Dataset object associated with the sampler
      - “probability”: Probability of sampling from this dataset

  • num_replicas – Number of replicas (nodes) in the distributed setting

  • rank – Rank of the current node

  • n_examples_per_epoch – Number of examples in an epoch. Effectively, the “length” of the sampler (since we often sample with replacement). May be None, in which case the number of examples per epoch must be set dynamically by a parent sampler.

  • shuffle – Whether to shuffle the indices. If False, the iterator will return all sampled indices from the first dataset, then the second, etc.

  • drop_last – Whether to drop the last incomplete batch if the dataset size is not divisible by the batch size

Returns:

An iterator over the dataset indices assigned to the current process. Its length is n_samples (this process's share of the epoch), not n_examples_per_epoch.

Return type:

iter


set_epoch(epoch: int) → None

class atomworks.ml.samplers.FallbackSamplerWrapper(sampler: Sampler, fallback_sampler: Sampler, n_fallback_retries: int = 2)

Bases: Sampler

A wrapper around a sampler that allows for a fallback sampler to be used when an error occurs.

Meant to be used with a FallbackDatasetWrapper.
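How the two wrappers communicate is not spelled out here. The sketch below shows one plausible protocol, and it is entirely hypothetical: the sampler side pairs each primary index with n_fallback_retries backup indices drawn from a fallback pool, and the dataset side tries the primary index first, moving on to the backups if loading raises. `ToyFallbackSampler` and `ToyFallbackDataset` are illustrative stand-ins, not atomworks' FallbackSamplerWrapper or FallbackDatasetWrapper.

```python
import random
from collections.abc import Callable, Iterator, Sequence


class ToyFallbackSampler:
    """Pairs each primary index with n_fallback_retries backup indices."""

    def __init__(self, primary: Sequence[int], fallback_pool: Sequence[int],
                 n_fallback_retries: int = 2, seed: int = 0):
        self.primary = primary
        self.fallback_pool = list(fallback_pool)
        self.n_fallback_retries = n_fallback_retries
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[tuple[int, list[int]]]:
        for idx in self.primary:
            backups = self.rng.sample(self.fallback_pool, self.n_fallback_retries)
            yield idx, backups

    def __len__(self) -> int:
        return len(self.primary)


class ToyFallbackDataset:
    """Loads the primary index and falls back to the backups if loading raises."""

    def __init__(self, load: Callable[[int], object]):
        self.load = load

    def __getitem__(self, item: tuple[int, list[int]]) -> object:
        idx, backups = item
        for candidate in [idx, *backups]:
            try:
                return self.load(candidate)
            except Exception:
                continue  # broken example: move on to the next backup
        raise RuntimeError(f"index {idx} and all {len(backups)} fallbacks failed")


def load(i: int) -> str:
    if i % 3 == 0:  # pretend every multiple of 3 is a corrupt file
        raise ValueError(f"cannot load example {i}")
    return f"example-{i}"


sampler = ToyFallbackSampler(primary=[1, 3], fallback_pool=[4, 5, 7], n_fallback_retries=2)
dataset = ToyFallbackDataset(load)
results = [dataset[item] for item in sampler]
print(results)  # index 1 loads directly; index 3 is replaced by one of its backups
```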

set_epoch(epoch: int) → None

class atomworks.ml.samplers.LazyWeightedRandomSampler(weights: Sequence[float], num_samples: int, replacement: bool = True, generator: Generator | None = None, prefetch_buffer_size: int = 1)

Bases: WeightedRandomSampler

class atomworks.ml.samplers.LoadBalancedDistributedSampler(dataset: Dataset, key_to_balance: str, num_replicas: int | None = None, rank: int | None = None, drop_last: bool = False)

Bases: DistributedSampler

DistributedSampler that balances large examples across replicas.

Helpful for validation, where we don’t want GPUs to be idle while waiting for the slowest replica to finish.

For example, we may want to avoid the scenario where one GPU receives many large examples that are slow to process, while another GPU receives many small examples that are quick to process.

NOTE: Only useful for validation, as the order of the examples is deterministic.

Parameters:
  • dataset – Dataset used for sampling.

  • key_to_balance – Name of the column in the dataset's data DataFrame that holds the size (length) of each example. The dataset must have a data attribute that can be indexed like a DataFrame; for example, if data is a pandas DataFrame, key_to_balance should be one of its columns (e.g., “n_tokens”).

  • num_replicas (int, optional) – Number of processes participating in distributed training. By default, world_size is retrieved from the current distributed group.

  • rank (int, optional) – Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.

  • drop_last (bool, optional) – If True, the sampler drops the tail of the data to make it evenly divisible across the number of replicas. If False, the sampler adds extra indices to make the data evenly divisible across the replicas. Default: False.
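One standard way to get this kind of balance is the greedy longest-processing-time heuristic: visit examples from largest to smallest and give each one to the replica with the smallest running total. The sketch below assumes that heuristic and plain integer sizes (such as an n_tokens column); it is not necessarily what LoadBalancedDistributedSampler does internally, and it ignores the drop_last truncation or padding that a real DistributedSampler needs to give every replica the same number of indices. `balance_across_replicas` is a hypothetical helper.

```python
import heapq


def balance_across_replicas(sizes: list[int], num_replicas: int) -> list[list[int]]:
    """Greedy longest-processing-time assignment of example indices to replicas."""
    # Heap of (running total size, replica id); a sorted list is already a valid heap.
    heap = [(0, r) for r in range(num_replicas)]
    assignments: list[list[int]] = [[] for _ in range(num_replicas)]
    for idx in sorted(range(len(sizes)), key=lambda i: sizes[i], reverse=True):
        total, replica = heapq.heappop(heap)  # currently lightest replica
        assignments[replica].append(idx)
        heapq.heappush(heap, (total + sizes[idx], replica))
    return assignments


n_tokens = [900, 850, 120, 100, 90, 60]
shards = balance_across_replicas(n_tokens, num_replicas=2)
print(shards)  # [[0, 3, 5], [1, 2, 4]]
print([sum(n_tokens[i] for i in s) for s in shards])  # [1060, 1060]
```

Both replicas end up with 1060 tokens, whereas a naive strided split (indices 0, 2, 4 vs. 1, 3, 5) would give 1110 vs. 1010.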

class atomworks.ml.samplers.MixedSampler(datasets_info: list[dict[str, any]], n_examples_per_epoch: int | None = None, shuffle: bool = True)

Bases: DistributedMixedSampler

A non-distributed sampler that samples from an arbitrary list of samplers with specified probabilities.

This class acts like a DistributedMixedSampler with rank=0 and num_replicas=1.

Parameters:
  • datasets_info – List of dictionaries, each of which must contain at a minimum:
      - “sampler”: Sampler object for the dataset
      - “dataset”: Dataset object associated with the sampler
      - “probability”: Probability of sampling from this dataset

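  • n_examples_per_epoch – Number of examples in an epoch. Effectively, the “length” of the sampler. May be None, in which case the number of examples per epoch must be set dynamically by a parent sampler.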

  • shuffle – Whether to shuffle the indices. If False, the iterator will return all sampled indices from the first dataset, then the second, etc.
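With a single replica there is no sharding step: every index drawn by the child samplers ends up in the epoch, and shuffle only decides the order. A minimal sketch of the documented ordering, assuming the per-dataset draws are already global ConcatDataset indices; `combine` is a hypothetical helper, and seeding a `random.Random` is an illustrative choice rather than atomworks' mechanism.

```python
import random


def combine(per_dataset: list[list[int]], shuffle: bool, seed: int = 0) -> list[int]:
    """Merge the indices drawn from each child sampler into one epoch."""
    # Without shuffling: all of dataset 0's draws, then dataset 1's, and so on.
    combined = [i for indices in per_dataset for i in indices]
    if shuffle:
        random.Random(seed).shuffle(combined)  # interleave datasets within the epoch
    return combined


draws = [[3, 1, 4], [1003, 1005]]  # global indices drawn from two child samplers
print(combine(draws, shuffle=False))  # [3, 1, 4, 1003, 1005]
print(sorted(combine(draws, shuffle=True)) == sorted(combine(draws, shuffle=False)))  # True
```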

atomworks.ml.samplers.calculate_af3_example_weights(df: DataFrame, alphas: dict[str, float], beta: float) → Series

Determines the weight of each example in the DataFrame using a methodology inspired by AF-3.

In AF-3, the weight of a given example is a function of:
  1. The size of the cluster to which the example belongs (specific for interfaces vs. chains)

  2. The number of proteins / nucleic acids / ligands in the example

  3. Whether the example is an interface or a chain

Specifically, AF-3 gives the following formula (Section 2.5.1 of the AF-3 Supplementary Information):

w ∝ (β_r / N_clust) * (a_prot * n_prot + a_nuc * n_nuc + a_ligand * n_ligand)

Where:
  • w is the weight of the example

  • β_r is a weighting hyperparameter that is distinct for interfaces and chains

  • N_clust is the number of examples in the cluster

  • a_prot, a_nuc, and a_ligand are the interface weight hyperparameters for proteins, nucleic acids, and ligands, respectively

  • n_prot, n_nuc, and n_ligand are the number of proteins, nucleic acids, and ligands in the example
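To see how the terms trade off, here is the formula evaluated for two single examples. The α and β values are illustrative placeholders, not the hyperparameters from the AF-3 paper or from atomworks, and `af3_weight` is a hypothetical helper.

```python
def af3_weight(beta_r: float, n_clust: int, counts: dict[str, int], alphas: dict[str, float]) -> float:
    """Unnormalized weight (beta_r / N_clust) * sum over k of a_k * n_k."""
    composition = sum(alphas[k] * counts.get(k, 0) for k in alphas)
    return beta_r / n_clust * composition


alphas = {"prot": 3.0, "nuc": 3.0, "ligand": 1.0}  # illustrative, not AF-3's values

# A protein-ligand interface that is alone in its cluster...
w_interface = af3_weight(beta_r=1.0, n_clust=1, counts={"prot": 1, "ligand": 1}, alphas=alphas)
# ...versus a single protein chain whose cluster has four members.
w_chain = af3_weight(beta_r=0.5, n_clust=4, counts={"prot": 1}, alphas=alphas)
print(w_interface, w_chain)  # 4.0 0.375
```

Although both examples contain one protein, the interface receives more than ten times the weight: it has a larger β_r, an extra ligand term, and a cluster of size one.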

We make the following modifications to the original AF-3 formula:
  • We introduce n_peptide and a_peptide to better control the sampling of peptides (which were being over-sampled). We define peptides as proteins with fewer than PEPTIDE_MAX_RESIDUES residues (see atomworks.ml.preprocessing.constants).

  • We introduce an incremental a_loi weight to control the sampling of ligands of interest (LOI), also referred to as subjects of investigation.

Thus, our full formula is:

w ∝ (β_r / N_clust) * (a_prot * n_prot + a_peptide * n_peptide + a_nuc * n_nuc + a_ligand * n_ligand + a_loi * is_loi)

Parameters:
  • df (pd.DataFrame) – DataFrame containing the PN unit or interface data

  • alphas (dict) – Dictionary containing the weight hyperparameters for proteins, nucleic acids, ligands, and possibly peptides (common across interfaces and chains)

  • beta (float) – Weighting hyperparameter (distinct for interfaces and chains)

Returns:

A Series containing the calculated weights for each row in the DataFrame

Return type:

pd.Series
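A vectorized version of the modified formula over a DataFrame might look like the sketch below. The column names (n_prot, n_peptide, n_nuc, n_ligand, is_loi, cluster_size) and the alpha keys are assumptions made for illustration; the actual function's schema, and how it obtains N_clust, may differ. `sketch_af3_weights` is a hypothetical name.

```python
import pandas as pd

# Hypothetical mapping from alpha key to the DataFrame column it multiplies.
TERMS = {
    "a_prot": "n_prot",
    "a_peptide": "n_peptide",
    "a_nuc": "n_nuc",
    "a_ligand": "n_ligand",
    "a_loi": "is_loi",
}


def sketch_af3_weights(df: pd.DataFrame, alphas: dict[str, float], beta: float) -> pd.Series:
    """Row-wise (beta / N_clust) * (a_prot*n_prot + ... + a_loi*is_loi)."""
    composition = pd.Series(0.0, index=df.index)
    for alpha_key, column in TERMS.items():
        # An alpha that is not configured (e.g. no peptide weight) contributes nothing.
        composition = composition + alphas.get(alpha_key, 0.0) * df[column].astype(float)
    return beta / df["cluster_size"] * composition


df = pd.DataFrame({
    "n_prot": [2, 1],
    "n_peptide": [0, 1],
    "n_nuc": [0, 0],
    "n_ligand": [0, 1],
    "is_loi": [False, True],
    "cluster_size": [4, 1],
})
alphas = {"a_prot": 3.0, "a_peptide": 1.0, "a_nuc": 3.0, "a_ligand": 1.0, "a_loi": 2.0}
print(sketch_af3_weights(df, alphas, beta=1.0).tolist())  # [1.5, 7.0]
```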

atomworks.ml.samplers.calculate_weights_by_inverse_cluster_size(dataset_df: DataFrame, cluster_column: str = 'cluster') → Tensor

Calculate weights for each row in the DataFrame as the inverse of its cluster size.

Parameters:
  • dataset_df (pd.DataFrame) – DataFrame containing the PN unit or interface data

  • cluster_column (str) – Column name in dataset_df corresponding to the cluster info. Default is “cluster”.

Returns:

A tensor containing the calculated weights for each row in the DataFrame

Return type:

torch.Tensor
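Weighting by inverse cluster size makes each cluster's weights sum to 1, so every cluster is equally likely to be sampled regardless of how many members it has. A pandas sketch of that computation, using a hypothetical helper name; it returns a Series, whereas the documented function returns a torch.Tensor (`torch.as_tensor(weights.to_numpy())` would convert it).

```python
import pandas as pd


def inverse_cluster_size_weights(df: pd.DataFrame, cluster_column: str = "cluster") -> pd.Series:
    """Weight each row by 1 / (number of rows in its cluster)."""
    sizes = df[cluster_column].map(df[cluster_column].value_counts())
    return 1.0 / sizes


df = pd.DataFrame({"cluster": ["a", "a", "b", "b", "b", "b"]})
weights = inverse_cluster_size_weights(df)
print(weights.tolist())  # [0.5, 0.5, 0.25, 0.25, 0.25, 0.25]
print(weights.groupby(df["cluster"]).sum().tolist())  # [1.0, 1.0]
```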

atomworks.ml.samplers.calculate_weights_for_pdb_dataset_df(dataset_df: DataFrame, alphas: dict[str, float], beta: float, cluster_column: str = 'cluster') → Tensor

Calculate weights for each row in the DataFrame based on the cluster size and the AF-3 weighting methodology.

Parameters:
  • dataset_df (pd.DataFrame) – DataFrame containing the PN unit or interface data

  • alphas (dict[str, float]) – Dictionary containing alpha values for the weighting calculation (common across interfaces and chains/pn_units)

  • beta (float) – Beta value for the weighting calculation (distinct for interfaces and chains/pn_units)

Returns:

A tensor containing the calculated weights for each row in the DataFrame

Return type:

torch.Tensor

atomworks.ml.samplers.get_cluster_sizes(df: DataFrame, cluster_column: str = 'cluster') → dict[str, int]

Generate a mapping between cluster alphanumeric IDs and the number of PN units/interfaces in each cluster.

Parameters:
  • df (pd.DataFrame) – DataFrame containing the PN unit or interface data

  • cluster_column (str) – Name of the column containing the cluster alphanumeric IDs

Returns:

A dictionary where the keys are unique cluster IDs and the values are the counts of occurrences.

Return type:

dict
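The same mapping can be sketched with pandas' `value_counts`. `cluster_sizes` is a hypothetical name, and the conversion to `str` keys and `int` values simply matches the documented `dict[str, int]` return type.

```python
import pandas as pd


def cluster_sizes(df: pd.DataFrame, cluster_column: str = "cluster") -> dict[str, int]:
    """Map each cluster ID to the number of rows (PN units or interfaces) it contains."""
    counts = df[cluster_column].value_counts()
    return {str(cluster): int(n) for cluster, n in counts.items()}


df = pd.DataFrame({"cluster": ["c1", "c2", "c1"]})
print(cluster_sizes(df))  # {'c1': 2, 'c2': 1}
```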

atomworks.ml.samplers.set_sampler_epoch(sampler: Sampler, epoch: int, add_random_offset: bool = False) → None

Control the random seed for a sampler.
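Epoch-aware samplers typically derive their shuffling seed from the epoch, so calling set_epoch before each epoch yields a fresh but reproducible order; torch's DistributedSampler, for instance, seeds its generator with seed + epoch. The sketch below shows that pattern with a hypothetical `ToyShuffleSampler` and a `set_epoch_if_supported` helper. It does not model add_random_offset, whose exact behavior is not described here.

```python
import random


class ToyShuffleSampler:
    """Yields a permutation of range(n) that is reproducible for a given epoch."""

    def __init__(self, n: int, seed: int = 0):
        self.n = n
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def __iter__(self):
        order = list(range(self.n))
        # Same idea as torch's DistributedSampler: seed the RNG with seed + epoch.
        random.Random(self.seed + self.epoch).shuffle(order)
        return iter(order)

    def __len__(self) -> int:
        return self.n


def set_epoch_if_supported(sampler: object, epoch: int) -> None:
    """Forward the epoch to samplers that define set_epoch; leave others untouched."""
    if hasattr(sampler, "set_epoch"):
        sampler.set_epoch(epoch)


sampler = ToyShuffleSampler(10)
set_epoch_if_supported(sampler, 1)
first = list(sampler)
second = list(sampler)  # same epoch, so the order is identical
set_epoch_if_supported(sampler, 2)
third = list(sampler)   # new epoch, so a new permutation is drawn
print(first == second)  # True
```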