Source code for buildamol.utils.visual

"""
Visualization auxiliary functions
"""

import pandas as pd
import networkx as nx
import plotly.graph_objects as go
import plotly.express as px
import matplotlib.colors as colors

import buildamol.utils as utils
import buildamol.utils.auxiliary as aux

import periodictable

# RDKit drawing/chemistry modules, re-exported from the auxiliary module.
# NOTE(review): presumably these are None when rdkit is not installed; the
# viewers below check `Chem is None` before using them.
Draw = aux.Draw
Chem = aux.Chem

# Name of the default backend; changed via `set_backend`, which only accepts
# "plotly", "py3dmol" or "nglview".
DEFAULT_BACKEND = "plotly"
"""
The default visualization backend for bare `draw` and `show` methods on objects. 
"""


def set_backend(backend: str):
    """
    Choose the visualization backend used by objects' bare `draw` and `show` methods.

    Parameters
    ----------
    backend : str
        Name of the backend. Matching ignores case and surrounding whitespace.
        Supported values:
            - plotly (default)
            - py3dmol
            - nglview

    Raises
    ------
    ValueError
        If `backend` does not name one of the supported backends.
    """
    global DEFAULT_BACKEND
    normalized = backend.strip().lower()
    supported = ("plotly", "py3dmol", "nglview")
    if normalized in supported:
        DEFAULT_BACKEND = normalized
        return
    raise ValueError(f"Unsupported backend: {normalized}")


# Module-wide defaults for the plotly viewers. `PlotlyViewer3D.__init__` copies
# these onto each new viewer (opacity, size, bond_color, bond_linewidth), so
# changing them only affects viewers created afterwards.
default_plotly_opacity = 1.0
"""
The default opacity for plotly-based visualizations.
"""

default_plotly_marker_size = 5
"""
The default marker size for plotly-based visualizations.
"""

default_plotly_bond_color = "black"
"""
The default color for plotly-based bond visualizations.
"""

default_plotly_linewidth = 1.2
"""
The default linewidth for plotly-based bond visualizations.
"""


class Chem2DViewer:
    """
    View a molecule in 2D using the RDKit library.

    Parameters
    ----------
    molecule
        The molecule to view. This may be any object that holds a biopython structure
        e.g. a Molecule, AtomGraph, or ResidueGraph, or an RDKit molecule. RDKit
        molecules are copied, so the caller's molecule is not modified.
    drawer: str
        The 2D drawer to use. This can be any of:
            - png (default, uses `MolDraw2DCairo`, requires cairo to be installed)
            - svg (uses `MolDraw2DSVG`)
    highlight_color : str
        The color to use for highlighting atoms and bonds (deprecated, specify a color
        when calling `highlight_atoms` or `highlight_bonds` instead).
    linewidth : float
        The linewidth of the bonds (deprecated, specify a linewidth when calling the
        `draw` method instead). It is still used by `draw` when no linewidth is given there.
    atoms : str
        The label to use for the atoms (deprecated, specify a label when calling the
        `label_atoms` method instead). This can be any of the following:
            - None (element, except for carbons)
            - "element" (elements, even for carbons)
            - "serial" (the atom serial number)
            - "id" (the atom id / name)
            - "resid" (atom id + parent residue)
            - "off" (no label)
            - any function that takes an (rdkit) atom and returns a string.
    """

    def __init__(
        self,
        molecule,
        drawer: str = "png",
        highlight_color: str = None,
        linewidth: float = None,
        atoms: str = None,
    ):
        if Chem is None:
            raise ImportError(
                "rdkit is not available. Please install it and be sure to use a compatible environment."
            )
        drawer = drawer.strip().lower()
        if drawer not in ("svg", "png"):
            raise ValueError(f"Unsupported drawer: {drawer}")
        self._drawer_type = drawer

        self._raw_molecule = None
        self._raw_is_rdkit = False
        if hasattr(molecule, "to_rdkit"):
            mol = molecule.to_rdkit()
            self._raw_molecule = molecule
        elif molecule.__class__.__name__ in ("AtomGraph", "ResidueGraph"):
            mol = molecule._molecule.to_rdkit()
            self._raw_molecule = molecule
        elif "Chem" in str(molecule.__class__.mro()[0]):
            # Copy, so that removing conformers below does not mutate the caller's molecule.
            # Atom indices are preserved by the copy.
            mol = Chem.Mol(molecule)
            self._raw_molecule = molecule
            self._raw_is_rdkit = True
        else:
            raise ValueError(
                f"Unsupported molecule type: {molecule.__class__.__name__}"
            )
        # NOTE(review): presumably dropped so RDKit lays out fresh 2D coordinates.
        mol.RemoveAllConformers()
        self.mol = mol

        if atoms is not None:
            aux.deprecation_warning(
                "The `atoms` argument is deprecated and will be removed in future versions. Please use the `label_atoms` method instead."
            )
            if callable(atoms):
                label = atoms
            else:
                label = self._named_label_function(atoms)
                if label is None:
                    raise ValueError(f"Unsupported atom label: {atoms}")
            for a in mol.GetAtoms():
                # SetProp needs a string; user callables may return other types.
                a.SetProp("atomLabel", str(label(a)))

        self._atoms_to_highlight = []
        self._atoms_highlight_colors = {}
        self._atoms_highlight_radii = {}
        self._bonds_to_highlight = {}

        if highlight_color is not None:
            aux.deprecation_warning(
                "The `highlight_color` argument is deprecated and will be removed in future versions. Please specify a color when calling `highlight_atoms` or `highlight_bonds` instead."
            )
        if linewidth is not None:
            aux.deprecation_warning(
                "The `linewidth` argument is deprecated and will be removed in future versions. Please specify a linewidth when calling the `draw` method instead."
            )
        self.highlight_color = highlight_color or "cyan"
        self.linewidth = linewidth or 1
        self.options = Draw.MolDrawOptions()
        # Options registered through `set_options`. MolDrawOptions settings are
        # properties and never appear in `self.options.__dict__`, so they are
        # recorded here to be replayed onto the drawer's options in `draw`.
        self._option_overrides = {}
        self._custom_colors = {}

    @staticmethod
    def _named_label_function(name):
        """Return an (RDKit atom -> str) labelling function for a named scheme, or None if unknown."""
        if name == "element":
            return lambda atom: atom.GetSymbol()
        if name == "serial":
            return lambda atom: str(atom.GetPDBResidueInfo().GetSerialNumber())
        if name == "id":
            return lambda atom: atom.GetPDBResidueInfo().GetName().strip()
        if name == "off":
            return lambda atom: ""
        if name == "resid":

            def resid(atom):
                info = atom.GetPDBResidueInfo()
                return f"{info.GetName().strip()}@{info.GetResidueName().strip()}[{info.GetResidueNumber()}]"

            return resid
        return None

    def label_atoms(self, func_or_mapping, rdkit: bool = None):
        """
        Generate custom atom labels

        Parameters
        ----------
        func_or_mapping : str, callable or dict
            If a string is provided it has to be one of the following:
                - "element" (elements, even for carbons)
                - "serial" (the atom serial number)
                - "id" (the atom id / name)
                - "resid" (atom id + parent residue)
                - "off" (no label)
            Alternatively, either a function that takes an atom and returns a string.
            Or a dictionary mapping atoms to strings. Only one type of key can be included in the dictionary!
            Supported dictionary keys are:
                - BuildAMol Atoms
                - RDKit Atoms
                - atom serial numbers (int)
                - atom ids (str) (will match all atoms with that id)
            Atoms missing from a dictionary get their element symbol (nothing for carbons).
        rdkit : bool
            Whether the function takes an RDKit atom or a BuildAMol atom.
            If None, this is determined by trying the function on an RDKit atom.

        Returns
        -------
        Chem2DViewer
            The viewer itself, for chaining.
        """
        if isinstance(func_or_mapping, str):
            label = self._named_label_function(func_or_mapping)
            if label is None:
                raise ValueError(
                    f"Unsupported atom label for string identifier: '{func_or_mapping}'. Supported are 'element', 'serial', 'id', 'resid', and 'off'."
                )
            func_or_mapping = label
            rdkit = True

        if rdkit is False and self._raw_is_rdkit:
            raise ValueError(
                "The underlying molecule is an RDKit molecule. Cannot perform BuildAMol operations on RDKit Atoms. Please set `rdkit=True`."
            )

        if isinstance(func_or_mapping, dict):
            return self._label_atoms_from_mapping(func_or_mapping)
        if not callable(func_or_mapping):
            raise ValueError(
                f"func_or_mapping must be a valid string identifier, a callable or a dictionary, got {type(func_or_mapping)}."
            )

        # Only guess the atom type when the caller did not state it explicitly.
        if rdkit is None:
            rdkit = self._accepts_rdkit_atoms(func_or_mapping)

        if rdkit:
            for a in self.mol.GetAtoms():
                a.SetProp("atomLabel", str(func_or_mapping(a)))
        else:
            for a in self.mol.GetAtoms():
                serial = a.GetPDBResidueInfo().GetSerialNumber()
                bam_atom = self._raw_molecule.get_atom(serial, by="serial")
                a.SetProp("atomLabel", str(func_or_mapping(bam_atom)))
        return self

    def _accepts_rdkit_atoms(self, func) -> bool:
        """Guess whether `func` expects RDKit atoms by calling it on the first RDKit atom."""
        test_atom = next(iter(self.mol.GetAtoms()), None)
        if test_atom is None:
            # No atoms to label, so the choice does not matter.
            return True
        try:
            func(test_atom)
        except Exception as err:
            if self._raw_is_rdkit:
                raise ValueError(
                    "The underlying molecule is an RDKit molecule. Cannot perform BuildAMol operations on RDKit Atoms. Please set `rdkit=True`."
                ) from err
            return False
        return True

    def _label_atoms_from_mapping(self, mapping: dict):
        """Set atom labels from a dict keyed by RDKit atoms, BuildAMol atoms, atom ids or serials."""
        if mapping:
            # Only the first key is inspected; the docstring requires a single key type.
            first_key = next(iter(mapping))
            if hasattr(first_key, "GetPDBResidueInfo"):
                mapping = {
                    atom.GetPDBResidueInfo().GetSerialNumber(): label
                    for atom, label in mapping.items()
                }
            elif hasattr(first_key, "element") and hasattr(first_key, "serial_number"):
                mapping = {atom.serial_number: label for atom, label in mapping.items()}
            elif isinstance(first_key, str):
                if self._raw_is_rdkit:
                    raise ValueError(
                        "Atom ids as mapping keys require a BuildAMol molecule, but the underlying molecule is an RDKit molecule."
                    )
                by_serial = {}
                for key, label in mapping.items():
                    for atom in self._raw_molecule.get_atoms(key):
                        by_serial[atom.serial_number] = label
                mapping = by_serial
            elif not isinstance(first_key, int):
                raise ValueError(
                    "When providing a mapping, the keys must be either BuildAMol Atoms, RDKit Atoms, or atom serial numbers (ints)."
                )

        for a in self.mol.GetAtoms():
            serial = a.GetPDBResidueInfo().GetSerialNumber()
            if serial in mapping:
                label = str(mapping[serial])
            else:
                # Default: element symbol, hidden for carbons.
                symbol = a.GetSymbol()
                label = "" if symbol == "C" else symbol
            a.SetProp("atomLabel", label)
        return self

    def highlight_atoms(self, *atoms, color, radius=0.3):
        """
        Highlight atoms in the molecule.

        Parameters
        ----------
        atoms : list
            The Atoms to highlight (BuildAMol atoms, RDKit atoms, atom ids or serial numbers),
            either as separate arguments or as a single list/tuple/set.
        color
            The color to use for highlighting. This can be either a string, a tuple of RGB values,
            or a callable that takes an atom and returns a color.
        radius
            The radius to use for highlighting. This can be a float or a callable that takes an atom and returns a float.

        Returns
        -------
        Chem2DViewer
            The viewer itself, for chaining.
        """
        if len(atoms) == 1 and isinstance(atoms[0], (list, tuple, set)):
            atoms = atoms[0]
        # Sets are not indexable; a list keeps the original iteration order.
        atoms = list(atoms)
        if not atoms:
            return self

        first = atoms[0]
        if isinstance(first, (str, int)):
            if self._raw_is_rdkit:
                raise ValueError(
                    "When providing atom ids or serial numbers the underlying molecule cannot be an RDKit molecule. Please provide RDKit or BuildAMol Atoms directly."
                )
            atoms = list(self._raw_molecule.get_atoms(atoms))
        elif hasattr(first, "element") and hasattr(first, "serial_number"):
            pass
        elif hasattr(first, "GetPDBResidueInfo") and not self._raw_is_rdkit:
            # Store BuildAMol atoms so that lookups by serial number work after RemoveHs.
            atoms = [
                self._raw_molecule.get_atom(a.GetPDBResidueInfo().GetSerialNumber())
                for a in atoms
            ]

        self._atoms_to_highlight.extend(atoms)
        for atom in atoms:
            self._atoms_highlight_colors[atom] = color(atom) if callable(color) else color
            self._atoms_highlight_radii[atom] = radius(atom) if callable(radius) else radius
        return self

    @staticmethod
    def _is_single_bond(obj) -> bool:
        """True if `obj` looks like one bond given as a pair of atoms (not a collection of bonds)."""
        if not isinstance(obj, tuple) or len(obj) != 2:
            return False
        return all(
            hasattr(x, "serial_number") or hasattr(x, "GetPDBResidueInfo") or isinstance(x, int)
            for x in obj
        )

    def highlight_bonds(self, *bonds, color=None):
        """
        Highlight bonds in the molecule.

        Parameters
        ----------
        bonds : list
            The bonds (tuples of BuildAMol Atoms) to highlight, either as separate arguments
            or as a single list/tuple/set of bonds.
        color
            The color to use for highlighting. This can be either a string or a tuple of RGB values,
            or a callable that takes a bond and returns a color. Defaults to `self.highlight_color`.

        Returns
        -------
        Chem2DViewer
            The viewer itself, for chaining.
        """
        # Unwrap a single collection argument, unless that argument is itself one
        # (atom, atom) bond. Without this check, a lone bond would be split into its two atoms.
        if (
            len(bonds) == 1
            and isinstance(bonds[0], (list, tuple, set))
            and not self._is_single_bond(bonds[0])
        ):
            bonds = bonds[0]
        for bond in bonds:
            self._bonds_to_highlight[bond] = (
                color(bond) if callable(color) else (color or self.highlight_color)
            )
        return self

    def highlight_residues(self, *residues, color):
        """
        Highlight all bonds and atoms in the given residues.

        Parameters
        ----------
        residues : list
            The residues (BuildAMol Residue objects or anything `get_residue` accepts) to highlight.
        color
            The color to use for highlighting. This can be either a string or a tuple of RGB values,
            or a callable. The same callable is applied to both bonds and atoms.

        Returns
        -------
        Chem2DViewer
            The viewer itself, for chaining.
        """
        if self._raw_is_rdkit:
            raise ValueError(
                "When providing residues the underlying molecule cannot be an RDKit molecule. Please provide RDKit or BuildAMol Atoms directly."
            )
        if len(residues) == 1 and isinstance(residues[0], (list, tuple, set)):
            residues = residues[0]

        bonds = []
        atoms = []
        for residue in residues:
            residue = self._raw_molecule.get_residue(residue)
            atoms.extend(residue.get_atoms())
            bonds.extend(self._raw_molecule.get_bonds(residue))
        # Pass lists rather than *-unpacking them. That way empty residues and
        # single-bond residues are not misinterpreted.
        self.highlight_bonds(bonds, color=color)
        self.highlight_atoms(atoms, color=color)
        return self

    def set_colors(self, _element_colors: dict):
        """
        Set the colors for the atoms.

        Parameters
        ----------
        _element_colors : dict
            A dictionary mapping elements to colors. Keys are either atomic numbers
            (e.g. 6 for carbon) or element symbols (e.g. "C"). Values are any color
            understood by matplotlib (names, RGB or RGBA tuples).

        Returns
        -------
        Chem2DViewer
            The viewer itself, for chaining.
        """
        palette = {}
        for key, value in _element_colors.items():
            if isinstance(key, str):
                key = periodictable.elements.symbol(key).number
            palette[key] = colors.to_rgba(value)
        self._custom_colors.update(palette)
        return self

    def set_options(self, **kwargs):
        """
        Set the default drawing options.

        Parameters
        ----------
        **kwargs
            Any additional arguments to pass to `MolDrawOptions`. Unknown or private
            (underscore-prefixed) names are ignored.

        Returns
        -------
        Chem2DViewer
            The viewer itself, for chaining.
        """
        aux.deprecation_warning(
            "The `set_options` method is deprecated and will be removed in future versions. Provide drawing options directly as kwargs to the `draw` method instead."
        )
        for k, v in kwargs.items():
            if not k.startswith("_") and hasattr(self.options, k):
                setattr(self.options, k, v)
                self._option_overrides[k] = v
        return self

    def draw(
        self,
        draw_hydrogens: bool = False,
        linewidth: float = None,
        fontsize: float = 20,
        width: int = 1000,
        height: int = 500,
        background: tuple = None,
        **kwargs,
    ):
        """
        Generate the 2D image.

        Parameters
        ----------
        draw_hydrogens : bool
            Whether to draw hydrogens.
        linewidth: float
            The linewidth of the bonds. Defaults to the (deprecated) constructor
            linewidth, which itself defaults to 1.
        fontsize : float
            The font size of the atom labels.
        width : int
            The width of the image in pixels.
        height : int
            The height of the image in pixels.
        background : tuple
            The background color to use. Use `None` for a transparent background.
        **kwargs
            Any additional arguments to pass to the `MolDrawOptions` of the RDKit drawer
            (either `MolDraw2DSVG` or `MolDraw2DCairo`).

        Returns
        -------
        str or PIL.Image.Image
            The SVG string (if drawer is "svg") or a PIL Image (if drawer is "png").
        """
        mol = self.mol if draw_hydrogens else Chem.rdmolops.RemoveHs(self.mol)

        if self._drawer_type == "svg":
            drawer = Draw.rdMolDraw2D.MolDraw2DSVG(width, height)
        else:
            drawer = Draw.rdMolDraw2D.MolDraw2DCairo(width, height)

        draw_options = drawer.drawOptions()
        # Plain attributes stored on `self.options` (kept for backwards compatibility).
        for k, v in self.options.__dict__.items():
            setattr(draw_options, k, v)
        # Property values registered through the deprecated `set_options`.
        for k, v in self._option_overrides.items():
            setattr(draw_options, k, v)
        for k, v in kwargs.items():
            if not k.startswith("_") and hasattr(draw_options, k):
                setattr(draw_options, k, v)

        draw_options.bondLineWidth = self.linewidth if linewidth is None else linewidth
        if self._custom_colors:
            draw_options.updateAtomPalette(self._custom_colors)
        draw_options.fixedFontSize = fontsize

        if background is None:
            draw_options.clearBackground = False
        else:
            draw_options.setBackgroundColour(colors.to_rgba(background))

        kws = self._prepare_highlighting(mol, draw_hydrogens)
        drawer.DrawMoleculeWithHighlights(mol, legend="", **kws)
        drawer.FinishDrawing()

        if self._drawer_type == "svg":
            return drawer.GetDrawingText()

        from PIL import Image
        from io import BytesIO

        return Image.open(BytesIO(drawer.GetDrawingText()))

    def show(self, draw_hydrogens: bool = False, **kwargs):
        """
        Show the molecule

        Parameters
        ----------
        draw_hydrogens : bool
            Whether to draw hydrogens.
        **kwargs
            Any additional keyword arguments to pass to `draw`. For the svg drawer,
            `dpi` (default 300) controls the rasterization resolution.
        """
        if self._drawer_type == "svg":
            dpi = kwargs.pop("dpi", 300)
        out = self.draw(draw_hydrogens=draw_hydrogens, **kwargs)
        if self._drawer_type != "svg":
            out.show()
            return

        # Turn the SVG string into a PIL image for display.
        from io import BytesIO
        from svglib.svglib import svg2rlg
        from reportlab.graphics import renderPM

        drawing = svg2rlg(BytesIO(out.encode("utf-8")))
        renderPM.drawToPIL(drawing, dpi=dpi).show()

    @staticmethod
    def _is_hydrogen(atom) -> bool:
        """True if `atom` (BuildAMol atom via `.element`, or RDKit atom via `GetSymbol`) is a hydrogen."""
        element = getattr(atom, "element", None)
        if element is None and hasattr(atom, "GetSymbol"):
            element = atom.GetSymbol()
        return element == "H"

    def _prepare_highlighting(self, mol, include_hydrogens: bool):
        """
        Build the highlight keyword arguments for `DrawMoleculeWithHighlights`.

        The stored highlight keys (atoms / atom pairs) are translated into atom and
        bond indices of `mol`, which may be a hydrogen-stripped copy of `self.mol`.
        """
        kws = {
            "highlight_atom_map": {},
            "highlight_bond_map": {},
            "highlight_radii": {},
            "highlight_linewidth_multipliers": {},
        }
        has_atoms_to_highlight = len(self._atoms_to_highlight) > 0
        has_bonds_to_highlight = len(self._bonds_to_highlight) > 0
        if not has_atoms_to_highlight and not has_bonds_to_highlight:
            return kws

        if include_hydrogens:
            keep_atom = lambda atom: True
            keep_bond = lambda bond: True
        else:
            # Hydrogens are absent from `mol`, so they must not be looked up.
            keep_atom = lambda atom: not self._is_hydrogen(atom)
            keep_bond = lambda bond: keep_atom(bond[0]) and keep_atom(bond[1])

        def index_of(atom):
            return self._rdkit_atom_from_buildamol_atom(atom, mol).GetIdx()

        if has_atoms_to_highlight:
            kws["highlight_atom_map"] = {
                index_of(atom): [colors.to_rgba(color)]
                for atom, color in self._atoms_highlight_colors.items()
                if keep_atom(atom)
            }
            kws["highlight_radii"] = {
                index_of(atom): radius
                for atom, radius in self._atoms_highlight_radii.items()
                if keep_atom(atom)
            }

        if has_bonds_to_highlight:
            bond_map = {}
            for bond, color in self._bonds_to_highlight.items():
                if not keep_bond(bond):
                    continue
                atom1, atom2 = bond
                rd_bond = mol.GetBondBetweenAtoms(index_of(atom1), index_of(atom2))
                # Pairs that are not actually bonded in `mol` are skipped.
                if rd_bond is not None:
                    bond_map[rd_bond.GetIdx()] = [colors.to_rgba(color)]
            kws["highlight_bond_map"] = bond_map

        return kws

    def _rdkit_atom_from_buildamol_atom(self, atom, mol=None):
        """
        Resolve `atom` to an RDKit atom of `mol` (defaults to `self.mol`).

        RDKit atoms are returned unchanged. BuildAMol atoms are matched by PDB serial
        number. Ints are treated as atom indices.

        Raises
        ------
        ValueError
            If no matching atom exists or the input type is unsupported.
        """
        mol = self.mol if mol is None else mol
        if hasattr(atom, "GetPDBResidueInfo"):
            return atom
        if hasattr(atom, "serial_number"):
            for candidate in mol.GetAtoms():
                info = candidate.GetPDBResidueInfo()
                if info is not None and info.GetSerialNumber() == atom.serial_number:
                    return candidate
            raise ValueError(f"Atom {atom} did not have an RDKit equivalent.")
        if isinstance(atom, int):
            if 0 <= atom < mol.GetNumAtoms():
                return mol.GetAtomWithIdx(atom)
            raise ValueError(
                f"Atom with index {atom} did not have an RDKit equivalent."
            )
        raise ValueError(
            f"Unsupported atom type: {atom.__class__.__name__}. The input has to be a BuildAMol Atom, an RDKit Atom, or an atom index (int)."
        )
class Py3DmolViewer:
    """
    View a molecule in 3D using the py3Dmol library.

    Attributes
    ----------
    view : py3Dmol.view
        The py3Dmol view object.
    pdb : str
        The PDB text of the first model.
    n_models : int
        Number of models added to the view. Models are addressed by creation
        order, starting at 0.

    Parameters
    ----------
    molecule
        The molecule to view: anything with a `get_atoms` method (PDB-encoded via
        `to_pdb` support if available), or a list/tuple/set of atoms.
    width : int
        The width of the viewer in pixels.
    height : int
        The height of the viewer in pixels.
    style : dict
        The style to apply to the visualization (merged into `default_style`).
    """

    default_style = {"stick": {}}

    def __init__(
        self, molecule, width: int = 500, height: int = 500, style: dict = None
    ) -> None:
        try:
            import py3Dmol
        except ImportError:
            py3Dmol = None
        if py3Dmol is None or Chem is None:
            raise ImportError(
                "py3Dmol and/or rdkit are not available. Please install them and be sure to use a compatible (Jupyter) environment."
            )
        if isinstance(molecule, (list, tuple, set)):
            molecule = aux.AtomIterator(molecule)
        if not hasattr(molecule, "get_atoms"):
            raise ValueError(
                f"Unsupported molecule type: {molecule.__class__.__name__}. The input has to be a Py3DmolViewer, Molecule, or any other class with an 'get_atoms' method that can be converted to PDB."
            )
        self.pdb = self._encode(molecule)

        self.style = dict(Py3DmolViewer.default_style)
        if style:
            self.style.update(style)

        self.view = py3Dmol.view(width=width, height=height)
        self.view.addModel(self.pdb, "pdb")
        self.n_models = 1
        self.view.setStyle(self.style)
        self.view.zoomTo()

    @staticmethod
    def _encode(obj) -> str:
        """PDB text for `obj`: a full PDB encoding if it supports `to_pdb`, otherwise an atoms table."""
        if hasattr(obj, "to_pdb"):
            return utils.pdb.encode_pdb(obj)
        return utils.pdb.make_atoms_table(obj)

    def set_style(self, style: dict, model: int = None) -> "Py3DmolViewer":
        """
        Set the visualization style.

        Parameters
        ----------
        style : dict
            The style to add.
        model : int
            A specific model (0-based creation order) to apply the style to.
            If None, the style is applied to all models.

        Returns
        -------
        Py3DmolViewer
            The viewer itself, for chaining.

        Raises
        ------
        ValueError
            If `model` refers to a model that has not been added.
        """
        if model is None:
            self.view.setStyle(style)
            return self
        # Valid models are 0 .. n_models - 1 (see `add`, which assigns index n_models).
        if model >= self.n_models:
            raise ValueError(
                f"Model {model} does not exist. The viewer contains {self.n_models} models."
            )
        self.view.setStyle({"model": model}, style)
        return self

    def add(self, other, style=None):
        """
        Add a second molecule to the viewer.

        Parameters
        ----------
        other
            Another Py3DmolViewer, an object with a `to_pdb` or `get_atoms` method,
            a list/tuple/set of atoms, or a PDB-formatted string.
        style : dict
            The style for the new model. Defaults to `other.style` for viewers and
            to this viewer's style otherwise.

        Returns
        -------
        Py3DmolViewer
            The viewer itself, for chaining.
        """
        if isinstance(other, Py3DmolViewer):
            pdb = other.pdb
            default_style = other.style
        elif isinstance(other, str):
            pdb = other
            default_style = self.style
        else:
            if isinstance(other, (list, tuple, set)):
                other = aux.AtomIterator(other)
            if not (hasattr(other, "to_pdb") or hasattr(other, "get_atoms")):
                raise ValueError(
                    f"Unsupported molecule type: {other.__class__.__name__}. The input has to be a Py3DmolViewer or Molecule."
                )
            pdb = self._encode(other)
            default_style = self.style

        self.view.addModel(pdb, "pdb")
        # The new model's index equals the number of models added before it.
        self.view.setStyle({"model": self.n_models}, style or default_style)
        self.n_models += 1
        return self

    def __iadd__(self, other):
        return self.add(other)

    def __add__(self, other):
        # NOTE: like `add`, this modifies and returns `self` rather than a new viewer.
        return self.add(other)

    def show(self):
        """
        Show the molecule in a Jupyter notebook
        """
        return self.view.show()
class NglViewer:
    """
    Display a molecule in 3D through the NGLView library.

    Parameters
    ----------
    molecule
        An object with a `to_pdb` method (e.g. a Molecule), or an AtomGraph /
        ResidueGraph, in which case its underlying `_molecule` is encoded.
    """

    def __init__(self, molecule):
        try:
            import nglview
        except ImportError:
            nglview = None
        if nglview is None:
            raise ImportError(
                "NGLView is not available. Please install it with `pip install nglview` and be sure to use a compatible environment."
            )

        is_graph = molecule.__class__.__name__ in ("AtomGraph", "ResidueGraph")
        if hasattr(molecule, "to_pdb"):
            source = molecule
        elif is_graph:
            source = molecule._molecule
        else:
            raise ValueError(
                f"Unsupported molecule type: {molecule.__class__.__name__}"
            )
        # PDB text handed to nglview when showing.
        self.pdb = utils.pdb.encode_pdb(source)

    def show(self):
        """
        Show the molecule in a Jupyter notebook
        """
        import io
        import nglview

        # nglview reads from a file-like object; a fresh StringIO starts at position 0.
        handle = io.StringIO(self.pdb)
        return nglview.show_file(handle, ext="pdb")
def rgba_to_hex(rgba: tuple) -> str:
    """
    Convert an rgba (or rgb) color to hex.

    Parameters
    ----------
    rgba : tuple
        The color channels as floats in [0, 1]. Values outside this range are clamped.

    Returns
    -------
    str
        The hex color, e.g. "#0000ffff" for opaque blue. Every channel is written
        as exactly two lowercase hex digits.
    """
    # Each channel is truncated to 0..255 and zero-padded; without padding,
    # e.g. 0 would become "0" and the string would be malformed.
    channels = (min(255, max(0, int(channel * 255))) for channel in rgba)
    return "#" + "".join(f"{channel:02x}" for channel in channels)
class PlotlyViewer3D:
    """
    Base class for plotly-based 3D viewers.

    The figure being built is stored in `self.figure` (a `plotly.graph_objects.Figure`).
    Drawing methods add traces to it and return `self` so calls can be chained.

    Several methods (`highlight_atoms`, `highlight_residues`, ...) use `self._src`,
    `self._atom_df`, `self._bond_df` and `self._setup_fig`. This class does not define them.
    NOTE(review): presumably they are provided by subclasses / a linking step.
    """

    # Colors cycled through by `_get_color` (e.g. for rainbow coloring).
    __continuous_colors__ = [
        "navy",
        "blue",
        "teal",
        "green",
        "lightgreen",
        "yellow",
        "orange",
        "red",
        "crimson",
        "darkred",
        "brown",
        "purple",
        "pink",
    ]
    # Marker colors keyed by title-case element symbol.
    __atom_colors__ = {
        "C": "darkslategray",
        "O": "red",
        "H": "lightgray",
        "N": "blue",
        "S": "yellow",
        "P": "purple",
        "F": "green",
        "Cl": "green",
        "Br": "green",
        "I": "green",
    }

    def __init__(self) -> None:
        # Call the base `reset` explicitly. Subclass overrides need linked molecule
        # data, which does not exist yet at construction time.
        PlotlyViewer3D.reset(self)
        self._color_idx = 0
        self.opacity = default_plotly_opacity
        self.size = default_plotly_marker_size
        self.bond_color = default_plotly_bond_color
        self.bond_linewidth = default_plotly_linewidth

    def _get_color(self):
        """Return the next color of `__continuous_colors__`, wrapping around at the end."""
        color = self.__continuous_colors__[self._color_idx]
        self._color_idx = (self._color_idx + 1) % len(self.__continuous_colors__)
        return color

    def _get_atom_color(self, atom):
        """Marker color for `atom`'s element (black if the element has no entry)."""
        return self.__atom_colors__.get(atom.element.title(), "black")

    def add(self, fig):
        """
        Add a plotly figure to the viewer.

        `fig` may be another PlotlyViewer3D, a figure (anything with `.data`),
        a single trace, or a list of traces.

        Returns
        -------
        PlotlyViewer3D
            The viewer itself, for chaining.
        """
        if isinstance(fig, PlotlyViewer3D):
            data = fig.figure.data
        else:
            data = getattr(fig, "data", fig)
        self.figure.add_traces(data)
        return self

    def __add__(self, fig):
        # NOTE: modifies and returns `self` rather than a new viewer.
        return self.add(fig)

    def show(self):
        """Display the figure."""
        self.figure.show()

    def write_html(self, path):
        """Write the figure as an HTML file to `path`."""
        self.figure.write_html(path)

    def reset(self, **kwargs):
        """
        Replace the figure with an empty one (no grid, axis lines or tick labels).

        Parameters
        ----------
        **kwargs
            `xlim`, `ylim`, `zlim` set the axis ranges; other keys are ignored.
        """
        self.figure = go.Figure(
            layout=go.Layout(
                scene=dict(
                    xaxis=dict(
                        showgrid=False,
                        showline=False,
                        showticklabels=False,
                        range=kwargs.pop("xlim", None),
                    ),
                    yaxis=dict(
                        showgrid=False,
                        showline=False,
                        showticklabels=False,
                        range=kwargs.pop("ylim", None),
                    ),
                    zaxis=dict(
                        showgrid=False,
                        showline=False,
                        showticklabels=False,
                        range=kwargs.pop("zlim", None),
                    ),
                    # aspectmode="cube",
                ),
                template="simple_white",
            )
        )
        return self

    def viewbox(self, xlim=None, ylim=None, zlim=None):
        """
        Set the axis ranges. A single number `v` is expanded to the range [-v, v].
        """
        if isinstance(xlim, (int, float)):
            xlim = [-xlim, xlim]
        if isinstance(ylim, (int, float)):
            ylim = [-ylim, ylim]
        if isinstance(zlim, (int, float)):
            zlim = [-zlim, zlim]
        self.figure.update_layout(
            scene=dict(
                xaxis=dict(range=xlim),
                yaxis=dict(range=ylim),
                zaxis=dict(range=zlim),
            )
        )
        return self

    def update_layout(self, **kwargs):
        """Forward `kwargs` to `Figure.update_layout`."""
        self.figure.update_layout(**kwargs)
        return self

    def draw_point(
        self,
        id: str,
        coord,
        color="black",
        opacity=1.0,
        size=5,
        showlegend=True,
        **kwargs,
    ):
        """Draw a single marker at `coord` (x, y, z); extra kwargs go to `go.Scatter3d`."""
        new = go.Scatter3d(
            x=[coord[0]],
            y=[coord[1]],
            z=[coord[2]],
            mode="markers",
            marker=dict(opacity=opacity, color=color, size=size),
            name=id,
            hoverinfo="name",
            showlegend=showlegend,
            **kwargs,
        )
        self.add(new)
        return self

    def draw_vector(
        self,
        id,
        coord_a,
        coord_b,
        color="black",
        linewidth=1.5,
        opacity=1.0,
        showlegend=True,
        hoverinfo: str = "skip",
        elongate: float = 1.0,
        legendgroup: str = None,
    ):
        """
        Draw a line from `coord_a` towards `coord_b`.

        The segment has length `elongate` times the distance between the two points,
        starting at `coord_a`.
        """
        new = go.Scatter3d(
            x=[coord_a[0], coord_a[0] + (coord_b[0] - coord_a[0]) * elongate],
            y=[coord_a[1], coord_a[1] + (coord_b[1] - coord_a[1]) * elongate],
            z=[coord_a[2], coord_a[2] + (coord_b[2] - coord_a[2]) * elongate],
            mode="lines",
            line=dict(color=color, width=linewidth),
            name=id,
            hoverinfo=hoverinfo,
            opacity=opacity,
            showlegend=showlegend,
            legendgroup=legendgroup,
        )
        self.add(new)
        return self

    def draw_edges(
        self,
        *edges,
        color="black",
        linewidth=1,
        opacity=1.0,
        elongate: float = 1.0,
        showlegend: bool = True,
        name: str = None,
    ):
        """Draw each (atom, atom) edge as a vector; `name` groups them in the legend."""
        for edge in edges:
            self.draw_vector(
                f"{edge[0].id}-{edge[1].id}",
                edge[0].coord,
                edge[1].coord,
                color=color,
                linewidth=linewidth,
                opacity=opacity,
                elongate=elongate,
                showlegend=showlegend,
                legendgroup=name,
            )
        return self

    def draw_points(
        self,
        coords: list,
        ids: list = None,
        colors: list = None,
        opacities: list = None,
        sizes: list = None,
        showlegends: list = None,
        **kwargs,
    ):
        """
        Draw several points. Per-point settings may be given as lists, or as a
        single value that is applied to every point.
        """
        n = len(coords)
        if ids is None:
            ids = [str(i) for i in range(n)]
        if colors is None:
            colors = ["black"] * n
        elif isinstance(colors, str):
            colors = [colors] * n
        if opacities is None:
            opacities = [1.0] * n
        elif isinstance(opacities, (int, float)):
            opacities = [opacities] * n
        if showlegends is None:
            showlegends = [True] * n
        elif isinstance(showlegends, bool):
            showlegends = [showlegends] * n
        if sizes is None:
            sizes = [self.size] * n
        elif isinstance(sizes, (int, float)):
            sizes = [sizes] * n

        for idx, coord in enumerate(coords):
            self.draw_point(
                ids[idx],
                coord,
                colors[idx],
                opacities[idx],
                sizes[idx],
                showlegends[idx],
                **kwargs,
            )
        return self

    def highlight_atoms(
        self,
        *atoms,
        names: list = None,
        colors: list = None,
        opacity: float = 1,
        size: int = 10,
        showlegend: bool = True,
        hoverinfo: str = "name",
    ):
        """
        Draw enlarged markers for the given atoms (resolved via `self._src.get_atom`).

        `colors` may be a list (one per atom) or a single color. By default, element
        colors are used.
        """
        if colors is not None and not isinstance(colors, list):
            colors = [colors] * len(atoms)

        atom_scatter = []
        for idx, atom in enumerate(atoms):
            atom = self._src.get_atom(atom)
            color = self._get_atom_color(atom) if colors is None else colors[idx]
            name = repr(atom) if names is None else names[idx]
            atom_scatter.append(
                go.Scatter3d(
                    x=[atom.coord[0]],
                    y=[atom.coord[1]],
                    z=[atom.coord[2]],
                    mode="markers",
                    marker=dict(color=color, opacity=opacity, size=size),
                    hoverinfo=hoverinfo,
                    showlegend=showlegend,
                    name=name,
                    legendgroup="Highlighted",
                )
            )
        self.add(atom_scatter)
        return self

    def highlight_residues(
        self,
        *residues,
        bond_colors: list = None,
        opacity: float = 0.6,
        linewidth: float = 2,
        draw_atoms: bool = False,
    ):
        """
        Redraw the bonds (and optionally atoms) of the given residues with thicker
        lines and, optionally, custom colors.

        Parameters
        ----------
        residues
            Residues, resolved via `self._src.get_residue`.
        bond_colors : list or color
            One bond color per residue, or a single color for all residues. None
            (or a missing entry) keeps the existing bond colors.
        opacity : float
            Opacity used while drawing these residues.
        linewidth : float
            Added to each bond's stored width before drawing.
        draw_atoms : bool
            Whether to also draw the residues' atoms.
        """
        if not isinstance(bond_colors, list):
            bond_colors = [bond_colors] * len(residues)

        residue_traces = []
        previous_opacity = self.opacity
        self.opacity = opacity
        try:
            for idx, residue in enumerate(residues):
                residue = self._src.get_residue(residue)
                atoms = self._atom_df[
                    self._atom_df["residue_serial"] == residue.id[1]
                ].copy()
                in_residue = self._bond_df["a"].isin(atoms.index) & self._bond_df[
                    "b"
                ].isin(atoms.index)
                # Explicit copies: the stored dataframes must stay untouched.
                bonds = self._bond_df[in_residue].copy()

                color = bond_colors[idx] if idx < len(bond_colors) else None
                if color is not None:
                    bonds["bond_color"] = color
                bonds["bond_order"] = bonds["bond_order"] + linewidth

                fig = self._setup_fig(atoms, bonds, draw_atoms=draw_atoms)
                residue_traces.extend(fig.data)
        finally:
            # Restore the viewer-wide opacity even if drawing fails.
            self.opacity = previous_opacity

        self.add(residue_traces)
        return self

    def draw_atoms(
        self,
        *atoms,
        names: list = None,
        colors: list = None,
        opacity: float = None,
        showlegend: bool = True,
        hoverinfo: str = "name",
    ):
        """Like `highlight_atoms`, but the opacity defaults to the viewer's opacity."""
        if opacity is None:
            opacity = self.opacity
        self.highlight_atoms(
            *atoms,
            names=names,
            colors=colors,
            opacity=opacity,
            showlegend=showlegend,
            hoverinfo=hoverinfo,
        )
        return self

    def draw_residues(
        self,
        *residues,
        bond_colors: list = None,
        opacity: float = None,
        linewidth: float = 2,
    ):
        """Like `highlight_residues`, but the opacity defaults to the viewer's opacity."""
        if opacity is None:
            opacity = self.opacity
        self.highlight_residues(
            *residues, bond_colors=bond_colors, opacity=opacity, linewidth=linewidth
        )
        return self

    def draw_atom(self, atom, id=None, color=None, opacity=None, size=None):
        """
        Draw a single atom as a marker.

        Defaults: element color, twice the viewer opacity (capped at 1), the viewer's
        marker size, and an id of "<atom id> <serial number>".
        """
        if color is None:
            color = self._get_atom_color(atom)
        if opacity is None:
            opacity = min(1, self.opacity * 2)
        if size is None:
            size = self.size
        if id is None:
            id = f"{atom.id} {atom.serial_number}"
        self.draw_point(id, atom.coord, color, opacity, size)
        return self

    def draw_bond(
        self,
        atom_a,
        atom_b,
        color="black",
        linewidth=1.5,
        showlegend=True,
        elongate: float = 1.0,
    ):
        """Draw a bond between two atoms as a line."""
        # Keywords are required here: positionally, `showlegend` would land in
        # `draw_vector`'s `opacity` slot.
        self.draw_vector(
            f"{atom_a.id}-{atom_b.id}",
            atom_a.coord,
            atom_b.coord,
            color=color,
            linewidth=linewidth,
            showlegend=showlegend,
            elongate=elongate,
        )
        return self
class MoleculeViewer3D(PlotlyViewer3D):
    """
    Plotly 3D viewer for molecules: atoms are drawn as colored markers and bonds as lines.

    Expects `self._src`, `self._atom_df` and `self._bond_df` to be set before drawing
    (NOTE(review): presumably by a linking step not visible here; `make_df` builds the dataframes).
    """

    def make_df(self, mol) -> tuple:
        """
        Build the atom and bond dataframes for `mol`.

        Returns
        -------
        tuple
            (atom_df indexed by atom serial number, bond_df with columns
            a, b, bond_color, bond_order). `bond_order` is the bond order scaled
            by `self.bond_linewidth`.
        """
        # Materialize once instead of re-iterating the molecule for every column.
        atoms = list(mol.get_atoms())
        bonds = list(mol.get_bonds())

        atom_df = pd.DataFrame(
            {
                "x": [atom.coord[0] for atom in atoms],
                "y": [atom.coord[1] for atom in atoms],
                "z": [atom.coord[2] for atom in atoms],
                "atom_id": [atom.id for atom in atoms],
                "atom_serial": [atom.serial_number for atom in atoms],
                "atom_element": [atom.element.title() for atom in atoms],
                "residue_serial": [atom.get_parent().id[1] for atom in atoms],
                "residue_name": [atom.get_parent().resname for atom in atoms],
                "chain_id": [atom.get_parent().get_parent().id for atom in atoms],
            }
        )
        atom_df.set_index("atom_serial", drop=False, inplace=True)

        bond_df = pd.DataFrame(
            {
                "a": [bond[0].serial_number for bond in bonds],
                "b": [bond[1].serial_number for bond in bonds],
                "bond_color": [self.bond_color] * len(bonds),
                "bond_order": [self.bond_linewidth * bond.order for bond in bonds],
            }
        )
        return atom_df, bond_df

    def setup(self, draw_atoms=True):
        """
        Setup the viewer with the molecule.
        """
        self.add(self._setup_fig(self._atom_df, self._bond_df, draw_atoms=draw_atoms))
        return self

    def _setup_fig(self, atom_df, bond_df, draw_atoms=True):
        """Create a figure with atom markers (optional) and one line trace per bond."""
        if not draw_atoms:
            fig = go.Figure()
        else:
            # Work on a copy so the caller's dataframe does not gain the helper column.
            atom_df = atom_df.copy()
            atom_df["__marker_size"] = self.size
            fig = px.scatter_3d(
                atom_df,
                x="x",
                y="y",
                z="z",
                color="atom_element",
                color_discrete_map=self.__atom_colors__,
                opacity=self.opacity,
                size="__marker_size",
                hover_data=[
                    "atom_id",
                    "atom_serial",
                    "residue_serial",
                    "residue_name",
                    "chain_id",
                ],
                template="none",
            )

        bonds = []
        for bond in bond_df.itertuples(index=False):
            a1 = atom_df.loc[bond.a]
            a2 = atom_df.loc[bond.b]
            bonds.append(
                go.Scatter3d(
                    x=[a1["x"], a2["x"]],
                    y=[a1["y"], a2["y"]],
                    z=[a1["z"], a2["z"]],
                    mode="lines",
                    line=dict(
                        color=bond.bond_color,
                        width=bond.bond_order**2,
                    ),
                    hoverinfo="skip",
                    showlegend=False,
                )
            )
        fig.add_traces(bonds)
        return fig

    def reset(self, **kwargs):
        """
        Rebuild the figure from the stored dataframes.

        `xlim`, `ylim`, `zlim` may be given (as for the base class) to set axis ranges.
        """
        self.figure = self._setup_fig(self._atom_df, self._bond_df)
        limits = {key: kwargs.get(key) for key in ("xlim", "ylim", "zlim")}
        if any(value is not None for value in limits.values()):
            self.viewbox(**limits)
        return self

    def rainbow(self):
        """
        Colorize the residues in rainbow colors
        """
        residues = list(self._src.get_residues())
        self.highlight_residues(
            *residues,
            bond_colors=[self._get_color() for _ in residues],
        )
        return self
class AtomGraphViewer3D(PlotlyViewer3D):
    """
    Plotly 3D viewer for atom graphs: nodes are atoms, edges are drawn as bonds.
    """

    def make_df(self, graph):
        """
        Build the atom and bond dataframes for `graph`.

        Edges without a "bond_order" attribute are treated as order 1.

        Returns
        -------
        tuple
            (atom_df indexed by atom serial number, bond_df with columns
            a, b, bond_color, bond_order).
        """
        nodes = list(graph.nodes)
        atom_df = pd.DataFrame(
            {
                "x": [atom.coord[0] for atom in nodes],
                "y": [atom.coord[1] for atom in nodes],
                "z": [atom.coord[2] for atom in nodes],
                "atom_id": [atom.id for atom in nodes],
                "atom_serial": [atom.serial_number for atom in nodes],
                "atom_element": [atom.element.title() for atom in nodes],
                "residue_serial": [atom.get_parent().id[1] for atom in nodes],
                "residue_name": [atom.get_parent().resname for atom in nodes],
                "chain_id": [atom.get_parent().get_parent().id for atom in nodes],
            }
        )
        atom_df.set_index("atom_serial", drop=False, inplace=True)

        # `default=1` keeps edges lacking a bond_order attribute. Filtering on the
        # attribute would silently drop them.
        edges = list(graph.edges(data="bond_order", default=1))
        bond_df = pd.DataFrame(
            {
                "a": [u.serial_number for u, _, _ in edges],
                "b": [v.serial_number for _, v, _ in edges],
                "bond_color": [self.bond_color] * len(edges),
                "bond_order": [self.bond_linewidth * order for _, _, order in edges],
            }
        )
        return atom_df, bond_df

    def _setup_fig(self, atom_df, bond_df, draw_atoms=True):
        """
        Create a figure with atom markers (optional) and one line trace per bond.

        `draw_atoms` mirrors `MoleculeViewer3D._setup_fig`, so the base class's
        `highlight_residues` can call this method.
        """
        if draw_atoms:
            fig = px.scatter_3d(
                atom_df,
                x="x",
                y="y",
                z="z",
                color="atom_element",
                color_discrete_map=self.__atom_colors__,
                opacity=self.opacity,
                hover_data=[
                    "atom_id",
                    "atom_serial",
                    "residue_serial",
                    "residue_name",
                    "chain_id",
                ],
                template="none",
            )
        else:
            fig = go.Figure()

        bonds = []
        for bond in bond_df.itertuples(index=False):
            a1 = atom_df.loc[bond.a]
            a2 = atom_df.loc[bond.b]
            bonds.append(
                go.Scatter3d(
                    x=[a1["x"], a2["x"]],
                    y=[a1["y"], a2["y"]],
                    z=[a1["z"], a2["z"]],
                    mode="lines",
                    line=dict(
                        color=bond.bond_color,
                        width=bond.bond_order**2,
                    ),
                    opacity=min(1, self.opacity * 2),
                    hoverinfo="skip",
                    showlegend=False,
                )
            )
        fig.add_traces(bonds)
        return fig
class ResidueGraphViewer3D(PlotlyViewer3D):
    """
    Plotly 3D viewer for residue graphs, whose nodes may be residues or (in detailed
    graphs) atoms.
    """

    def make_df(self, graph):
        """
        Build the node and edge dataframes for `graph`.

        Returns
        -------
        tuple
            (node_df indexed by `get_id()`, edge_df with columns a, b, bond_color,
            bond_order). All edges use `self.bond_linewidth` as their width.
        """
        nodes = list(graph.nodes)
        edges = list(graph.edges)

        node_df = pd.DataFrame(
            {
                "_id": [node.get_id() for node in nodes],
                "x": [node.coord[0] for node in nodes],
                "y": [node.coord[1] for node in nodes],
                "z": [node.coord[2] for node in nodes],
                "id": [str(node.id) for node in nodes],
                "serial": [node.serial_number for node in nodes],
                # Atoms are colored by element, residues by residue name.
                "element_or_resname": [
                    getattr(node, "element", getattr(node, "resname", "")).title()
                    for node in nodes
                ],
                "parent_id": [str(node.get_parent().id) for node in nodes],
                "parent_serial": [
                    getattr(node.get_parent(), "serial_number", -1) for node in nodes
                ],
            }
        )
        node_df.set_index("_id", drop=False, inplace=True)

        edge_df = pd.DataFrame(
            {
                "a": [edge[0].get_id() for edge in edges],
                "b": [edge[1].get_id() for edge in edges],
                "bond_color": [self.bond_color] * len(edges),
                "bond_order": [self.bond_linewidth] * len(edges),
            }
        )
        return node_df, edge_df

    def _setup_fig(self, atom_df, bond_df, draw_atoms=True):
        """
        Create a figure with node markers (optional) and one line trace per edge.

        `draw_atoms` mirrors `MoleculeViewer3D._setup_fig`, so the base class's
        `highlight_residues` can call this method.
        """
        if draw_atoms:
            fig = px.scatter_3d(
                atom_df,
                x="x",
                y="y",
                z="z",
                color="element_or_resname",
                color_discrete_map=self.__atom_colors__,
                opacity=self.opacity,
                hover_data=[
                    "id",
                    "serial",
                    "parent_serial",
                    "parent_id",
                ],
                template="none",
            )
        else:
            fig = go.Figure()

        bonds = []
        for bond in bond_df.itertuples(index=False):
            a1 = atom_df.loc[bond.a]
            a2 = atom_df.loc[bond.b]
            bonds.append(
                go.Scatter3d(
                    x=[a1["x"], a2["x"]],
                    y=[a1["y"], a2["y"]],
                    z=[a1["z"], a2["z"]],
                    mode="lines",
                    line=dict(
                        color=bond.bond_color,
                        width=bond.bond_order**2,
                    ),
                    opacity=min(1, self.opacity * 2),
                    hoverinfo="skip",
                    showlegend=False,
                )
            )
        fig.add_traces(bonds)
        return fig

    def rainbow(self):
        """Draw every residue node (nodes without an element) in the next rainbow color."""
        for node in self._src.nodes:
            if getattr(node, "element", None) is not None:
                continue
            self.draw_atom(node, color=self._get_color())
        return self
if __name__ == "__main__":
    # Manual smoke test for Chem2DViewer. Requires buildamol's sugar library and
    # rdkit; the png drawer additionally needs cairo.
    import buildamol as bam

    bam.load_sugars()
    man = bam.molecule("MAN")
    # NOTE(review): presumably builds a mannose chain via the "14bb" linkage — confirm.
    man = man % "14bb" * 2
    man.change_element(1, "Au")
    v = Chem2DViewer(man, drawer="png")
    v.label_atoms(lambda a: a.GetSymbol())
    v.highlight_atoms(
        man.atoms,
        # Blue with alpha scaled by each atom's share of the total mass.
        color=lambda a: (0, 0, 1, ((a.mass or 0) / man.mass) ** 0.5),
        # radius=lambda a: 0.2 + 0.2 * (a.mass or 0),
    )
    v.highlight_bonds(man.get_bonds(man.get_residue(1)), color="blue")
    v.highlight_residues(1, color=(1, 1, 0, 0.3))
    v.set_colors(
        {
            6: "pink",
        }
    )
    v.show(draw_hydrogens=False, linewidth=5, background="white")
    pass
    # Earlier manual checks for the plotly viewers, kept for reference:
    # v = MoleculeViewer3D()
    # v.link(man)
    # v.setup()
    # v.show()
    # man.repeat(5, "14bb")
    # # v = Chem2DViewer(man)
    # # v.show()
    # v = MoleculeViewer3D()
    # v.link(man)
    # v.setup()
    # v.show()
    # atoms = man.atoms[:10]
    # v.draw_atoms(*atoms)
    # v.show()
    # manv = ResidueGraphViewer3D()
    # manv.link(man.make_residue_graph(detailed=True))
    # # manv.highlight_residues(1, bond_colors=["red"])
    # manv.rainbow()
    # manv.show()
    # pass