Source code for buildamol.optimizers.translatron

"""
This is the Translatron environment that can be used to place a molecule according to some constraints.
"""

import gym
from gym import spaces
import numpy as np

import buildamol.utils.auxiliary as aux
import buildamol.structural.base as structural

__all__ = ["Translatron"]


[docs] class Translatron(gym.Env): """ The Translatron environment can be used to place a molecule in space according to some constraints. It will produce a vector of 6 values, the first 3 are the translation in x, y, and z, and the last 3 are the rotations around the x, y, and z axes. Parameters ---------- graph : AtomGraph or ResidueGraph The graph of the molecule to be placed. constraint_func : callable A function that takes the environment and the new coordinates and returns a reward value. finish_func : callable A function that takes the environment and the new coordinates and returns a boolean value indicating whether the optimization is finished. bounds : tuple The bounds for the minimal and maximal translation and rotation values. If a tuple of length two this is interpreted as the low and high bounds for translation only. Otherwise provide a tuple of length 6 for the bounds of translation and rotation. In this case values can be either singular (int/float) in which case they are interpreted as symmetric extrama (min=-value, max=+value) or as tuples with (min=value[0], max=value[1]). Mixed inputs are allowed. """ def __init__( self, graph, constraint_func: callable, finish_func: callable = None, bounds: tuple = (-10, 10), ): self.graph = graph self.constraint_func = constraint_func self.finish_func = finish_func or self.__no_finish if len(bounds) == 2: low = [bounds[0], bounds[0], bounds[0], -np.pi, -np.pi, -np.pi] high = [bounds[1], bounds[1], bounds[1], np.pi, np.pi, np.pi] elif len(bounds) == 6: bounds = [(-i, i) if isinstance(i, (int, float)) else i for i in bounds] low = [ bounds[0][0], bounds[1][0], bounds[2][0], bounds[3][0], bounds[4][0], bounds[5][0], ] high = [ bounds[0][1], bounds[1][1], bounds[2][1], bounds[3][1], bounds[4][1], bounds[5][1], ] else: raise ValueError( f"Expected 'bounds' to be of length 2 (only translation min/max) or length 6 (min/max for translation and rotations), but got length {len(bounds)}" ) self.action_space = spaces.Box( low=np.array(low), high=np.array(high), shape=(6,), dtype=np.float32, ) self._setup(graph) if aux.USE_NUMBA: self.apply = self._numba_apply else: self.apply = self._normal_apply
[docs] def reset(self, *args, **kwargs): self.state[:] = self.__backup_coords
[docs] def step(self, action): new_coords = self.apply(action) reward = self.constraint_func(self, new_coords) done = self.finish_func(self, new_coords) return new_coords, reward, done, {}
[docs] def eval(self, state): return self.constraint_func(self, state)
def _normal_apply(self, action): """ Apply an action to the current state. """ new = structural.rotate_coords(self.state, action[3], structural.x_axis) new = structural.rotate_coords(new, action[4], structural.y_axis) new = structural.rotate_coords(new, action[5], structural.z_axis) new = new + action[:3] return new def _numba_apply(self, action): new = structural._numba_wrapper_rotate_coords( self.state, action[3], structural.x_axis ) new = structural._numba_wrapper_rotate_coords(new, action[4], structural.y_axis) new = structural._numba_wrapper_rotate_coords(new, action[5], structural.z_axis) new = new + action[:3] return new
[docs] def blank(self): return np.zeros(6)
def _setup(self, graph): coords = np.array([i.coord for i in graph.nodes]) self.state = coords self.__backup_coords = coords.copy() def __no_finish(self, coords, *args, **kwargs): return False
if __name__ == "__main__": import buildamol as bam mol = bam.read_pdb("/Users/noahhk/GIT/biobuild/0.pdb") mol2 = mol.copy() mol2.move([10, 10, 10]) mol2.rotate(90, [1, 0, 0]) mol2.rotate(36, [0, 1, 0]) ref_coords = np.array([i.coord for i in mol.get_atoms()]) def constraint(env, coords): dist = (coords - ref_coords) ** 2 dist = np.mean(dist) env._dist = dist return dist def finish(env, coords): return env._dist < 0.1 env = Translatron(mol2.get_atom_graph(), constraint) sol, eval = bam.optimizers.swarm_optimize(env, n_particles=50) v = mol.draw(atoms=False) mol2.move(sol[:3]) mol2.rotate(sol[3], "x") mol2.rotate(sol[4], "y") mol2.rotate(sol[5], "z") v.draw_edges(*mol2.get_bonds(), color="blue", linewidth=3, showlegend=False) v.show()