Source code for atomworks.ml.datasets.parsers.custom_metadata_row_parsers
"""Row parser for non-standard metadata dataframes"""
import ast
from pathlib import Path
from typing import Any

import pandas as pd
from atomworks.io.constants import PDB_MIRROR_PATH
from atomworks.ml.datasets.parsers import MetadataRowParser

class AF2FB_DistillationParser(MetadataRowParser):  # noqa: N801
    # TODO: Deprecate in favor of GenericDFParser
    """
    DEPRECATION WARNING: This parser is deprecated and will be removed in a future release.
    Use the GenericDFParser instead, providing `path` and `example_id` columns.

    Parser for AF2FB distillation metadata.

    The AF2FB distillation dataset is provided courtesy of Meta/Facebook.
    It contains ~7.6 million AF2-predicted structures from UniRef50.
    Metadata (i.e., the sequences, cluster identities at 30% sequence identity,
    whether a sequence has an MSA and a template, the sequence hash, etc.) are
    stored in the `af2_distillation_facebook.parquet` dataframe.

    The parquet has the following columns:
        - example_id
        - n_atoms
        - n_res
        - mean_plddt
        - min_plddt
        - median_plddt
        - sequence_hash
        - has_msa
        - msa_depth
        - has_template
        - cluster_id
        - seq (WARNING: this is a relatively data-heavy column)
    """
    def __init__(self, base_dir: str, file_extension: str = ".cif"):
        """
        Initialize the AF2FB_DistillationParser.

        This parser is designed to handle the AF2FB distillation dataset, which contains
        approximately 7.6 million AlphaFold2-predicted structures from UniRef50.

        Args:
            - base_dir (str): The base directory where the AF2FB distillation dataset is
                stored (e.g., "/squash/af2_distillation_facebook", which is stored on
                `tukwila` for ML model training).
            - file_extension (str): The file extension of the structure files.
                Defaults to ".cif".

        Raises:
            - AssertionError: If the specified dataset directory does not exist.
        """
        self.dataset_dir = Path(base_dir)
        self.file_extension = file_extension
        assert self.dataset_dir.exists(), f"Dataset directory {self.dataset_dir} does not exist."
    @staticmethod
    def _get_shard_from_hash(hash_value: str) -> str:
        """Due to the size of the AF2FB dataset, we store it with two-level sharding.

        The two levels of sharding are an optimization for faster filesystem
        performance (no directory should hold more than ~10k files).

        Example:
            - example_id: UniRef50_A0A1S3ZVX8
            - sequence_hash: f771c39dfbf

            The two-level shard is therefore `f7/71/`, and the files can be found at:
            - ./cif/f7/71/UniRef50_A0A1S3ZVX8.cif
            - ./msa/f7/71/f771c39dfbf.a3m
            - ./template/f7/71/f771c39dfbf.atab
        """
        return f"{hash_value[:2]}/{hash_value[2:4]}/"
    def _parse(self, row: pd.Series) -> dict:
        example_id = row["example_id"]
        sequence_hash = row["sequence_hash"]
        path = (
            self.dataset_dir / "cif" / self._get_shard_from_hash(sequence_hash) / f"{example_id}{self.file_extension}"
        )
        return {
            "example_id": example_id,
            "path": path,
            "assembly_id": "1",  # default to the first assembly (= identity if none given)
            "sequence_hash": sequence_hash,
        }
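
A minimal usage sketch (illustrative, not part of the module source), assuming the dataset directory exists locally. The `example_id` and `sequence_hash` values are taken from the sharding example above; `_parse` is called directly for illustration, though the public entry point inherited from `MetadataRowParser` may differ:

    parser = AF2FB_DistillationParser(base_dir="/squash/af2_distillation_facebook")
    row = pd.Series({"example_id": "UniRef50_A0A1S3ZVX8", "sequence_hash": "f771c39dfbf"})
    parsed = parser._parse(row)
    # parsed["path"] -> /squash/af2_distillation_facebook/cif/f7/71/UniRef50_A0A1S3ZVX8.cif
    # parsed["assembly_id"] -> "1"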

class ValidationDFParserLikeAF3(MetadataRowParser):
    # TODO: Deprecate in favor of GenericDFParser
    """
    Parser for AF-3-style validation DataFrame rows.

    As output, we give:
        - pdb_id: The PDB ID of the structure.
        - assembly_id: The assembly ID of the structure, required to load the correct
            assembly from the CIF file.
        - path: The path to the CIF file.
        - example_id: An identifier that combines the pdb_id and assembly_id.
        - ground_truth: A dictionary containing non-feature information for loss and
            validation. For validation, we initialize it with the following:
            - interfaces_to_score: A list of tuples like
                (pn_unit_iid_1, pn_unit_iid_2, interface_type), which represent
                low-homology interfaces to score.
            - pn_units_to_score: A list of tuples like (pn_unit_iid, pn_unit_type),
                which represent low-homology pn_units to score.
    """
    def __init__(self, base_dir: Path = PDB_MIRROR_PATH, file_extension: str = ".cif.gz"):
        self.base_dir = base_dir
        self.file_extension = file_extension
    def _parse(self, row: pd.Series) -> dict[str, Any]:
        # Build the path to the CIF file (the mirror is sharded by the two middle
        # characters of the PDB ID)
        pdb_id = row["pdb_id"]
        path = Path(f"{self.base_dir}/{pdb_id[1:3]}/{pdb_id}{self.file_extension}")

        # Extract the interfaces and pn_units to score. The values are stored as
        # tuple/list literals, so ast.literal_eval is a safe drop-in for eval here.
        # Example: [("A_1", "B_1", "protein-protein"), ("B_1", "C_1", "protein-ligand")]
        interfaces_to_score = (
            ast.literal_eval(row["interfaces_to_score"])
            if isinstance(row["interfaces_to_score"], str)
            else [ast.literal_eval(interface) for interface in row["interfaces_to_score"]]
        )
        # Example: [("A_1", "protein"), ("B_1", "DNA")]
        pn_units_to_score = (
            ast.literal_eval(row["pn_units_to_score"])
            if isinstance(row["pn_units_to_score"], str)
            else [ast.literal_eval(unit) for unit in row["pn_units_to_score"]]
        )
        return {
            "example_id": row["example_id"],
            "path": path,
            "pdb_id": pdb_id,
            "assembly_id": row["assembly_id"],
            "ground_truth": {
                "interfaces_to_score": interfaces_to_score,
                "pn_units_to_score": pn_units_to_score,
            },
        }
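
A minimal usage sketch (illustrative, not part of the module source). The row values below are hypothetical, including the example_id format, and the resulting path is shown relative to the default PDB_MIRROR_PATH:

    parser = ValidationDFParserLikeAF3()  # defaults to PDB_MIRROR_PATH and ".cif.gz"
    row = pd.Series(
        {
            "example_id": "1abc_assembly1",  # hypothetical identifier format
            "pdb_id": "1abc",
            "assembly_id": "1",
            "interfaces_to_score": '[("A_1", "B_1", "protein-protein")]',
            "pn_units_to_score": '[("A_1", "protein")]',
        }
    )
    parsed = parser._parse(row)
    # parsed["path"] -> <PDB_MIRROR_PATH>/ab/1abc.cif.gz
    # parsed["ground_truth"]["interfaces_to_score"] -> [("A_1", "B_1", "protein-protein")]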