Samplers#
This module contains various sampling strategies and utilities for data processing.
- class atomworks.ml.samplers.DistributedMixedSampler(datasets_info: list[dict[str, any]], num_replicas: int, rank: int, n_examples_per_epoch: int | None, shuffle: bool = True, drop_last: bool = True)[source]#
Bases:
Sampler
Custom DistributedSampler implementation that samples from an arbitrary list of samplers with specified probabilities.
Child samplers can be any type of non-distributed sampler, including a MixedSampler. After gathering all indices, shards the samples across nodes, ensuring each node receives a unique slice of the dataset.
Example
Imagine we have the following sampling tree:
              DistributedMixedSampler
                  /            \
               0.8              0.2
                /                 \
           Sampler1          MixedSampler
                               /      \
                             0.9      0.1
                             /          \
                        Sampler2     Sampler3
If we initialized DistributedMixedSampler with n_examples_per_epoch=100 and num_replicas=2, it would collect 80 samples from Sampler1 and 20 samples from the MixedSampler. The MixedSampler would in turn collect 18 samples from Sampler2 and 2 samples from Sampler3. After collecting those 100 samples, the DistributedMixedSampler would shard the samples across the two nodes, ensuring each node receives a unique slice of 50 examples.
If any of the child samplers were distributed samplers, the DistributedMixedSampler would not receive n_examples_per_epoch indices, and we would raise an error.
NOTE: The order of the datasets in datasets_info MUST match the order of the datasets in the ConcatDataset associated with this MixedSampler.
- Parameters:
datasets_info – List of dictionaries, where each dictionary must contain at a minimum:
- “sampler”: Sampler object for the dataset
- “dataset”: Dataset object associated with the sampler
- “probability”: Probability of sampling from this dataset
num_replicas – Number of replicas (nodes) in the distributed setting
rank – Rank of the current node
n_examples_per_epoch – Number of examples in an epoch. Effectively, the “length” of the sampler (since we often sample with replacement). May be None, in which case the number of examples per epoch must be set dynamically by a parent sampler.
shuffle – Whether to shuffle the indices. If False, the iterator will return all sampled indices from the first dataset, then the second, etc.
drop_last – Whether to drop the last incomplete batch if the dataset size is not divisible by the batch size
- Returns:
An iterator over indices of the dataset for the current process (of length n_samples, i.e., the per-replica share of n_examples_per_epoch, not n_examples_per_epoch itself)
- Return type:
iter
References
PyTorch DistributedSampler (pytorch/pytorch)
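A minimal construction sketch of the sampling tree above. The TensorDatasets and RandomSamplers are illustrative placeholders for real dataset/sampler pairs, and pairing the nested branch with the ConcatDataset of its children is an assumption, not a documented requirement:

    import torch
    from torch.utils.data import ConcatDataset, RandomSampler, TensorDataset
    from atomworks.ml.samplers import DistributedMixedSampler, MixedSampler

    # Placeholder map-style datasets
    ds1 = TensorDataset(torch.arange(1000))
    ds2 = TensorDataset(torch.arange(500))
    ds3 = TensorDataset(torch.arange(100))

    # Nested branch: n_examples_per_epoch is left as None so the parent
    # sampler can set it dynamically (20 in the example above)
    inner = MixedSampler(
        datasets_info=[
            {"sampler": RandomSampler(ds2, replacement=True), "dataset": ds2, "probability": 0.9},
            {"sampler": RandomSampler(ds3, replacement=True), "dataset": ds3, "probability": 0.1},
        ],
    )

    sampler = DistributedMixedSampler(
        datasets_info=[
            {"sampler": RandomSampler(ds1, replacement=True), "dataset": ds1, "probability": 0.8},
            # Assumption: the nested branch's "dataset" is the ConcatDataset of its children
            {"sampler": inner, "dataset": ConcatDataset([ds2, ds3]), "probability": 0.2},
        ],
        num_replicas=2,
        rank=0,
        n_examples_per_epoch=100,
    )
    # Each of the two ranks iterates over a unique slice of 50 indices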
- class atomworks.ml.samplers.FallbackSamplerWrapper(sampler: Sampler, fallback_sampler: Sampler, n_fallback_retries: int = 2)[source]#
Bases:
Sampler
A wrapper around a sampler that allows for a fallback sampler to be used when an error occurs.
Meant to be used with a FallbackDatasetWrapper.
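A usage sketch based on the documented signature; the primary/fallback pairing shown here is illustrative:

    import torch
    from torch.utils.data import RandomSampler, TensorDataset, WeightedRandomSampler
    from atomworks.ml.samplers import FallbackSamplerWrapper

    ds = TensorDataset(torch.arange(100))
    # Primary (weighted) sampler; the plain uniform sampler serves as the
    # fallback when the primary errors (n_fallback_retries presumably bounds
    # the number of retries before falling back)
    primary = WeightedRandomSampler(weights=torch.rand(100), num_samples=100)
    sampler = FallbackSamplerWrapper(primary, RandomSampler(ds), n_fallback_retries=2)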
- class atomworks.ml.samplers.LazyWeightedRandomSampler(weights: Sequence[float], num_samples: int, replacement: bool = True, generator: Generator | None = None, prefetch_buffer_size: int = 1)[source]#
Bases:
WeightedRandomSampler
- class atomworks.ml.samplers.LoadBalancedDistributedSampler(dataset: Dataset, key_to_balance: str, num_replicas: int | None = None, rank: int | None = None, drop_last: bool = False)[source]#
Bases:
DistributedSampler
DistributedSampler that balances large examples across replicas.
Helpful for validation, where we don’t want GPUs to be idle while waiting for the slowest replica to finish.
For example, we may want to avoid the scenario where one GPU receives many large examples that are slow to process, while another GPU receives many small examples that are quick to process.
NOTE: Only useful for validation, as the order of the examples is deterministic.
- Parameters:
dataset – Dataset used for sampling.
key_to_balance – Key in the dataset’s data DataFrame that contains the length (size) of each example. The dataset must have a data attribute that can be accessed like a DataFrame; for example, if data is a pandas DataFrame, key_to_balance should be a column in that DataFrame (e.g., “n_tokens”).
num_replicas (int, optional) – Number of processes participating in distributed training. By default, world_size is retrieved from the current distributed group.
rank (int, optional) – Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.
drop_last (bool, optional) – If True, the sampler will drop the tail of the data to make it evenly divisible across the number of replicas. If False, the sampler will add extra indices to make the data evenly divisible across the replicas. Default: False.
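A sketch of the required data contract, assuming a dataset that exposes a pandas DataFrame via its data attribute; the ToyDataset class and its n_tokens values are hypothetical:

    import pandas as pd
    from torch.utils.data import Dataset
    from atomworks.ml.samplers import LoadBalancedDistributedSampler

    class ToyDataset(Dataset):
        """Hypothetical dataset exposing a `data` DataFrame with example sizes."""
        def __init__(self, n_tokens):
            self.data = pd.DataFrame({"n_tokens": n_tokens})
        def __len__(self):
            return len(self.data)
        def __getitem__(self, idx):
            return self.data.iloc[idx]["n_tokens"]

    # Mix of small and large examples
    ds = ToyDataset([10, 5000, 20, 4500, 30, 4000])
    sampler = LoadBalancedDistributedSampler(ds, key_to_balance="n_tokens", num_replicas=2, rank=0)
    # Large and small examples are spread across the two replicas rather than
    # clustering on a single GPU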
- class atomworks.ml.samplers.MixedSampler(datasets_info: list[dict[str, any]], n_examples_per_epoch: int | None = None, shuffle: bool = True)[source]#
Bases:
DistributedMixedSampler
A non-distributed sampler that samples from an arbitrary list of samplers with specified probabilities.
This class acts like a DistributedMixedSampler with rank=0 and num_replicas=1.
- Parameters:
datasets_info – List of dictionaries, where each dictionary must contain at a minimum:
- “sampler”: Sampler object for the dataset
- “dataset”: Dataset object associated with the sampler
- “probability”: Probability of sampling from this dataset
n_examples_per_epoch – Number of examples in an epoch. Effectively, the “length” of the sampler.
shuffle – Whether to shuffle the indices. If False, the iterator will return all sampled indices from the first dataset, then the second, etc.
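A flat, non-distributed two-way mix, analogous to the distributed example above; the datasets are placeholders:

    import torch
    from torch.utils.data import RandomSampler, TensorDataset
    from atomworks.ml.samplers import MixedSampler

    ds_a = TensorDataset(torch.arange(800))
    ds_b = TensorDataset(torch.arange(200))
    # Order of entries must match the order of datasets in the associated ConcatDataset
    sampler = MixedSampler(
        datasets_info=[
            {"sampler": RandomSampler(ds_a, replacement=True), "dataset": ds_a, "probability": 0.8},
            {"sampler": RandomSampler(ds_b, replacement=True), "dataset": ds_b, "probability": 0.2},
        ],
        n_examples_per_epoch=100,
    )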
- atomworks.ml.samplers.calculate_af3_example_weights(df: DataFrame, alphas: dict[str, float], beta: float) Series [source]#
Determines the weight of each example in the DataFrame using a methodology inspired by AF-3.
- In AF-3, the weight of a given example is a function of:
The size of the cluster to which the example belongs (specific for interfaces vs. chains)
The number of proteins / nucleic acids / ligands in the example
Whether the example is an interface or a chain
- Specifically, AF-3 gives the following formula (Section 2.5.1 of the AF-3 Supplementary Information):
w ∝ (β_r / N_clust) * (a_prot * n_prot + a_nuc * n_nuc + a_ligand * n_ligand)
- Where:
w is the weight of the example
β_r is a weighting hyperparameter that is distinct for interfaces and chains
N_clust is the number of examples in the cluster
a_prot, a_nuc, and a_ligand are the interface weight hyperparameters for proteins, nucleic acids, and ligands, respectively
n_prot, n_nuc, and n_ligand are the number of proteins, nucleic acids, and ligands in the example
- We make the following modifications to the original AF-3 formula:
We introduce n_peptide and a_peptide to better control the sampling of peptides (which were being over-sampled). We define peptides as proteins with fewer than PEPTIDE_MAX_RESIDUES residues (see atomworks.ml.preprocessing.constants).
We introduce an incremental a_loi weight to control the sampling of ligands of interest (LOI), also described as the Subject of Investigation.
- Thus, our full formula is:
w ∝ (β_r / N_clust) * (a_prot * n_prot + a_peptide * n_peptide + a_nuc * n_nuc + a_ligand * n_ligand + a_loi * is_loi)
- Parameters:
df (pd.DataFrame) – DataFrame containing the PN unit or interface data
alphas (dict) – Dictionary containing the weight hyperparameters for proteins, nucleic acids, ligands, and possibly peptides (common across interfaces and chains)
beta (float) – Weighting hyperparameter (distinct for interfaces and chains)
- Returns:
A Series containing the calculated weights for each row in the DataFrame
- Return type:
pd.Series
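A worked instance of the formula in plain pandas. The column names and hyperparameter values below are illustrative only, not the columns the function itself expects:

    import pandas as pd

    # Hypothetical rows: one chain-like example, one interface-like example
    df = pd.DataFrame({
        "n_clust":   [10, 2],   # size of each example's cluster
        "n_prot":    [2, 1],
        "n_peptide": [0, 1],
        "n_nuc":     [1, 0],
        "n_ligand":  [0, 1],
        "is_loi":    [0, 1],
    })
    alphas = {"prot": 3.0, "peptide": 1.0, "nuc": 3.0, "ligand": 1.0, "loi": 2.0}
    beta = 1.0  # distinct for interfaces vs. chains

    w = (beta / df["n_clust"]) * (
        alphas["prot"] * df["n_prot"]
        + alphas["peptide"] * df["n_peptide"]
        + alphas["nuc"] * df["n_nuc"]
        + alphas["ligand"] * df["n_ligand"]
        + alphas["loi"] * df["is_loi"]
    )
    # w = [0.9, 3.5]: row 0 -> (1/10) * (3*2 + 3*1); row 1 -> (1/2) * (3 + 1 + 1 + 2)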
- atomworks.ml.samplers.calculate_weights_by_inverse_cluster_size(dataset_df: DataFrame, cluster_column: str = 'cluster') Tensor [source]#
Calculate weights for each row in the DataFrame as the inverse of its cluster size.
- Parameters:
dataset_df (pd.DataFrame) – DataFrame containing the PN unit or interface data
cluster_column (str) – Column name in dataset_df corresponding to the cluster info. Default is “cluster”.
- Returns:
A tensor containing the calculated weights for each row in the DataFrame
- Return type:
torch.Tensor
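A tiny illustration, assuming the default “cluster” column: three rows in one cluster and one row in another yield weights of 1/3, 1/3, 1/3, and 1.

    import pandas as pd
    from atomworks.ml.samplers import calculate_weights_by_inverse_cluster_size

    df = pd.DataFrame({"cluster": ["A", "A", "A", "B"]})
    weights = calculate_weights_by_inverse_cluster_size(df)
    # expected: tensor([0.3333, 0.3333, 0.3333, 1.0000])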
- atomworks.ml.samplers.calculate_weights_for_pdb_dataset_df(dataset_df: DataFrame, alphas: dict[str, float], beta: float, cluster_column: str = 'cluster') Tensor [source]#
Calculate weights for each row in the DataFrame based on the cluster size and the AF-3 weighting methodology.
- Parameters:
dataset_df (pd.DataFrame) – DataFrame containing the PN unit or interface data
alphas (dict[str, float]) – Dictionary containing alpha values for the weighting calculation (common across interfaces and chains/pn_units)
beta (float) – Beta value for the weighting calculation (distinct for interfaces and chains/pn_units)
cluster_column (str) – Column name in dataset_df corresponding to the cluster info. Default is “cluster”.
- Returns:
A tensor containing the calculated weights for each row in the DataFrame
- Return type:
torch.Tensor
- atomworks.ml.samplers.get_cluster_sizes(df: DataFrame, cluster_column: str = 'cluster') dict[str, int] [source]#
Generate a mapping between cluster alphanumeric IDs and the number of PN units/interfaces in each cluster.
- Parameters:
df (pd.DataFrame) – DataFrame containing the PN unit or interface data
cluster_column (str) – Name of the column containing the cluster alphanumeric IDs
- Returns:
A dictionary where the keys are unique cluster IDs and the values are the counts of occurrences.
- Return type:
dict
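In spirit, this amounts to a value count over the cluster column; a minimal illustration:

    import pandas as pd
    from atomworks.ml.samplers import get_cluster_sizes

    df = pd.DataFrame({"cluster": ["A", "A", "B"]})
    sizes = get_cluster_sizes(df)
    # expected: {"A": 2, "B": 1}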