# Source code for atomworks.ml.transforms.base
"""Base classes for transformations."""
from __future__ import annotations
import contextlib
import logging
import os
import pickle
import pprint
import re
import time
from abc import ABC, ABCMeta, abstractmethod
from collections.abc import Callable, Iterable
from typing import Any, ClassVar
import numpy as np
import torch
from toolz import valmap
from atomworks.ml.transforms._checks import check_contains_keys, check_does_not_contain_keys
from atomworks.ml.utils.rng import capture_rng_states, rng_state, serialize_rng_state_dict
logger = logging.getLogger("transforms")
DEBUG = os.getenv("DEBUG", True)
if DEBUG:
logger.setLevel(logging.DEBUG)
logger.debug("Debug mode is on")
import traceback
else:
logger.setLevel(logging.INFO)
class TransformPipelineError(Exception):
    """A custom error class for Transform pipelines (via `Compose`)."""

    def __init__(self, message: str, rng_state_dict: dict[str, Any] | None = None):
        """Wrap *message* and, optionally, the RNG state captured when the pipeline failed.

        Args:
            message: Human-readable description of the failure.
            rng_state_dict: Snapshot of the random-number-generator states at the
                start of the failing pipeline, kept for post-mortem debugging.
        """
        # Keep the RNG snapshot accessible to callers that want to reproduce the failure.
        self.rng_state_dict = rng_state_dict
        super().__init__(message)
class TransformedDict(dict):
    """A thin wrapper around a dictionary that can be used to track the transform history."""

    def __new__(cls, __existing_dict_to_wrap: dict[str, Any] | None = None, **kwargs):
        """Create a new instance or return the existing TransformedDict instance.

        NOTE: To get a pure dictionary, simply use `dict(transformed_dict)` on a TransformedDict
        instance. TransformedDict's behave just like dicts for all intents and purposes, so you
        can use them just like a regular dictionary.

        Args:
            __existing_dict_to_wrap (dict, optional): This is useful for wrapping an existing
                dictionary. The odd name `__existing_dict_to_wrap` is used as an unlikely name
                to avoid conflicts with the `dict` class.
            **kwargs: Additional keyword arguments to pass to the dictionary constructor. This
                ensures that a TransformedDict can be initialized just like a regular dictionary
                if no existing dictionary to wrap is provided.
        """
        # Wrapping is idempotent: an existing TransformedDict is handed back untouched
        # (its transform history is preserved, not reset).
        if isinstance(__existing_dict_to_wrap, TransformedDict):
            return __existing_dict_to_wrap

        instance = super().__new__(cls)
        if __existing_dict_to_wrap is None:
            # No dict given: behave exactly like the plain `dict(**kwargs)` constructor.
            instance.update(kwargs)
        else:
            assert len(kwargs) == 0, "Either `__existing_dict_to_wrap` or `kwargs` must be provided, but not both."
            # Copy the provided mapping's entries into the new instance.
            instance.update(__existing_dict_to_wrap)

        # Start with an empty history; `Transform.__call__` appends one record per transform.
        instance.__transform_history__ = []
        return instance
class Transform(ABC):
    """
    Abstract base class for transformations on dictionary objects.

    Class level attributes:
        - validate_input (bool): Whether to validate the input.
        - raise_if_invalid_input (bool): Whether to raise an error if the input is invalid.
        - requires_previous_transforms (list[str]): Transforms that must have been applied before this transform.
        - incompatible_previous_transforms (list[str]): Transforms that cannot have preceded this transform.
        - previous_transforms_order_matters (bool): Whether the order of the transforms is important.
        - _track_transform_history (bool): Whether to track the transform history.

    To write a subclass, you need to implement the following methods:
        - check_input(data: dict): Validates the input data. Should raise an error if the input is invalid.
          The returned value is not used.
        - forward(data: dict): Applies the transformation to the input data and returns the transformed data.
    """

    # Validation flags; subclasses may override at class level or per instance.
    validate_input: bool = True
    raise_if_invalid_input: bool = True
    # Declarative constraints on the pipeline history, enforced by `_check_transform_history`.
    requires_previous_transforms: ClassVar[list[str]] = []
    incompatible_previous_transforms: ClassVar[list[str]] = []
    previous_transforms_order_matters: bool = False
    # When True, `__call__` appends a timing record to `data.__transform_history__`.
    _track_transform_history: bool = True

    # To be implemented by subclasses (optional)
    def check_input(self, data: dict[str, Any]) -> None:  # noqa: B027
        """
        Check if the input dictionary is valid for the transform. Raises an error if the input is invalid.
        """
        pass

    @abstractmethod
    def forward(self, data: dict[str, Any], *args, **kwargs) -> dict[str, Any]:
        """
        Apply a transformation to the input dictionary and return the transformed dictionary.

        Parameters:
            data (dict): The input dictionary to transform.

        Returns:
            dict: The transformed dictionary.
        """
        pass

    # Internal logic for formatting error messages, debugging, logging and transform history tracking
    def _format_error_msg(self, e: Exception) -> str:
        """
        Formats the error message with optional traceback when in DEBUG mode.
        """
        msg = f"Invalid input for {self.__class__.__name__}: {e}"
        if DEBUG:
            # NOTE: `traceback` is only imported at module level when DEBUG is truthy,
            # which matches this guard.
            msg += f"\n\n{traceback.format_exc()}\n" + "=" * 80
        return msg

    def _transform_to_str(self, t: str | Transform | ABCMeta) -> str:
        """
        Convert a transform (given as a name, class, or instance) to a string.
        """
        if isinstance(t, str):
            # case: transform was provided as string, e.g. as `"RemoveKeys"`
            return t
        elif isinstance(t, ABCMeta):
            # case: transform was provided as class, e.g. as `RemoveKeys`
            return t.__name__
        elif isinstance(t, Transform):
            # case: transform was provided as instance, e.g. as `RemoveKeys()`
            return t.__class__.__name__
        else:
            raise ValueError(f"Transform `{t}` cannot be converted to a string form for comparison of history.")

    def _ensure_has_transform_history(self, data: dict[str, Any] | TransformedDict) -> TransformedDict:
        """Ensure that the data dictionary has a transform history by wrapping it in a `TransformedDict`."""
        # `TransformedDict.__new__` is idempotent: an existing TransformedDict is returned as-is.
        data = TransformedDict(data)
        return data

    def _get_transform_history(self, data: TransformedDict) -> list[str]:
        """
        Get the transform history from the data.
        """
        return data.__transform_history__

    def _maybe_update_transform_history(self, data: TransformedDict) -> dict[str, Any]:
        """
        Update the transform history by appending the current transform to the transform history.
        """
        if self._track_transform_history:
            this_transform_record = {
                "name": self.__class__.__name__,
                # identifies this exact instance; matched again in `_maybe_record_processing_time`
                "instance": hex(id(self)),
                "start_time": time.time(),
                "end_time": None,
                "processing_time": None,
            }
            # record the current transform in the transform history
            # (a fresh list is built instead of appending in place, so previously
            # captured history objects are not mutated)
            data.__transform_history__ = [*data.__transform_history__, this_transform_record]
        return data

    def _maybe_restore_transform_history(self, data: TransformedDict, transform_history: list[str]) -> dict[str, Any]:
        """
        Restore the transform history, in case the data was copied.
        """
        if not hasattr(data, "__transform_history__") or len(data.__transform_history__) == 0:
            # restore previous transform history if it is not present (e.g. if the data was copied)
            data.__transform_history__ = transform_history
        return data

    def _maybe_record_processing_time(self, data: TransformedDict) -> dict[str, Any]:
        """
        Record the processing time for the transform.
        """
        if self._track_transform_history and len(data.__transform_history__) > 0:
            # walk backwards so the most recent record created by this instance is found first
            for reverse_idx in range(len(data.__transform_history__) - 1, -1, -1):
                # record the processing time for the current transform
                record = data.__transform_history__[reverse_idx]
                if record["instance"] == hex(id(self)):
                    # NOTE(review): there is no early `break`, so if the same instance
                    # appears multiple times in the history, *every* matching record is
                    # re-stamped here — confirm this is intended.
                    start_time = record["start_time"]
                    end_time = time.time()
                    data.__transform_history__[reverse_idx]["end_time"] = end_time
                    data.__transform_history__[reverse_idx]["processing_time"] = end_time - start_time
        return data

    def _check_transform_history(self, data: TransformedDict) -> None:
        """
        Check if the previous transforms are valid for the transform.

        Raises an error if the input is invalid.
        """
        # extract the transform history
        history = [record["name"] for record in data.__transform_history__]
        # ensure that `incompatible_previous_transforms` did not get applied
        for t in self.incompatible_previous_transforms:
            t = self._transform_to_str(t)
            if t in history:
                raise ValueError(
                    f"Transform `{self.__class__.__name__}` cannot be applied if any of the transforms {self.incompatible_previous_transforms} "
                    f"have been applied before it. Current transform history: {history}"
                )
        # get indices of `requires_previous_transforms` in the transform history
        indices = []
        for t in self.requires_previous_transforms:
            t = self._transform_to_str(t)
            # required transform names are treated as regex patterns matched against history entries
            pattern = re.compile(t)
            matches = [index for index, t in enumerate(history) if pattern.search(t)]
            if len(matches) == 0:
                raise ValueError(f"Transform `{t}` is missing from the transform history, which is {history}.")
            elif len(matches) > 1:
                raise ValueError(
                    f"Transform `{t}` appears multiple times in the transform history, which is {history}."
                )
            assert len(matches) == 1
            indices.append(matches[0])
        # check if the indices are in the correct order
        if self.previous_transforms_order_matters and (indices != sorted(indices)):
            current_order = ">".join([history[i] for i in sorted(indices)])
            required_order = ">".join(self.requires_previous_transforms)
            raise ValueError(
                f"Transform `{self.__class__.__name__}` requires the transforms {required_order} "
                f"to have been applied before it in this order, but the current order is {current_order}."
            )

    def __call__(self, data: dict[str, Any], *args, **kwargs) -> dict[str, Any]:
        """
        Validate and apply the transformation to the given dictionary.

        Raises:
            ValueError: If the input is invalid and raise_if_invalid_input is True.
        """
        # enable history tracking if it is not already enabled
        data = self._ensure_has_transform_history(data)
        # validate input
        if self.validate_input:
            try:
                # check if the input is valid
                self._check_transform_history(data)
                self.check_input(data)
            except Exception as e:
                # if it is not valid, log or raise an error
                formatted_msg = self._format_error_msg(e)
                if self.raise_if_invalid_input:
                    logger.error(formatted_msg)
                    raise TransformPipelineError(formatted_msg) from e
                else:
                    # validation is advisory only: warn and return the input unchanged
                    logger.warning(formatted_msg)
                    return data
        # update transform history if it is being tracked
        data = self._maybe_update_transform_history(data)
        # get previous transform history (needed for `_maybe_restore_transform_history` later)
        # (NOTE: It is necessary to carry the transform history outside the `forward` method
        # and the `data` object to allow users to seamlessly copy the dict and work with the
        # dict without losing the transform history.)
        transform_history = self._get_transform_history(data)
        # apply the transformation
        data = self.forward(data, *args, **kwargs)
        assert isinstance(
            data, dict
        ), f"`forward` method of {self.__class__.__name__} must return a dictionary, not {type(data)}."
        # restore the transform history if `data` was copied (which loses the transform history)
        data = self._ensure_has_transform_history(data)
        data = self._maybe_restore_transform_history(data, transform_history)
        data = self._maybe_record_processing_time(data)
        return data

    def __repr__(self) -> str:
        """String representation of the transform for debugging, notebooks and logging."""
        # Get all the attributes of the class
        repr_str = f"{self.__class__.__name__} at {hex(id(self))}"
        if len(self.__dict__) > 0:
            attributes = [
                f"{k}={pprint.pformat(v, indent=2, depth=1, compact=True, sort_dicts=False)}"
                for k, v in self.__dict__.items()
            ]
            repr_str += "(\n " + ",\n ".join(attributes) + "\n)"
        return repr_str

    def __add__(self, other: Transform) -> Compose:
        """Chain transforms with `+`, producing a `Compose` pipeline."""
        # Case 1: self & other are `Compose` instances
        # ... overridden in `Compose` class
        # Case 2: self is a `Compose` instance and other is a `Transform` instance
        # ... overridden in `Compose` class
        # Case 3: self is a `Transform` instance and other is a `Compose` instance
        if isinstance(self, Transform) and isinstance(other, Compose):
            return Compose([self, *other.transforms], track_rng_state=other.track_rng_state)
        # Case 4: self & other are simple `Transform` instances
        elif isinstance(self, Transform) and isinstance(other, Transform):
            return Compose([self, other])
        # Case 5: other is not a `Transform` instance
        else:
            raise ValueError(f"Expected a Transform or Compose, but got a {type(other)}")
class Compose(Transform):
    """
    Compose multiple transformations together.

    This class allows you to chain multiple transformations and apply them sequentially to a data dictionary.
    It is particularly useful for preprocessing pipelines where multiple steps need to be applied in a specific order.

    Attributes:
        - transforms (list[Transform]): A list of transformations to be applied.
        - track_rng_state (bool): Whether to track and serialize the random number generator (RNG) state. This is
          useful for debugging when dealing with probabilistic transformations. The RNG state is returned with
          the error message if the transform pipeline fails, allowing you to instantiate the same RNG state
          with `eval` for debugging.
    """

    _track_transform_history: bool = False  # Compose does not show up in the transform history

    def __init__(self, transforms: list[Transform], track_rng_state: bool = True, print_rng_state: bool = False):
        """
        Initialize the Compose transformation pipeline.

        Args:
            - transforms (list[Transform]): A list of transformations to be applied sequentially.
            - track_rng_state (bool): Whether to track and serialize the random number generator (RNG) state.
                This is useful for debugging when dealing with probabilistic transformations. The RNG state
                is returned with the error message if the transform pipeline fails, allowing you to instantiate
                the same RNG state with `eval` for debugging.
            - print_rng_state (bool): Whether to print the RNG state upon failure. This can be useful
                for debugging and reproducing specific states for transforms with stochasticity.

        Raises:
            ValueError: If `transforms` is not a list or tuple, if it is empty, or if it contains elements that
                are not instances of `Transform`.
        """
        # NOTE: `isinstance` with a `X | Y` union requires Python >= 3.10
        if not isinstance(transforms, list | tuple):
            raise ValueError(f"Expected a list or tuple of Transforms, but got a {type(transforms)}")
        if not len(transforms) > 0:
            raise ValueError("Got an empty list of transforms.")
        if not all(isinstance(t, Transform) for t in transforms):
            # surface the first offending element's type in the error message
            invalid_type = next(t for t in transforms if not isinstance(t, Transform))
            raise ValueError(f"Expected a list or tuple of Transforms, but got a {type(invalid_type)}")
        self.transforms = transforms
        self.track_rng_state = track_rng_state
        # Most recent RNG snapshot taken by `forward` (kept for post-mortem debugging).
        self.latest_rng_state_dict = None
        self.print_rng_state = print_rng_state

    def __add__(self, other: Transform | list[Transform] | Compose) -> Compose:
        """Concatenate with another pipeline, transform, or list of transforms via `+`."""
        if isinstance(other, Compose):
            return Compose(
                self.transforms + other.transforms, track_rng_state=self.track_rng_state or other.track_rng_state
            )
        elif isinstance(other, Transform):
            return Compose([*self.transforms, other], track_rng_state=self.track_rng_state)
        elif isinstance(other, list):
            return Compose(self.transforms + other, track_rng_state=self.track_rng_state)
        else:
            raise ValueError(f"Expected a Transform or list of Transforms or Compose, but got a {type(other)}")

    def _stop_transforms(
        self,
        next_transform: Transform,
        next_transform_idx: int,
        stop_before: Transform | int | str | None = None,
    ) -> bool:
        """Return True when the pipeline should stop *before* applying `next_transform`."""
        if stop_before is None:
            return False
        elif isinstance(stop_before, int):
            # stop at a positional index
            return next_transform_idx == stop_before
        elif isinstance(stop_before, str):
            # stop at the first transform whose class name matches
            return next_transform.__class__.__name__ == stop_before
        elif isinstance(stop_before, Transform):
            # stop at the first transform of the same class as the given instance
            return next_transform.__class__.__name__ == stop_before.__class__.__name__
        else:
            raise ValueError(f"Expected a Transform or str or int, but got a {type(stop_before)}")

    def forward(
        self,
        data: dict,
        rng_state_dict: dict[str, Any] | None = None,
        _stop_before: Transform | str | int | None = None,
    ) -> dict:
        """
        Apply a series of transformations to the input data.

        Args:
            data (dict): The input data to be transformed.
            rng_state_dict (dict[str, Any] | None, optional): Random number generator state dictionary.
                If provided, sets the RNG state before applying transforms. Defaults to None.
            _stop_before (Transform | str | int | None, optional): Specifies a point to stop the transformation
                process. Can be a Transform instance, a string (transform class name), or an integer (index).
                Defaults to None.

        Returns:
            dict: The transformed data.

        Raises:
            Exception: If any transform in the pipeline fails, with details about the failure point and RNG state.
        """
        # set the RNG state context if given
        # (NOTE: when no state is provided, `nullcontext()` yields None, so the
        # `as` clause rebinds `rng_state_dict` to None in that case)
        with (
            rng_state(rng_state_dict, include_cuda=False) if rng_state_dict else contextlib.nullcontext()
        ) as rng_state_dict:
            if self.track_rng_state and rng_state_dict is None:
                # collect RNG states at the start of the pipeline and execute the transforms
                rng_state_dict = capture_rng_states()
            self.latest_rng_state_dict = rng_state_dict
            try:
                # execute the transforms
                for idx, transform in enumerate(self.transforms):
                    if self._stop_transforms(transform, idx, _stop_before):
                        # ... capability to stop before a specific transform for debugging
                        break
                    # ... otherwise apply the transform
                    data = transform(
                        data
                    )  # BREAKPOINT: Set debug breakpoint here to step through the transforms one-by-one
            except KeyboardInterrupt:
                # never swallow a user interrupt
                raise
            except Exception as e:
                # construct error message including the RNG states
                msg = f"Transforms failed at stage `{transform.__class__.__name__}`: " + str(e)
                if "example_id" in data:
                    msg += f"\nFailure occurred for example ID: {data['example_id']}."
                if self.track_rng_state and self.print_rng_state:
                    msg += "\nRandom number generator states at the start of the pipeline (you can instantiate the string below with `eval` for debugging):\n"
                    msg += repr(serialize_rng_state_dict(rng_state_dict))
                # Update error message of original exception
                e.args = (msg,)
                # Raise the new custom exception with the original traceback
                raise e.with_traceback(e.__traceback__)  # noqa: B904
        return data

    def __repr__(self) -> str:
        """Compact multi-line listing of the composed transform class names."""
        return "Compose(\n " + ",\n ".join([str(t.__class__.__name__) for t in self.transforms]) + "\n)"

    def __len__(self) -> int:
        """Number of transforms in the pipeline."""
        return len(self.transforms)

    def __getitem__(self, idx: int | slice | Iterable[int]) -> Transform:
        """Index into the pipeline: int -> single transform; slice or iterable of ints -> sub-pipeline."""
        if isinstance(idx, slice):
            return Compose(self.transforms[idx], track_rng_state=self.track_rng_state)
        elif hasattr(idx, "__iter__"):
            return Compose([self.transforms[i] for i in idx], track_rng_state=self.track_rng_state)
        else:
            return self.transforms[idx]
class RemoveKeys(Transform):
    """
    Remove keys from the data dictionary.
    """

    def __init__(self, keys: list[str], require_keys_exist: bool = True):
        """
        Args:
            keys: The keys to delete from the data dictionary.
            require_keys_exist: If True (default), validation fails when any of
                `keys` is missing from the input dictionary.
        """
        self.keys = keys
        # `Transform.__call__` only runs `check_input` when `validate_input` is True.
        self.validate_input = require_keys_exist

    def check_input(self, data: dict) -> None:
        # BUG FIX: `require_keys_exist` previously had no effect because no
        # `check_input` was defined (the base-class default is a no-op).
        check_contains_keys(data, self.keys)

    def forward(self, data: dict) -> dict:
        """Delete the configured keys (those present) from `data` in place and return it."""
        for key in self.keys:
            if key in data:
                del data[key]
        return data
class SubsetToKeys(Transform):
    """
    Keep only the keys in the data dictionary.
    """

    def __init__(self, keys: list[str], require_keys_exist: bool = True):
        """
        Args:
            keys: The keys to keep in the data dictionary.
            require_keys_exist: If True (default), validation fails when any of
                `keys` is missing from the input dictionary.
        """
        self.keys = keys
        # `Transform.__call__` only runs `check_input` when `validate_input` is True.
        self.validate_input = require_keys_exist

    def check_input(self, data: dict) -> None:
        # BUG FIX: `require_keys_exist` previously had no effect because no
        # `check_input` was defined (the base-class default is a no-op).
        check_contains_keys(data, self.keys)

    def forward(self, data: dict) -> dict:
        """Return a new dict containing only the configured keys (those present)."""
        return {key: data[key] for key in self.keys if key in data}
class AddData(Transform):
    """
    Add data to the data dictionary.
    """

    def __init__(self, data: dict, allow_overwrite: bool = False):
        """
        Args:
            data: Key/value pairs to merge into the pipeline dictionary.
            allow_overwrite: If False (default), `check_input` rejects inputs that
                already contain any of the keys in `data`.
        """
        self.data = data
        # The overwrite check is only needed when overwriting is forbidden.
        self.validate_input = not allow_overwrite

    def check_input(self, data: dict) -> None:
        """Raise if any key to be added already exists in `data`."""
        check_does_not_contain_keys(data, self.data.keys())

    def forward(self, data: dict) -> dict:
        # BUG FIX: `forward` was missing; since `Transform.forward` is abstract,
        # `AddData` could not be instantiated at all. Merge the configured entries in.
        data.update(self.data)
        return data
class LogData(Transform):
    """
    Log the data dictionary. Meant for debugging.
    """

    # Log-only transforms are kept out of the recorded pipeline history.
    _track_transform_history: bool = False

    def __init__(self, log_level: int = logging.INFO, depth: int | None = 1, **pprint_kwargs):
        """
        Args:
            log_level: Logging level used when emitting the message (default INFO).
            depth: Maximum nesting depth shown by `pprint` (None for unlimited).
            **pprint_kwargs: Extra keyword arguments forwarded to `pprint.pformat`.
        """
        assert depth is None or depth > 0, "Depth must be a positive integer or None"
        self.log_level = log_level
        self.depth = depth
        self.pprint_kwargs = pprint_kwargs

    def forward(self, data: dict) -> dict:
        """Pretty-print `data` to the module logger and pass it through unchanged."""
        divider = "=" * 80
        formatted = pprint.pformat(data, indent=2, depth=self.depth, sort_dicts=False, **self.pprint_kwargs)
        logger.log(
            level=self.log_level,
            msg=f"{divider}\nData: \n{formatted}\n{divider}",
        )
        return data
class PickleToDisk(Transform):
    """
    Save the data dictionary to a pickle file.
    """

    def __init__(
        self,
        dir_path: str,
        file_name_func: Callable[[dict], str] | None = None,
        save_transform_history: bool = False,
        overwrite: bool = False,
    ):
        """
        Args:
            dir_path: Directory in which to write the pickle files (created if missing).
            file_name_func: Function mapping the data dict to a file name.
                Defaults to `f"{data['id']}.pkl"`.
            save_transform_history: Stored on the instance; not otherwise used here.
            overwrite: Whether to overwrite an existing file of the same name.
        """
        self.dir_path = dir_path
        # BUG FIX: the default lambda was previously assigned to a *local* variable
        # after `self.file_name_func` had already been set, so `self.file_name_func`
        # stayed None and `forward` crashed with a TypeError when no function was given.
        self.file_name_func = file_name_func or (lambda data: f"{data['id']}.pkl")  # noqa
        self.overwrite = overwrite
        self.save_transform_history = save_transform_history
        # Ensure the directory exists
        os.makedirs(self.dir_path, exist_ok=True)

    def forward(self, data: dict) -> dict:
        """Serialize `data` (as a plain dict) to `<dir_path>/<file_name>` and pass it through.

        Raises:
            ValueError: If the target file exists and `overwrite` is False.
        """
        file_name = self.file_name_func(data)
        file_path = os.path.join(self.dir_path, file_name)
        if os.path.exists(file_path) and not self.overwrite:
            raise ValueError(f"File {file_path} already exists. Set overwrite=True to overwrite it.")
        with open(file_path, "wb") as f:
            # NOTE: We cast the data to a dict to ensure that the data is serializable
            # and that deserialization does not fail due to the presence of custom classes.
            # (in particular the `TransformedDict` class)
            pickle.dump(dict(data), f)
        return data
class RaiseError(Transform):
    """
    Raises an error for testing and debugging purposes.
    """

    def __init__(self, error_type: type[Exception] = ValueError, error_message: str = "User requested raising an error."):
        """
        Args:
            error_type: Exception class to instantiate and raise.
            error_message: Message passed to the raised exception.
        """
        self.error_type = error_type
        self.error_message = error_message

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        # BUG FIX: `forward` was missing; since `Transform.forward` is abstract,
        # the class could not be instantiated at all. Raising unconditionally is
        # the documented purpose of this transform.
        raise self.error_type(self.error_message)
class Identity(Transform):
    """
    Identity transform. Does nothing and just passes the data through.
    """

    # No validation, no failure, no history entry: a pure pass-through.
    validate_input = False
    raise_if_invalid_input = False
    _track_transform_history = False

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        # BUG FIX: `forward` was missing; since `Transform.forward` is abstract,
        # `Identity()` raised a TypeError on instantiation.
        return data
class RandomRoute(Transform):
    """
    Route probabilistically between various transforms.

    This transform is useful for routing between different transforms probabilistically, e.g. for
    sampling different cropping strategies.
    """

    validate_input = False
    raise_if_invalid_input = False
    _track_transform_history: bool = True  # RandomRoute records history because it changes the RNG state

    def __init__(self, transforms: list[Transform], probs: list[float]):
        """
        Initializes the RandomRoute transform.

        Args:
            transforms (list[Transform]): A list of transformations to route between.
            probs (list[float]): A list of probabilities corresponding to each transform. The probabilities
                must be non-negative and sum to 1. There must be as many probabilities as there are transforms.

        Raises:
            AssertionError: If inputs are invalid (e.g. probabilities don't add up, are negative, etc.)
        """
        # Validate inputs
        assert len(transforms) == len(probs), (
            f"Number of transforms must match number of probabilities. "
            f"Got {len(transforms)} transforms and {len(probs)} probabilities."
        )
        # A probability sum of 0 is also allowed; `forward` treats it as "skip entirely".
        assert np.isclose(np.sum(probs), 1) or np.isclose(
            np.sum(probs), 0
        ), f"Probabilities must sum to 1 or 0. Got {np.sum(probs)}"
        assert all(isinstance(t, Transform) for t in transforms), (
            f"All elements in transforms must be Transform instances. "
            f"Got {type(next(t for t in transforms if not isinstance(t, Transform)))}"
        )
        self.transforms = transforms
        self.probs = probs

    @classmethod
    def from_dict(cls, transform_dict: dict[Transform, float]) -> RandomRoute:
        """Build a RandomRoute from a `{transform: probability}` mapping."""
        probs = list(transform_dict.values())
        transforms = list(transform_dict.keys())
        return cls(transforms, probs)

    @classmethod
    def from_list(cls, transform_list: list[tuple[float, Transform]]) -> RandomRoute:
        """Build a RandomRoute from a list of `(probability, transform)` pairs."""
        probs, transforms = zip(*transform_list, strict=False)
        return cls(transforms, probs)

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        """Draw one transform according to `self.probs` and apply it (consumes NumPy global RNG)."""
        # Choose a transform probabilistically
        # EDGE CASE: If the probabilities sum to 0, skip the transform
        if np.isclose(np.sum(self.probs), 0):
            # skip (no RNG draw happens in this case)
            return data
        idx = np.random.choice(len(self.transforms), p=self.probs)
        # Apply the transform
        return self.transforms[idx](data)
class ConditionalRoute(Transform):
    """
    Route conditionally between various transforms.

    This Transform is useful for routing between different transforms based on a condition, e.g. skipping transforms during inference.
    """

    def __init__(self, condition_func: Callable[[dict[str, Any]], Any], transform_map: dict[Any, Transform]):
        """
        Initialize the ConditionalRoute transformation.

        Args:
            condition_func (Callable[[dict[str, Any]], Any]): A function that takes the data dictionary and returns a condition value.
            transform_map (dict[Any, Transform]): A dictionary mapping condition values to their corresponding transforms.

        Example:
            ```python
            ConditionalRoute(
                condition_func=lambda data: data.get("mode", "inference"),
                transform_map={
                    "train": TrainingTransform(),
                    "inference": Identity(),
                    # Defaults to Identity if no match; "inference" included for clarity
                },
            )
            ```
        """
        self.condition_func = condition_func
        self.transform_map = transform_map

    def check_input(self, data: dict[str, Any]) -> None:
        # Routing imposes no requirements on the incoming dictionary.
        pass

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        """
        Apply the appropriate transformation based on the condition value.

        Args:
            data (dict[str, Any]): The input data dictionary.

        Returns:
            dict[str, Any]: The transformed data dictionary.
        """
        route_key = self.condition_func(data)
        # Fall back to a pass-through when the condition value has no mapping.
        chosen_transform = self.transform_map.get(route_key, Identity())
        return chosen_transform(data)
def convert_to_torch(data: dict[str, Any], keys: list[str], device: str = "cpu") -> dict[str, Any]:
    """Convert the contents of specified `data` keys to torch tensors and move them to the specified device.

    For each given top-level `data` key, all nested numpy arrays (including those found inside
    nested dictionaries and lists) are converted to torch tensors.

    Args:
        data (dict[str, Any]): The input data dictionary.
        keys (list[str]): List of `data` keys within which to search for numpy arrays to convert to torch tensors.
        device (str): The device to which the tensors should be moved (e.g., 'cpu', 'cuda'). Default is 'cpu'.

    Returns:
        dict[str, Any]: The data dictionary with numpy arrays converted to torch tensors.

    Raises:
        KeyError: If any entry of `keys` is not present in `data`.
    """
    # Numpy dtypes that torch can represent directly; arrays of other dtypes
    # (e.g. object or string arrays) are left untouched.
    # NOTE: kept as a tuple on purpose — `ndarray.dtype in (...)` relies on dtype
    # *equality*, which a hash-based set would not honor.
    convertible_dtypes = (
        np.float64,
        np.float32,
        np.float16,
        np.complex64,
        np.complex128,
        np.int64,
        np.int32,
        np.int16,
        np.int8,
        np.uint64,
        np.uint32,
        np.uint16,
        np.uint8,
        np.bool_,
    )

    def _to_torch(value: Any) -> Any:
        """Recursively convert numpy arrays to torch tensors inside dicts and lists."""
        if isinstance(value, np.ndarray) and value.dtype in convertible_dtypes:
            return torch.tensor(value, device=device)
        if isinstance(value, dict):
            return {k: _to_torch(v) for k, v in value.items()}
        if isinstance(value, list):
            return [_to_torch(item) for item in value]
        return value

    for key in keys:
        if key not in data:
            raise KeyError(f"Key '{key}' not found in the data dictionary.")
        data[key] = _to_torch(data[key])
    return data
class ConvertToTorch(Transform):
    """
    Converts the contents of specified `data` keys to torch tensors and moves them to the specified device.
    """

    def __init__(self, keys: list[str], device: str = "cpu"):
        """
        Args:
            keys: Top-level `data` keys whose contents should be converted.
            device: Target device for the resulting tensors (default "cpu").
        """
        self.keys = keys
        self.device = device

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        """Delegate to the module-level `convert_to_torch` helper."""
        return convert_to_torch(data, self.keys, self.device)
class RaiseOnCondition(Transform):
    """
    Raises a user-specified exception if a given condition is met.
    """

    def __init__(self, condition: callable, error_message: str, exception_to_raise: type[Exception] = ValueError):
        """
        Args:
            condition: Predicate evaluated against the data dictionary on each call.
            error_message: Message passed to the raised exception.
            exception_to_raise: Exception class instantiated when the condition holds.
        """
        self.condition = condition
        self.error_message = error_message
        self.exception_class = exception_to_raise

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        """Pass `data` through unchanged unless the condition triggers."""
        if not self.condition(data):
            return data
        raise self.exception_class(self.error_message)