"""Utilities to create the dataset to be loaded by RF2AA. See the main script for term Glossary."""
from __future__ import annotations
import logging
import math
from collections import defaultdict
from typing import Any, Final
import networkx as nx
import numpy as np
import requests
from biotite.structure import AtomArray, CellList
from scipy.spatial.distance import cdist
from atomworks.io.common import not_isin
from atomworks.io.constants import ELEMENT_NAME_TO_ATOMIC_NUMBER, METAL_ELEMENTS
from atomworks.ml.common import default
from atomworks.ml.preprocessing.constants import ClashSeverity
logger = logging.getLogger("preprocess")
[docs]
def get_pn_units_with_non_biological_bonds(atom_array: AtomArray, bond_mask: np.ndarray) -> np.ndarray:
"""
Checks for non-biological bonds between PN units within an assembly.
Note that "inter-PN-unit bonds" in this instance do not include bonds between non-polymers,
which are treated as the same PN unit (by definition).
Specifically, this function looks for inter-molecular:
- Oxygen-oxygen bonds
- Fluorine-fluorine bonds
- Bonds involving free oxygen or hydroxyl groups (e.g. HOH, OH, O)
Args:
atom_array: AtomArray containing the relevant structure
bond_mask: Mask of inter-molecular bonds
Returns:
numpy.ndarray: Array of polymer/non-polymer unit instance IDs that contain non-biological bonds
"""
# Get atoms and residues involved in the inter-molecular bonds
bonds_to_check = atom_array.bonds.as_array()[bond_mask]
atom_a_elements = atom_array.atomic_number[bonds_to_check[:, 0]]
atom_b_elements = atom_array.atomic_number[bonds_to_check[:, 1]]
atom_a_res_names = atom_array.res_name[bonds_to_check[:, 0]]
atom_b_res_names = atom_array.res_name[bonds_to_check[:, 1]]
# Check for non-biological bonds
non_biological_bonds = (
(atom_a_elements == ELEMENT_NAME_TO_ATOMIC_NUMBER["O"])
& (atom_b_elements == ELEMENT_NAME_TO_ATOMIC_NUMBER["O"]) # Oxygen-oxygen bonds
| (atom_a_elements == ELEMENT_NAME_TO_ATOMIC_NUMBER["F"])
& (atom_b_elements == ELEMENT_NAME_TO_ATOMIC_NUMBER["F"]) # Fluorine-fluorine bonds
| np.isin(atom_a_res_names, ["HOH", "OH", "O"]) # Bonds involving free oxygen or hydroxyl groups
| np.isin(atom_b_res_names, ["HOH", "OH", "O"])
)
# Get unique polymer/non-polymer unit IDs with non-biological bonds
return np.unique(
np.concatenate(
[
atom_array.pn_unit_iid[bonds_to_check[:, 0]][non_biological_bonds],
atom_array.pn_unit_iid[bonds_to_check[:, 1]][non_biological_bonds],
]
)
)
[docs]
def get_clashing_pn_units(
pn_unit_iids_to_consider: np.Array, atom_array: AtomArray, cell_list: CellList, clash_distance: float
) -> tuple[set[str], dict[str, set[str]]]:
"""Wrapper function to find clashing PN units within an atom array."""
# Process PN unit-by-PN unit to avoid memory issues
clashing_pn_units_dict = defaultdict(set)
clashing_pn_units_set = set()
for query_pn_unit_iid in pn_unit_iids_to_consider:
query_pn_unit_atom_array = atom_array[atom_array.pn_unit_iid == query_pn_unit_iid]
clashing_pn_units = get_pn_units_clashing_with_pn_unit(
query_pn_unit_atom_array, atom_array, cell_list, clash_distance
)
if clashing_pn_units:
clashing_pn_units_set.update(clashing_pn_units | {query_pn_unit_iid})
clashing_pn_units_dict[query_pn_unit_iid] = set(clashing_pn_units)
return clashing_pn_units_set, clashing_pn_units_dict
[docs]
def get_atom_mask_from_cell_list(
coord: np.array, cell_list: CellList, cell_list_size: int, cutoff: float, chunk_size: int = int(2e9)
) -> np.ndarray:
"""
Builds a mask indicating which atoms clash with the query PN unit. If the number of comparisons is too large,
the computation is split into manageable chunks along the rows of `coord`.
Args:
coord (ndarray): The coordinates of the query PN unit. Shape is (n, 3).
cell_list (CellList): A CellList object that allows efficient vicinity searches.
cell_list_size (int): The number of atoms in the cell list.
clash_distance (float): The distance threshold below which atoms are considered to be clashing.
chunk_size (int): The maximum number of comparisons allowed in a single chunk.
Returns:
ndarray: Mask indicating which atoms in `cell_list` clash with the atoms in `coord`. Shape is (n, cell_list_size), dtype is bool.
"""
num_coords = coord.shape[0]
clashing_atom_mask = np.zeros((num_coords, cell_list_size), dtype=bool)
max_rows_per_chunk = max(1, chunk_size // cell_list_size) # Ensure at least 1 row per chunk
if num_coords * cell_list_size > chunk_size:
logger.info(
f"{num_coords * cell_list_size:,} comparisons needed; distance computation split into {math.ceil(num_coords / max_rows_per_chunk)} chunks."
)
for i in range(0, num_coords, max_rows_per_chunk):
end = min(i + max_rows_per_chunk, num_coords)
clashing_atom_mask[i:end, :] = cell_list.get_atoms(coord[i:end], cutoff, as_mask=True)
else:
clashing_atom_mask = cell_list.get_atoms(coord, cutoff, as_mask=True)
return clashing_atom_mask
[docs]
def get_pn_units_clashing_with_pn_unit(
query_pn_unit: AtomArray, filtered_atom_array: AtomArray, cell_list: CellList, clash_distance: float
) -> set[str]:
"""
Finds clashes between a query PN unit and the rest of the structure.
A clash is defined as any pair of atoms from the query PN unit and the rest of the structure that are closer than `clash_distance`.
Args:
query_pn_unit: AtomArray containing the query PN unit that we want to check for clashes
filtered_atom_array: AtomArray containing atoms with non-zero occupancy
cell_list: CellList of structure for rapid distance computations
clash_distance: Distance threshold for clashing atoms
Returns:
Set of polymer/non-polymer unit instance IDs that are clashing with the query PN unit
"""
clashing_atom_mask = get_atom_mask_from_cell_list(
query_pn_unit.coord, cell_list, len(filtered_atom_array), clash_distance
)
collapsed_mask = np.any(clashing_atom_mask, axis=0)
clashing_atoms = filtered_atom_array[collapsed_mask]
# Filter out query PN unit atoms
query_pn_unit_iids = np.unique(query_pn_unit.pn_unit_iid)
clashing_atoms = clashing_atoms[not_isin(clashing_atoms.pn_unit_iid, query_pn_unit_iids)]
return set(np.unique(clashing_atoms.pn_unit_iid))
[docs]
def handle_clashing_pn_units(
clashing_pn_units_set: set, clashing_pn_units_dict: dict, atom_array: AtomArray, pn_unit_iid_map: dict
) -> AtomArray:
"""
Resolves clashing PN units according to the following process:
1. Sort clashing PN units by the number of atoms within the PN unit
2. Iterate through the sorted list keeping (1) the larger PN unit (2) the lower transformation number until all clashes are resolved
Args:
clashing_pn_units_set: Set of polymer/non-polymer unit instance IDs that contain clashing atoms
clashing_pn_units_dict: Dictionary mapping polymer/non-polymer unit instance IDs to a list of clashing polymer/non-polymer unit instance IDs
atom_array: AtomArray containing atoms with non-zero occupancy
pn_unit_iid_map: Dictionary mapping integer representations of polymer/non-polymer unit instance IDs to the polymer/non-polymer unit instance IDs themselves
Returns:
AtomArray: AtomArray with clashing atoms removed
ClashSeverity: Enum representing the severity of the clash
"""
pn_units_to_remove = set()
pn_units_to_keep = set()
# Build a dictionary of clashing PN unit details
clashing_pn_unit_details = {}
for pn_unit in clashing_pn_units_set:
pn_unit_atom_array = atom_array[atom_array.pn_unit_iid == pn_unit]
clashing_pn_unit_details[pn_unit] = {
"num_atoms": len(pn_unit_atom_array),
"is_metal": pn_unit_atom_array[0].element.upper in METAL_ELEMENTS,
"is_polymer": pn_unit_atom_array[0].is_polymer,
}
# Sort clashing PN units by the number of atoms in the reference PN unit (high to low) and then the transformation ID (low to high)
# The last character in the polymer/non-polymer unit instance ID is the transformation ID, even in cases where we have covalently bound chains (e.g., A_1,B_1)
def sort_key(x: tuple[str, dict[str, Any]]) -> tuple[int, int]:
try:
return (x[1]["num_atoms"], -int(pn_unit_iid_map[x[0]][-1]))
except ValueError:
return (x[1]["num_atoms"], -ord(pn_unit_iid_map[x[0]][-1]))
sorted_clashing_pn_units = [x[0] for x in sorted(clashing_pn_unit_details.items(), key=sort_key, reverse=True)]
# Define the ClashSeverity
num_polymers = len(
np.unique(atom_array.pn_unit_iid[atom_array.is_polymer])
) # Number of polymer PN units in the structure
num_clashing_polymers = len(
np.unique([pn_unit for pn_unit in clashing_pn_units_set if clashing_pn_unit_details[pn_unit]["is_polymer"]])
) # Number of clashing polymer PN units
if num_clashing_polymers / num_polymers > 0.5:
clash_severity = ClashSeverity.SEVERE
elif num_clashing_polymers > 0:
clash_severity = ClashSeverity.MODERATE
else:
clash_severity = ClashSeverity.MILD
# Keep the larger PN unit
for pn_unit in sorted_clashing_pn_units:
if pn_unit not in pn_units_to_remove:
pn_units_to_keep.add(pn_unit)
pn_units_to_remove.update(clashing_pn_units_dict[pn_unit] - pn_units_to_keep)
logger.warning(
f"Removing clashing PN units: {[pn_unit_iid_map[pn_unit] for pn_unit in pn_units_to_remove]} from the structure."
)
# Remove clashing PN units
atom_array = atom_array[not_isin(atom_array.pn_unit_iid, list(pn_units_to_remove))]
return atom_array, clash_severity
[docs]
def get_intra_pn_unit_bonds(pn_unit_iid: str, full_atom_array: AtomArray) -> np.ndarray:
# NOTE: Not currently used; kept for potential future use
"""
Retrieve all bonds within the PN unit.
Includes inter-chain bonds if the PN unit is composed of multiple chains.
Does NOT include bonds between the PN unit and any other PN units (e.g., protein-ligand bonds).
Args:
pn_unit_iid (str): The polymer/non-polymer unit instance ID (e.g., 'A1', 'C2,B3', etc.)
full_atom_array (AtomArray): The full structure AtomArray. Must include all atoms, including those with zero occupancy.
Returns:
numpy.ndarray: Array of Biotite bond objects (atom_a_index, atom_b_index, bond_type) within the specified PN unit.
"""
all_bonds = full_atom_array.bonds.as_array()
atom_a_indices, atom_b_indices = all_bonds[:, 0], all_bonds[:, 1]
mask = (full_atom_array.pn_unit_iid[atom_a_indices] == pn_unit_iid) & (
full_atom_array.pn_unit_iid[atom_b_indices] == pn_unit_iid
)
return all_bonds[mask]
[docs]
def calculate_molecule_diameter(full_molecule_atoms: AtomArray) -> float:
# NOTE: Not currently used; kept for potential future use
"""
Calculates the molecular diameter, defined as the maximum number of bonds between any two vertices
in the molecule, using a maximum spanning tree.
Args:
full_molecule_atoms (AtomArray): The array of atoms representing the molecule. Must include all atoms, including those with zero occupancy.
Returns:
float: The molecular diameter. If the diameter cannot be computed, returns 0.0.
"""
if len(full_molecule_atoms.bonds.as_array() > 0):
try:
nx_graph = full_molecule_atoms.bonds.as_graph() # as a NetworkX graph
tree = nx.maximum_spanning_tree(nx_graph)
return nx.diameter(tree)
except nx.exception.NetworkXError:
logger.warning("Could not compute diameter.")
return 0.0
else:
return 0.0
[docs]
def get_soi_ligands_from_pdb_id(pdb_id: str) -> set[str]:
"""
This function takes a PDB ID and returns a set of ligand names that are annotated as subject of
investigation (SOI) in the PDB. Such ligands are often considered
biologically meaningful.
NOTE: This function is kept from the old pipeline for convenience and testing,
but not used in the current processing pipeline, where the LOI is extracted directly
from the cif file instead of via an FTP query.
Args:
pdb_id (str): The PDB ID to query.
Returns:
Set[str]: A set of ligand names that are annotated as SOI in the PDB.
"""
try:
pdb_id = pdb_id.lower()
core_response = requests.get(f"https://data.rcsb.org/rest/v1/core/entry/{pdb_id}")
core_json_dict = core_response.json()
nonpolymer_ids = core_json_dict["rcsb_entry_container_identifiers"].get("non_polymer_entity_ids", [])
soi_ligand_names = []
for nonpolymer_id in nonpolymer_ids:
ligand_response = requests.get(
f"https://data.rcsb.org/rest/v1/core/nonpolymer_entity/{pdb_id}/{nonpolymer_id}"
)
ligand_json_dict = ligand_response.json()
ligand_name = ligand_json_dict["pdbx_entity_nonpoly"]["comp_id"]
if "rcsb_nonpolymer_entity_annotation" in ligand_json_dict:
for annotation_dict in ligand_json_dict["rcsb_nonpolymer_entity_annotation"]:
if annotation_dict["type"] == "SUBJECT_OF_INVESTIGATION" and annotation_dict["name"] == ligand_name:
soi_ligand_names.append(ligand_name)
break
except Exception:
return set()
return set(soi_ligand_names)
[docs]
def get_ligand_validity_scores_from_pdb_id(pdb_id: str) -> list[dict[str, str | int | float | None]]:
"""
Query the RCSB PDB for ligand validity scores for a given PDB ID.
Args:
pdb_id (str): The PDB ID to query.
Returns:
records: (list[dict[str, str | int | float | None]]): A list of dictionaries, each containing
the ligand validity scores for a ligand (e.g. RSCC, RSR) as well identifiers such as the
residue name, chain ID, and entity ID. Can easily be converted to a pandas DataFrame for
easier handling via `pd.DataFrame(records)`.
References:
- https://www.rcsb.org/docs/general-help/ligand-structure-quality-in-pdb-structures
"""
pdb_graphql_url: Final[str] = "https://data.rcsb.org/graphql"
# Query string in graphql language to get ligand validity scores from a PDB entry
ligand_validity_query: Final[str] = """
query ($id: String!) {
entry(entry_id:$id){
nonpolymer_entities {
rcsb_nonpolymer_entity_container_identifiers {
nonpolymer_comp_id
rcsb_id
}
rcsb_nonpolymer_entity_annotation {
type
}
nonpolymer_entity_instances {
rcsb_nonpolymer_entity_instance_container_identifiers {
auth_seq_id
auth_asym_id
asym_id
entity_id
entry_id
}
rcsb_nonpolymer_instance_validation_score {
RSCC
RSR
alt_id
completeness
intermolecular_clashes
is_best_instance
mogul_angle_outliers
mogul_angles_RMSZ
mogul_bond_outliers
mogul_bonds_RMSZ
ranking_model_fit
ranking_model_geometry
score_model_fit
score_model_geometry
stereo_outliers
average_occupancy
type
is_subject_of_investigation
is_subject_of_investigation_provenance
}
}
}
}
}
"""
# Perform the actual query for the target PDB ID
response = requests.post(pdb_graphql_url, json={"query": ligand_validity_query, "variables": {"id": pdb_id}})
# Extract the records from the response
records = []
if response.status_code == 200:
data = response.json()
try:
nonpolymer_entities = default(data["data"]["entry"]["nonpolymer_entities"], [])
for entity in nonpolymer_entities:
res_name = entity["rcsb_nonpolymer_entity_container_identifiers"]["nonpolymer_comp_id"]
for instance in entity.get("nonpolymer_entity_instances", []):
record_template = {"res_name": res_name}
record_template.update(instance.get("rcsb_nonpolymer_entity_instance_container_identifiers", {}))
validation_scores = default(instance["rcsb_nonpolymer_instance_validation_score"], [])
for score in validation_scores:
record = record_template.copy()
record.update(score)
records.append(record)
except KeyError:
logger.debug(f"No validation scores found for PDB ID: {pdb_id}")
except TypeError:
logger.debug(f"No validation scores found for PDB ID: {pdb_id}")
else:
logger.debug(f"Query failed with status code {response.status_code} and response: {response.text}")
return records
[docs]
def get_inter_pn_unit_bond_mask(atom_array: AtomArray) -> np.ndarray:
"""
Returns a mask indicating which bonds in `atom_array.bonds` are between two distinct PN units.
Because we are operating at the PN unit-level, such bonds cannot be bonds between non-polymers.
WARNING: Must be applied before reassigning PN unit IIDs (e.g., as is done for covalent modifications).
Args:
atom_array (AtomArray): The full atom array. Must have PN unit-level annotations.
Returns:
numpy.ndarray: A boolean mask indicating which bonds are between two PN units.
"""
bond_pn_unit_a = atom_array.pn_unit_iid[atom_array.bonds.as_array()[:, 0]]
bond_pn_unit_b = atom_array.pn_unit_iid[atom_array.bonds.as_array()[:, 1]]
return bond_pn_unit_a != bond_pn_unit_b
[docs]
def get_bonded_polymer_pn_units(query_pn_unit_iid: str, filtered_atom_array: AtomArray) -> set[str]:
"""
Returns a set of polymer PN units that are covalently bonded to a given PN unit.
For example, useful to detect oligosaccharides that are covalently bonded to a protein.
Args:
query_pn_unit_iid (str): The full ID of the non-polymer PN unit to check for bonds.
filtered_atom_array (AtomArray): AtomArray with non-zero occupancy
Returns:
set[str]: A set of full IDs of polymer PN units that are covalently bonded to the query PN unit.
"""
# Check if the non polymer is covalently bonded to a polymer
inter_pn_unit_bonds = filtered_atom_array.bonds.as_array()[get_inter_pn_unit_bond_mask(filtered_atom_array)]
bond_atom_a_pn_unit_iids = filtered_atom_array.pn_unit_iid[inter_pn_unit_bonds[:, 0]]
bond_atom_b_pn_unit_iids = filtered_atom_array.pn_unit_iid[inter_pn_unit_bonds[:, 1]]
# Find bonded PN units
bonded_pn_units = set(bond_atom_b_pn_unit_iids[np.where(bond_atom_a_pn_unit_iids == query_pn_unit_iid)[0]]) | set(
bond_atom_a_pn_unit_iids[np.where(bond_atom_b_pn_unit_iids == query_pn_unit_iid)[0]]
)
# Get the set of polymers
polymer_pn_unit_iids = set(filtered_atom_array.pn_unit_iid[filtered_atom_array.is_polymer])
# Set intersection between bonded PN units and polymer PN units
return bonded_pn_units & polymer_pn_unit_iids