# Source code for biobuild.utils.visual

"""
Visualization auxiliary functions
"""

import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import networkx as nx
import matplotlib.colors as colors

import biobuild.utils.auxiliary as aux

# Optional 3D viewer backends. Each name is bound to None when its package
# cannot be imported; the viewer classes check for None before using them.
try:
    import nglview
except Exception:  # optional backend: any import-time failure disables it
    nglview = None

try:
    import py3Dmol
except Exception:  # optional backend: any import-time failure disables it
    py3Dmol = None


# Module-level aliases for the `Draw` and `Chem` objects exposed by the
# auxiliary module. The viewers in this file treat `Chem is None` as
# "rdkit is not available" and raise an ImportError in that case.
Draw = aux.Draw
Chem = aux.Chem

default_plotly_opacity = 1.0
"""
The default opacity for plotly-based visualizations.
"""

default_plotly_bond_color = "black"
"""
The default color for plotly-based bond visualizations.
"""

default_plotly_linewidth = 1.2
"""
The default linewidth for plotly-based bond visualizations.
"""


class Chem2DViewer:
    """
    Render a flat (2D) depiction of a molecule using the RDKit library.

    Parameters
    ----------
    molecule
        The molecule to view. Accepted are objects offering a ``to_rdkit``
        method, objects whose class is named ``AtomGraph`` or
        ``ResidueGraph``, and objects whose class representation contains
        ``"Chem"`` (e.g. an RDKit molecule), which are used as they are.
    highlight_color : str
        The color used for highlighted atoms and bonds.
    """

    def __init__(self, molecule, highlight_color: str = "cyan"):
        if Chem is None:
            raise ImportError(
                "rdkit is not available. Please install it and be sure to use a compatible environment."
            )

        class_name = molecule.__class__.__name__
        if hasattr(molecule, "to_rdkit"):
            rdmol = molecule.to_rdkit()
        elif class_name in ("AtomGraph", "ResidueGraph"):
            rdmol = molecule._molecule.to_rdkit()
        elif "Chem" in str(molecule.__class__.mro()[0]):
            rdmol = molecule
        else:
            raise ValueError(f"Unsupported molecule type: {class_name}")

        # Drop all stored conformers (presumably so that the depiction is
        # laid out in 2D instead of being derived from 3D coordinates).
        rdmol.RemoveAllConformers()

        self.mol = rdmol
        self.highlight_color = highlight_color
        self._atoms_to_highlight = []
        self._bonds_to_highlight = []

    def draw(self, draw_hydrogens: bool = False, width: int = 1000, height: int = 500):
        """
        Generate the 2D image.

        Parameters
        ----------
        draw_hydrogens : bool
            Whether to draw hydrogens. If False, hydrogens are removed
            from the molecule before drawing.
        width : int
            The width of the image in pixels.
        height : int
            The height of the image in pixels.

        Returns
        -------
        The image produced by ``Draw.MolToImage``.
        """
        target = self.mol if draw_hydrogens else Chem.rdmolops.RemoveHs(self.mol)
        rgb = colors.to_rgb(self.highlight_color)
        return Draw.MolToImage(
            target,
            size=(width, height),
            highlightAtoms=self._atoms_to_highlight,
            highlightBonds=self._bonds_to_highlight,
            highlightColor=rgb,
        )

    def show(self, draw_hydrogens: bool = False):
        """
        Show the molecule.

        Parameters
        ----------
        draw_hydrogens : bool
            Whether to draw hydrogens.
        """
        image = self.draw(draw_hydrogens=draw_hydrogens)
        return image.show()

    def highlight_atoms(self, *atoms):
        """
        Highlight atoms in the molecule.

        Parameters
        ----------
        atoms
            The Biobuild Atoms to highlight. Their ``serial_number`` is
            stored as the index of the atom to highlight.
        """
        for atom in atoms:
            self._atoms_to_highlight.append(atom.serial_number)

    def highlight_bonds(self, *bonds):
        """
        Highlight bonds in the molecule.

        Parameters
        ----------
        bonds
            The bonds (tuples of two Biobuild Atoms) to highlight. Each bond
            is looked up in the RDKit molecule via the atoms' serial numbers.
        """
        for atom_a, atom_b in bonds:
            rd_bond = self.mol.GetBondBetweenAtoms(
                atom_a.serial_number, atom_b.serial_number
            )
            self._bonds_to_highlight.append(rd_bond.GetIdx())
class Py3DmolViewer:
    """
    Display a molecule in 3D using the py3Dmol library.

    Attributes
    ----------
    view : py3Dmol.view
        The py3Dmol view object.

    Parameters
    ----------
    molecule
        The molecule to view. Accepted are objects offering a ``to_rdkit``
        method, objects whose class is named ``AtomGraph`` or
        ``ResidueGraph``, and objects whose class representation contains
        ``"Chem"`` (e.g. an RDKit molecule), which are used as they are.
    width : int
        The width of the viewer in pixels.
    height : int
        The height of the viewer in pixels.
    """

    def __init__(self, molecule, width: int = 500, height: int = 500) -> None:
        backends_ready = py3Dmol is not None and Chem is not None
        if not backends_ready:
            raise ImportError(
                "py3Dmol and/or rdkit are not available. Please install them and be sure to use a compatible (Jupyter) environment."
            )

        class_name = molecule.__class__.__name__
        if hasattr(molecule, "to_rdkit"):
            rdmol = molecule.to_rdkit()
        elif class_name in ("AtomGraph", "ResidueGraph"):
            rdmol = molecule._molecule.to_rdkit()
        elif "Chem" in str(molecule.__class__.mro()[0]):
            rdmol = molecule
        else:
            raise ValueError(f"Unsupported molecule type: {class_name}")

        # The molecule is handed to the viewer as a mol block and drawn as sticks.
        mol_block = Chem.MolToMolBlock(rdmol)
        self.view = py3Dmol.view(width=width, height=height)
        self.view.addModel(mol_block, "sdf")
        self.view.setStyle({"stick": {}})

    def show(self):
        """
        Show the molecule in a Jupyter notebook
        """
        view = self.view
        return view.show()
class NglViewer:
    """
    Display a molecule in 3D using the NGLView library.

    Attributes
    ----------
    structure
        The biopython structure obtained from the given molecule.

    Parameters
    ----------
    molecule
        The molecule to view. Accepted are objects offering a
        ``to_biopython`` method and objects whose class is named
        ``AtomGraph`` or ``ResidueGraph``.
    """

    def __init__(self, molecule):
        if nglview is None:
            raise ImportError(
                "NGLView is not available. Please install it with `pip install nglview` and be sure to use a compatible environment."
            )

        class_name = molecule.__class__.__name__
        if hasattr(molecule, "to_biopython"):
            structure = molecule.to_biopython()
        elif class_name in ("AtomGraph", "ResidueGraph"):
            # graphs are converted through their `structure` attribute
            structure = molecule.structure.to_biopython()
        else:
            raise ValueError(f"Unsupported molecule type: {class_name}")
        self.structure = structure

    def show(self):
        """
        Show the molecule in a Jupyter notebook
        """
        return nglview.show_biopython(self.structure)
def rgba_to_hex(rgba: tuple) -> str:
    """
    Convert an rgb(a) color to hex.

    Parameters
    ----------
    rgba : tuple
        The color channels to convert, each given as a number between
        0 and 1 (the values are scaled by 255).

    Returns
    -------
    str
        The hex color: a leading ``#`` followed by two lowercase hex digits
        per channel.
    """
    # Every channel must occupy exactly two hex digits, otherwise small
    # values (e.g. 0 -> "0") produce a malformed color string. Out-of-range
    # inputs are clamped to the valid 0-255 interval.
    return "#" + "".join(
        f"{min(255, max(0, int(channel * 255))):02x}" for channel in rgba
    )
class PlotlyViewer3D:
    """
    Base class of the plotly-based 3D viewers.

    The viewer owns a plotly figure (``self.figure``) and offers helpers to
    add points, lines, atoms and residues to it.

    Several methods read ``self._src``, ``self._atom_df`` and
    ``self._bond_df`` or call ``self._setup_fig``. None of these are defined
    in this class; they have to be provided elsewhere (e.g. by a subclass)
    before those methods are used.
    """

    # palette that `_get_color` cycles through
    __continuous_colors__ = [
        "navy",
        "blue",
        "teal",
        "green",
        "lightgreen",
        "yellow",
        "orange",
        "red",
        "crimson",
        "darkred",
        "brown",
        "purple",
        "pink",
    ]

    # colors per element symbol (title case); other elements are drawn black
    __atom_colors__ = {
        "C": "darkslategray",
        "O": "red",
        "H": "lightgray",
        "N": "blue",
        "S": "yellow",
        "P": "purple",
        "F": "green",
        "Cl": "green",
        "Br": "green",
        "I": "green",
    }

    def __init__(self) -> None:
        # Use the base implementation explicitly so that an overriding
        # `reset` of a subclass is not invoked during construction.
        PlotlyViewer3D.reset(self)
        self._color_idx = 0
        self.opacity = default_plotly_opacity
        self.bond_color = default_plotly_bond_color
        self.bond_linewidth = default_plotly_linewidth

    def _get_color(self):
        """
        Return the next color of the continuous palette. The palette wraps
        around once its end is reached.
        """
        color = self.__continuous_colors__[self._color_idx]
        self._color_idx = (self._color_idx + 1) % len(self.__continuous_colors__)
        return color

    def _get_atom_color(self, atom):
        """
        Return the color for the element of an atom ("black" if the element
        has no entry in the color table).
        """
        return self.__atom_colors__.get(atom.element.title(), "black")

    def add(self, fig):
        """
        Add a plotly figure to the viewer.

        Parameters
        ----------
        fig
            Either an object with a ``data`` attribute (whose traces are
            added) or the trace(s) to add directly.
        """
        data = getattr(fig, "data", fig)
        self.figure.add_traces(data)

    def show(self):
        """
        Show the figure.
        """
        self.figure.show()

    def write_html(self, path):
        """
        Write the figure to an HTML file.

        Parameters
        ----------
        path
            The path of the file to write.
        """
        self.figure.write_html(path)

    def reset(self):
        """
        Replace the figure by a new, empty one whose 3D axes show neither
        grid, axis lines nor tick labels.
        """
        hidden_axis = dict(showgrid=False, showline=False, showticklabels=False)
        self.figure = go.Figure(
            layout=go.Layout(
                scene=dict(
                    xaxis=dict(hidden_axis),
                    yaxis=dict(hidden_axis),
                    zaxis=dict(hidden_axis),
                    # aspectmode="cube",
                ),
                template="simple_white",
            )
        )

    def draw_point(self, id: str, coord, color="black", opacity=1.0, showlegend=True):
        """
        Draw a single point.

        Parameters
        ----------
        id : str
            The name of the trace.
        coord
            The x, y and z coordinates of the point.
        color
            The marker color.
        opacity : float
            The marker opacity.
        showlegend : bool
            Whether the point gets a legend entry.
        """
        new = go.Scatter3d(
            x=[coord[0]],
            y=[coord[1]],
            z=[coord[2]],
            mode="markers",
            marker=dict(opacity=opacity, color=color),
            name=id,
            hoverinfo="name",
            showlegend=showlegend,
        )
        self.add(new)

    def draw_vector(
        self,
        id,
        coord_a,
        coord_b,
        color="black",
        linewidth=1.5,
        opacity=1.0,
        showlegend=True,
        hoverinfo: str = "skip",
        elongate: float = 1.0,
        legendgroup: str = None,
    ):
        """
        Draw a line that starts at one point and is directed at another.

        Parameters
        ----------
        id
            The name of the trace.
        coord_a
            The x, y and z coordinates of the start point.
        coord_b
            The x, y and z coordinates of the point the line is directed at.
            The line ends at ``coord_a + (coord_b - coord_a) * elongate``.
        color
            The line color.
        linewidth : float
            The line width.
        opacity : float
            The opacity of the trace.
        showlegend : bool
            Whether the line gets a legend entry.
        hoverinfo : str
            The plotly hoverinfo setting of the trace.
        elongate : float
            Factor by which the distance between the two points is scaled.
        legendgroup : str
            The legend group of the trace.
        """
        new = go.Scatter3d(
            x=[coord_a[0], coord_a[0] + (coord_b[0] - coord_a[0]) * elongate],
            y=[coord_a[1], coord_a[1] + (coord_b[1] - coord_a[1]) * elongate],
            z=[coord_a[2], coord_a[2] + (coord_b[2] - coord_a[2]) * elongate],
            mode="lines",
            line=dict(color=color, width=linewidth),
            name=id,
            hoverinfo=hoverinfo,
            opacity=opacity,
            showlegend=showlegend,
            legendgroup=legendgroup,
        )
        self.add(new)

    def draw_edges(
        self,
        *edges,
        color="black",
        linewidth=1,
        opacity=1.0,
        elongate: float = 1.0,
        showlegend: bool = True,
        name: str = None,
    ):
        """
        Draw one line per edge.

        Parameters
        ----------
        edges
            The edges to draw. Each edge holds two objects that offer an
            ``id`` and a ``coord`` attribute.
        color
            The line color.
        linewidth : float
            The line width.
        opacity : float
            The opacity of the lines.
        elongate : float
            Factor by which the length of each edge is scaled.
        showlegend : bool
            Whether the lines get legend entries.
        name : str
            The legend group that the lines are assigned to.
        """
        for edge in edges:
            self.draw_vector(
                f"{edge[0].id}-{edge[1].id}",
                edge[0].coord,
                edge[1].coord,
                color=color,
                linewidth=linewidth,
                opacity=opacity,
                elongate=elongate,
                showlegend=showlegend,
                legendgroup=name,
            )

    def draw_points(
        self,
        ids: list,
        coords: list,
        colors: list = None,
        opacities: list = None,
        showlegends: list = None,
    ):
        """
        Draw multiple points.

        Parameters
        ----------
        ids : list
            The trace names, one per point.
        coords : list
            The coordinates, one entry per point.
        colors : list
            The colors, one per point. By default all points are black.
        opacities : list
            The opacities, one per point. By default all points are opaque.
        showlegends : list
            Whether a point gets a legend entry, one entry per point.
            By default all points get one.
        """
        count = len(coords)
        if colors is None:
            colors = ["black"] * count
        if opacities is None:
            opacities = [1.0] * count
        if showlegends is None:
            showlegends = [True] * count

        for idx, coord in enumerate(coords):
            self.draw_point(
                ids[idx], coord, colors[idx], opacities[idx], showlegends[idx]
            )

    def highlight_atoms(
        self,
        *atoms,
        names: list = None,
        colors: list = None,
        opacity: float = 1,
        showlegend: bool = True,
        hoverinfo: str = "name",
    ):
        """
        Draw atoms as enlarged markers in the legend group "Highlighted".

        Parameters
        ----------
        atoms
            The atoms to highlight. Each one is resolved through
            ``self._src.get_atom``.
        names : list
            The trace names, one per atom. By default the ``repr`` of the
            atom is used.
        colors
            The colors, either a list with one entry per atom or a single
            value that is used for all atoms. By default atoms are colored
            by element.
        opacity : float
            The marker opacity.
        showlegend : bool
            Whether the atoms get legend entries.
        hoverinfo : str
            The plotly hoverinfo setting of the traces.
        """
        if colors is not None and not isinstance(colors, list):
            colors = [colors] * len(atoms)

        atom_scatter = []
        for idx, atom in enumerate(atoms):
            atom = self._src.get_atom(atom)
            color = self._get_atom_color(atom) if colors is None else colors[idx]
            name = repr(atom) if names is None else names[idx]
            atom_scatter.append(
                go.Scatter3d(
                    x=[atom.coord[0]],
                    y=[atom.coord[1]],
                    z=[atom.coord[2]],
                    mode="markers",
                    marker=dict(color=color, opacity=opacity, size=10),
                    hoverinfo=hoverinfo,
                    showlegend=showlegend,
                    name=name,
                    legendgroup="Highlighted",
                )
            )
        self.add(atom_scatter)

    def highlight_residues(
        self,
        *residues,
        bond_colors: list = None,
        opacity: float = 0.6,
        linewidth: float = 2,
    ):
        """
        Redraw the atoms and bonds of residues with thicker bonds.

        Parameters
        ----------
        residues
            The residues to highlight. Each one is resolved through
            ``self._src.get_residue``.
        bond_colors
            The bond colors, either a list with one entry per residue or a
            single value that is used for all residues. If None, the bonds
            keep the color stored in the bond table.
        opacity : float
            The opacity that is used while the residues are drawn.
        linewidth : float
            The value that is added to the bond order (i.e. the basis of the
            line width) of each bond.
        """
        if not isinstance(bond_colors, list):
            bond_colors = [bond_colors] * len(residues)

        residue_traces = []
        for idx, residue in enumerate(residues):
            residue = self._src.get_residue(residue)
            atoms = self._atom_df[self._atom_df["residue_serial"] == residue.id[1]]
            within_residue = self._bond_df["a"].isin(atoms.index) & self._bond_df[
                "b"
            ].isin(atoms.index)

            # Work on an explicit copy so that the bond table of the viewer
            # is never modified (and pandas raises no chained-assignment warning).
            bonds = self._bond_df[within_residue].copy()

            # Only override the color if one was actually provided. Writing
            # None into the column would discard the default bond color.
            color = bond_colors[idx] if bond_colors else None
            if color is not None:
                bonds["bond_color"] = color
            bonds["bond_order"] = bonds["bond_order"] + linewidth

            # `_setup_fig` reads `self.opacity`, so the requested opacity is
            # set temporarily and restored even if drawing fails.
            previous_opacity = self.opacity
            self.opacity = opacity
            try:
                fig = self._setup_fig(atoms, bonds)
            finally:
                self.opacity = previous_opacity
            residue_traces.extend(fig.data)

        self.add(residue_traces)

    def draw_atoms(
        self,
        *atoms,
        names: list = None,
        colors: list = None,
        opacity: float = None,
        showlegend: bool = True,
        hoverinfo: str = "name",
    ):
        """
        Draw atoms. This works like `highlight_atoms` but uses the opacity
        of the viewer by default.

        Parameters
        ----------
        atoms
            The atoms to draw.
        names : list
            The trace names, one per atom.
        colors
            The colors (a list with one entry per atom or a single value).
        opacity : float
            The marker opacity. If None, the opacity of the viewer is used.
        showlegend : bool
            Whether the atoms get legend entries.
        hoverinfo : str
            The plotly hoverinfo setting of the traces.
        """
        # `is None` (instead of a truthiness check) keeps an explicit 0.0
        if opacity is None:
            opacity = self.opacity
        self.highlight_atoms(
            *atoms,
            names=names,
            colors=colors,
            opacity=opacity,
            showlegend=showlegend,
            hoverinfo=hoverinfo,
        )

    def draw_residues(
        self,
        *residues,
        bond_colors: list = None,
        opacity: float = None,
        linewidth: float = 2,
    ):
        """
        Draw residues. This works like `highlight_residues` but uses the
        opacity of the viewer by default.

        Parameters
        ----------
        residues
            The residues to draw.
        bond_colors
            The bond colors (a list with one entry per residue or a single
            value).
        opacity : float
            The opacity. If None, the opacity of the viewer is used.
        linewidth : float
            The value that is added to the bond order of each bond.
        """
        # `is None` (instead of a truthiness check) keeps an explicit 0.0
        if opacity is None:
            opacity = self.opacity
        self.highlight_residues(
            *residues, bond_colors=bond_colors, opacity=opacity, linewidth=linewidth
        )

    def draw_atom(self, atom, id=None, color=None, opacity=None):
        """
        Draw a single atom as a point.

        Parameters
        ----------
        atom
            The atom to draw.
        id
            The name of the trace. By default the atom's id and serial
            number, separated by a space.
        color
            The marker color. By default the atom is colored by element.
        opacity : float
            The marker opacity. By default twice the opacity of the viewer,
            capped at 1.
        """
        if color is None:
            # same lookup as everywhere else: title-cased element, "black" fallback
            color = self._get_atom_color(atom)
        if opacity is None:
            opacity = min(1, self.opacity * 2)
        if id is None:
            id = str(atom.id) + " " + str(atom.serial_number)
        self.draw_point(
            id,
            atom.coord,
            color,
            opacity,
        )

    def draw_bond(
        self,
        atom_a,
        atom_b,
        color="black",
        linewidth=1.5,
        showlegend=True,
        elongate: float = 1.0,
    ):
        """
        Draw a line between two atoms.

        Parameters
        ----------
        atom_a, atom_b
            The two atoms. Both need an ``id`` and a ``coord`` attribute.
        color
            The line color.
        linewidth : float
            The line width.
        showlegend : bool
            Whether the line gets a legend entry.
        elongate : float
            Factor by which the distance between the two atoms is scaled.
        """
        # `draw_vector` expects the end point itself (not the difference
        # vector); the options are passed by keyword so that `showlegend`
        # cannot end up in the `opacity` slot.
        self.draw_vector(
            f"{atom_a.id}-{atom_b.id}",
            atom_a.coord,
            atom_b.coord,
            color=color,
            linewidth=linewidth,
            showlegend=showlegend,
            elongate=elongate,
        )
class MoleculeViewer3D(PlotlyViewer3D):
    """
    Plotly-based 3D viewer for molecules.
    """

    def make_df(self, mol) -> tuple:
        """
        Assemble the atom and bond tables of a molecule.

        Parameters
        ----------
        mol
            The molecule. Its atoms are obtained from ``mol.get_atoms()``
            and its bonds from the ``bond_order`` edge attribute of
            ``mol._AtomGraph``.

        Returns
        -------
        tuple
            The atom dataframe (indexed by atom serial number) and the bond
            dataframe (serial numbers of both atoms, color and line width
            basis per bond).
        """
        atom_columns = {
            "x": [a.coord[0] for a in mol.get_atoms()],
            "y": [a.coord[1] for a in mol.get_atoms()],
            "z": [a.coord[2] for a in mol.get_atoms()],
            "atom_id": [a.id for a in mol.get_atoms()],
            "atom_serial": [a.serial_number for a in mol.get_atoms()],
            "atom_element": [a.element.title() for a in mol.get_atoms()],
            "residue_serial": [a.get_parent().id[1] for a in mol.get_atoms()],
            "residue_name": [a.get_parent().resname for a in mol.get_atoms()],
            "chain_id": [a.get_parent().get_parent().id for a in mol.get_atoms()],
        }
        atom_table = pd.DataFrame(atom_columns)
        atom_table.set_index("atom_serial", drop=False, inplace=True)

        orders_by_edge = nx.get_edge_attributes(mol._AtomGraph, "bond_order")
        edges = list(orders_by_edge.keys())
        bond_columns = {
            "a": [edge[0].serial_number for edge in edges],
            "b": [edge[1].serial_number for edge in edges],
            "bond_color": [self.bond_color for _ in edges],
            # the bond order scales the default line width
            "bond_order": [
                self.bond_linewidth * order for order in orders_by_edge.values()
            ],
        }
        bond_table = pd.DataFrame(bond_columns)
        return atom_table, bond_table

    def _setup_fig(self, atom_df, bond_df):
        """
        Build a figure with one marker per atom (colored by element) and
        one line trace per bond.
        """
        hover_columns = [
            "atom_id",
            "atom_serial",
            "residue_serial",
            "residue_name",
            "chain_id",
        ]
        fig = px.scatter_3d(
            atom_df,
            x="x",
            y="y",
            z="z",
            color="atom_element",
            color_discrete_map=self.__atom_colors__,
            opacity=self.opacity,
            hover_data=hover_columns,
            template="none",
        )

        bond_traces = []
        for _, bond in bond_df.iterrows():
            start = atom_df.loc[bond["a"]]
            end = atom_df.loc[bond["b"]]
            # the line width is the squared "bond_order" entry;
            # no trace-level opacity is set for the bonds of this viewer
            line_style = dict(
                color=bond["bond_color"],
                width=bond["bond_order"] ** 2,
            )
            bond_traces.append(
                go.Scatter3d(
                    x=[start["x"], end["x"]],
                    y=[start["y"], end["y"]],
                    z=[start["z"], end["z"]],
                    mode="lines",
                    line=line_style,
                    hoverinfo="skip",
                    showlegend=False,
                )
            )
        fig.add_traces(bond_traces)
        return fig

    def reset(self):
        """
        Rebuild the figure from the stored atom and bond dataframes.
        """
        atom_df, bond_df = self._atom_df, self._bond_df
        self.figure = self._setup_fig(atom_df, bond_df)

    def rainbow(self):
        """
        Colorize the residues in rainbow colors
        """
        residues = tuple(self._src.get_residues())
        palette = [self._get_color() for _ in self._src.get_residues()]
        self.highlight_residues(*residues, bond_colors=palette)
class AtomGraphViewer3D(PlotlyViewer3D):
    """
    Plotly-based 3D viewer for atom graphs.
    """

    def make_df(self, graph):
        """
        Assemble the atom and bond tables of an atom graph.

        Parameters
        ----------
        graph
            The graph. Its nodes are the atoms and its bonds are taken from
            the ``bond_order`` edge attribute.

        Returns
        -------
        tuple
            The atom dataframe (indexed by atom serial number) and the bond
            dataframe (serial numbers of both atoms, color and line width
            basis per bond).
        """
        nodes = graph.nodes
        atom_table = pd.DataFrame(
            {
                "x": [node.coord[0] for node in nodes],
                "y": [node.coord[1] for node in nodes],
                "z": [node.coord[2] for node in nodes],
                "atom_id": [node.id for node in nodes],
                "atom_serial": [node.serial_number for node in nodes],
                "atom_element": [node.element.title() for node in nodes],
                "residue_serial": [node.get_parent().id[1] for node in nodes],
                "residue_name": [node.get_parent().resname for node in nodes],
                "chain_id": [node.get_parent().get_parent().id for node in nodes],
            }
        )
        atom_table.set_index("atom_serial", drop=False, inplace=True)

        orders_by_edge = nx.get_edge_attributes(graph, "bond_order")
        edges = list(orders_by_edge.keys())
        bond_table = pd.DataFrame(
            {
                "a": [edge[0].serial_number for edge in edges],
                "b": [edge[1].serial_number for edge in edges],
                "bond_color": [self.bond_color for _ in edges],
                # the bond order scales the default line width
                "bond_order": [
                    self.bond_linewidth * order for order in orders_by_edge.values()
                ],
            }
        )
        return atom_table, bond_table

    def _setup_fig(self, atom_df, bond_df):
        """
        Build a figure with one marker per atom (colored by element) and
        one line trace per bond.
        """
        fig = px.scatter_3d(
            atom_df,
            x="x",
            y="y",
            z="z",
            color="atom_element",
            color_discrete_map=self.__atom_colors__,
            opacity=self.opacity,
            hover_data=[
                "atom_id",
                "atom_serial",
                "residue_serial",
                "residue_name",
                "chain_id",
            ],
            template="none",
        )

        bond_traces = []
        for _, bond in bond_df.iterrows():
            start, end = atom_df.loc[bond["a"]], atom_df.loc[bond["b"]]
            trace = go.Scatter3d(
                x=[start["x"], end["x"]],
                y=[start["y"], end["y"]],
                z=[start["z"], end["z"]],
                mode="lines",
                # the line width is the squared "bond_order" entry
                line=dict(
                    color=bond["bond_color"],
                    width=bond["bond_order"] ** 2,
                ),
                # bonds use twice the opacity of the viewer, capped at 1
                opacity=min(1, self.opacity * 2),
                hoverinfo="skip",
                showlegend=False,
            )
            bond_traces.append(trace)
        fig.add_traces(bond_traces)
        return fig
class ResidueGraphViewer3D(PlotlyViewer3D):
    """
    Plotly-based 3D viewer for residue graphs.
    """

    def make_df(self, graph):
        """
        Assemble the node and edge tables of a residue graph.

        Parameters
        ----------
        graph
            The graph. Nodes need ``get_id``, ``coord``, ``id``,
            ``serial_number`` and ``get_parent``; the label of a node is its
            ``element`` or, if it has none, its ``resname``.

        Returns
        -------
        tuple
            The node dataframe (indexed by the value of ``get_id()``) and
            the edge dataframe (ids of both nodes, color and line width
            basis per edge).
        """
        nodes = graph.nodes

        labels = []
        for node in nodes:
            fallback = getattr(node, "resname", "")
            labels.append(getattr(node, "element", fallback).title())

        node_table = pd.DataFrame(
            {
                "_id": [node.get_id() for node in nodes],
                "x": [node.coord[0] for node in nodes],
                "y": [node.coord[1] for node in nodes],
                "z": [node.coord[2] for node in nodes],
                "id": [str(node.id) for node in nodes],
                "serial": [node.serial_number for node in nodes],
                "element_or_resname": labels,
                "parent_id": [str(node.get_parent().id) for node in nodes],
                # -1 marks parents without a serial number
                "parent_serial": [
                    getattr(node.get_parent(), "serial_number", -1) for node in nodes
                ],
            }
        )
        node_table.set_index("_id", drop=False, inplace=True)

        edges = graph.edges
        edge_table = pd.DataFrame(
            {
                "a": [edge[0].get_id() for edge in edges],
                "b": [edge[1].get_id() for edge in edges],
                "bond_color": [self.bond_color for _ in edges],
                "bond_order": [self.bond_linewidth for _ in edges],
            }
        )
        return node_table, edge_table

    def _setup_fig(self, atom_df, bond_df):
        """
        Build a figure with one marker per node (colored by element or
        residue name) and one line trace per edge.
        """
        fig = px.scatter_3d(
            atom_df,
            x="x",
            y="y",
            z="z",
            color="element_or_resname",
            color_discrete_map=self.__atom_colors__,
            opacity=self.opacity,
            hover_data=[
                "id",
                "serial",
                "parent_serial",
                "parent_id",
            ],
            template="none",
        )

        edge_traces = []
        for _, bond in bond_df.iterrows():
            start, end = atom_df.loc[bond["a"]], atom_df.loc[bond["b"]]
            trace = go.Scatter3d(
                x=[start["x"], end["x"]],
                y=[start["y"], end["y"]],
                z=[start["z"], end["z"]],
                mode="lines",
                # the line width is the squared "bond_order" entry
                line=dict(
                    color=bond["bond_color"],
                    width=bond["bond_order"] ** 2,
                ),
                # edges use twice the opacity of the viewer, capped at 1
                opacity=min(1, self.opacity * 2),
                hoverinfo="skip",
                showlegend=False,
            )
            edge_traces.append(trace)
        fig.add_traces(edge_traces)
        return fig

    def rainbow(self):
        """
        Draw every node that has no element (attribute missing or None),
        each with the next color of the continuous palette.
        """
        for node in self._src.nodes:
            if getattr(node, "element", None) is None:
                self.draw_atom(node, color=self._get_color())
if __name__ == "__main__":
    # Manual demo, only executed when this file is run as a script:
    # loads the sugars, creates the "MAN" molecule, calls
    # `repeat(5, "14bb")` on it and shows the result in the 2D viewer.
    import biobuild as bb

    bb.load_sugars()
    man = bb.molecule("MAN")
    man.repeat(5, "14bb")

    v = Chem2DViewer(man)
    v.show()

    # Disabled experiments with the plotly-based viewers:
    # v = MoleculeViewer3D()
    # v.link(man)
    # atoms = man.atoms[:10]
    # v.draw_atoms(*atoms)
    # v.show()

    # manv = ResidueGraphViewer3D()
    # manv.link(man.make_residue_graph(detailed=True))
    # # manv.highlight_residues(1, bond_colors=["red"])
    # manv.rainbow()
    # manv.show()
    # pass