Source code for atomworks.io.transforms.atom_array

"""
Transforms operating predominantly on Biotite's `AtomArray` objects.
These operations should take as input, and return, `AtomArray` objects.
"""

import logging
from collections import Counter, defaultdict

import biotite.structure as struc
import networkx as nx
import numpy as np
import pandas as pd
from biotite.structure import AtomArray, AtomArrayStack, stack

from atomworks.io.common import listmap, not_isin, sum_string_arrays
from atomworks.io.constants import ELEMENT_NAME_TO_ATOMIC_NUMBER, HYDROGEN_LIKE_SYMBOLS, WATER_LIKE_CCDS
from atomworks.io.utils.bonds import (
    generate_inter_level_bond_hash,
    get_coarse_graph_as_nodes_and_edges,
    get_connected_nodes,
    hash_graph,
)
from atomworks.io.utils.ccd import atom_array_from_ccd_code
from atomworks.io.utils.selection import annot_start_stop_idxs

# Module logger, registered under the "atomworks.io" name.
logger = logging.getLogger("atomworks.io")

# hydride is treated as optional: a failed import is logged as a warning instead of
# being raised. In that case the name `hydride` stays undefined in this module.
try:
    import hydride
except ImportError:
    logger.warning("Hydride library not found, hydrogens cannot be inferred. Pip install hydride to enable.")


def subset_atom_array(atom_array: AtomArray | AtomArrayStack, keep: np.ndarray) -> AtomArray | AtomArrayStack:
    """Select atoms from an AtomArray or AtomArrayStack with a boolean mask.

    Args:
        atom_array: The structure to subset.
        keep: Mask over atoms; atoms where the mask is True are retained.

    Returns:
        The subset. For a stack the mask is applied to the second (atom) axis,
        for a single array it is applied directly.
    """
    is_stack = isinstance(atom_array, AtomArrayStack)
    return atom_array[:, keep] if is_stack else atom_array[keep]
def is_any_coord_nan(atom_array: AtomArray | AtomArrayStack) -> np.ndarray:
    """Flag atoms that have at least one NaN coordinate.

    Args:
        atom_array: The structure to inspect.

    Returns:
        Boolean mask of shape [n_atoms]. For a stack, an atom is flagged if any of
        its coordinates is NaN in any model.
    """
    nan_coords = np.isnan(atom_array.coord)
    # A stack has a leading model axis that must be reduced as well.
    reduce_axes = (0, -1) if isinstance(atom_array, AtomArrayStack) else -1
    return nan_coords.any(axis=reduce_axes)
def remove_nan_coords(atom_array: AtomArray | AtomArrayStack) -> AtomArray | AtomArrayStack:
    """Drop every atom that has at least one NaN coordinate.

    Args:
        atom_array: The structure to filter.

    Returns:
        The structure restricted to atoms whose coordinates are all non-NaN.
    """
    has_nan = is_any_coord_nan(atom_array)
    return subset_atom_array(atom_array, np.logical_not(has_nan))
def remove_ccd_components(
    atom_array: AtomArray | AtomArrayStack, ccd_codes_to_remove: list[str]
) -> AtomArray | AtomArrayStack:
    """Remove all atoms whose residue name is one of the given CCD codes.

    Args:
        atom_array: The structure to filter.
        ccd_codes_to_remove: CCD codes (matched against `res_name`) to drop. Any
            iterable is accepted; it is materialised into a list first.

    Returns:
        The filtered structure.
    """
    codes = list(ccd_codes_to_remove)
    keep = not_isin(atom_array.res_name, codes)
    return subset_atom_array(atom_array, keep)
def remove_hydrogens(atom_array: AtomArray | AtomArrayStack) -> AtomArray | AtomArrayStack:
    """Drop all atoms whose element is listed in `HYDROGEN_LIKE_SYMBOLS`.

    Args:
        atom_array: The structure to filter.

    Returns:
        The structure without hydrogen-like atoms.
    """
    return subset_atom_array(atom_array, not_isin(atom_array.element, HYDROGEN_LIKE_SYMBOLS))
def remove_waters(atom_array: AtomArray | AtomArrayStack) -> AtomArray | AtomArrayStack:
    """Drop all residues whose CCD code is listed in `WATER_LIKE_CCDS`.

    Args:
        atom_array: The structure to filter.

    Returns:
        The structure without water-like residues.
    """
    water_free = remove_ccd_components(atom_array, ccd_codes_to_remove=WATER_LIKE_CCDS)
    return water_free
def ensure_atom_array_stack(atom_array_or_stack: AtomArray | AtomArrayStack) -> AtomArrayStack:
    """Return the input as an AtomArrayStack.

    Args:
        atom_array_or_stack: Either a single AtomArray or an AtomArrayStack.

    Returns:
        The stack itself if a stack was given, otherwise a one-model stack
        wrapping the given AtomArray.

    Raises:
        TypeError: If the input is neither an AtomArray nor an AtomArrayStack.
    """
    if isinstance(atom_array_or_stack, AtomArrayStack):
        return atom_array_or_stack
    if isinstance(atom_array_or_stack, AtomArray):
        return stack([atom_array_or_stack])
    raise TypeError(f"Expected AtomArray or AtomArrayStack, got {type(atom_array_or_stack)}")
def resolve_arginine_naming_ambiguity(atom_array: AtomArray, raise_on_error: bool = True) -> AtomArray:
    """Make NH1 the arginine nitrogen that lies closer to CD than NH2.

    For every ARG residue in which NH1 is farther from CD than NH2 (and both
    distances are finite), the coordinates of NH1 and NH2 are exchanged in place.
    Atom names and all other annotations stay where they are.

    Args:
        atom_array: The structure to fix (modified in place).
        raise_on_error: If True, a ValueError raised while computing or applying the
            swap (e.g. mismatching numbers of CD/NH1/NH2 atoms) is re-raised.
            Otherwise it is logged as a warning.

    Returns:
        The same `atom_array` object.
    """
    # TODO: Generalize to AtomArrayStack
    is_arg = atom_array.res_name == "ARG"
    names = atom_array.atom_name
    nh1_mask = is_arg & (names == "NH1")
    nh2_mask = is_arg & (names == "NH2")
    cd_mask = is_arg & (names == "CD")

    try:
        cd_coord = atom_array.coord[cd_mask]
        dist_to_nh1 = np.linalg.norm(cd_coord - atom_array.coord[nh1_mask], axis=-1)
        dist_to_nh2 = np.linalg.norm(cd_coord - atom_array.coord[nh2_mask], axis=-1)

        # One entry per arginine: swap only where both distances are usable.
        distances_finite = np.isfinite(dist_to_nh1) & np.isfinite(dist_to_nh2)
        local_to_swap = (dist_to_nh1 > dist_to_nh2) & distances_finite

        # Expand the per-residue decision to a per-atom mask over the whole array.
        to_swap = np.zeros(atom_array.array_length(), dtype=bool)
        to_swap[nh1_mask] = local_to_swap
        to_swap[nh2_mask] = local_to_swap

        if np.any(to_swap):
            logger.debug(f"Resolving {np.sum(local_to_swap)} arginine naming ambiguities.")
            nh1_swap = nh1_mask & to_swap
            nh2_swap = nh2_mask & to_swap
            # Boolean indexing copies, so both sides are captured before overwriting.
            old_nh1_coord = atom_array.coord[nh1_swap]
            old_nh2_coord = atom_array.coord[nh2_swap]
            atom_array.coord[nh1_swap] = old_nh2_coord
            atom_array.coord[nh2_swap] = old_nh1_coord
    except ValueError as e:
        if raise_on_error:
            raise e
        logger.warning(f"Error resolving arginine naming ambiguity: {e}. Returning original atom array.")

    return atom_array
def mse_to_met(atom_array: AtomArray) -> AtomArray:
    """Convert MSE residues (selenomethionine) to MET (methionine).

    Residue name and hetero flag are updated for all MSE atoms. The selenium atom
    (if present) is renamed from SE to SD and its element is changed to sulfur.
    The atoms of the converted residues are then put into standard order.

    Args:
        atom_array: The structure to convert (modified in place).

    Returns:
        The same `atom_array` object.
    """
    mse_mask = atom_array.res_name == "MSE"
    if not np.any(mse_mask):
        return atom_array

    se_mask = (atom_array.atom_name == "SE") & mse_mask
    logger.debug(f"Converting {np.sum(se_mask)} MSE residues to MET.")

    # Update residue name and hetero flag for every atom of the MSE residues
    atom_array.res_name[mse_mask] = "MET"
    atom_array.hetero[mse_mask] = False

    # An MSE residue may lack its SE atom (e.g. unresolved); only touch the
    # selenium bookkeeping when there is at least one SE atom to inspect.
    if np.any(se_mask):
        atom_array.atom_name[se_mask] = "SD"

        # ... handle cases for integer or string representations of element
        _elt_prev = atom_array.element[se_mask][0]
        if _elt_prev == "SE":
            atom_array.element[se_mask] = "S"
        elif _elt_prev == ELEMENT_NAME_TO_ATOMIC_NUMBER["SE"]:
            atom_array.element[se_mask] = ELEMENT_NAME_TO_ATOMIC_NUMBER["S"]
        elif _elt_prev == str(ELEMENT_NAME_TO_ATOMIC_NUMBER["SE"]):
            atom_array.element[se_mask] = str(ELEMENT_NAME_TO_ATOMIC_NUMBER["S"])

    # Reorder atoms for canonical MET ordering
    # NOTE(review): the write-back relies on AtomArray accepting boolean-mask
    # assignment - confirm against the Biotite version in use.
    atom_array_mse = atom_array[mse_mask]
    atom_array_mse = atom_array_mse[struc.info.standardize_order(atom_array_mse)]
    atom_array[mse_mask] = atom_array_mse

    return atom_array
def keep_last_residue(atom_array: AtomArray | AtomArrayStack) -> AtomArray | AtomArrayStack:
    """Remove duplicate residues, keeping only the last occurrence.

    Residues are duplicates when they share `chain_id` and `res_id` but differ in
    `res_name`. For each such position, only the atoms of the residue name listed
    last are retained.

    Args:
        atom_array: The structure to filter.

    Returns:
        The structure with duplicate residues removed.
    """
    residue_key = ["chain_id", "res_id", "res_name"]
    atom_df = pd.DataFrame(
        {
            "chain_id": atom_array.chain_id,
            "res_id": atom_array.res_id,
            "res_name": atom_array.res_name,
        }
    )

    # One row per distinct (chain_id, res_id, res_name) combination
    collapsed_df = atom_df.drop_duplicates(subset=residue_key)

    # Among rows sharing (chain_id, res_id), everything but the last is a duplicate to drop
    duplicate_mask = collapsed_df.duplicated(subset=["chain_id", "res_id"], keep="last")
    duplicates_df = collapsed_df[duplicate_mask]

    # Left merge keeps one row per atom; the indicator tells whether the atom's
    # residue appears in the set of residues to drop
    merged_df = atom_df.merge(duplicates_df, on=residue_key, how="left", indicator=True)

    # Hand a plain boolean ndarray (not a pandas Series) to the indexing helper
    keep = (merged_df["_merge"] == "left_only").to_numpy()

    return subset_atom_array(atom_array, keep)
def maybe_fix_non_polymer_at_symmetry_center(
    atom_array_stack: AtomArrayStack, clash_distance: float = 1.0, clash_ratio: float = 0.5
) -> AtomArrayStack:
    """Remove symmetric copies of non-polymers that clash with themselves.

    In some PDB entries, non-polymer molecules are placed at the symmetry center and
    clash with themselves when transformed via symmetry operations. The duplicates are
    removed in these cases, keeping the identity copy (or, more generally, the copy with
    the lowest transformation ID in sorted order).

    A non-polymer is considered to clash with a symmetric copy of itself if the number
    of clashing atom pairs with that copy exceeds `clash_ratio` times the number of
    atoms of the chain in the current copy.

    Examples:
        - PDB ID `7mub` has a potassium ion at the symmetry center that clashes with
          itself when reflected with the symmetry operation.
        - PDB ID `1xan` has a ligand at a symmetry center that similarly clashes with
          itself when reflected.

    Args:
        atom_array_stack: The stack to be patched. Only the first model is used to
            detect clashes; the resulting atom mask is applied to all models.
        clash_distance: The distance threshold for two atoms to be considered clashing.
        clash_ratio: The fraction of atoms that must clash for the molecule to be
            considered clashing.

    Returns:
        The patched stack, or the input stack unchanged if no clash is found.
    """
    # Select one model AtomArray to simplify computations
    atom_array = atom_array_stack[0]

    # Filter to only atoms with coordinates to avoid non-physical clashes at the origin
    if "occupancy" in atom_array.get_annotation_categories():
        resolved_mask = atom_array.occupancy > 0
    else:
        resolved_mask = np.ones(atom_array.array_length(), dtype=bool)
    resolved_atom_array = atom_array[(resolved_mask) & (~is_any_coord_nan(atom_array))]

    non_polymer_mask = ~resolved_atom_array.is_polymer
    if not np.any(non_polymer_mask):
        return atom_array_stack  # Early exit
    non_polymers = resolved_atom_array[non_polymer_mask]  # [n]

    # Build cell list for rapid distance computations
    cell_list = struc.CellList(non_polymers, cell_size=3.0)

    # Quick check to see whether any non-polymer atom is within `clash_distance` of any other
    clash_matrix = cell_list.get_atoms(non_polymers.coord, clash_distance, as_mask=True)  # [n, n]
    identity_matrix = np.identity(len(non_polymers), dtype=bool)
    if np.array_equal(clash_matrix, identity_matrix):
        return atom_array_stack  # Early exit

    # Remove identity matrix so we don't count self-clashes
    clash_matrix = clash_matrix & ~identity_matrix
    logger.debug("Found clashing non-polymer at a symmetry center, resolving.")

    # Get list of chain_ids with clashing atoms (for computational efficiency)
    clashing_atom_mask = np.sum(clash_matrix, axis=1) > 0
    clashing_chain_ids = np.unique(non_polymers.chain_id[clashing_atom_mask])

    # Work with string transformation IDs throughout so that sorting, comparison and the
    # construction of chain_iid strings all agree
    transformation_ids = non_polymers.transformation_id.astype(str)

    # For each clashing chain, count the clashes between each symmetric copy and the
    # other copies of the same chain, and mark copies above the threshold for removal.
    # Copies are visited in sorted order, so the lowest transformation ID is kept.
    chain_iids_to_remove = []
    for chain_id in clashing_chain_ids:
        chain_mask = non_polymers.chain_id == chain_id
        mask = chain_mask & clashing_atom_mask  # Mask for clashing atoms in the current chain

        transformation_ids_to_check = sorted(np.unique(transformation_ids[mask]).tolist())
        while transformation_ids_to_check:
            transformation_id = transformation_ids_to_check.pop(0)
            transformation_mask = transformation_ids == transformation_id
            current_mask = mask & transformation_mask
            other_mask = mask & ~transformation_mask

            # Rows: clashing atoms of the current transformation.
            # Columns: clashing atoms of the other transformations of this chain.
            chain_clash_matrix = clash_matrix[current_mask][:, other_mask]

            # Count clashing pairs per transformation ID of the column atom
            _, clashing_columns = np.nonzero(chain_clash_matrix)
            clash_count_by_transformation_id = Counter(transformation_ids[other_mask][clashing_columns].tolist())

            threshold = clash_ratio * np.sum(chain_mask & transformation_mask)

            # Transformation IDs above the threshold are scheduled for removal and are
            # not used as a reference copy afterwards
            transformation_ids_to_remove = [
                trans_id for trans_id, count in clash_count_by_transformation_id.items() if count > threshold
            ]
            chain_iids_to_remove.extend([f"{chain_id}_{trans_id}" for trans_id in transformation_ids_to_remove])
            transformation_ids_to_check = [
                id_ for id_ in transformation_ids_to_check if id_ not in transformation_ids_to_remove
            ]

    # Filter and return
    keep_mask = not_isin(atom_array.chain_iid, np.array(chain_iids_to_remove, dtype=atom_array.chain_iid.dtype))
    atom_array_stack = atom_array_stack[:, keep_mask]

    return atom_array_stack
def add_polymer_annotation(atom_array: AtomArray | AtomArrayStack, chain_info_dict: dict) -> AtomArray | AtomArrayStack:
    """Annotate every atom with whether its chain is a polymer.

    Args:
        atom_array: The structure to annotate (modified in place).
        chain_info_dict: Mapping from chain ID to a dict that has an
            "is_polymer" entry.

    Returns:
        The same structure with the `is_polymer` annotation set.
    """
    per_atom_flags = [chain_info_dict[chain]["is_polymer"] for chain in atom_array.get_annotation("chain_id")]
    atom_array.set_annotation("is_polymer", np.array(per_atom_flags))
    return atom_array
def update_nonpoly_seq_ids(atom_array: AtomArray, chain_info_dict: dict) -> AtomArray:
    """Overwrite `res_id` with `auth_seq_id` for atoms in non-polymer chains.

    Args:
        atom_array: The structure to update (modified in place). Must carry the
            `auth_seq_id` annotation.
        chain_info_dict: Mapping from chain ID to a dict that has an
            "is_polymer" entry.

    Returns:
        The same structure with updated residue IDs for non-polymer chains.
    """
    auth_ids = atom_array.get_annotation("auth_seq_id")
    atom_chain_ids = atom_array.get_annotation("chain_id")

    # Per-atom polymer flag, looked up through the atom's chain
    is_polymer = np.array([chain_info_dict[chain]["is_polymer"] for chain in atom_chain_ids])
    non_polymer_mask = ~is_polymer

    atom_array.res_id[non_polymer_mask] = auth_ids[non_polymer_mask]
    return atom_array
def replace_negative_res_ids_with_auth_seq_id(atom_array: AtomArray) -> AtomArray:
    """Replace `res_id` values of -1 with the corresponding `auth_seq_id` values.

    When loading from the PDB, this step is generally not needed. However, some AF-3
    predictions have negative res_ids without labeling chains as non-polymeric via
    the entity_id field.

    Args:
        atom_array: The structure to fix (modified in place). Must carry the
            `auth_seq_id` annotation.

    Returns:
        The same structure with every `res_id == -1` replaced by the atom's
        `auth_seq_id`.
    """
    author_seq_ids = atom_array.get_annotation("auth_seq_id")
    negative_res_id_mask = atom_array.res_id == -1

    # auth_seq_id may arrive as strings (unicode or object dtype); "." placeholders
    # are mapped to -1 before the integer conversion
    is_string_like = author_seq_ids.dtype.kind in "UO"
    if is_string_like:
        is_placeholder = author_seq_ids == "."
        author_seq_ids = np.where(is_placeholder, -1, author_seq_ids).astype(int)

    atom_array.res_id[negative_res_id_mask] = author_seq_ids[negative_res_id_mask]
    return atom_array
def add_charge_from_ccd_codes(atom_array: AtomArray | AtomArrayStack) -> AtomArray | AtomArrayStack:
    """Set the `charge` annotation from the Chemical Component Dictionary (CCD).

    For each residue name in the structure, the CCD entry is loaded and its per-atom
    charges are assigned to atoms with matching (res_name, atom_name). Atoms that are
    not found, and residues whose CCD look-up raises a ValueError, get a charge of 0.
    An existing charge annotation is overwritten.

    Args:
        atom_array: The structure to annotate (modified in place). Can be either an
            AtomArray or AtomArrayStack.

    Returns:
        The same structure with the `charge` annotation set.

    WARNING:
        This function assumes that each residue is exactly as in the CCD. The charges
        will be incorrect if a residue is in a different protonation state or has been
        ionized in the original structure. Use this function with caution!

    NOTE:
        To add charges to canonical amino acids based on the pH, take a look at the
        'hydride.estimate_amino_acid_charges' function instead.

    Example:
        >>> atom_array = load_any("6lyz.cif", model=1)
        >>> atom_array_with_charges = add_charge_from_ccd_codes(atom_array)
    """
    if "charge" in atom_array.get_annotation_categories():
        logger.info("Charge annotation already present in atom array. It will be overwritten.")

    # (res_name, atom_name) -> charge, for every residue name present in the structure
    charge_lookup_table: dict[tuple[str, str], float] = {}
    for res_name in np.unique(atom_array.res_name):
        try:
            ccd_array = atom_array_from_ccd_code(res_name)
            residue_entries = {}
            for atom_name, charge in zip(ccd_array.atom_name, ccd_array.charge, strict=False):
                residue_entries[(res_name, atom_name)] = charge
            charge_lookup_table.update(residue_entries)
        except ValueError:
            logger.info(f"CCD charge look-up failed for {res_name}. Assuming charge is 0 for all atoms.")
            continue

    # Per-atom charges, defaulting to 0 where the table has no entry
    atom_keys = zip(atom_array.res_name, atom_array.atom_name, strict=False)
    charges = np.array([charge_lookup_table.get(key, 0) for key in atom_keys])

    atom_array.set_annotation("charge", charges)
    return atom_array
def add_hydrogen_atom_positions(
    atom_array: AtomArray | AtomArrayStack,
    residue_level_annots_to_copy_to_hydrogens: list[str] | None = None,
) -> AtomArray | AtomArrayStack:
    """Add hydrogens using the hydride library.

    Existing hydrogens are removed first, then new hydrogens are added with
    `hydride.add_hydrogen`. If no `charge` annotation is present, charges are taken
    from the CCD beforehand. Atoms with NaN coordinates are tolerated: they keep NaN
    coordinates, and hydrogens newly bonded to them get NaN coordinates as well.

    Args:
        atom_array: The structure to which hydrogens will be added.
        residue_level_annots_to_copy_to_hydrogens: Residue-level annotations that are
            copied to the newly added hydrogens of each residue, in addition to
            "auth_seq_id" and "label_entity_id". Annotations not present on the input
            are ignored. Defaults to no extra annotations.

    Returns:
        The structure with hydrogens added, preserving the input type.

    Raises:
        ImportError: If the optional hydride dependency could not be imported.
        TypeError: If the input is neither an AtomArray nor an AtomArrayStack.
    """
    # The module-level import of hydride is optional; fail with a clear message here
    if "hydride" not in globals():
        raise ImportError("Hydride library not found, hydrogens cannot be inferred. Pip install hydride to enable.")

    # Remove existing hydrogens
    atom_array = remove_hydrogens(atom_array)

    # Determine which fields to copy from the original array to the new hydrogens
    # (deduplicated, restricted to annotations that exist, in a stable order)
    requested_fields = ["auth_seq_id", "label_entity_id", *(residue_level_annots_to_copy_to_hydrogens or [])]
    present_fields = set(atom_array.get_annotation_categories())
    fields_to_copy_from_residue_if_present = [
        field for field in dict.fromkeys(requested_fields) if field in present_fields
    ]

    # Ensure charge annotation exists
    if "charge" not in atom_array.get_annotation_categories():
        atom_array = add_charge_from_ccd_codes(atom_array)

    def _copy_missing_annotations_residue_wise(
        from_array: AtomArray, to_array: AtomArray, fields_to_copy: list[str]
    ) -> AtomArray:
        """Copy the given annotations residue-wise from `from_array` to `to_array` (in place).

        The value at the first atom of each residue in `from_array` is spread over all
        atoms of the corresponding residue in `to_array`.
        """
        first_atom_per_residue = from_array[struc.get_residue_starts(from_array)]
        for field in fields_to_copy:
            residue_values = getattr(first_atom_per_residue, field)
            to_array.set_annotation(field, struc.spread_residue_wise(to_array, residue_values))
        return to_array

    def _add_hydrogens_nan_tolerant(array: AtomArray) -> AtomArray:
        """Add hydrogens to `array`, handling atoms with NaN coordinates."""
        original_nan_coords_mask = np.any(np.isnan(array.coord), axis=1)
        if not np.any(original_nan_coords_mask):
            result_atom_array, _ = hydride.add_hydrogen(array)
            return result_atom_array

        # Temporarily set NaN coordinates to zero so that hydride doesn't error
        array.coord[original_nan_coords_mask] = np.zeros((np.sum(original_nan_coords_mask), 3))

        # Add hydrogens using hydride
        result_atom_array, original_atoms_mask = hydride.add_hydrogen(array)

        # Reset the coordinates of atoms that originally had at least one NaN coordinate to be fully NaN
        originally_nan_inds = np.arange(result_atom_array.array_length())[original_atoms_mask][
            original_nan_coords_mask
        ]
        result_atom_array.coord[originally_nan_inds, :] = np.nan

        # Newly added hydrogens bonded to heavy atoms with NaN coordinates were placed
        # relative to the temporary zero position, so set them to NaN as well
        result_nan_coords_mask = np.any(np.isnan(result_atom_array.coord), axis=1)
        is_hydrogen = result_atom_array.element == "H"
        heavy_atom_nan_idces = np.where(result_nan_coords_mask & ~is_hydrogen)[0]
        for idx in heavy_atom_nan_idces:
            bonded_atoms = result_atom_array.bonds.get_bonds(idx)[0]
            bonded_h_atoms = bonded_atoms[is_hydrogen[bonded_atoms]]
            new_bonded_h_atoms = bonded_h_atoms[~original_atoms_mask[bonded_h_atoms]]
            result_atom_array.coord[new_bonded_h_atoms, :] = np.nan

        return result_atom_array

    if isinstance(atom_array, AtomArrayStack):
        # Process each model separately and re-stack the results
        updated_arrays = []
        for old_arr in atom_array:
            arr = _add_hydrogens_nan_tolerant(old_arr)
            arr = _copy_missing_annotations_residue_wise(old_arr, arr, fields_to_copy_from_residue_if_present)
            updated_arrays.append(arr)
        return struc.stack(updated_arrays)

    if isinstance(atom_array, AtomArray):
        arr = _add_hydrogens_nan_tolerant(atom_array)
        return _copy_missing_annotations_residue_wise(atom_array, arr, fields_to_copy_from_residue_if_present)

    raise TypeError(f"Expected AtomArray or AtomArrayStack, got {type(atom_array)}")
def add_pn_unit_id_annotation(atom_array: AtomArray | AtomArrayStack) -> AtomArray | AtomArrayStack:
    """Add the polymer/non-polymer unit ID (`pn_unit_id`) annotation.

    Two covalently bonded ligands are considered one PN unit, but a ligand bonded to a
    protein is considered two PN units. See the README glossary for more details on how
    `chains`, `pn_units`, and `molecules` are defined within this codebase.

    Every atom starts with its `chain_id` as `pn_unit_id`. Non-polymer chains that are
    connected to each other then share the comma-joined, sorted list of their chain IDs.

    Args:
        atom_array: The structure to annotate (modified in place). Requires the
            `is_polymer` annotation.

    Returns:
        The same structure including the `pn_unit_id` annotation.
    """
    chain_ids = atom_array.chain_id

    # ...initialize the pn_unit_id to chain_id (updated below for multi-chain non-polymer PN units).
    # The object dtype allows assigning strings longer than the original chain IDs.
    pn_unit_id_annotation = chain_ids.astype(object)

    # ...build the chain-level graph for non-polymer chains; subsetting through the
    # helper keeps the atom mask on the atom axis for stacks as well
    non_polymer_atom_array = subset_atom_array(atom_array, ~atom_array.is_polymer)
    connected_chains = get_connected_nodes(*get_coarse_graph_as_nodes_and_edges(non_polymer_atom_array, "chain_id"))

    for connected_chain in connected_chains:
        # ...set the same pn_unit_id for each chain in the connected group
        pn_unit_id = ",".join(sorted(connected_chain))
        for chain_id in connected_chain:
            pn_unit_id_annotation[chain_ids == chain_id] = pn_unit_id

    atom_array.set_annotation("pn_unit_id", pn_unit_id_annotation.astype(str))
    return atom_array
def add_pn_unit_iid_annotation(atom_array: AtomArray | AtomArrayStack) -> AtomArray | AtomArrayStack:
    """Add the polymer/non-polymer unit instance ID (`pn_unit_iid`) annotation.

    The transformation ID is appended to each comma-separated chain ID within the
    `pn_unit_id`, e.g. "A,B" with transformation "1" becomes "A_1,B_1".

    Args:
        atom_array: The structure to annotate (modified in place). Requires the
            `pn_unit_id` and `transformation_id` annotations.

    Returns:
        The same structure including the `pn_unit_iid` annotation.
    """
    # ...temporary key combining pn_unit_id and transformation_id, used only for grouping
    grouping_key = sum_string_arrays(atom_array.pn_unit_id, "_", atom_array.transformation_id)
    formatted_iids = np.full(atom_array.array_length(), fill_value="", dtype=object)

    # (All atoms sharing a grouping key share pn_unit_id and transformation_id,
    # so the first atom of the group is representative)
    for key in np.unique(grouping_key):
        in_group = grouping_key == key
        group_atoms = subset_atom_array(atom_array, in_group)

        transformation_id = group_atoms.transformation_id[0]
        pn_unit_id = group_atoms.pn_unit_id[0].astype(str)

        # ...attach the transformation_id to every chain ID of the unit and re-join
        formatted_iids[in_group] = ",".join(
            [f"{unit_id}_{transformation_id}" for unit_id in pn_unit_id.split(",")]
        )

    atom_array.set_annotation("pn_unit_iid", formatted_iids.astype(str))
    return atom_array
def add_molecule_id_annotation(atom_array: AtomArray | AtomArrayStack) -> AtomArray | AtomArrayStack:
    """Add the molecule ID (`molecule_id`) annotation.

    PN units that are connected to each other form one molecule. Molecules are
    numbered from 0 in the order the connected groups are returned.

    Args:
        atom_array: The structure to annotate (modified in place). Requires the
            `pn_unit_id` annotation.

    Returns:
        The same structure including the `molecule_id` annotation (int16).
    """
    # ...create the annotation that will hold the molecule index
    atom_array.add_annotation("molecule_id", dtype=np.int16)

    # ...group pn_units by connectivity
    pn_unit_groups = get_connected_nodes(*get_coarse_graph_as_nodes_and_edges(atom_array, "pn_unit_id"))

    # ...all atoms of all pn_units in a group receive the group's index
    for molecule_id, pn_unit_group in enumerate(pn_unit_groups):
        for pn_unit_id in pn_unit_group:
            atom_array.molecule_id[atom_array.pn_unit_id == pn_unit_id] = molecule_id

    return atom_array
def add_molecule_iid_annotation(atom_array_stack: AtomArrayStack) -> AtomArrayStack:
    """Add the molecule instance ID (`molecule_iid`) annotation.

    Each distinct (molecule_id, transformation_id) pair is mapped to an integer
    (0-indexed, following the sorted order of the combined string keys).

    Args:
        atom_array_stack: The structure to annotate (modified in place). Requires the
            `molecule_id` and `transformation_id` annotations.

    Returns:
        The same structure including the `molecule_iid` annotation (int16).
    """
    molecule_ids = atom_array_stack.molecule_id.astype(str)
    transformation_ids = atom_array_stack.transformation_id.astype(str)

    # ...combine both IDs into one key. The separator keeps distinct pairs distinct:
    # without it (1, "11") and (11, "1") would both produce "111".
    molecule_iids_str = np.char.add(np.char.add(molecule_ids, "_"), transformation_ids)

    # ...map each unique key to an integer (0-indexed)
    _, inverse_indices = np.unique(molecule_iids_str, return_inverse=True)

    # ...set the annotation
    atom_array_stack.set_annotation("molecule_iid", inverse_indices.astype(np.int16))
    return atom_array_stack
def annotate_entities(
    atom_array: AtomArray,
    level: str,
    lower_level_id: str | list[str],
    lower_level_entity: str,
    add_inter_level_bond_hash: bool = True,
) -> tuple[AtomArray, dict]:
    """Annotate entities at a given `id` level from connectivity and annotations one level below.

    Instances at `level` that produce the same hash are assigned the same entity ID.
    Entity IDs are integers assigned from 0 in order of first appearance.

    The intended use is, for example:
        - For the `molecule` level, `molecule_entities` are generated for each
          `molecule_id` based on the connectivity at the `pn_unit` level.
        - For the `pn_unit` level, `pn_unit_entities` are generated for each
          `pn_unit_id` based on the connectivity at the `chain` level.
        - For the `chain` level, `chain_entities` are generated for each `chain_id`
          based on the connectivity at the `residue` level.

    Args:
        atom_array: The AtomArray to process (modified in place). Must carry the
            `<level>_id` annotation.
        level: The level at which to annotate entities (e.g. "chain", "pn_unit",
            "molecule").
        lower_level_id: Annotation(s) that determine segment boundaries at the lower
            level, e.g. "pn_unit_id", "chain_id" or "res_id".
        lower_level_entity: The annotation identifying entities at the lower level,
            e.g. "pn_unit_entity", "chain_entity" or "res_name".
        add_inter_level_bond_hash: Whether to add a hash of the inter-level bonds to
            the entity hash. For some cases this may be necessary to distinguish
            entities (e.g. when determining molecule-level entities). In others
            (e.g. for polymers) this may be overkill.

    Returns:
        A tuple containing:
            - atom_array: The updated AtomArray with the `<level>_entity` annotation.
            - entities_info: A dictionary mapping entity IDs to lists of instance IDs.

    Example:
        >>> atom_array, entities_info = annotate_entities(
        ...     atom_array, level="chain", lower_level_id="res_id", lower_level_entity="res_name"
        ... )
        >>> print(atom_array.chain_entity)
        [0, 0, 1, 1, 2, 2]
    """
    hash_to_entity_id: dict = {}
    level_ids = atom_array.get_annotation(level + "_id")

    # ... initialize annotations to fill
    entities_annotation = np.zeros(len(atom_array), dtype=int)
    entities_info = defaultdict(list)

    # ... only the first lower level id is used for the inter-level bond hash
    # (since we hash at the atom-level, this simplification is valid)
    bond_hash_lower_level_id = lower_level_id[0] if isinstance(lower_level_id, list) else lower_level_id

    for instance_id in np.unique(level_ids):
        is_instance = level_ids == instance_id
        instance = atom_array[is_instance]

        # ... get connectivity for the coarse graph at the lower level
        # NOTE(review): only edges are added, so an instance without lower-level edges
        # yields an empty graph and its node attributes are not applied - confirm
        # this is intended.
        _, edges = get_coarse_graph_as_nodes_and_edges(instance, lower_level_id)
        instance_graph = nx.Graph()
        instance_graph.add_edges_from(edges)

        # ... set node attributes to lower level entities (first atom of each segment)
        lower_level_iter = struc.segments.segment_iter(
            instance, annot_start_stop_idxs(instance, lower_level_id, add_exclusive_stop=True)
        )
        node_attrs = {
            idx: lower_level_instance.get_annotation(lower_level_entity)[0]
            for idx, lower_level_instance in enumerate(lower_level_iter)
        }
        nx.set_node_attributes(instance_graph, node_attrs, "node")

        # ... create the graph hash
        entity_hash = hash_graph(instance_graph, node_attr="node")

        # ... add the inter-level bond hash
        if add_inter_level_bond_hash:
            entity_hash += generate_inter_level_bond_hash(
                atom_array=instance,
                lower_level_id=bond_hash_lower_level_id,
                lower_level_entity=lower_level_entity,
            )

        # ... reuse the entity id of a previously seen hash, otherwise take the next free id
        entity_id = hash_to_entity_id.setdefault(entity_hash, len(hash_to_entity_id))

        # ... assign the entity id to the instance
        entities_annotation[is_instance] = entity_id
        entities_info[entity_id].append(instance_id)

    atom_array.set_annotation(level + "_entity", entities_annotation)

    return atom_array, dict(entities_info)
def add_chain_iid_annotation(atom_array_stack: AtomArrayStack) -> AtomArrayStack:
    """Add the chain instance ID (`chain_iid`) annotation.

    The instance ID is the chain ID and the transformation ID joined by an
    underscore.

    Args:
        atom_array_stack: The structure to annotate (modified in place). Requires the
            `transformation_id` annotation.

    Returns:
        The same structure including the `chain_iid` annotation.
    """
    chain_ids = atom_array_stack.chain_id
    transformation_ids = atom_array_stack.transformation_id
    atom_array_stack.set_annotation("chain_iid", sum_string_arrays(chain_ids, "_", transformation_ids))
    return atom_array_stack
def add_iid_annotations_to_assemblies(
    assemblies_dict: dict[str | int, AtomArray | AtomArrayStack],
) -> dict[str | int, AtomArray | AtomArrayStack]:
    """Add chain, PN unit, and molecule instance IDs to every assembly.

    Chain IIDs are always added. PN unit and molecule IIDs are added only when the
    assembly carries the `pn_unit_id` / `molecule_id` annotation, respectively.

    Args:
        assemblies_dict: Mapping from assembly ID to assembly (updated in place).

    Returns:
        The same dictionary with annotated assemblies.
    """
    for assembly_id in list(assemblies_dict):
        annotated = add_chain_iid_annotation(assemblies_dict[assembly_id])

        # ...PN unit IIDs need the `pn_unit_id` annotation
        if "pn_unit_id" in annotated.get_annotation_categories():
            annotated = add_pn_unit_iid_annotation(annotated)

        # ...molecule IIDs need the `molecule_id` annotation
        if "molecule_id" in annotated.get_annotation_categories():
            annotated = add_molecule_iid_annotation(annotated)

        assemblies_dict[assembly_id] = annotated

    return assemblies_dict
def add_id_and_entity_annotations(atom_array: AtomArray) -> AtomArray:
    """Add all 6 ('chain', 'pn_unit', 'molecule') x ('id', 'entity') annotations.

    PN unit and molecule IDs are added first. Entities are then annotated level by
    level, from chains up to molecules, since each level uses the entity annotation
    of the level below.

    Args:
        atom_array: The structure to annotate.

    Returns:
        The annotated structure.
    """
    # ...both ID annotations require bonds
    atom_array = add_pn_unit_id_annotation(atom_array)
    atom_array = add_molecule_id_annotation(atom_array)

    # (level, id annotation of the level below, entity annotation of the level below)
    entity_plan = (
        ("chain", "res_id", "res_name"),
        ("pn_unit", "chain_id", "chain_entity"),
        ("molecule", "pn_unit_id", "pn_unit_entity"),
    )
    for level, lower_level_id, lower_level_entity in entity_plan:
        atom_array, _ = annotate_entities(
            atom_array=atom_array,
            level=level,
            lower_level_id=lower_level_id,
            lower_level_entity=lower_level_entity,
        )

    return atom_array
def add_chain_type_annotation(
    atom_array: AtomArray | AtomArrayStack, chain_info_dict: dict
) -> AtomArray | AtomArrayStack:
    """Add an integer `chain_type` annotation.

    Args:
        atom_array: The full structure (modified in place).
        chain_info_dict: Mapping from chain ID to a dict whose "chain_type" entry
            exposes the integer code through `.value`.

    Returns:
        The same structure with the `chain_type` annotation (int8) filled in.
    """
    atom_array.add_annotation("chain_type", dtype=np.int8)

    atom_chain_ids = atom_array.chain_id
    for chain_id in np.unique(atom_chain_ids):
        # The integer value of the chain type is stored for efficiency
        type_value = chain_info_dict[chain_id]["chain_type"].value
        atom_array.chain_type[atom_chain_ids == chain_id] = type_value

    return atom_array
def add_atomic_number_annotation(atom_array: AtomArray | AtomArrayStack) -> AtomArray | AtomArrayStack:
    """Add the atomic number (`atomic_number`) annotation.

    Element symbols are upper-cased and looked up in `ELEMENT_NAME_TO_ATOMIC_NUMBER`.

    Args:
        atom_array: The structure to annotate (modified in place).

    Returns:
        The same structure including the `atomic_number` annotation (int8).
    """
    upper_symbols = np.char.upper(atom_array.element)
    atomic_numbers = listmap(ELEMENT_NAME_TO_ATOMIC_NUMBER.get, upper_symbols)
    atom_array.set_annotation("atomic_number", np.array(atomic_numbers, dtype=np.int8))
    return atom_array