"""
The DistanceRotatron environment evaulates conformations based on the pairwise distances between nodes in the optimized graph.
It uses two forces, a global "unfolding" force to maximize spacial separation between nodes, and a local "pushback" force to maximize distances between the closest nodes.
The evaluation is computed as:
.. math::
e_i = \\sum_{j \\neq i} d_{ij}^{unfold} + pushback \\cdot \\sum_{k=1}^N \\text{sorted}(d)_{ik}
There are multiple variations of this basic formulation available (see the functions below).
"""
import gym
import numpy as np
from scipy.spatial.distance import cdist
import buildamol.optimizers.base_rotatron as Rotatron
import buildamol.graphs.base_graph as base_graph
import buildamol.utils.auxiliary as aux
import buildamol.structural.base as structural
# Rotatron = Rotatron.Rotatron
def concatenation_wrapper(x):
pass
[docs]
def simple_concatenation_function(x, unfold, pushback, n_smallest, clash_distance):
"""
A simple concatentation function that computes the evaluation as:
Mean distance ** unfold + (mean of n smallest distances) ** pushback
"""
k = min(n_smallest, len(x) - 1)
smallest = np.partition(x, k)[:k] # np.sort(x)[:n_smallest]
e = np.power(np.mean(x), unfold) + np.power(np.mean(smallest), pushback)
return e
_numba_wrapper_simple_concatenation_function = aux.njit(simple_concatenation_function)
[docs]
def concatenation_function_with_penalty(
x, unfold, pushback, n_smallest, clash_distance
):
"""
A concatentation function that computes the evaluation as:
[(Mean distance ** unfold + (mean of n smallest distances) ** pushback)] / clash penalty
"""
k = min(n_smallest, len(x) - 1)
smallest = np.partition(x, k)[:k] # np.sort(x)[:n_smallest]
penalty = np.sum(x < 1.5 * clash_distance)
e = np.power(np.mean(x), unfold) + np.power(np.mean(smallest), pushback)
e /= (1 + penalty) ** 2
return e
_numba_wrapper_concatenation_function_with_penalty = aux.njit(
concatenation_function_with_penalty
)
[docs]
def concatenation_function_no_pushback(x, unfold, pushback, n_smallest, clash_distance):
"""
A concatentation function that computes the evaluation as:
Mean distance ** unfold
"""
e = np.power(np.mean(x), unfold)
return e
_numba_wrapper_concatenation_function_no_pushback = aux.njit(
concatenation_function_no_pushback
)
[docs]
def concatenation_function_no_unfold(x, unfold, pushback, n_smallest, clash_distance):
"""
A concatentation function that computes the evaluation as:
(Mean of n smallest distances) ** pushback
"""
k = min(n_smallest, len(x) - 1)
smallest = np.partition(x, k)[:k] # np.sort(x)[:n_smallest]
e = np.power(np.mean(smallest), pushback)
return e
_numba_wrapper_concatenation_function_no_unfold = aux.njit(
concatenation_function_no_unfold
)
[docs]
def concatenation_function_linear(x, unfold, pushback, n_smallest, clash_distance):
"""
A concatentation function that computes the evaluation as:
Mean distance * unfold + (mean of n smallest distances) * pushback
"""
k = min(n_smallest, len(x) - 1)
smallest = np.partition(x, k)[:k] # np.sort(x)[:n_smallest]
e = np.multiply(np.mean(x), unfold) + np.multiply(np.mean(smallest), pushback)
return e
_numba_wrapper_concatenation_function_linear = aux.njit(concatenation_function_linear)
__numba_wrappers__ = {
simple_concatenation_function: _numba_wrapper_simple_concatenation_function,
concatenation_function_with_penalty: _numba_wrapper_concatenation_function_with_penalty,
concatenation_function_no_pushback: _numba_wrapper_concatenation_function_no_pushback,
concatenation_function_no_unfold: _numba_wrapper_concatenation_function_no_unfold,
concatenation_function_linear: _numba_wrapper_concatenation_function_linear,
}
[docs]
class DistanceRotatron(Rotatron.Rotatron):
"""
A distance-based Rotatron environment.
Parameters
----------
graph : AtomGraph or ResidueGraph
The graph to optimize
rotatable_edges : list
A list of edges that can be rotated during optimization.
If None, all non-locked edges are used.
radius : float
The radius around rotatable edges to include in the distance calculation.
Set to -1 to disable.
pushback : float
Short distances between atoms are given higher weight in the evaluation using this factor.
unfold : float
The exponent to use when computing the mean distance to others for
each node. Higher values give higher values to global unfolding of the graph.
clash_distance : float
The distance at which atoms are considered to be clashing.
crop_nodes_further_than : float
Nodes that are further away than this factor times the radius from any rotatable edge at the beginning
of the optimization are removed from the graph and not considered during optimization. This speeds up
computation. Set to -1 to disable.
n_smallest : int
The number of smallest distances to use when computing the evaluation for each node.
concatenation_function : callable
A custom function to use when computing the evaluation for each node.
This function should take the state array as first argument and may take any additional arguments.
These additional arguments must be passed as keyword arguments to the environment during setup.
The function must return a float.
bounds : tuple
The bounds for the minimal and maximal rotation angles.
n_processes : int
The number of processes to use for parallel computation during edge mask generation.
"""
def __init__(
self,
graph: "base_graph.BaseGraph",
rotatable_edges: list = None,
radius: float = 20,
pushback: float = 3,
unfold: float = 2,
clash_distance: float = 1.2,
crop_nodes_further_than: float = -1,
n_smallest: int = 10,
concatenation_function: callable = None,
bounds: tuple = (-np.pi, np.pi),
n_processes: int = 1,
**kwargs,
):
self.hyperparameters = {
"pushback": pushback,
"unfold": unfold,
"clash_distance": clash_distance,
"crop_nodes_further_than": crop_nodes_further_than,
"n_smallest": n_smallest,
"concatenation_function": concatenation_function,
"bounds": bounds,
"radius": radius,
"n_processes": n_processes,
**kwargs,
}
self.kwargs = kwargs
self.radius = radius
self.crop_radius = crop_nodes_further_than * radius if radius > 0 else -1
self.clash_distance = clash_distance
self.pushback = pushback
self.unfold = unfold
self.n_smallest = n_smallest
# =====================================
if self.crop_radius > 0:
graph, rotatable_edges = self._setup_helpers_crop_faraway_nodes(
self.crop_radius, graph, rotatable_edges
)
# =====================================
if radius > 0:
self._radius = radius
else:
self._radius = np.inf
# =====================================
# self.concatenation_function = concatenation_wrapper
self._state_dists = np.zeros((len(graph.nodes), len(graph.nodes)))
# =====================================
self.edx = 0
Rotatron.Rotatron.__init__(
self, graph, rotatable_edges, n_processes=n_processes, **kwargs
)
self.action_space = gym.spaces.Box(
low=bounds[0], high=bounds[1], shape=(len(self.rotatable_edges),)
)
self.observation_space = gym.spaces.Box(
low=-np.inf, high=np.inf, shape=(len(self.graph.nodes), 3)
)
# =====================================
self._last_eval = 0.0
self.n_smallest = min(n_smallest, self.n_nodes)
self.hyperparameters["n_smallest"] = self.n_smallest
# =====================================
self._numba_func_signature = (
"unfold",
"pushback",
"n_smallest",
"clash_distance",
)
if concatenation_function is None:
concatenation_function = concatenation_function_with_penalty
if (
kwargs.get("numba", False)
or aux.USE_ALL_NUMBA
or (self.n_nodes**2 > 100000 and aux.USE_NUMBA)
):
concatenation_function = __numba_wrappers__.get(
concatenation_function, concatenation_function
)
self.eval = self._numba_eval
else:
self.eval = self._normal_eval
self._concatenation_function = concatenation_function
self._concatenation_function_kwargs = aux.get_args(
self._concatenation_function, self.hyperparameters
)
self._bounds_tuple = bounds
# =====================================
def _normal_eval(self, state):
"""
Calculate the evaluation score for a given state
Parameters
----------
state : np.ndarray
The state of the environment
Returns
-------
float
The evaluation for the state
"""
pairwise_dists = cdist(state, state)
np.fill_diagonal(pairwise_dists, self._radius)
mask = pairwise_dists < self._radius
mask = np.logical_and(mask, self.rotation_unit_masks)
dist_eval = np.zeros(len(pairwise_dists))
_changed_entries = mask.any(axis=1)
dist_eval[~_changed_entries] = -1
for i in np.where(_changed_entries)[0]:
dist_eval[i] = self._concatenation_function(
pairwise_dists[i][mask[i]], **self._concatenation_function_kwargs
)
mean_dist_eval = np.divide(1.0, np.mean(dist_eval[dist_eval > -1]))
final = np.log(mean_dist_eval) # - self._backup_eval
min_dist = np.min(pairwise_dists)
self._state_dists = min_dist
self._last_eval = final
return final
def _numba_eval(self, state):
min_dist, final = _numba_wrapper_eval(
state=state,
concatenation_function=self._concatenation_function,
rotation_unit_masks=self.rotation_unit_masks,
last_eval=self._last_eval,
radius=self._radius,
**self._concatenation_function_kwargs,
)
self._state_dists = min_dist
self._last_eval = final
return final
[docs]
def is_done(self, state):
return np.min(self._state_dists) > self.clash_distance
[docs]
def concatenation_function(self, x):
mask = x < self._radius
mask = np.logical_and(mask, self.rotation_unit_masks[self.ndx])
self.ndx += 1
if not np.logical_or.reduce(mask):
return -1
return self._concatenation_function(
x[mask], **self._concatenation_function_kwargs
)
@aux.njit
def _numba_wrapper_eval(
state,
concatenation_function,
rotation_unit_masks,
last_eval,
radius,
unfold,
pushback,
n_smallest,
clash_distance,
):
pairwise_dists = structural._numba_wrapper_euclidean_distances(state, state)
np.fill_diagonal(pairwise_dists, radius)
dist_eval = np.zeros(len(state))
mask = pairwise_dists < radius
mask = np.logical_and(mask, rotation_unit_masks)
dist_eval = np.zeros(len(pairwise_dists))
for i in range(len(pairwise_dists)):
if not mask[i].any():
dist_eval[i] = -1
else:
dist_eval[i] = concatenation_function(
pairwise_dists[i][mask[i]], unfold, pushback, n_smallest, clash_distance
)
mask = dist_eval > -1
mean_dist_eval = np.divide(1.0, np.mean(dist_eval[mask]))
final = np.log(mean_dist_eval) # - self._backup_eval
min_dist = np.min(pairwise_dists)
return min_dist, final
__all__ = [
"DistanceRotatron",
"simple_concatenation_function",
"concatenation_function_with_penalty",
"concatenation_function_no_pushback",
"concatenation_function_no_unfold",
"concatenation_function_linear",
]
if __name__ == "__main__":
import buildamol as bam
DistanceRotatron._backup_eval = 0.0
mol = bam.Molecule.from_json(
"/Users/noahhk/GIT/biobuild/buildamol/optimizers/__testing__/files/EX8.json"
)
print("init: ", mol.count_clashes())
graph = mol.get_residue_graph(True)
env = DistanceRotatron(
graph, numba=False, concatenation_function=simple_concatenation_function
) # , concatenatiofn_function=simple_concatenation_function)
from time import time
for i in range(15):
t0 = time()
out = bam.optimizers.optimize(mol.copy(), env)
print(time() - t0, out.count_clashes())
# import matplotlib.pyplot as plt
# import seaborn as sns
# mol = bam.molecule("/Users/noahhk/GIT/biobuild/_tutorials copy/ext8_opt.pdb")
# graph = mol.get_residue_graph(True)
# edges = graph.find_rotatable_edges(min_descendants=3)
# from time import time
# for i in range(5, 10):
# t0 = time()
# d = DistanceRotatron(
# graph,
# edges,
# pushback=3,
# concatenation_function=concatenation_function_with_penalty,
# n_processes=i,
# )
# # opt = bam.optimizers.optimize(mol, d, "swarm")
# print(time() - t0)
# exit()
# t1 = time()
# d = DistanceRotatron(
# graph,
# edges,
# pushback=3,
# concatenation_function=concatenation_function_with_penalty,
# n_processes=5,
# )
# opt = bam.optimizers.optimize(mol, d, "swarm")
# print(opt.count_clashes())
# opt.to_pdb(f"opt9_pow_pushback_{d.pushback}.pdb")
# import stable_baselines3 as sb3
# model = sb3.PPO("MlpPolicy", d, verbose=1)
# model.learn(total_timesteps=10000)
# model.save("ppo_distance_rotatron")
# x_ = d._best_eval
# x0 = d.step(d.blank())
# a = d.blank()
# a[22] = 0.1
# x = d.step(a)
# print(x[1])
# pass
# -------------------------------------
# bam.load_sugars()
# glc = bam.molecule("GLC")
# glc.repeat(2, "14bb")
# bonds = [glc.get_bonds("O4", "C4")[0]]
# env = DistanceRotatron(glc.make_atom_graph(), bonds, radius=20)
# actions = np.arange(
# -np.pi,
# np.pi,
# np.pi / 80,
# )
# cmap = sns.color_palette("Blues", len(actions))
# evals = []
# v = glc.draw()
# for i in actions:
# new_state, e, done, _ = env.step(np.array([i]))
# evals.append([i, e])
# _glc = glc.copy()
# _glc.rotate_around_bond(*bonds[0], np.degrees(i), descendants_only=True)
# color = "lightgray"
# if e > 0.2157:
# color = "red"
# # elif e > 0.1970:
# # color = "orange"
# # elif e < 0.1965:
# # color = "green"
# if color == "lightgray":
# opacity = 0.4
# else:
# opacity = 1.0
# v.draw_edges(
# _glc.get_bonds(_glc.residues[1]), color=color, linewidth=3, opacity=opacity
# )
# evals = np.array(evals)
# plt.plot(evals[:, 0], evals[:, 1])
# v.show()
# plt.show()
# pass