Source code for atomworks.io.utils.bonds

"""
Utility functions for the detection, and creation, of bonds in a structure.
"""

__all__ = [
    "generate_inter_level_bond_hash",
    "get_coarse_graph_as_nodes_and_edges",
    "get_connected_nodes",
    "get_inferred_polymer_bonds",
    "get_struct_conn_bonds",
    "hash_atom_array",
    "hash_graph",
]

import hashlib
import logging
from typing import Any

import biotite.structure as struc
import networkx as nx
import numpy as np
import pandas as pd
from biotite.structure import AtomArray
from biotite.structure.io.pdbx.convert import (
    PDBX_BOND_TYPE_TO_TYPE_ID,
    _filter_bonds,
    _filter_canonical_links,
    _get_struct_conn_col_name,
)

from atomworks.enums import ChainType, ChainTypeInfo
from atomworks.io.common import sum_string_arrays, to_hashable
from atomworks.io.constants import (
    AA_LIKE_CHEM_TYPES,
    CHEM_TYPE_POLYMERIZATION_ATOMS,
    DEFAULT_VALENCE,
    HYDROGEN_LIKE_SYMBOLS,
    NA_LIKE_CHEM_TYPES,
    STRUCT_CONN_BOND_ORDER_TO_INT,
    STRUCT_CONN_BOND_TYPES,
)
from atomworks.io.utils.ccd import get_chem_comp_leaving_atom_names, get_chem_comp_type
from atomworks.io.utils.selection import get_annotation, get_residue_starts
from atomworks.io.utils.testing import has_ambiguous_annotation_set

# Module-level logger; uses the fixed "atomworks.io" logger name rather than `__name__`.
logger = logging.getLogger("atomworks.io")


def _get_leaving_atom_idxs_for(atom_name: str, res_name: str, atom_names: np.ndarray, offset: int = 0) -> np.ndarray:
    """
    Locate the leaving atoms that belong to one bonding atom of a residue.

    Args:
        atom_name (str): Name of the atom that takes part in the bond.
        res_name (str): Residue name used to look up the leaving-atom names.
        atom_names (np.ndarray): Atom names of the residue that is searched.
        offset (int): Value added to every returned position, e.g. the index of the
            residue's first atom to turn residue-local positions into global indices.

    Returns:
        np.ndarray: Positions within `atom_names` (shifted by `offset`) whose name is listed
            as a leaving atom for `atom_name`. Empty if nothing is listed for `atom_name`.
    """
    # ... look up which atom names leave when `atom_name` forms a bond
    leaving_by_atom = get_chem_comp_leaving_atom_names(res_name)
    leaving_names = leaving_by_atom.get(atom_name, ())

    # ... find where those names occur in the given residue
    is_leaving = np.isin(atom_names, leaving_names)
    return offset + np.nonzero(is_leaving)[0]


def get_inferred_polymer_bonds(atom_array: AtomArray) -> tuple[np.ndarray, np.ndarray]:
    """
    Infers and returns polymer bonds between consecutive residues in an atom array based on
    chemical component types and chain types.

    The function identifies bonds by looking at consecutive residues within the same chain and
    determining the appropriate bonding atoms based on either the chain type (as a fallback) or
    more detailed chemical component types. It also tracks leaving atoms that are displaced
    during bond formation.

    Leaving groups are inferred from the CCD entries for the chemical components. If a CCD code
    is missing from your local CCD mirror, leaving groups will not be inferred.

    Side effects:
        Residue pairs for which a bond is inferred are flagged in the `is_polymer` mask. If the
        atom array has no `is_polymer` annotation yet, the mask is stored as a new annotation.
        NOTE(review): if the annotation already exists, the mask obtained from `get_annotation`
        is modified in place; whether that is the array stored on `atom_array` depends on
        `get_annotation` -- confirm.

    Args:
        - atom_array (AtomArray): The atom array containing the structure information. Must
            include annotations for chain_id, res_id, res_name, and atom_name. Optionally
            includes chain_type annotation.

    Returns:
        - polymer_bonds (np.ndarray): Integer array of shape [n_bonds, 3] with one row
            (atom1_idx, atom2_idx, bond_type) for each inferred polymer bond. The shape is
            (0, 3) if no bond was inferred.
        - leaving_atom_idxs (np.ndarray): 1D array of atom indices that represent leaving groups
            displaced during bond formation.

    Example:
        >>> # Create an atom array with two consecutive peptide residues
        >>> atom_array = AtomArray(length=10)
        >>> atom_array.chain_id = ["A"] * 10
        >>> atom_array.res_id = [1] * 5 + [2] * 5
        >>> atom_array.res_name = ["ALA"] * 5 + ["GLY"] * 5
        >>> atom_array.atom_name = ["N", "CA", "C", "OXT", "CB"] + ["N", "CA", "C", "O", "H2"]
        >>> # Get the polymer bonds
        >>> bonds, leaving = get_inferred_polymer_bonds(atom_array)
        >>> print(bonds)  # Shows C-N peptide bond between residues (C of ALA to N of GLY)
        [[2 5 1]]
        >>> print(leaving)  # Leaving OXT from C and H2 from N (other hydrogens omitted for simplicity)
        [3 9]
    """
    # ... initialize return values
    bonds: list[tuple[int, int, struc.BondType]] = []
    leaving: list[np.ndarray] = []

    # ... get annotations we need to work with
    chain_ids = atom_array.chain_id
    res_ids = atom_array.res_id
    res_names = atom_array.res_name
    atom_names = atom_array.atom_name
    chain_types = get_annotation(atom_array, "chain_type", default=None)
    is_polymer = get_annotation(atom_array, "is_polymer", default=np.zeros(atom_array.array_length(), dtype=bool))

    # ... get iterators over the residues
    residue_starts = get_residue_starts(atom_array, add_exclusive_stop=True)
    this_res_starts = residue_starts[:-2]
    next_res_starts = residue_starts[1:-1]
    next_res_stops = residue_starts[2:]

    # ... loop over the residues and add the bonds
    for this_res_start, next_res_start, next_res_stop in zip(
        this_res_starts, next_res_starts, next_res_stops, strict=False
    ):
        # ... skip if residues are not on the same chain
        if chain_ids[this_res_start] != chain_ids[next_res_start]:
            continue

        # ... and skip if residues don't have consecutive res_id's
        #     (NOTE: same res_id is allowed, if ins_code is different)
        if res_ids[next_res_start] - res_ids[this_res_start] > 1:
            continue

        # ... get fallback default bonding atoms based on chain type
        bonding_atoms = None
        if chain_types is not None:
            chain_type = ChainType.as_enum(chain_types[this_res_start])
            bonding_atoms = ChainTypeInfo.ATOMS_AT_POLYMER_BOND.get(chain_type, None)

        # ... get (more detailed) bonding atoms based on chem-comp types
        this_link = get_chem_comp_type(res_names[this_res_start], mode="warn")
        next_link = get_chem_comp_type(res_names[next_res_start], mode="warn")

        # ... decide which bonds to form:
        both_aa = (this_link in AA_LIKE_CHEM_TYPES) and (next_link in AA_LIKE_CHEM_TYPES)
        both_na = (this_link in NA_LIKE_CHEM_TYPES) and (next_link in NA_LIKE_CHEM_TYPES)
        if (this_link in CHEM_TYPE_POLYMERIZATION_ATOMS) and (both_aa or both_na):
            bonding_atoms = CHEM_TYPE_POLYMERIZATION_ATOMS[this_link]

        # ... add the bonds if we have bonding atoms
        if bonding_atoms is not None:
            # bonding_atoms: tuple[str, str] = (atom1_name, atom2_name)
            atom1_name, atom2_name = bonding_atoms

            # ... get the atoms names within the current residues
            this_res_atom_names = atom_names[this_res_start:next_res_start]
            next_res_atom_names = atom_names[next_res_start:next_res_stop]

            # ... find the indices of the bonding atoms based on the atoms names
            atom1_idx = np.where(this_res_atom_names == atom1_name)[0]
            atom2_idx = np.where(next_res_atom_names == atom2_name)[0]

            if len(atom1_idx) == 0 or len(atom2_idx) == 0:
                # ... bonding atoms are not found in the adjacent residues
                # ... -> skip this bond
                logger.info(
                    f"Bonding atoms {atom1_name} and {atom2_name} not found "
                    f"in the adjacent residues {this_res_start} and {next_res_start}!"
                )
                continue

            # ... add the bond
            bonds.append(
                (
                    this_res_start + atom1_idx[0],  # ... add global atom idx offset
                    next_res_start + atom2_idx[0],  # ... add global atom idx offset
                    struc.BondType.SINGLE,
                )
            )

            # ... compute the leaving atoms
            leaving_this_res = _get_leaving_atom_idxs_for(
                atom_name=atom1_name,
                res_name=res_names[this_res_start],
                atom_names=this_res_atom_names,
                offset=this_res_start,
            )
            leaving_next_res = _get_leaving_atom_idxs_for(
                atom_name=atom2_name,
                res_name=res_names[next_res_start],
                atom_names=next_res_atom_names,
                offset=next_res_start,
            )
            # ... only keep non-empty index arrays
            if len(leaving_this_res) > 0:
                leaving.append(leaving_this_res)
            if len(leaving_next_res) > 0:
                leaving.append(leaving_next_res)

            # ... optionally add `is_polymer` annotation to the atom array
            is_polymer[this_res_start:next_res_stop] = True

    if "is_polymer" not in atom_array.get_annotation_categories():
        # ... if polymer annotation was not present before, we set it here based on the inferred bonds
        atom_array.set_annotation("is_polymer", is_polymer)

    # NOTE: The explicit integer dtype keeps the empty result consistent with the non-empty one
    #  (an empty list would otherwise be turned into a float array).
    bond_array = np.array(bonds, dtype=int).reshape(-1, 3)
    leaving_atom_idxs = np.concatenate(leaving) if len(leaving) > 0 else np.array([], dtype=int)
    return bond_array, leaving_atom_idxs
def get_struct_conn_dict_from_atom_array(
    atom_array: AtomArray,
) -> dict[str, np.ndarray]:
    """Returns a struct_conn dictionary corresponding to a given AtomArray.

    This contains the keys used in `get_struct_conn_bonds`.

    Only inter-residue bonds that are not 'standard' backbone links between adjacent canonical
    residues are reported. Chains are identified by `chain_id`; if that leaves atoms ambiguous,
    the `chain_iid` annotation is used instead.

    NOTE: These AtomArray-derived struct_conn_dicts will never contain disulfide or hydrogen bonds,
    as Biotite does not distinguish these in the BondList. Possible types are "covale" and "metalc".

    Args:
        atom_array (AtomArray): The atom array to get the struct_conn dictionary from.

    Returns:
        dict[str, np.ndarray]: The struct_conn dictionary. Empty if there are no bonds to report.

    Raises:
        ValueError: If a residue contains duplicate atom names, or if atoms cannot be identified
            unambiguously (neither via `chain_id` nor via `chain_iid`).
    """
    struct_conn_dict: dict[str, np.ndarray] = {}

    # Atoms are later referenced by name, so atom names must be unique within each residue
    for res_array in struc.residue_iter(atom_array):
        if len(np.unique(res_array.atom_name)) != len(res_array.atom_name):
            raise ValueError(
                "Duplicate atom names detected in the same residue -- cannot infer struct_conn. "
                "This may happen when a non-polymer is loaded from a CIF file without using `atomworks.io.parser.parse`. "
            )

    # Keep only inter-residue bonds
    bond_array = _filter_bonds(atom_array, "inter")
    if len(bond_array) == 0:
        return struct_conn_dict

    # Filter out 'standard' links, i.e. backbone bonds between adjacent canonical
    # nucleotide/amino acid residues
    bond_array = bond_array[~_filter_canonical_links(atom_array, bond_array)]
    if len(bond_array) == 0:
        return struct_conn_dict

    use_iids = False  # By default, we use chain_ids to determine bonds
    has_chain_iids = "chain_iid" in atom_array.get_annotation_categories()

    # Determine whether we need to fall back to using chain_iids
    if has_ambiguous_annotation_set(atom_array):
        if not has_chain_iids:
            raise ValueError(
                "Ambiguous bond annotations detected. This happens when there are atoms that "
                "have the same `(chain_id, res_id, res_name, atom_id, ins_code)` identifier. "
                "This happens for example when you have a bio-assembly with multiple copies "
                "of a chain that only differ by `transformation_id`.\n"
                "You can fix this for example by re-naming the chains to be named uniquely. "
                "For the purposes of this function, you can also add a unambiguous chain_iid annotation instead. "
            )
        if has_ambiguous_annotation_set(
            atom_array, annotation_set=["chain_iid", "res_id", "res_name", "atom_name", "ins_code"]
        ):
            raise ValueError(
                "Ambiguous bond annotations detected. This happens when there are atoms that "
                "have the same `(chain_id, res_id, res_name, atom_id, ins_code)` identifier. "
                "This happens for example when you have a bio-assembly with multiple copies "
                "of a chain that only differ by `transformation_id`.\n"
                "In this case, falling back to the `chain_iid` annotation was insufficient to resolve the ambiguity. "
                "You can fix this for example by re-naming the chains to be named uniquely. "
                "For the purposes of this function, you can also add a unambiguous chain_iid annotation instead. "
            )
        use_iids = True

    # Add the bond type information
    struct_conn_dict["conn_type_id"] = np.array([PDBX_BOND_TYPE_TO_TYPE_ID[btype] for btype in bond_array[:, 2]])

    # Map the struct_conn column names to the annotation each one is filled from
    label_asym_id_field = "chain_iid" if use_iids else "chain_id"
    cif_field_to_annot = {
        "label_asym_id": label_asym_id_field,
        "label_comp_id": "res_name",
        "label_seq_id": "res_id",
        "label_atom_id": "atom_name",
        "pdbx_PDB_ins_code": "ins_code",
    }
    for col_name, annot_name in cif_field_to_annot.items():
        annot = atom_array.get_annotation(annot_name)
        # ...for each bond partner
        for i in range(2):
            atom_indices = bond_array[:, i]
            struct_conn_dict[_get_struct_conn_col_name(col_name, i + 1)] = annot[atom_indices].astype(str)

    return struct_conn_dict
def _resolve_struct_conn_res_id(row: pd.Series, partner: int, chain_is_polymer: bool) -> int | None:
    """Pick the residue ID of bond partner `partner` (1 or 2) from a `struct_conn` row.

    Polymer chains prefer `label_seq_id`; all other chains prefer `auth_seq_id` (if that column is
    present). Values of "." or "?" are treated as missing, in which case the other ID is tried.

    Returns:
        int | None: The residue ID, or None if neither ID holds a usable value.
    """
    label_seq_id = row[f"ptnr{partner}_label_seq_id"]
    auth_key = f"ptnr{partner}_auth_seq_id"
    auth_seq_id = row[auth_key] if auth_key in row else None

    candidates = (label_seq_id, auth_seq_id) if chain_is_polymer else (auth_seq_id, label_seq_id)
    for value in candidates:
        if value is not None and value not in (".", "?"):
            return int(value)
    return None


def get_struct_conn_bonds(
    atom_array: AtomArray,
    struct_conn_dict: dict[str, np.ndarray],
    add_bond_types: list[str] | None = None,
    raise_on_failure: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Computes the bonds listed in the 'struct_conn' category of a CIF block for an atom array.

    Only bonds whose `conn_type_id` is listed in `add_bond_types` are considered. The atom array
    itself is not modified; the bonds and leaving atoms are returned to the caller.

    Args:
        atom_array (AtomArray): The atom array used to get atom indices. The `chain_id`,
            `res_name`, `res_id`, `ins_code`, `atom_name` and `is_polymer` annotations are read;
            `chain_iid`, `alt_atom_id` and `uses_alt_atom_id` are used if present.
        struct_conn_dict (dict[str, np.ndarray]): The struct_conn category of a CIF block as a dictionary.
            E.g. (Only mandatory fields, as defined by the RCSB, are shown)
            ```
            {
                'conn_type_id': array(['disulf', ...]),
                'ptnr1_label_asym_id': array(['A', ...]),
                'ptnr1_label_comp_id': array(['CYS', ...]),
                'ptnr1_label_seq_id': array(['6', ...]),
                'ptnr1_label_atom_id': array(['SG', ...]),
                'ptnr1_symmetry': array(['1_555', ...]),
                'ptnr2_label_asym_id': array(['A', ...]),
                'ptnr2_label_comp_id': array(['CYS', ...]),
                'ptnr2_label_seq_id': array(['127', ...]),
                'ptnr2_label_atom_id': array(['SG', ...]),
                'ptnr2_symmetry': array(['1_555', ...]),
            }
            ```
            However, in this function, we only require the following fields:
                - conn_type_id (e.g., "covale")
                - ptnr1_label_asym_id (chain_id or chain_iid, e.g., "A" or "A_1")
                - ptnr1_label_comp_id (residue name in the CCD, e.g., "CYS")
                - ptnr1_label_seq_id (residue ID, e.g., "6")
                - ptnr1_label_atom_id (atom name, e.g., "SG")
                - ptnr2_label_asym_id
                - ptnr2_label_comp_id
                - ptnr2_label_seq_id
                - ptnr2_label_atom_id
            The optional fields `ptnr*_auth_seq_id`, `pdbx_ptnr*_PDB_ins_code` and
            `pdbx_value_order` are used if present.
        add_bond_types (list[str] | None): A list of bond types that should be added. Valid bond types are:
            ["covale", "disulf", "metalc", "hydrog"]. Defaults to ["covale"] (also used if None is
            given), which is the use-case in structure-prediction, where we would a-priori know
            covalent bonds (except for disulfides).
        raise_on_failure (bool): If True, raise an error if specified bonds cannot be made
            (e.g., if the residues or atoms are missing). If False, such bonds are skipped.
            Defaults to False.

    NOTE: While chain_iid annotations are allowed, a given bond is expected to contain only one
    annotation type, i.e. both chain_id or both chain_iid

    Returns:
        bonds (np.ndarray): Integer array of shape [n_bonds, 3] with one row
            (atom1_idx, atom2_idx, bond_type) per bond to be added to the atom array.
        leaving (np.ndarray): An array of indices of atoms that are leaving groups for bookkeeping.

    Raises:
        ValueError: If `add_bond_types` contains an invalid bond type, or if `raise_on_failure`
            is True and a bond cannot be made.

    References:
        - https://mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_struct_conn.conn_type_id.html
    """
    if add_bond_types is None:
        add_bond_types = ["covale"]

    # ... validate input
    invalid_bond_types = set(add_bond_types) - STRUCT_CONN_BOND_TYPES
    if len(invalid_bond_types) > 0:
        raise ValueError(
            f"Invalid bond type(s) provided: {invalid_bond_types}! Valid bond types are: {STRUCT_CONN_BOND_TYPES}"
        )

    if len(struct_conn_dict) == 0:
        return np.empty((0, 3), dtype=int), np.empty((0,), dtype=int)

    # ... convert struct_conn_dict to a DataFrame
    struct_conn_df = pd.DataFrame(struct_conn_dict)
    struct_conn_df = struct_conn_df[struct_conn_df["conn_type_id"].isin(add_bond_types)]

    if struct_conn_df.empty:
        # ... skip if no bonds to add
        return np.empty((0, 3), dtype=int), np.empty((0,), dtype=int)

    logger.debug(f"Attempting to add {len(struct_conn_df)} bonds from `struct_conn`")

    # ... extract relevant annotations
    chain_ids = atom_array.chain_id
    res_names = atom_array.res_name
    res_ids = atom_array.res_id
    ins_codes = atom_array.ins_code
    atom_names = atom_array.atom_name
    is_polymer = atom_array.is_polymer
    alt_atom_ids = get_annotation(atom_array, "alt_atom_id", default=atom_names)
    uses_alt_atom_id = get_annotation(atom_array, "uses_alt_atom_id", default=np.zeros(len(atom_array), dtype=bool))

    all_res_names = np.unique(res_names)
    all_chain_ids = np.unique(chain_ids)
    polymer_chain_ids = np.unique(chain_ids[is_polymer])

    # Get iid-level annotations if present
    if "chain_iid" in atom_array.get_annotation_categories():
        chain_iids = atom_array.chain_iid
        all_chain_iids = np.unique(chain_iids)
        polymer_chain_iids = np.unique(chain_iids[is_polymer])
    else:
        chain_iids = None
        all_chain_iids = None
        polymer_chain_iids = None

    # ... initialize return values
    bonds: list[list[int]] = []
    leaving: list[np.ndarray] = []

    for _, row in struct_conn_df.iterrows():
        res_name1 = row["ptnr1_label_comp_id"]
        res_name2 = row["ptnr2_label_comp_id"]

        if (res_name1 not in all_res_names) or (res_name2 not in all_res_names):
            # ... skip if the residues were removed from the structure
            if raise_on_failure:
                raise ValueError(f"Residue {res_name1} or {res_name2} not found in the atom array!")
            continue

        chain_id1 = row["ptnr1_label_asym_id"]
        chain_id2 = row["ptnr2_label_asym_id"]

        # Default to using id-level identifiers
        relevant_chain_identifiers = chain_ids
        relevant_polymer_chain_identifiers = polymer_chain_ids

        # If iid-level identifiers are present, use these as a fallback
        if (chain_id1 not in all_chain_ids) or (chain_id2 not in all_chain_ids):
            if (chain_iids is not None) and (chain_id1 in all_chain_iids) and (chain_id2 in all_chain_iids):
                relevant_chain_identifiers = chain_iids
                relevant_polymer_chain_identifiers = polymer_chain_iids
            else:
                # ... skip, but warn if the chains are not present in the structure
                logger.info(
                    f"Found covalent bond involving chains {chain_id1} and {chain_id2}, but at least one "
                    "chain was removed during cleaning. This is likely because the chain is made up of a "
                    "residue that is not in the local CCD. This should automatically be resolved once you "
                    "update your CCD, unless you are working with an outdated structure file."
                )
                if raise_on_failure:
                    raise ValueError(f"Chain {chain_id1} or {chain_id2} not found in the atom array!")
                continue

        # For non-polymers, we use the auth_seq_id if available and valid (i.e., not "." or "?");
        # otherwise we use the label_seq_id
        # (Required to avoid ambiguity, since if using `label` only we may have multiple residue within a
        #  chain with the same label_seq_id and the same res_name; see: 6MUB)
        res_id1 = _resolve_struct_conn_res_id(row, 1, chain_id1 in relevant_polymer_chain_identifiers)
        res_id2 = _resolve_struct_conn_res_id(row, 2, chain_id2 in relevant_polymer_chain_identifiers)
        if (res_id1 is None) or (res_id2 is None):
            message = (
                f"No usable residue ID found in `struct_conn` for the bond between "
                f"{chain_id1}/{res_name1} and {chain_id2}/{res_name2}!"
            )
            logger.info(message)
            if raise_on_failure:
                raise ValueError(message)
            continue

        ins_code1 = row.get("pdbx_ptnr1_PDB_ins_code", "")
        ins_code2 = row.get("pdbx_ptnr2_PDB_ins_code", "")
        ins_code1 = "" if ins_code1 in (".", "?") else ins_code1
        ins_code2 = "" if ins_code2 in (".", "?") else ins_code2

        # ... get masks for the residues to which atoms 1 & 2 belong
        in_res1 = (
            (relevant_chain_identifiers == chain_id1)
            & (res_ids == res_id1)
            & (res_names == res_name1)
            & (ins_codes == ins_code1)
        )
        in_res2 = (
            (relevant_chain_identifiers == chain_id2)
            & (res_ids == res_id2)
            & (res_names == res_name2)
            & (ins_codes == ins_code2)
        )

        # ... global indices of the atoms in each residue
        res1_atom_idxs = np.flatnonzero(in_res1)
        res2_atom_idxs = np.flatnonzero(in_res2)

        # NOTE: Since the masks also match on the residue name, this covers residues that were
        #  dropped when resolving sequence heterogeneity (see PDB ID `3nez` for an example)
        if (len(res1_atom_idxs) == 0) or (len(res2_atom_idxs) == 0):
            logger.info(
                f"Residue {chain_id1}/{res_id1}/{res_name1} or {chain_id2}/{res_id2}/{res_name2} "
                "not found in the atom array!"
            )
            if raise_on_failure:
                raise ValueError(
                    f"Residue {chain_id1}/{res_id1}/{res_name1} or {chain_id2}/{res_id2}/{res_name2} "
                    "not found in the atom array!"
                )
            continue

        # If all residues are present, we can proceed with identifying the global indices of the
        # atoms in the bond and add the bond
        atom_name1 = row["ptnr1_label_atom_id"]
        atom_name2 = row["ptnr2_label_atom_id"]

        # ... look the atoms up within their residue, using the alternative atom names for
        #     residues that are flagged as using them
        names_res1 = alt_atom_ids if uses_alt_atom_id[res1_atom_idxs[0]] else atom_names
        names_res2 = alt_atom_ids if uses_alt_atom_id[res2_atom_idxs[0]] else atom_names
        atom1_matches = res1_atom_idxs[names_res1[res1_atom_idxs] == atom_name1]
        atom2_matches = res2_atom_idxs[names_res2[res2_atom_idxs] == atom_name2]

        # ... skip, but warn if the atoms are not present in their residue
        if (len(atom1_matches) == 0) or (len(atom2_matches) == 0):
            logger.info(
                f"Covalent bond involving atoms {atom_name1} and {atom_name2} was found in `struct_conn`, but the "
                "atoms are not present in the residue's AtomArray!"
            )
            if raise_on_failure:
                raise ValueError(
                    f"Atom {chain_id1}/{res_id1}/{res_name1}/{atom_name1} or "
                    f"{chain_id2}/{res_id2}/{res_name2}/{atom_name2} not found in the atom array!"
                )
            continue

        atom1_global_idx = atom1_matches[0]
        atom2_global_idx = atom2_matches[0]

        # ... add the bond
        # Metal coordination bonds don't have a `pdbx_value_order`, so these are handled separately
        if row["conn_type_id"] == "metalc":
            bonds.append([atom1_global_idx, atom2_global_idx, struc.BondType.COORDINATION])
        else:
            bond_order = STRUCT_CONN_BOND_ORDER_TO_INT.get(row.get("pdbx_value_order"), 1)
            bonds.append([atom1_global_idx, atom2_global_idx, struc.BondType(bond_order)])

        # ... and identify the leaving atoms (always looked up by the standard atom name)
        for bonded_atom_idx, res_name, res_atom_idxs in (
            (atom1_global_idx, res_name1, res1_atom_idxs),
            (atom2_global_idx, res_name2, res2_atom_idxs),
        ):
            leaving_in_res = _get_leaving_atom_idxs_for(
                atom_name=atom_names[bonded_atom_idx],
                res_name=res_name,
                atom_names=atom_names[res_atom_idxs],
            )
            if len(leaving_in_res) > 0:
                # ... convert residue-local positions to global atom indices
                leaving.append(res_atom_idxs[leaving_in_res])

    # NOTE: The explicit integer dtype keeps the result consistent with the early returns above,
    #  even if every bond was skipped.
    bond_array = np.array(bonds, dtype=int).reshape(-1, 3)
    leaving_atom_idxs = np.concatenate(leaving) if len(leaving) > 0 else np.array([], dtype=int)
    return bond_array, leaving_atom_idxs
def get_coarse_graph_as_nodes_and_edges(
    atom_array: AtomArray, annotations: str | tuple[str]
) -> tuple[np.ndarray, np.ndarray]:
    """
    Returns the coarse-grained nodes and edges at the given annotation level based on the atom array's bond connectivity.

    Every unique annotation value (or combination of values) becomes a node. Every bond between
    two atoms becomes an edge between the nodes of those atoms, and each node additionally gets
    an edge to itself. Duplicate edges are removed.

    Args:
        - atom_array (AtomArray): The atom array containing atomic information and bonds.
        - annotations (str | tuple[str]): A single annotation or a tuple of annotations to be used for node identification.

    Returns:
        - nodes (np.ndarray): An array of unique nodes, each represented by a combination of annotations.
        - edges (np.ndarray): An array of edges, where each edge is a tuple of node indices representing a bond between two nodes.

    Example:
        >>> atom_array = cached_parse("5ocm")["atom_array"]
        >>> nodes, edges = get_coarse_graph_as_nodes_and_edges(atom_array, ["chain_id", "transformation_id"])
        >>> print(nodes)
        array([('A', '1'), ('F', '1'), ('G', '1'), ('H', '1'), ('I', '1'), ('W', '1'), ('X', '1'), ('Y', '1')],
              dtype=[('chain_id', '<U4'), ('transformation_id', '<U1')])
        >>> print(edges)
        array([[0, 0], [1, 1], [2, 2], [3, 3], [5, 5], [6, 6]])
    """
    # ... normalize a single annotation name to a list of names
    annotations = [annotations] if isinstance(annotations, str) else annotations
    # ... atom indices of both bond partners (the bond type column is not needed)
    atom1, atom2, _ = atom_array.bonds.as_array().T

    if len(annotations) > 1:
        # ... combine several annotations into one structured array with one field per annotation
        _annots = np.zeros(
            len(atom_array), dtype=[(annot, atom_array.get_annotation(annot).dtype) for annot in annotations]
        )
        for annot in annotations:
            _annots[annot] = atom_array.get_annotation(annot)  # [n_atoms, n_annotations]
    else:
        _annots = atom_array.get_annotation(annotations[0])  # [n_atoms]

    annot1 = _annots[atom1]  # [n_bonds, n_annotations]
    annot2 = _annots[atom2]  # [n_bonds, n_annotations]

    nodes = np.unique(_annots, axis=0)  # [n_nodes, n_annotations]
    # ... each node gets an edge to itself, in addition to the edges derived from the bonds
    self_edges = np.vstack([nodes, nodes]).T  # [n_nodes, 2]
    edges = np.unique(np.vstack([self_edges, np.vstack([annot1, annot2]).T]), axis=0)  # [n_edges, 2]

    # Map nodes to integers
    node_to_idx = {to_hashable(node): i for i, node in enumerate(nodes)}
    if len(edges) > 0:
        # ... replace the annotation values in each edge by the index of the matching node
        edges = np.apply_along_axis(
            lambda x: (node_to_idx[to_hashable(x[0])], node_to_idx[to_hashable(x[1])]), 1, edges
        )

    return nodes, edges
def get_connected_nodes(nodes: np.ndarray, edges: np.ndarray) -> list[list[Any]]:
    """Returns connected nodes as a mapped list given corresponding arrays of nodes and edges.

    Args:
        nodes (np.ndarray): The nodes of the graph.
        edges (np.ndarray): The edges of the graph, given as pairs of indices into `nodes`.

    Returns:
        list[list[Any]]: One list per connected component, holding the entries of `nodes` that
            belong to that component. Nodes without any edge form a component of their own.

    Example:
        >>> nodes = np.array([("A", "1"), ("B", "1"), ("C", "1"), ("D", "1")])
        >>> edges = np.array([[0, 1], [0, 2], [1, 2]])
        >>> connected_nodes = get_connected_nodes(nodes, edges)
        >>> print(connected_nodes)
        [[("A", "1"), ("B", "1"), ("C", "1")], [("D", "1")]]
    """
    # ...make the graph
    graph = nx.Graph()
    # ...register every node up front, so that nodes without any edge are not dropped
    graph.add_nodes_from(range(len(nodes)))
    graph.add_edges_from(edges)

    # ...return lists of connected chains
    return [[nodes[x] for x in component] for component in nx.connected_components(graph)]
def hash_graph(
    graph: nx.Graph,
    node_attr: str | None = None,
    edge_attr: str | None = None,
    iterations: int = 3,
    digest_size: int = 16,
) -> str:
    """
    Hash a graph with the Weisfeiler-Lehman (WL) algorithm, extended by simple node/edge summaries.

    The summaries are appended to the WL hash, separated by underscores, to deal with common
    edge cases where WL fails (e.g. disconnected graphs):
        - if `node_attr` is given: the number of nodes, followed by every unique node attribute
          value together with its count (`value:count`, comma-separated)
        - if `edge_attr` is given: the number of edges

    Args:
        - graph (networkx.Graph): The input graph to be hashed.
        - node_attr (str | None): The node attribute to be used for hashing. If None, node attributes are ignored.
        - edge_attr (str | None): The edge attribute to be used for hashing. If None, edge attributes are ignored.
        - iterations (int): The number of iterations for the WL algorithm. Default is 3.
        - digest_size (int): The size of the hash digest for WL. Default is 16.

    Returns:
        - str: The computed hash of the graph.

    Example:
        >>> import networkx as nx
        >>> G = nx.gnm_random_graph(10, 15)
        >>> hash_graph(G)
        '504894f49dd84b17c391b163af69624b'
    """
    # ... the WL-hash always comes first
    wl_hash = nx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash(
        graph, node_attr=node_attr, edge_attr=edge_attr, iterations=iterations, digest_size=digest_size
    )
    parts = [wl_hash]

    if node_attr is not None:
        # ... number of nodes
        parts.append(f"{len(graph.nodes)}")

        # ... unique node attribute values with their counts
        attr_values = list(nx.get_node_attributes(graph, node_attr).values())
        unique_values, counts = np.unique(attr_values, return_counts=True)
        value_counts = [f"{elt}:{count}" for elt, count in zip(unique_values, counts, strict=False)]
        parts.append(",".join(value_counts))

    if edge_attr is not None:
        # ... number of edges
        parts.append(f"{len(graph.edges)}")

    return "_".join(parts)
def _atom_array_to_networkx_graph(
    atom_array: AtomArray,
    annotations: tuple[str] = ("element", "atom_name"),
    bond_order: bool = True,
    cast_aromatic_bonds_to_same_type: bool = True,
) -> nx.Graph:
    """Build a NetworkX graph with one node per atom and one edge per bond of an AtomArray.

    Args:
        atom_array (AtomArray): The atoms and bonds to convert.
        annotations (tuple[str]): Annotations that are combined (via `sum_string_arrays`) into
            the "node_data" attribute of each node. If empty, no node attribute is set.
        bond_order (bool): Whether to store the bond type as the "bond_type" attribute of each edge.
        cast_aromatic_bonds_to_same_type (bool): Whether to replace the aromatic single, double
            and triple bond types by 0 before storing them.

    Returns:
        nx.Graph: The bond graph, with atom indices as nodes.
    """
    bond_table = atom_array.bonds.as_array()
    atom_pairs = bond_table[:, :2]

    # ... one node per atom (so that unbonded atoms are part of the graph), one edge per bond
    graph = nx.Graph()
    graph.add_nodes_from(range(len(atom_array)))
    graph.add_edges_from(atom_pairs)

    # ... attach the bond types to the edges
    if bond_order:
        bond_types = bond_table[:, -1]
        if cast_aromatic_bonds_to_same_type:
            aromatic_types = (
                struc.BondType.AROMATIC_SINGLE,
                struc.BondType.AROMATIC_DOUBLE,
                struc.BondType.AROMATIC_TRIPLE,
            )
            for aromatic_type in aromatic_types:
                bond_types[bond_types == aromatic_type] = 0

        edge_types = {}
        for atom_pair, edge_type in zip(atom_pairs, bond_types, strict=False):
            edge_types[tuple(atom_pair)] = edge_type
        nx.set_edge_attributes(graph, edge_types, "bond_type")

    # ... attach the combined annotations to the nodes
    if annotations:
        per_annotation = [atom_array.get_annotation(annot).astype(str) for annot in annotations]
        node_labels = sum_string_arrays(*per_annotation)
        nx.set_node_attributes(graph, {node: node_labels[node] for node in graph.nodes()}, "node_data")

    return graph
def hash_atom_array(
    atom_array: AtomArray,
    annotations: tuple[str] = ("element", "atom_name"),
    bond_order: bool = True,
    cast_aromatic_bonds_to_same_type: bool = False,
    use_md5: bool = False,
    md5_length: int | None = None,
) -> str:
    """
    Computes a hash for an AtomArray based on the bond connectivity and the selected node annotations.

    Args:
        atom_array (AtomArray): The array of atoms to hash
        annotations (tuple[str]): The node annotations to include in the hash
        bond_order (bool): Whether to include bond order in the hash
        cast_aromatic_bonds_to_same_type (bool): Whether to treat all aromatic bonds as the same type
        use_md5 (bool): Whether to use MD5 hashing on the output
        md5_length (int | None): If using MD5, the number of characters to keep from the hash.
            If None, returns full hash. Ignored if `use_md5` is False.

    Returns:
        str: The computed hash
    """
    # ... create the bond graph
    bond_graph = _atom_array_to_networkx_graph(
        atom_array,
        annotations=annotations,
        bond_order=bond_order,
        cast_aromatic_bonds_to_same_type=cast_aromatic_bonds_to_same_type,
    )

    # ... hash the graph, including node/edge attributes only if they were set on the graph
    hash_str = hash_graph(
        bond_graph, node_attr="node_data" if annotations else None, edge_attr="bond_type" if bond_order else None
    )

    if use_md5:
        # NOTE: MD5 only serves as a compact fingerprint here. `usedforsecurity=False` keeps this
        #  working on Python builds that restrict MD5 (e.g. FIPS mode); the digest is unchanged.
        hash_str = hashlib.md5(hash_str.encode(), usedforsecurity=False).hexdigest()
        if md5_length is not None:
            hash_str = hash_str[:md5_length]

    return hash_str
def generate_inter_level_bond_hash(
    atom_array: AtomArray, lower_level_id: str, lower_level_entity: str | None = None
) -> str:
    """
    Generates a hash string representing the inter-level bonds within an AtomArray.

    When computing entities IDs, we must consider inter-level bonds at the atom- and residue-level to avoid ambiguity.

    The hash is derived from a SHA-256 digest rather than the builtin `hash()`, because the
    builtin hash of strings differs between Python processes (hash randomization). The same
    bonds therefore yield the same hash string in every run.

    Args:
        atom_array (AtomArray): The array of atoms containing bond and annotation information.
        lower_level_id (str): The level which to find, and hash, the inter-level bonds.
            For example, when computing molecule entities, we'd consider the inter-PN Unit bonds.
        lower_level_entity (str | None): An additional entity annotation to use when computing the hash.
            Optional; if None, then only residue ID, residue name, and atom name are used.

    Returns:
        str: A hash string (a signed 64-bit integer in decimal notation) representing the
            inter-level bonds, or an empty string if there are no inter-level bonds.
    """
    # ...find the inter-level bonds (bonds whose two atoms differ in `lower_level_id`)
    all_bonds = atom_array.bonds.as_array()
    level_ids = atom_array.get_annotation(lower_level_id)
    inter_level_bonds = all_bonds[level_ids[all_bonds[:, 0]] != level_ids[all_bonds[:, 1]]]

    if not inter_level_bonds.size:
        return ""

    def _to_builtin(value: Any) -> Any:
        # numpy scalars are converted so that their `repr` does not depend on the numpy version
        return value.item() if isinstance(value, np.generic) else value

    def _describe_atom(atom_idx: int) -> tuple:
        atom = atom_array[atom_idx]
        entity = getattr(atom, lower_level_entity) if lower_level_entity else None
        return (
            _to_builtin(entity),
            _to_builtin(atom.res_id),
            _to_builtin(atom.res_name),
            _to_builtin(atom.atom_name),
        )

    # ...describe every bond by its two atoms; sorting the pair makes the result independent of
    #    the order of the atoms within the bond
    bond_tuples = [
        tuple(sorted([_describe_atom(atom_a_idx), _describe_atom(atom_b_idx)]))
        for atom_a_idx, atom_b_idx in inter_level_bonds[:, :2]
    ]

    # ...sort the list of tuples (to be independent of the bond order), and hash
    canonical_form = repr(tuple(sorted(bond_tuples)))
    digest = hashlib.sha256(canonical_form.encode()).digest()
    return str(int.from_bytes(digest[:8], byteorder="big", signed=True))
def spoof_struct_conn_dict_from_string(bonds: list[tuple[str, str]]) -> dict[str, list[str]]:
    """Build a struct_conn-style dictionary from user-specified atom pairs.

    NOTE: For SMILES, atoms are named with their element type and the order that they appear in the
    SMILES string. For example, the first carbon atom in a SMILES string would be named "C1", the
    second "C2", and so on.

    NOTE: We only support covalent bonds; every entry is written with `conn_type_id = "covale"`.

    Args:
        bonds (list[tuple[str, str]]): One tuple per bond, holding the two bonded atoms. Each atom
            is specified as "CHAIN_ID/RES_NAME/RES_ID/ATOM_NAME".

            In PyMol, you can hover over an atom to display the relevant information.
            For example, clicking a CA atom in PyMol prints in the console:
            ```
            >>> You clicked /6wtf/A/A/ALA`294/CA
            ```
            We could then copy this information to specify an atom, "A/ALA/294/CA"

    Returns:
        dict[str, list[str]]: A dictionary in struct_conn format (all values are kept as strings).

    Raises:
        ValueError: If a bond is not a pair, or an atom specification does not have exactly four
            "/"-separated fields.

    Example:
        ```
        >>> bonds = [
        ...     ("A/THR/4/CG", "D/UNL/1/C5"),
        ...     ("A/CYS/5/SG", "A/CYS/137/SG")
        ... ]
        >>> struct_conn_dict = spoof_struct_conn_dict_from_string(bonds)
        >>> print(struct_conn_dict)
        {
            'conn_type_id': ['covale', 'covale'],
            'ptnr1_label_asym_id': ['A', 'A'],
            'ptnr1_label_comp_id': ['THR', 'CYS'],
            'ptnr1_label_seq_id': ['4', '5'],
            'ptnr1_label_atom_id': ['CG', 'SG'],
            'ptnr2_label_asym_id': ['D', 'A'],
            'ptnr2_label_comp_id': ['UNL', 'CYS'],
            'ptnr2_label_seq_id': ['1', '137'],
            'ptnr2_label_atom_id': ['C5', 'SG'],
        }
        ```
    """
    partner_prefixes = ("ptnr1", "ptnr2")

    # ... set up the empty columns (key order: bond type, then partner 1, then partner 2)
    struct_conn_dict: dict[str, list[str]] = {"conn_type_id": []}
    for prefix in partner_prefixes:
        for suffix in ("label_asym_id", "label_comp_id", "label_seq_id", "label_atom_id"):
            struct_conn_dict[f"{prefix}_{suffix}"] = []

    for bond in bonds:
        try:
            # ... a bond must consist of exactly two atom specifications
            first_partner, second_partner = bond

            for prefix, atom_spec in zip(partner_prefixes, (first_partner, second_partner)):
                # ... each specification is CHAIN_ID/RES_NAME/RES_ID/ATOM_NAME
                chain_id, res_name, res_id, atom_name = atom_spec.split("/")
                struct_conn_dict[f"{prefix}_label_asym_id"].append(chain_id)
                struct_conn_dict[f"{prefix}_label_comp_id"].append(res_name)
                struct_conn_dict[f"{prefix}_label_seq_id"].append(res_id)
                struct_conn_dict[f"{prefix}_label_atom_id"].append(atom_name)

            # ... all bonds are assumed to be covalent
            struct_conn_dict["conn_type_id"].append("covale")
        except ValueError as parse_error:
            raise ValueError(f"Error parsing bond string '{bond}': {parse_error}") from None

    return struct_conn_dict


def _get_bond_degree_per_atom(atom_array: struc.AtomArray) -> np.ndarray:
    """
    Returns the total degree (= sum of bond orders) for each atom.

    Aromatic single / double / triple bonds contribute 1 / 2 / 3 respectively; every other bond
    type contributes its raw integer code.
    """
    # ... read the raw bond table: the first two columns are the bonded atoms, the last the bond type
    bond_table = atom_array.bonds._bonds
    endpoints = bond_table[:, :2]
    orders = bond_table[:, -1].copy()

    # ... strip aromaticity so that only the plain bond order remains
    aromatic_to_order = (
        (struc.BondType.AROMATIC_SINGLE, 1),
        (struc.BondType.AROMATIC_DOUBLE, 2),
        (struc.BondType.AROMATIC_TRIPLE, 3),
    )
    for aromatic_type, order in aromatic_to_order:
        orders[orders == aromatic_type] = order

    # ... each bond counts towards both of its atoms; `minlength` zero-fills trailing unbonded atoms
    return np.bincount(
        endpoints.ravel(),
        weights=np.repeat(orders, 2),
        minlength=atom_array.array_length(),
    )


def correct_formal_charges_for_specified_atoms(atom_array: struc.AtomArray, to_update: np.ndarray) -> struc.AtomArray:
    """
    Fix formal charges for atoms in an AtomArray after forming bonds between CCD components.

    The formal charge is computed as (sum of bond orders) - (default valence of the element).
    Atoms whose element has no entry in `DEFAULT_VALENCE` are left untouched. The charges are
    written into `atom_array.charge` in place.

    Args:
        atom_array (AtomArray): The AtomArray to fix.
        to_update (np.ndarray): A boolean mask of atoms whose formal charges should be fixed.
            These are normally the atoms for which bonds were modified.

    Returns:
        AtomArray: The AtomArray with fixed formal charges (the same object that was passed in).
    """
    # ... without explicit hydrogens the bond degree is incomplete, so charges cannot be inferred
    has_hydrogens = np.isin(atom_array.element, HYDROGEN_LIKE_SYMBOLS).any()
    if not has_hydrogens:
        logger.warning("Hydrogens not given. Cannot fix formal charges.")
        return atom_array

    # ... look up default valences (a sentinel marks elements without a default valence)
    missing_valence = -10
    selected_elements = atom_array.element[to_update]
    valences = np.array([DEFAULT_VALENCE.get(symbol, missing_valence) for symbol in selected_elements])

    # ... total bond order of each selected atom
    bond_degree = _get_bond_degree_per_atom(atom_array)[to_update]

    # ... formal charge = bond degree - default valence
    charges = bond_degree - valences

    # ... only write back atoms with a known valence, translating local -> global indices
    has_valence = valences != missing_valence
    selected_idxs = np.arange(atom_array.array_length())[to_update]
    atom_array.charge[selected_idxs[has_valence]] = charges[has_valence]

    return atom_array


def correct_bond_types_for_nucleophilic_additions(
    atom_array: struc.AtomArray, to_update: np.ndarray
) -> struc.AtomArray:
    """
    Account for nucleophilic additions that result in carbons that violate the octet rule.

    In some cases (see: 1TQH), there is no leaving group specified, since the bond is formed by a
    nucleophilic addition to a carbonyl carbon. In this case, we should convert the C=O double bond
    to a C-O single bond.

    For every selected carbon with a total bond order above 4 and no bonded hydrogen, the first
    C=O double bond found is replaced by a single bond (in place) and a warning is logged.

    Args:
        atom_array (AtomArray): The AtomArray to fix.
        to_update (np.ndarray): A boolean mask of atoms that are candidates for correction.

    Returns:
        AtomArray: The AtomArray with fixed bond types (the same object that was passed in).
    """
    candidate_carbons = (atom_array.element == "C") & to_update
    if not candidate_carbons.any():
        # (Early exit) no carbon among the atoms to update
        return atom_array

    overbonded_carbons = (_get_bond_degree_per_atom(atom_array) > 4) & candidate_carbons

    # ... snapshot of the bond table, taken once before any bond is modified below
    bonds_arr = atom_array.bonds.as_array()

    for c_idx in np.where(overbonded_carbons)[0]:
        involves_carbon = (bonds_arr[:, 0] == c_idx) | (bonds_arr[:, 1] == c_idx)

        # If any of the bonds involves a hydrogen, we skip this carbon
        # (Handling hydrogens requires inferring leaving atoms, which is out-of-scope for this function)
        bonded_elements = atom_array.element[bonds_arr[involves_carbon, :2]]
        if np.isin(bonded_elements, HYDROGEN_LIKE_SYMBOLS).any():
            continue

        # Look for a double bond to an oxygen; only the first one found is corrected
        for atom1, atom2, bond_type in bonds_arr[involves_carbon]:
            other_idx = atom2 if atom1 == c_idx else atom1
            if atom_array.element[other_idx] != "O" or bond_type != struc.BondType.DOUBLE:
                continue

            # Set the bond order to single and log a warning
            atom_array.bonds.remove_bond(atom1, atom2)
            atom_array.bonds.add_bond(atom1, atom2, struc.BondType.SINGLE)
            logger.warning(
                f"Corrected C=O double bond to single bond between atom {c_idx} (C) and {other_idx} (O) due to nucleophilic addition (degree > 4). "
                f"chain_id: {atom_array.chain_id[c_idx]}, res_name: {atom_array.res_name[c_idx]}, res_id: {atom_array.res_id[c_idx]}, atom_name of invalid carbon: {atom_array.atom_name[c_idx]}, atom_name of oxygen: {atom_array.atom_name[other_idx]}"
            )
            break

    return atom_array