"""
The DistanceRotatron environment evaulates conformations based on the pairwise distances between nodes in the optimized graph.
It uses two forces, a global "unfolding" force to maximize spacial separation between nodes, and a local "pushback" force to maximize distances between the closest nodes.
The evaluation is computed as:
.. math::
e_i = \\sum_{j \\neq i} d_{ij}^{unfold} + pushback \\cdot \\sum_{k=1}^N \\text{sorted}(d)_{ik}
There are multiple variations of this basic formulation available (see the functions below).
"""
import gym
import numpy as np
from scipy.spatial.distance import cdist
import biobuild.optimizers.Rotatron as Rotatron
import biobuild.graphs.BaseGraph as BaseGraph
# Rotatron = Rotatron.Rotatron
[docs]
def simple_concatenation_function(self, x):
    """
    Score a single node as the sum of a global and a local distance term.

    The evaluation is computed as:
    mean(x) ** unfold + (mean of the n_smallest smallest values of x) ** pushback

    Parameters
    ----------
    self
        The environment; ``n_smallest``, ``unfold`` and ``pushback`` are read from it.
    x : np.ndarray
        1D array of distances from one node to the other considered nodes.

    Returns
    -------
    float
        The evaluation of the node.
    """
    closest = np.sort(x)[: self.n_smallest]
    global_term = np.power(np.mean(x), self.unfold)
    local_term = np.power(np.mean(closest), self.pushback)
    return global_term + local_term
[docs]
def concatenation_function_with_penalty(self, x):
    """
    Score a single node like ``simple_concatenation_function``, damped by near-clashes.

    The evaluation is computed as:
    (mean(x) ** unfold + (mean of the n_smallest smallest values of x) ** pushback) / (1 + k) ** 2
    where ``k`` is the number of distances in ``x`` below ``1.5 * clash_distance``.

    Parameters
    ----------
    self
        The environment; ``n_smallest``, ``unfold``, ``pushback`` and
        ``clash_distance`` are read from it.
    x : np.ndarray
        1D array of distances from one node to the other considered nodes.

    Returns
    -------
    float
        The (penalized) evaluation of the node.
    """
    # distances "close to clashing" shrink the score quadratically
    n_close = np.sum(x < 1.5 * self.clash_distance)
    closest = np.sort(x)[: self.n_smallest]
    score = np.power(np.mean(x), self.unfold) + np.power(
        np.mean(closest), self.pushback
    )
    return score / (1 + n_close) ** 2
[docs]
def concatenation_function_no_pushback(self, x):
    """
    Score a single node using only the global "unfold" term.

    The evaluation is computed as:
    mean(x) ** unfold

    Parameters
    ----------
    self
        The environment; ``unfold`` is read from it.
    x : np.ndarray
        1D array of distances from one node to the other considered nodes.

    Returns
    -------
    float
        The evaluation of the node.
    """
    mean_distance = np.mean(x)
    return np.power(mean_distance, self.unfold)
[docs]
def concatenation_function_no_unfold(self, x):
    """
    A concatenation function that uses only the local "pushback" term.

    The evaluation is computed as:
    (mean of the n_smallest smallest distances) ** pushback

    The global mean-distance ("unfold") term is not used.

    Parameters
    ----------
    self
        The environment; ``n_smallest`` and ``pushback`` are read from it.
    x : np.ndarray
        1D array of distances from one node to the other considered nodes.

    Returns
    -------
    float
        The evaluation of the node.
    """
    smallest = np.sort(x)[: self.n_smallest]
    e = np.power(np.mean(smallest), self.pushback)
    return e
[docs]
def concatenation_function_linear(self, x):
    """
    Score a single node as a weighted (linear) sum of a global and a local term.

    The evaluation is computed as:
    mean(x) * unfold + (mean of the n_smallest smallest values of x) * pushback

    Parameters
    ----------
    self
        The environment; ``n_smallest``, ``unfold`` and ``pushback`` are read from it.
    x : np.ndarray
        1D array of distances from one node to the other considered nodes.

    Returns
    -------
    float
        The evaluation of the node.
    """
    closest = np.sort(x)[: self.n_smallest]
    global_term = np.multiply(np.mean(x), self.unfold)
    local_term = np.multiply(np.mean(closest), self.pushback)
    return global_term + local_term
[docs]
class DistanceRotatron(Rotatron):
    """
    A distance-based Rotatron environment.

    Every node of the graph is scored by a *concatenation function* that is
    applied to the distances from that node to the other nodes that are
    (a) closer than ``radius`` and (b) selected by the node's entry in
    ``rotation_unit_masks`` (presumably set up by the ``Rotatron`` base
    class — NOTE(review): confirm). The evaluation of a state is
    ``log(1 / mean(node_scores))``, so **lower evaluations are better**.

    Parameters
    ----------
    graph : AtomGraph or ResidueGraph
        The graph to optimize. If node cropping is enabled (see
        ``crop_nodes_further_than``), nodes are removed from this graph
        object in place.
    rotatable_edges : list
        A list of edges that can be rotated during optimization.
        If None, all non-locked edges are used.
    radius : float
        Only node pairs closer than this distance contribute to the
        evaluation. Set to -1 (any value <= 0) to disable the cutoff.
    pushback : float
        Weight of the local term computed from each node's ``n_smallest``
        shortest distances (an exponent in the default concatenation
        function). Short distances between atoms are given higher weight in
        the evaluation using this factor.
    unfold : float
        Weight of the global term, i.e. the mean distance from a node to the
        others (an exponent in the default concatenation function). Higher
        values favor global unfolding of the graph.
    clash_distance : float
        Node pairs closer than this distance are considered to be clashing.
    crop_nodes_further_than : float
        Nodes whose distance to the midpoint of *every* rotatable edge exceeds
        this factor times ``radius`` are removed from the graph at the
        beginning of the optimization and are not considered afterwards.
        This speeds up computation. Only used when ``radius > 0``; set to -1
        to disable.
    n_smallest : int
        The number of smallest distances to use when computing the
        evaluation for each node.
    concatenation_function : callable
        A custom function to use when computing the evaluation for each node.
        It receives the environment (self) as first argument and a 1D array
        of the selected distances from one node to others as second argument
        and must return a scalar greater than -1 (-1 is used internally to
        mark nodes without any selected neighbour).
        Defaults to ``concatenation_function_with_penalty``.
    bounds : tuple
        The bounds for the minimal and maximal rotation angles.
    **kwargs
        Stored unchanged as ``self.kwargs``.
    """

    def __init__(
        self,
        graph: "BaseGraph.BaseGraph",
        rotatable_edges: list = None,
        radius: float = 20,
        pushback: float = 3,
        unfold: float = 2,
        clash_distance: float = 0.9,
        crop_nodes_further_than: float = -1,
        n_smallest: int = 10,
        concatenation_function: callable = None,
        bounds: tuple = (-np.pi, np.pi),
        **kwargs,
    ):
        self.kwargs = kwargs
        self.radius = radius
        # cropping is relative to the radius, so it needs a positive radius
        self.crop_radius = crop_nodes_further_than * radius if radius > 0 else -1
        self.clash_distance = clash_distance
        self.pushback = pushback
        self.unfold = unfold
        self.n_smallest = n_smallest
        if concatenation_function is None:
            concatenation_function = concatenation_function_with_penalty
        self._concatenation_function = concatenation_function
        self._bounds_tuple = bounds

        # =====================================
        # optionally drop nodes that are far away from all rotatable edges
        if self.crop_radius > 0:
            rotatable_edges = self._get_rotatable_edges(graph, rotatable_edges)
            edge_coords = np.array(
                [(a.coord + b.coord) / 2 for a, b in rotatable_edges]
            )
            # without any edge there is nothing to crop against (and cdist
            # would reject the empty, non-2D coordinate array)
            if edge_coords.size > 0:
                nodes = list(graph.nodes)
                node_coords = np.array([node.coord for node in nodes])
                # True for nodes farther than crop_radius from *every* edge midpoint
                too_far = (cdist(edge_coords, node_coords) > self.crop_radius).all(
                    axis=0
                )
                if too_far.any():
                    graph.remove_nodes_from(
                        [node for node, far in zip(nodes, too_far) if far]
                    )

        # =====================================
        # effective distance cutoff; np.inf means "no cutoff"
        self._radius = radius if radius > 0 else np.inf

        # =====================================
        def concatenation_wrapper(x):
            # x is one row of the pairwise-distance matrix; self.ndx is the
            # index of that row (reset in eval, advanced once per row here)
            mask = np.logical_and(
                x < self._radius, self.rotation_unit_masks[self.ndx]
            )
            self.ndx += 1
            if not mask.any():
                # sentinel for "no selected neighbour"; a float so it can never
                # force an integer result array
                return -1.0
            return self._concatenation_function(self, x[mask])

        self.concatenation_function = concatenation_wrapper
        # distances of the most recently evaluated state (used by count_clashes)
        self._state_dists = np.zeros((len(graph.nodes), len(graph.nodes)))

        # =====================================
        self.edx = 0
        Rotatron.__init__(self, graph, rotatable_edges)
        self.action_space = gym.spaces.Box(
            low=bounds[0], high=bounds[1], shape=(len(self.rotatable_edges),)
        )
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(len(self.graph.nodes), 3)
        )
        # =====================================
        self._best_clashes = self.count_clashes()

    def eval(self, state):
        """
        Calculate the evaluation score for a given state.

        Parameters
        ----------
        state : np.ndarray
            The state of the environment: one coordinate row per node
            (presumably shape ``(n_nodes, 3)`` as in ``observation_space``).

        Returns
        -------
        float
            ``log(1 / mean)`` of the per-node scores of all nodes that have at
            least one selected neighbour (lower is better). If no node has a
            selected neighbour, the previous evaluation (``self._last_eval``)
            is returned.
        """
        pairwise_dists = cdist(state, state)
        # self-distances are set to the cutoff so `x < self._radius` never selects them
        np.fill_diagonal(pairwise_dists, self._radius)
        # always record the distances of the evaluated state so count_clashes
        # reflects it, even when the early return below is taken
        self._state_dists[:, :] = pairwise_dists
        self.ndx = 0
        # explicit float array: np.apply_along_axis would take its output
        # dtype from the first row's result and could truncate all scores
        dist_eval = np.array(
            [self.concatenation_function(row) for row in pairwise_dists],
            dtype=float,
        )
        mask = dist_eval > -1
        if not mask.any():
            # NOTE(review): assumes _last_eval was set by an earlier call
            # (or by the Rotatron base class) — confirm
            return self._last_eval
        mean_dist_eval = np.divide(1.0, np.mean(dist_eval[mask]))
        final = np.log(mean_dist_eval)
        self._last_eval = final
        return final

    def step(self, action):
        """
        Apply one rotation per rotatable edge and evaluate the resulting state.

        The best state/action seen so far is updated when the evaluation
        improves (is lower) without increasing the number of clashes.

        Parameters
        ----------
        action : np.ndarray
            One rotation angle per rotatable edge (see ``action_space``).

        Returns
        -------
        tuple
            ``(new_state, evaluation, done, info)`` where ``done`` is True
            when no clashes remain and ``info`` is an empty dict.
        """
        # NOTE(review): with no rotatable edges, new_state is never assigned
        for i, _edge in enumerate(self.rotatable_edges):
            self.edx = i
            new_state = self._rotate(i, action[i])

        self._last_eval = self.eval(new_state)
        clashes = self.count_clashes()
        done = clashes == 0
        self._action_history += action

        if self._last_eval < self._best_eval and clashes <= self._best_clashes:
            self._best_eval = self._last_eval
            # update the best-state/action buffers in place
            self._best_state *= 0
            self._best_state += new_state
            self._best_action *= 0
            self._best_action += action
            self._best_clashes = clashes

        return new_state, self._last_eval, done, {}

    def is_done(self, state):
        """
        Return True when the last evaluated state has no clashes.

        ``state`` is not used; the check is based on the distances stored by
        the most recent call to :meth:`eval`.
        """
        return self.count_clashes() == 0

    def count_clashes(self):
        """
        Count entries of the last evaluated distance matrix below ``clash_distance``.

        The matrix is symmetric, so each clashing pair is counted twice.
        Until :meth:`eval` has been called the matrix is all zeros, so every
        entry counts. The diagonal holds the distance cutoff and therefore
        only counts if ``radius`` is positive and smaller than
        ``clash_distance``.
        """
        return np.sum(self._state_dists < self.clash_distance)
# Public API of this module: the environment class plus the ready-made
# per-node concatenation (scoring) functions it can be configured with.
__all__ = [
    "DistanceRotatron",
    # per-node scoring functions (pass one as `concatenation_function`)
    "simple_concatenation_function",
    "concatenation_function_with_penalty",
    "concatenation_function_no_pushback",
    "concatenation_function_no_unfold",
    "concatenation_function_linear",
]
if __name__ == "__main__":
    # Development / demo script: optimize a structure with a DistanceRotatron
    # environment and probe the evaluation of a single-edge perturbation.
    import sys

    import biobuild as bb
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Input structure: first command-line argument; falls back to the
    # original author's machine-specific test file.
    pdb_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/Users/noahhk/GIT/biobuild/_tutorials copy/ext8_opt.pdb"
    )
    mol = bb.molecule(pdb_path)
    graph = mol.get_residue_graph(True)
    edges = graph.find_rotatable_edges(min_descendants=3)
    d = DistanceRotatron(
        graph,
        edges,
        pushback=3,
        concatenation_function=concatenation_function_with_penalty,
    )
    opt = bb.optimizers.optimize(mol, d, "swarm")
    print(opt.count_clashes())
    opt.to_pdb(f"opt9_pow_pushback_{d.pushback}.pdb")
    # import stable_baselines3 as sb3
    # model = sb3.PPO("MlpPolicy", d, verbose=1)
    # model.learn(total_timesteps=10000)
    # model.save("ppo_distance_rotatron")
    x_ = d._best_eval
    x0 = d.step(d.blank())
    a = d.blank()
    # perturb edge 22 (as in the original experiment) only if it exists
    if len(a) > 22:
        a[22] = 0.1
    x = d.step(a)
    print(x[1])
    pass
    # -------------------------------------
    # bb.load_sugars()
    # glc = bb.molecule("GLC")
    # glc.repeat(2, "14bb")
    # bonds = [glc.get_bonds("O4", "C4")[0]]
    # env = DistanceRotatron(glc.make_atom_graph(), bonds, radius=20)
    # actions = np.arange(
    # -np.pi,
    # np.pi,
    # np.pi / 80,
    # )
    # cmap = sns.color_palette("Blues", len(actions))
    # evals = []
    # v = glc.draw()
    # for i in actions:
    # new_state, e, done, _ = env.step(np.array([i]))
    # evals.append([i, e])
    # _glc = glc.copy()
    # _glc.rotate_around_bond(*bonds[0], np.degrees(i), descendants_only=True)
    # color = "lightgray"
    # if e > 0.2157:
    # color = "red"
    # # elif e > 0.1970:
    # # color = "orange"
    # # elif e < 0.1965:
    # # color = "green"
    # if color == "lightgray":
    # opacity = 0.4
    # else:
    # opacity = 1.0
    # v.draw_edges(
    # _glc.get_bonds(_glc.residues[1]), color=color, linewidth=3, opacity=opacity
    # )
    # evals = np.array(evals)
    # plt.plot(evals[:, 0], evals[:, 1])
    # v.show()
    # plt.show()
    # pass