Source code for atomworks.ml.utils.debug

import logging
import os
import pickle
import re
from datetime import datetime

from atomworks.ml.common import default

logger = logging.getLogger("atomworks.ml")
_USER = default(os.getenv("USER"), "")

try:
    import wandb
except ImportError:
    wandb = None


def _remove_special_characters(s: str) -> str:
    assert isinstance(s, str)
    # Remove unwanted characters using regex
    clean_s = re.sub(r"[^a-zA-Z0-9_]", "", s)
    return f"{clean_s}"


[docs] def save_failed_example_to_disk( example_id: str, fail_dir: str, *, data: dict = {}, rng_state_dict: dict = {}, error_msg: str = "", ) -> None: """ Attempts to save a failed example to disk as a pickle file. Args: - example_id (str): The ID of the example. - fail_dir (str): The directory where the failed example should be saved. Defaults to a specific path. - rng_state_dict (dict): The random number generator state dictionary. - error_msg (str): The error message associated with the failure. Returns: None """ try: # Get wandb run ID if currently in a wandb run run_id = "" if wandb is not None and hasattr(wandb, "run") and wandb.run is not None: run_id = wandb.run.id file_path = os.path.join(fail_dir, run_id, _remove_special_characters(example_id) + ".pkl") # Ensure the fail directory exists os.makedirs(os.path.dirname(file_path), exist_ok=True, mode=0o777) # Allow everyone to read/write with open(file_path, "wb") as f: data = { "example_id": example_id, "rng_state_dict": rng_state_dict, "error_msg": error_msg, "wandb_run_id": run_id, "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "user": _USER, } | data pickle.dump(data, f) except KeyboardInterrupt as e: raise e except Exception as e: logger.warning( f"Failed to save failed example to disk: {e}. Are you sure the directory exists? Do you have write permissions?" )