Source code for atomworks.ml.common
from __future__ import annotations

import ast
import re
from typing import Any

from atomworks.io.common import default, exists  # noqa: F401
[docs]
def generate_example_id(dataset_names: list[str], pdb_id: str, assembly_id: str, query_pn_unit_iids: list) -> str:
    """Build the unique example ID string for one example.

    The ID is handy for debugging and for tracking performance on specific
    examples. It concatenates four brace-wrapped fields, in this order:

    1. ``dataset_names``: the composed list of dataset names (e.g.
       ``['pdb', 'pn_unit']`` for the pn_unit dataset nested within the PDB
       dataset).
    2. ``pdb_id``: the PDB ID, or any group-level identifier when a non-PDB
       dataset is used, within the dataset given by (1).
    3. ``assembly_id``
    4. ``query_pn_unit_iids``

    Args:
        dataset_names: Composed list of dataset names.
        pdb_id: Group-level identifier of the example.
        assembly_id: Assembly identifier.
        query_pn_unit_iids: Instance IDs of the query pn_units.

    Returns:
        The example ID, formatted as
        ``{dataset_names}{pdb_id}{assembly_id}{query_pn_unit_iids}`` with every
        field rendered through its default string formatting.

    Examples of the resulting IDs::

        pn_unit dataset:      {['pdb', 'pn_unit']}{6vyb}{1}{['A_1']}
        interface dataset:    {['pdb', 'interfaces']}{6vyb}{1}{['A_1', 'B_1']}
        distillation dataset: {['af2_distillation']}{6vyb}{1}{['A_1']}
    """
    fields = (dataset_names, pdb_id, assembly_id, query_pn_unit_iids)
    # Each field is wrapped in a literal pair of curly braces.
    return "".join(f"{{{field}}}" for field in fields)
[docs]
def _parse_example_id_literal(text: str, example_id: str) -> Any:
    """Safely turn the ``repr`` of a Python literal (e.g. a list) back into an object."""
    try:
        # ``ast.literal_eval`` only accepts literals (strings, numbers, lists, ...),
        # so a crafted example ID cannot execute arbitrary code the way ``eval`` would.
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError) as err:
        raise ValueError(
            f"Invalid example ID format: {example_id}, could not parse {text!r} as a Python literal."
        ) from err


def parse_example_id(example_id: str) -> dict:
    """
    Parse the example ID into its components: dataset names, pdb_id, assembly_id, and query_pn_unit_iids.

    Args:
        example_id (str): The example ID string generated by the generate_example_id function.

    Returns:
        dict: A dictionary containing the parsed components under the keys
            ``"datasets"``, ``"pdb_id"``, ``"assembly_id"`` and ``"query_pn_unit_iids"``.
            ``pdb_id`` and ``assembly_id`` are returned as strings; an empty
            ``query_pn_unit_iids`` field yields an empty list.

    Raises:
        ValueError: If the ID does not contain exactly four brace-delimited
            fields, or if a list field is not a valid Python literal.
    """
    # Use regular expression to find all parts within curly braces.
    # NOTE: the match is non-greedy, so a field that itself contains "}" will
    # not round-trip through this parser.
    matches = re.findall(r"\{(.*?)\}", example_id)
    if len(matches) != 4:
        raise ValueError(f"Invalid example ID format: {example_id}, with {len(matches)} matches found (expected 4).")

    datasets_str, pdb_id, assembly_id, query_pn_unit_iids_str = matches

    # Convert the string representations of the lists back into actual lists.
    datasets = _parse_example_id_literal(datasets_str, example_id)
    query_pn_unit_iids = (
        _parse_example_id_literal(query_pn_unit_iids_str, example_id) if query_pn_unit_iids_str else []
    )

    return {
        "datasets": datasets,
        "pdb_id": pdb_id,
        "assembly_id": assembly_id,
        "query_pn_unit_iids": query_pn_unit_iids,
    }
[docs]
def as_list(value: Any) -> list:
    """Coerce ``value`` into a list.

    Behaviour by input kind:
        - Strings: wrapped whole in a one-element list (they are iterable, but
          are deliberately treated as single values).
        - Other iterables (lists, tuples, sets, generators, ...): expanded
          with ``list(value)``.
        - Anything that cannot be iterated: wrapped in a one-element list.

    Args:
        value: The object to convert.

    Returns:
        list: A new list built as described above.
    """
    # Guard clause: a string would otherwise be split into characters.
    if isinstance(value, str):
        return [value]

    try:
        # Duck typing: anything ``list()`` can consume counts as iterable.
        items = list(value)
    except TypeError:
        # Not iterable, so treat it as a single value.
        return [value]
    return items