Source code for atomworks.ml.common

from __future__ import annotations

import ast
import re
from typing import Any

from atomworks.io.common import default, exists  # noqa: F401


def generate_example_id(dataset_names: list[str], pdb_id: str, assembly_id: str, query_pn_unit_iids: list) -> str:
    """Build the unique example ID string for a single example.

    The ID is useful for debugging and for tracking performance on specific
    examples. An example is identified by four components, in order:

        (1) a composed list of dataset names (e.g., [pdb, pn_unit] to indicate
            the pn_unit dataset nested within the PDB dataset)
        (2) pdb_id (or any group-level identifier, if using a non-PDB dataset),
            within the dataset specified by (1)
        (3) assembly_id
        (4) query_pn_unit_iids

    Each component is rendered with its default string formatting and wrapped
    in a pair of curly braces; the four groups are concatenated with no
    separator.

    Args:
        dataset_names: Composed list of dataset names.
        pdb_id: PDB ID or other group-level identifier.
        assembly_id: Assembly identifier.
        query_pn_unit_iids: List of query pn_unit instance IDs.

    Returns:
        str: ``{[dataset_names]}{pdb_id}{assembly_id}{query_pn_unit_iids}``, e.g.

            pn_unit dataset:      {['pdb', 'pn_unit']}{6vyb}{1}{['A_1']}
            interface dataset:    {['pdb', 'interfaces']}{6vyb}{1}{['A_1', 'B_1']}
            distillation dataset: {['af2_distillation']}{6vyb}{1}{['A_1']}
    """
    components = (dataset_names, pdb_id, assembly_id, query_pn_unit_iids)
    # Doubled braces emit literal "{" / "}" around each formatted component.
    return "".join(f"{{{component}}}" for component in components)
def parse_example_id(example_id: str) -> dict:
    """Parse an example ID into its components.

    The ID is expected to contain exactly four curly-brace groups, in order:
    dataset names, pdb_id, assembly_id, and query_pn_unit_iids (the format
    produced by ``generate_example_id``).

    Note:
        Groups are located with a non-greedy ``{...}`` match, so a component
        that itself contains ``}`` will be split at that character.

    Args:
        example_id (str): The example ID string generated by the
            generate_example_id function.

    Returns:
        dict: A dictionary with the keys ``"datasets"``, ``"pdb_id"``,
            ``"assembly_id"`` and ``"query_pn_unit_iids"``. ``pdb_id`` and
            ``assembly_id`` are returned as the raw strings found between the
            braces; the other two are parsed from their Python-literal text.
            An empty fourth group yields an empty list.

    Raises:
        ValueError: If the ID does not contain exactly four brace groups, or
            if a list component is not a plain Python literal.
        SyntaxError: If a list component is not syntactically valid Python.
    """
    # Use regular expression to find all parts within curly braces
    matches = re.findall(r"\{(.*?)\}", example_id)

    if len(matches) != 4:
        raise ValueError(f"Invalid example ID format: {example_id}, with {len(matches)} matches found (expected 4).")

    datasets_repr, pdb_id, assembly_id, query_repr = matches

    # SECURITY: these substrings come straight from the ID string, which may
    # originate outside this process. ast.literal_eval only accepts Python
    # literals (lists, strings, numbers, ...), so unlike eval() it cannot
    # execute code embedded in a crafted ID.
    datasets = ast.literal_eval(datasets_repr)
    query_pn_unit_iids = ast.literal_eval(query_repr) if query_repr else []

    return {
        "datasets": datasets,
        "pdb_id": pdb_id,
        "assembly_id": assembly_id,
        "query_pn_unit_iids": query_pn_unit_iids,
    }
def as_list(value: Any) -> list:
    """Return ``value`` as a list.

    Uses duck typing to decide how to convert:
        - Strings: wrapped as a single element (not split into characters)
        - Other iterable objects (lists, tuples, etc.): converted with ``list()``
        - Anything ``list()`` rejects with a ``TypeError``: wrapped as a single element

    Args:
        value: The value to convert.

    Returns:
        list: A new list built from ``value``.
    """
    # Strings are iterable, but we want to treat them as single values.
    if isinstance(value, str):
        return [value]

    try:
        items = list(value)
    except TypeError:
        # Not iterable: wrap the value itself.
        items = [value]
    return items