Source code for buildamol.graphs.base_graph

"""
The basic Class for Molecular Graphs
"""

from abc import abstractmethod
import warnings

import Bio.PDB as bio
import networkx as nx
import numpy as np

# from scipy.spatial.transform import Rotation
import buildamol.structural.base as base


[docs] class BaseGraph(nx.Graph): """ The basic class for molecular graphs """ def __init__(self, id, bonds: list): super().__init__(bonds) self.id = id self._structure = None self._molecule = None self._neighborhood = None self._locked_edges = set() self._structure_was_searched = False self.__descendent_cache = {} self.__last_cache_size = len(self.nodes) @property def structure(self): """ Returns the underlying `bio.PDB.Structure` object """ if not self._structure_was_searched: self._structure = self._get_structure() self._structure_was_searched = True return self._structure @property def chains(self): """ Returns the chains in the molecule """ if not self.structure: return return list(self.structure.get_chains()) @property def residues(self): """ Returns the residues in the molecule """ if not self.structure: return return list(self.structure.get_residues()) @property def atoms(self): """ Returns the atoms in the molecule """ if not self.structure: return return list(self.structure.get_atoms()) @property def central_node(self): """ Returns the central most node of the graph. This is computed based on the mean of all node coordinates. """ # get the central node center = np.mean([i.coord for i in self.nodes]) # get the node closest to the center root_node = min(self.nodes, key=lambda x: np.linalg.norm(x.coord - center)) return root_node @property def nodes_in_cycles(self) -> set: """ Returns the nodes in cycles """ cycles = nx.cycle_basis(self) if len(cycles) == 0: return set() return set.union(*[set(i) for i in cycles]) @property def bonds(self): """ Returns the bonds in the molecule """ return list(self.edges)
[docs] def show(self): """ Show the graph """ self.draw().show()
[docs] def draw(self): """ Prepare a 3D view of the graph but do not show it yet Returns ------- PlotlyViewer3D A 3D viewer """ raise NotImplementedError
[docs] @abstractmethod def get_neighbors(self, node, n: int = 1, mode="upto"): """ Get the neighbors of a node Parameters ---------- node The target node n : int, optional The number of edges to separate the node from its neighbors. mode : str, optional The mode to use for getting the neighbors, by default "upto" - "upto": get all neighbors up to a distance of `n` edges - "exact": get all neighbors exactly `n` edges away Returns ------- set The neighbors of the node """ raise NotImplementedError
[docs] def get_descendants(self, node_1, node_2, use_cache: bool = True): """ Get all descendant nodes that come after a specific edge defined in the direction from node1 to node2 (i.e. get all nodes that come after node2). This method is directed in contrast to the `get_neighbors()` method, which will get all neighboring nodes of an anchor node irrespective of direction. Parameters ---------- node_1, node_2 The nodes that define the edge use_cache: bool, optional Whether to use the cache for the descendants, by default True. If True and the graph has not received new nodes since the last time the cache was updated, a simple lookup is performed. Otherwise the descendant nodes are recursively calculated again. Returns ------- set The descendant nodes Examples -------- In case of this graph: .. code-block:: A---B---C---D---E \\ F---H | G ``` A---B---C---D---E \\ F---H | G ``` >>> graph.get_descendants("B", "C") {"D", "E"} >>> graph.get_descendants("B", "F") {"H", "G"} >>> graph.get_descendants("B", "A") set() # because in this direction there are no other nodes """ if node_1 is node_2: raise KeyError("Cannot get descendants of a node with itself!") if use_cache: _seen = self.__descendent_cache.get((node_1, node_2)) if _seen: size, _seen = _seen if size == len(self.nodes): return _seen __all_nodes = set(self.nodes) # if use_cache: # _seen = self.__descendent_cache.get((node_2, node_1)) # if _seen: # size, _seen = _seen # if size == len(self.nodes): # return __all_nodes - _seen - {node_1, node_2} _seen = set((node_1, node_2)) _new_neighs = {node_2} descendants = set() while _new_neighs: neigh = _new_neighs.pop() _seen.add(neigh) descendants.clear() for d in self.adj[neigh]: if d in _seen: continue _desc_from_cache = self.__descendent_cache.get((neigh, d)) if _desc_from_cache: _desc_from_cache = _desc_from_cache[1] _seen.add(d) _seen.update(_desc_from_cache) else: descendants.add(d) _new_neighs.update(descendants) _new_neighs.difference_update(_seen) _seen.difference_update((node_1, node_2)) self.__descendent_cache[(node_1, node_2)] = len(self.nodes), _seen self.__descendent_cache[(node_2, node_1)] = len(self.nodes), ( __all_nodes - _seen - {node_1, node_2} ) # v = self._molecule.draw() # v.draw_vector("edge", node_1.coord, node_2.coord, elongate=1.3, linewidth=4, color="limegreen") # for i in _seen: # v.draw_point("n", i.coord, color="grey", showlegend=False) # v.show() return _seen
[docs] def get_ancestors(self, node_1, node_2, use_cache: bool = True): """ Get all ancestor nodes that come before a specific edge defined in the direction from node1 to node2 (i.e. get all nodes that comebefore node1). This method is directed in contrast to the `get_neighbors()` method, which will get all neighboring nodes of an anchor node irrespective of direction. Parameters ---------- node_1, node_2 The nodes that define the edge use_cache: bool, optional Whether to use the cache for the ancestors, by default True. If True and the graph has not received new nodes since the last time the cache was updated, a simple lookup is performed. Otherwise the ancestor nodes are recursively calculated again. Returns ------- set The ancestor nodes Examples -------- In case of this graph: .. code-block:: A---B---C---D---E \\ F---H | G ``` A---B---C---D---E \\ F---H | G ``` >>> graph.get_ancestors("B", "C") {"A", "F", "G", "H"} >>> graph.get_ancestors("F", "B") {"H", "G"} >>> graph.get_ancestors("A", "B") set() # because in this direction there are no other nodes """ return self.get_descendants(node_2, node_1, use_cache)
[docs] def search_by_constraints(self, constraints: list) -> list: """ Search for neighboring nodes that match a set of constraints. Parameters ---------- constraints : list A list of constraint functions, where each entry represents the constraints for a specific node. All constraints must be satisfied for all nodes in the neighborhood to be considered a match. Returns ------- list A list of dictionaries where each dictionary contains nodes that match the constraints. The keys represent the constraint index which the nodes satisfy and the values are the nodes themselves. """ raise NotImplementedError
[docs] def find_cycles(self) -> list: """ Find all cycles in the graph Returns ------- list A list of cycles in the graph, where each cycle is a list of nodes """ return nx.cycle_basis(self)
[docs] def find_nodes_in_cycles(self) -> set: """ Find all nodes that are in cycles Returns ------- set The nodes in cycles """ cycles = [set(i) for i in nx.cycle_basis(self)] if len(cycles) == 0: return set() return set.union(*cycles)
[docs] def find_edges_in_cycles(self) -> set: """ Find all edges that connect nodes in cycles, where both nodes are in the same cycle Returns ------- set The edges in cycles """ nodes_in_cycles = self.find_cycles() return set( (i, j) for i, j in self.edges if self.in_same_cycle(i, j, cycles=nodes_in_cycles) )
[docs] def find_rotatable_edges( self, root_node=None, min_descendants: int = 1, min_ancestors: int = 1, max_descendants: int = None, max_ancestors: int = None, ) -> list: """ Find all edges in the graph that are rotatable (i.e. not locked, single, and not in a circular constellation). You can also filter and direct the edges. Parameters ---------- root_node A root node by which to direct the edges (closer to further). min_descendants: int, optional The minimum number of descendants that an edge must have to be considered rotatable. min_ancestors: int, optional The minimum number of ancestors that an edge must have to be considered rotatable. max_descendants: int, optional The maximum number of descendants that an edge must have to be considered rotatable. max_ancestors: int, optional The maximum number of ancestors that an edge must have to be considered rotatable. Returns ------- list A list of rotatable edges """ if not max_descendants: max_descendants = np.inf if not max_ancestors: max_ancestors = np.inf circulars = [set(i) for i in nx.cycle_basis(self)] # we changed stuff to generators to gain some performance # revert if it causes issues. We know that the root_node # step needs a list so we unpack if needed... rotatable_edges = ( i for i in self.edges if not self.is_locked(*i) # and (hasattr(i[0], "element") and hasattr(i[1], "element")) and self[i[0]][i[1]].get("bond_order", 1) == 1 and not self.in_same_cycle(*i, circulars) ) if root_node is not None: rotatable_edges = list(rotatable_edges) _directed = nx.dfs_tree(self, root_node) rotatable_edges = [ i for i in _directed.edges if i in rotatable_edges or i[::-1] in rotatable_edges ] rotatable_edges = [ i for i in rotatable_edges if min_descendants < len(self.get_descendants(*i)) < max_descendants and min_ancestors < len(self.get_ancestors(*i)) < max_ancestors ] return rotatable_edges
[docs] def find_edges( self, root_node=None, min_descendants: int = 1, min_ancestors: int = 1, max_descendants: int = None, max_ancestors: int = None, bond_order: int = None, exclude_cycles: bool = False, only_cycles: bool = False, exclude_locked: bool = False, only_locked: bool = False, ) -> list: """ Find edges in the graph according to the given criteria. This does not restrict for edges that are rotatable. Parameters ---------- root_node A root node by which to direct the edges (closer to further). min_descendants: int, optional The minimum number of descendants that an edge must have to be considered rotatable. min_ancestors: int, optional The minimum number of ancestors that an edge must have to be considered rotatable. max_descendants: int, optional The maximum number of descendants that an edge must have to be considered rotatable. max_ancestors: int, optional The maximum number of ancestors that an edge must have to be considered rotatable. bond_order: int or tuple, optional The bond order to filter by. If a tuple is given, the bond order must be one of the values in the tuple. exclude_cycles: bool, optional Whether to exclude edges that are in cycles, by default False only_cycles: bool, optional Whether to only include edges that are in cycles, by default False exclude_locked: bool, optional Whether to exclude locked edges, by default False only_locked: bool, optional Whether to only include locked edges, by default False Returns ------- list A list of rotatable edges """ if not max_descendants: max_descendants = np.inf if not max_ancestors: max_ancestors = np.inf matching_edges = iter(self.edges) if only_locked and exclude_locked: raise ValueError( "Cannot exclude and include locked edges at the same time!" ) if exclude_locked: matching_edges = (i for i in matching_edges if i not in self._locked_edges) elif only_locked: matching_edges = (i for i in matching_edges if i in self._locked_edges) if bond_order is not None: if isinstance(bond_order, int): matching_edges = ( i for i in matching_edges if self[i[0]][i[1]].get("bond_order", 1) == bond_order ) elif isinstance(bond_order, tuple): matching_edges = ( i for i in matching_edges if self[i[0]][i[1]].get("bond_order", 1) in bond_order ) else: raise ValueError(f"Invalid datatype {type(bond_order)} for bond_order!") if exclude_cycles and only_cycles: raise ValueError("Cannot exclude and include cycles at the same time!") elif exclude_cycles: circulars = self.find_edges_in_cycles() if len(circulars) > 0: matching_edges = ( i for i in matching_edges if not self.in_same_cycle(*i, circulars) ) elif only_cycles: circulars = self.find_edges_in_cycles() if len(circulars) > 0: matching_edges = ( i for i in matching_edges if self.in_same_cycle(*i, circulars) ) if root_node is not None: matching_edges = list(matching_edges) _directed = nx.dfs_tree(self, root_node) matching_edges = [ i for i in _directed.edges if i in matching_edges or i[::-1] in matching_edges ] matching_edges = [ i for i in matching_edges if min_descendants < len(self.get_descendants(*i)) < max_descendants and min_ancestors < len(self.get_ancestors(*i)) < max_ancestors ] return matching_edges
[docs] def sample_edges( self, edges: list = None, n: int = 3, m: int = 3, ) -> list: """ Sample a number of rotatable edges from the graph. This is done by clustering the nodes together to sample "representive" edges from each cluster. This is useful for subsampling the rotatable edges for an optimization to reduce the search space. Parameters ---------- edges : list, optional The edges to sample from, by default None, in which case all rotatable edges are sampled. n: int The number of clusters to sample from. m : int The number of edges to sample from each cluster root_node A root node to direct the edges (optional) Returns ------- list A list of sampled edges """ center = np.mean([i.coord for i in self.nodes]) if edges is None: edges = self.find_rotatable_edges() rotatable_edges = np.array(edges) if len(rotatable_edges) == 0: raise ValueError("No rotatable edges found!") elif len(rotatable_edges) < n * m: return rotatable_edges.tolist() # x, y, z, n_neighbors_a_3, n_neighbors_b_3, n_descendants, dist_to_center data = np.zeros((len(rotatable_edges), 7)) for idx, edge in enumerate(rotatable_edges): node_a, node_b = edge data[idx, 0:3] = (node_a.coord + node_b.coord) / 2 data[idx, 3] = len(self.get_neighbors(node_a, 3)) data[idx, 4] = len(self.get_neighbors(node_b, 3)) data[idx, 5] = len(self.get_descendants(*edge)) data[idx, 6] = np.linalg.norm(data[idx, 0:3] - center) from sklearn.cluster import KMeans kmeans = KMeans(n_clusters=min(n, len(rotatable_edges)), n_init="auto") kmeans.fit(data) labels = kmeans.predict(data) _rotatable_edges = [] for i in range(kmeans.n_clusters): mask = np.where(labels == i) cluster = rotatable_edges[mask] if len(cluster) > m: prob = ( 0.3 * (data[mask, 3] + data[mask, 4]) + data[mask, 5] - data[mask, 6] ).squeeze() prob += np.abs(prob.min()) prob /= prob.sum() cluster = np.random.choice( np.arange(len(cluster)), m, replace=False, p=prob ) cluster = rotatable_edges[mask][cluster] _rotatable_edges.extend(cluster.tolist()) _rotatable_edges = [tuple(i) for i in _rotatable_edges] return _rotatable_edges
[docs] def in_same_cycle(self, node_1, node_2, cycles=None) -> bool: """ Check if two nodes are in the same cycle Parameters ---------- node_1, node_2 The nodes to check Returns ------- bool True if the nodes are in the same cycle, False otherwise """ if not cycles: cycles = nx.cycle_basis(self) for cycle in cycles: if node_1 in cycle and node_2 in cycle: return True return False
[docs] def in_cycle(self, node, cycles=None) -> bool: """ Check if a node is in a cycle Parameters ---------- node The node to check Returns ------- bool True if the node is in a cycle, False otherwise """ if not cycles: cycles = nx.cycle_basis(self) for cycle in cycles: if node in cycle: return True return False
[docs] def get_cycle(self, node, cycles=None) -> set: """ Get the cycle that a node is in Parameters ---------- node The node to check Returns ------- set The nodes in the cycle that the node is in. If the node is not in a cycle, None is returned. """ if not cycles: cycles = nx.cycle_basis(self) for cycle in cycles: if node in cycle: return set(cycle) return None
[docs] def direct_edges(self, root_node=None, edges: list = None) -> list: """ Sort the edges such that the first node in each edge is the one closer to the root node. If no root node is provided, the central node is used. Parameters ---------- root_node The root node to use for sorting the edges. If not provided, the central node is used. edges : list, optional The edges to sort, by default None, in which case all edges are sorted. Returns ------- list The sorted edges """ if not root_node: root_node = self.central_node if edges is None: edges = list(self.edges) if root_node not in self.nodes: raise ValueError(f"Root node {root_node} not in graph") _directed = nx.dfs_tree(self, source=root_node).edges _tupled_edges = [(i[0], i[1]) for i in edges] out = [edge if edge in _directed else edge[::-1] for edge in _tupled_edges] return out
[docs] def clear_cache(self): """ Clear the descendant cache """ self.__descendent_cache.clear()
[docs] def lock_edge(self, node_1, node_2): """ Lock an edge, preventing it from being rotated. Parameters ---------- node_1, node_2 The nodes that define the edge """ self._locked_edges.add((node_1, node_2))
[docs] def unlock_edge(self, node_1, node_2): """ Unlock an edge, allowing it to be rotated. Parameters ---------- node_1, node_2 The nodes that define the edge """ if (node_1, node_2) in self._locked_edges: self._locked_edges.remove((node_1, node_2))
[docs] def is_locked(self, node_1, node_2): """ Check if an edge is locked Parameters ---------- node_1, node_2 The nodes that define the edge Returns ------- bool Whether the edge is locked """ return (node_1, node_2) in self._locked_edges
[docs] def get_locked_edges(self): """ Get all locked edges Returns ------- set The locked edges """ return self._locked_edges
[docs] def get_unlocked_edges(self): """ Get all unlocked edges Returns ------- set The unlocked edges """ return set(self.edges) - self._locked_edges
[docs] def lock_all(self): """ Lock all edges """ self._locked_edges = set(self.edges)
[docs] def unlock_all(self): """ Unlock all edges """ self._locked_edges = set()
[docs] def rotate_around_edge( self, node_1, node_2, angle: float, descendants_only: bool = False, update_coords: bool = True, ): """ Rotate descending nodes around a specific edge by a given angle. Parameters ---------- node_1, node_2 The nodes that define the edge around which to rotate. angle: float The angle to rotate by, in radians. descendants_only: bool, optional Whether to only rotate the descending nodes, by default False, in which case the entire graph will be rotated. update_coords: bool, optional Whether to update the coordinates of the nodes after rotation, by default True. Returns ------- new_coords: dict The new coordinates of the nodes after rotation. """ # ---------- sanity checks ---------- # We can skip these here for a little performance boost # since we should assume that these methods are only ever # called from their wrappers in the entity classes... # ---------- sanity checks ---------- # if node_1 not in self.nodes or node_2 not in self.nodes: # raise ValueError("One or more nodes not in graph!") if node_1 is node_2: raise ValueError("Cannot rotate around an edge with only one node!") elif self.is_locked(node_1, node_2): raise ValueError("Cannot rotate around a locked edge!") # we need to get a reference node index to normalise the rotated # coordinates to the original coordinate system # indices = list(self.nodes) # idx_1 = next(idx for idx, i in enumerate(self.nodes) if i is node_1) # define the axis of rotation as the cross product of the edge's vectors edge_vector = node_2.coord - node_1.coord edge_vector /= np.linalg.norm(edge_vector) # create the rotation matrix # r = Rotation.from_rotvec(angle * edge_vector) # create a numpy array of the node coordinates if descendants_only: nodes = {i: i.coord for i in self.get_descendants(node_1, node_2)} nodes[node_2] = node_2.coord else: nodes = {i: i.coord for i in self.nodes} node_coords = np.array(tuple(nodes.values())) # indices = list(nodes.keys()) # idx_2 = indices.index(node_2) idx_2 = next(idx for idx, i in enumerate(nodes.keys()) if i is node_2) # apply the rotation matrix to the node coordinates # node_coords_rotated = r.apply(node_coords) node_coords_rotated = base.rotate_coords( node_coords - node_coords[idx_2], angle, edge_vector ) node_coords_rotated += node_coords[idx_2] # # now adjust for the translatisonal shift around the axis # _diff = node_coords_rotated[idx_2] - node_coords[idx_2] # node_coords_rotated -= _diff # update the node coordinates in the graph new_coords = {i: node_coords_rotated[idx] for idx, i in enumerate(nodes.keys())} if update_coords: for node, coord in new_coords.items(): node.coord = coord _new_coords = {i: i.coord for i in self.nodes} _new_coords.update(new_coords) return _new_coords
def _get_structure(self): """ Get the underlying `bio.PDB.Structure` object """ if not hasattr(list(self.nodes)[0], "get_parent"): warnings.warn("Nodes are not Biopython entities with linked parents!") return None structure = list(self.nodes)[0].get_parent() if structure is None: warnings.warn("Nodes do not seem to have linked parents!") return None while not isinstance(structure, bio.Structure.Structure): structure = structure.get_parent() return structure def __str__(self): lines = "\n".join(nx.generate_network_text(self)) return lines
if __name__ == "__main__": import buildamol as bam from timeit import timeit import seaborn as sns import matplotlib.pyplot as plt from functools import partial mol = bam.molecule( "/Users/noahhk/GIT/biobuild/buildamol/optimizers/_testing/files/EX8.json" ) v = mol.draw() g = BaseGraph(None, mol.bonds) nx.set_edge_attributes(g, 1, "bond_order") _g = mol.get_atom_graph() _g.sample_rotatable_edges = partial(BaseGraph.sample_rotatable_edges, _g) edges = _g.sample_edges(_g.find_rotatable_edges(min_descendants=10), n=4, m=3) v.draw_edges(*edges, color="limegreen", linewidth=8, elongate=1.1) v.show() pass # ref_atom_graph = mol.get_atom_graph() # a, b = mol.get_atoms(68, 65) # x = g.get_descendants(a, b) # for i in x: # v.draw_atom(i, color="purple") # measure_performance = True # repeats = 500 # number = 800 # if measure_performance: # test_old = lambda: BaseGraph.get_descendants_old(ref_atom_graph, a, b) # test_new = lambda: g.get_descendants(a, b) # # test_2 = lambda: g.get_descendants_2(a, b) # times_old = [timeit(test_old, number=number) for _ in range(repeats)] # times_new = [timeit(test_new, number=number) for _ in range(repeats)] # # times_2 = [timeit(test_2, number=number) for _ in range(repeats)] # sns.distplot(times_old, label="old", kde=True, bins=20) # sns.distplot(times_new, label="new", kde=True, bins=20) # # sns.distplot(times_2, label="2", kde=True, bins=20) # plt.legend() # plt.show() # pass # v.draw_edges((a, b), color="limegreen", linewidth=4, elongate=1.1) # ref_decendants = ref_atom_graph.get_descendants(a, b) # descendants = g.get_descendants_2(a, b) # for i in descendants: # v.draw_atom(i, color="purple") # for i in ref_decendants: # v.draw_atom(i, color="orange") # v.show() # pass