"""Utility functions to visualize atom arrays with py3Dmol in Jupyter notebooks."""
__all__ = ["view"]
import gzip
import io
import logging
import os
import uuid
from itertools import cycle
from pathlib import Path
import biotite.structure as struc
import numpy as np
import py3Dmol
from biotite.structure import AtomArray, AtomArrayStack
from biotite.structure.io import mol, pdb, pdbx
from atomworks.io.constants import ATOMIC_NUMBER_TO_ELEMENT, METAL_ELEMENTS
from atomworks.io.utils.io_utils import read_any, to_cif_string
# Logger registered under the "atomworks.io" name.
logger = logging.getLogger("atomworks.io")

# `pymol_remote` is an optional dependency: record whether it could be imported so that the
# PyMOL helpers can raise a clear error later instead of failing at module import time.
try:
    import pymol_remote.client
except ImportError:
    _is_pymol_remote_installed = False
    logger.warning("PymolSession not installed, visualization will not work")
else:
    _is_pymol_remote_installed = True
# Hex color codes used to color chains, listed in the order in which they are cycled through.
# The palette is written as name -> hex so each entry is self-describing; only the hex codes
# are exported (dicts preserve insertion order, so the list order matches the order below).
IPD_PYMOL_COLORS = list(
    {
        "pymol_gray": "#888888",
        "good_yellow": "#FAC72C",
        "good_teal": "#29B0C1",
        "good_green": "#AAC32F",
        "good_pink": "#EC72A4",
        "good_blue": "#4499E7",
        "good_gray": "#DCDCDC",
        "good_red": "#E44A3E",
        "good_light_green": "#65B37C",
        "paper_teal": "#4FB9AF",
        "paper_navaho": "#FFE0AC",
        "paper_melon": "#FFC6B2",
        "paper_pink": "#FFACB7",
        "paper_purple": "#D59AB5",
        "paper_lightblue": "#9596C6",
        "paper_blue": "#6686C5",
        "paper_darkblue": "#4B5FAA",
        "pymol_black": "#222222",
    }.values()
)
def _element_is_metal(element) -> bool:
    """Return whether a single element annotation refers to a metal.

    `element` is first looked up in `ATOMIC_NUMBER_TO_ELEMENT`; if it is not a key there, its
    capitalized string form is used instead. The result is a membership test in `METAL_ELEMENTS`.

    The fallback is computed lazily: `dict.get(x, x.capitalize())` would evaluate
    `x.capitalize()` before the lookup and therefore raise `AttributeError` for integer inputs,
    even when the integer is a valid key.
    """
    symbol = ATOMIC_NUMBER_TO_ELEMENT.get(element)
    if symbol is None:
        symbol = str(element).capitalize()
    return symbol in METAL_ELEMENTS


# Element-wise version of `_element_is_metal`. `otypes` is fixed so that the output dtype does
# not have to be inferred from a first call (which also makes empty inputs valid).
_is_metal = np.vectorize(_element_is_metal, otypes=[bool])
# --- py3Dmol-based notebook viewer ---
def view(
    structure: AtomArray | AtomArrayStack,
    *,
    zoom_to_selection: dict[str, int | str] | None = None,
    show_hover: bool = True,
    show_unoccupied: bool = False,
    show_cartoon: bool = True,
    show_surface: bool = True,
    width: int = 600,
    height: int = 400,
    ligand_linewidth: float = 0.2,
    polymer_sidechain_linewidth: float = 0.05,
    min_polymer_size: int = 1,
    colors: list[str] = IPD_PYMOL_COLORS,
) -> py3Dmol.view:
    """Visualize an AtomArray structure using py3Dmol for display in jupyter notebooks.

    Each chain is styled according to its content:
        - protein / nucleic acid chains: thin outlined sticks, plus a chain-colored cartoon if
          `show_cartoon` is set
        - chains consisting only of metal atoms: spheres
        - everything else (ligands): sticks with carbons in the chain color and all other atoms
          colored by element

    Args:
        - structure (AtomArray | AtomArrayStack): The atomic structure to be visualized. For an
            AtomArrayStack only the first model is shown (a warning is logged).
        - zoom_to_selection (dict[str, int | str] | None, optional): A dictionary specifying the
            selection to zoom into. Defaults to None. Here are some examples:
            - `{'serial': 35}` - will zoom to the atom with index 35 in the atom array
            - `{'chain': 'A', 'resi': 35}` - will zoom to the residue id 35 in chain A
            - `{'chain': 'C'}` - will zoom to the entire chain C
            !WARNING! If the selection is wrong, the visualization will be empty.
        - show_hover (bool, optional): Whether to enable hover functionality to display atom details.
            Defaults to True.
        - show_unoccupied (bool, optional): Whether to show unoccupied atoms (occupancy <= 0).
            Only has an effect if the structure has an `occupancy` annotation. Defaults to False.
        - show_cartoon (bool, optional): Whether to show the cartoon. Defaults to True.
        - show_surface (bool, optional): Whether to show the surface. Defaults to True.
        - width (int, optional): The width of the visualization window. Defaults to 600.
        - height (int, optional): The height of the visualization window. Defaults to 400.
        - ligand_linewidth (float, optional): The linewidth for ligand representation. Defaults to 0.2.
        - polymer_sidechain_linewidth (float, optional): The linewidth for polymer sidechain representation.
            Defaults to 0.05.
        - min_polymer_size (int, optional): The minimum size for a chain to be displayed as a polymer.
            Defaults to 1.
        - colors (list[str], optional): A list of colors to cycle through for different chains.
            Defaults to IPD_PYMOL_COLORS.

    Returns:
        py3Dmol.view: The py3Dmol view object for the structure visualization.
    """
    if isinstance(structure, AtomArrayStack):
        logger.warning("AtomArrayStack is not supported; using the first model.")
        structure = structure[0]

    # Initialize the py3Dmol view with specified width and height
    viewer = py3Dmol.view(width=width, height=height)

    # Handle unoccupied atoms
    if not show_unoccupied and ("occupancy" in structure.get_annotation_categories()):
        structure = structure[structure.occupancy > 0]

    # Convert the structure to a temporary CIF string and hand it to py3Dmol in mmCIF format
    cif_str = to_cif_string(
        structure,
        _allow_ambiguous_bond_annotations=True,
        include_entity_poly=False,
    )
    viewer.addModel(cif_str, "structure", format="mmcif")

    # Assign a style to each chain based on what it contains; colors wrap around if there are
    # more chains than colors.
    for chain_id, color in zip(struc.get_chains(structure), cycle(colors)):
        # Select the chain once and reuse it for all classification checks.
        chain = structure[structure.chain_id == chain_id]

        # NOTE(review): `is_protein` requires *all* atoms to match while `is_nucleic` only
        # requires *any* atom to match. The asymmetry is kept as-is; confirm it is intended.
        is_protein = np.all(
            struc.filter_polymer(chain, pol_type="peptide", min_size=min_polymer_size)
            & struc.filter_amino_acids(chain)
        )
        is_nucleic = np.any(
            struc.filter_polymer(chain, pol_type="nucleotide", min_size=min_polymer_size)
            & struc.filter_nucleotides(chain)
        )
        is_ion = np.all(_is_metal(chain.element))

        if is_protein or is_nucleic:
            # Apply protein or nucleic acid style
            style = {"stick": {"radius": polymer_sidechain_linewidth, "style": "outline"}}
            if show_cartoon:
                style["cartoon"] = {"color": color, "arrows": True}
            viewer.setStyle({"chain": chain_id}, style)
        elif is_ion:
            # Apply ion style
            viewer.setStyle(
                {"chain": chain_id},
                {"sphere": {"scale": 0.8}},
            )
        else:
            # Apply ligand style
            # ... first, set the style for carbon atoms colored by chain
            viewer.setStyle(
                {"chain": chain_id, "elem": "C"},
                {"stick": {"color": color, "radius": ligand_linewidth}},
            )
            # ... then, set the style for all other atoms based on the element
            viewer.setStyle(
                {"chain": chain_id, "not": {"elem": "C"}},
                {"stick": {"colorscheme": "element", "radius": ligand_linewidth}},
            )

    if show_surface:
        viewer.addSurface(py3Dmol.VDW, {"opacity": 0.4, "color": "gray"})

    # Add hover functionality to display atom details on hover
    if show_hover:
        # JavaScript callbacks executed by the viewer: the first attaches a
        # "chain:resname(resid):atomname(idx<serial>)" label, the second removes it again.
        on_hover_js = (
            "function(atom,viewer) {\n"
            "if(!atom.label) {\n"
            "atom.label = viewer.addLabel(\n"
            "atom.chain + ':' +\n"
            "atom.resn + '(' + atom.resi + '):' +\n"
            "atom.atom + '(idx' + atom.serial + ')',\n"
            '{position: atom, backgroundColor:"white", fontColor:"black"}\n'
            ");\n"
            "}\n"
            "}"
        )
        on_unhover_js = (
            "function(atom,viewer) {\n"
            "if(atom.label) {\n"
            "viewer.removeLabel(atom.label);\n"
            "delete atom.label;\n"
            "}\n"
            "}"
        )
        viewer.setHoverable({}, True, on_hover_js, on_unhover_js)

    # Zoom to the entire structure or to a specific selection if provided
    viewer.zoomTo()
    if zoom_to_selection is not None:
        viewer.zoomTo(zoom_to_selection)

    return viewer
def get_pymol_session(hostname: str | None = None, port: int | None = None) -> "pymol_remote.client.PymolSession":
    """
    Establishes a connection to a PyMOL server and returns a `pymol_remote.client.PymolSession` object.

    First attempts to reuse an existing global session if no hostname/port is specified.
    Otherwise tries to establish a new connection, attempting up to 5 consecutive ports.

    If you want to use `pymol_remote`, make sure to follow the usage instructions at
    https://github.com/Croydon-Brixton/pymol-remote

    Args:
        - hostname (str | None, optional): The hostname of the PyMOL server. Defaults to 'localhost' if None.
        - port (int | None, optional): The starting port number to attempt connection. Defaults to 9123 if None.

    Returns:
        pymol_remote.client.PymolSession: An active connection to the PyMOL server.

    Raises:
        - ImportError: If `pymol_remote` package is not installed.
        - RuntimeError: If unable to establish connection after trying 5 consecutive ports. The error
            of the last attempt is attached as the cause.
    """
    if not _is_pymol_remote_installed:
        raise ImportError("`pymol_remote` is not installed or in the pythonpath, visualization will not work.")

    # ... get existing session if available
    if (hostname is None) and (port is None):
        existing = pymol_remote.client._GLOBAL_SERVER_PROXY
        if existing:
            return pymol_remote.client.PymolSession(hostname=existing.hostname, port=existing.port, force_new=False)

    # ... otherwise, try to connect to a new session
    hostname = hostname or "localhost"
    port = port or 9123
    n_ports = 5

    last_error = None
    for candidate_port in range(port, port + n_ports):
        try:
            return pymol_remote.client.PymolSession(hostname=hostname, port=candidate_port)
        except Exception as err:
            # Deliberately broad: any failure on this port just means "try the next one".
            # The error is kept so it can be reported if every port fails.
            last_error = err

    raise RuntimeError(
        f"Failed to connect to Pymol on {hostname} (tried ports {port}-{port + n_ports - 1}). "
        "Ensure you are using SSH forwarding and `pymol_remote` correctly."
    ) from last_error
def _serialize_for_pymol(structure, object_id: str, as_bcif: bool) -> tuple[bytes, str]:
    """Serialize `structure` into the bytes sent to PyMOL and return them with the PyMOL format name.

    Bare CIF / BinaryCIF blocks are wrapped into a file object under the key `object_id` before
    writing. Text formats are encoded as UTF-8.

    Raises:
        ValueError: If the type of `structure` is not supported.
    """
    if isinstance(structure, AtomArray | AtomArrayStack):
        fmt = "bcif" if as_bcif else "cif"
        payload = to_cif_string(
            structure,
            id=object_id,
            _allow_ambiguous_bond_annotations=True,
            include_entity_poly=True,
            include_nan_coords=False,
            include_bonds=True,
            extra_fields=[],
            as_bcif=as_bcif,
        )
    elif isinstance(structure, pdbx.CIFFile | pdb.PDBFile | mol.SDFile | pdbx.CIFBlock):
        # isinstance checks (rather than a lookup on the exact type) so subclasses work as well
        if isinstance(structure, pdb.PDBFile):
            fmt = "pdb"
        elif isinstance(structure, mol.SDFile):
            fmt = "sdf"
        else:
            fmt = "cif"
        if isinstance(structure, pdbx.CIFBlock):
            wrapper = pdbx.CIFFile()
            wrapper[object_id] = structure
            structure = wrapper
        text_buffer = io.StringIO()
        structure.write(text_buffer)
        payload = text_buffer.getvalue()
    elif isinstance(structure, pdbx.BinaryCIFFile | pdbx.BinaryCIFBlock):
        fmt = "bcif"
        if isinstance(structure, pdbx.BinaryCIFBlock):
            wrapper = pdbx.BinaryCIFFile()
            wrapper[object_id] = structure
            structure = wrapper
        byte_buffer = io.BytesIO()
        structure.write(byte_buffer)
        payload = byte_buffer.getvalue()
    else:
        raise ValueError(
            f"Unsupported structure type: {type(structure)}. Only AtomArray, AtomArrayStack, CIFFile, "
            "CIFBlock, BinaryCIFFile, BinaryCIFBlock, PDBFile, and SDFile are supported."
        )

    # turn str into bytes if it is not already
    if not isinstance(payload, bytes):
        payload = payload.encode("utf-8")
    return payload, fmt


def view_pymol(
    structure: AtomArray
    | AtomArrayStack
    | pdbx.CIFFile
    | pdbx.BinaryCIFFile
    | pdb.PDBFile
    | mol.SDFile
    | pdbx.CIFBlock
    | pdbx.BinaryCIFBlock
    | os.PathLike,
    id: str | None = None,
    hostname: str | None = None,
    port: int | None = None,
    as_bcif: bool = False,
    overwrite: bool = False,
    grid_slot: int | None = None,
) -> str:
    """
    Visualizes an AtomArray structure in PyMOL by connecting to a PyMOL server and loading the structure. If no ID is
    provided, generates a unique identifier for the structure.

    Args:
        - structure (AtomArray | AtomArrayStack | CIFFile | BinaryCIFFile | PDBFile | SDFile | CIFBlock |
            BinaryCIFBlock | PathLike): The atomic structure to be visualized in PyMOL. A path (`str` or
            `pathlib.Path`) is read with `read_any`; if no `id` is provided, the file name up to its first `.`
            is used as the ID.
        - id (str | None, optional): Unique identifier for the structure in PyMOL. If None, generates a random 9-character
            string in XXX-XXX-XXX format. Defaults to None.
        - hostname (str | None, optional): The hostname of the PyMOL server. If None, uses 'localhost' or attempts to reuse
            an existing connection. Defaults to None.
        - port (int | None, optional): The port number for the PyMOL server connection. If None, uses default port 9123 or
            attempts to reuse existing connection. Defaults to None.
        - as_bcif (bool, optional): Whether to transport the structure as BCIF instead of CIF. This speeds up the
            network transfer of the structure but reading bcif files is not supported by all pymol versions.
            (pymol 2.6 (LTS) and 3.1+ support bcif, older versions do not).
            This only takes effect if `structure` is an `AtomArray` or `AtomArrayStack`.
            Defaults to False.
        - overwrite (bool, optional): Whether to overwrite an existing object with the same ID. Defaults to False.
        - grid_slot (int | None, optional): The grid slot to use for the structure. If None, a random slot is chosen.
            Defaults to None.

    Returns:
        str: The identifier used for the structure in PyMOL.

    Raises:
        ImportError: If `pymol_remote` package is not installed.
        RuntimeError: If unable to establish connection to PyMOL server.
        ValueError: If an object with the same ID already exists and `overwrite` is False, or if the type of
            `structure` is not supported.
    """
    # Establish a connection to the pymol server
    session = get_pymol_session(hostname, port)

    if isinstance(structure, str | Path):
        id = id or os.path.basename(structure).split(".")[0]
        structure = read_any(structure)

    # Generate a unique ID for the structure if not provided
    if id is None:
        # Generate random 9-character string in 3-3-3 format
        random_str = str(uuid.uuid4()).replace("-", "")[:9]
        id = f"{random_str[:3]}-{random_str[3:6]}-{random_str[6:]}"

    already_exists = id in session.get_names()
    if already_exists and not overwrite:
        raise ValueError(f"Object {id=} already exists in PyMOL, set overwrite=True to overwrite.")

    # Serialize BEFORE deleting anything: if the input cannot be serialized, the object that is
    # currently loaded in PyMOL must not be lost.
    payload, fmt = _serialize_for_pymol(structure, id, as_bcif)

    if already_exists:
        logger.warning(f"Object {id=} already exists in PyMOL, overwriting.")
        session.delete(id)

    # Send to pymol, compressed for faster network transfer
    session.set_state(gzip.compress(payload), object=id, format=fmt)

    grid_slot = np.random.randint(0, 10_000) if grid_slot is None else grid_slot
    session.set("grid_slot", grid_slot, id)
    return id