Source code for atomworks.ml.transforms._checks

"""
Convenience utils for common validation checks in transforms.

All checks take a `data` dictionary as input and raise an error if the check fails.
"""

from __future__ import annotations

from typing import Any

from atomworks.io.utils.selection import get_annotation_categories


[docs] def check_contains_keys(data: dict[str, Any], keys: list[str]) -> None: """Check if a key is in a dictionary.""" for key in keys: if key not in data: raise KeyError(f"Key `{key}` not in data. Available keys: {list(data.keys())}")
[docs] def check_does_not_contain_keys(data: dict[str, Any], keys: list[str]) -> None: """Check if a key is not in a dictionary.""" for key in keys: if key in data: raise KeyError(f"Key `{key}` already exists in data")
[docs] def check_is_instance(data: dict[str, Any], key: str, expected_type: type) -> None: """Check if the value of a key in a dictionary is of a certain type.""" if not isinstance(data[key], expected_type): raise ValueError(f"Key `{key}` in data is not of type `{expected_type}`, got {type(data[key])}")
[docs] def check_is_shape(data: dict[str, Any], key: str, expected_shape: tuple[int, ...]) -> None: """Check if the value of a key in a dictionary has a certain shape.""" if data[key].shape != expected_shape: raise ValueError(f"Key `{key}` in data has shape {data[key].shape} but expected shape {expected_shape}")
[docs] def check_nonzero_length(data: dict[str, Any], key: str) -> None: """Check if the length of the value of a key in a dictionary is nonzero.""" if len(data[key]) == 0: raise ValueError(f"Key {key} in data has length 0")
[docs] def check_atom_array_annotation( data: dict[str, Any], required: list[str], forbidden: list[str] = [], n_body: int = 1 ) -> None: """Check if `atom_array` key has the annotations specified in `required`.""" annotations = set(get_annotation_categories(data["atom_array"], n_body=n_body)) if not set(required).issubset(annotations): missing = set(required) - annotations raise ValueError(f"Key `atom_array` is missing the following annotations: {missing}") if len(forbidden) > 0 and set(forbidden).issubset(annotations): raise ValueError(f"Key `atom_array` has the following forbidden annotations: {forbidden}")
[docs] def check_atom_array_has_bonds(data: dict[str, Any]) -> None: """Check if `atom_array` key has bonds.""" if data["atom_array"].bonds is None: raise ValueError("Key `atom_array` in data has no `bonds` defined.")