"""
Tools for various FEP-related analyses.
Copyright Schrodinger, LLC. All rights reserved.
"""
# Contributors: Yujie Wu, Dan Sindhikara
import copy
import math
import os
import re
import sys
from past.utils import old_div
from pathlib import Path
from typing import Dict
from typing import List
from typing import Union
import numpy
from schrodinger.application.desmond import bennett
from schrodinger.application.desmond import cms
from schrodinger.application.desmond import config
from schrodinger.application.desmond import measurement
from schrodinger.application.desmond import util
from schrodinger.application.desmond.constants import BOLTZMANN
# FIXME: Replace this class with Joe's version.
class Restraint:
    """
    A single restraint acting on a group of atoms.

    :param atom: The restrained atoms. The free-energy helpers in this
        module classify a restraint by the number of atoms: 2 for a stretch,
        3 for an angle, 4 for a torsion.
    :param ref: Reference value of the restrained coordinate. Angle and
        torsion references are treated as degrees by those helpers.
    :param k: Force constant.
    """

    def __init__(self, atom, ref, k):
        self.atom, self.ref, self.k = atom, ref, k
def _float_range(start, stop, step, closed=True):
"""
Return evenly spaced float values from start to stop.
:param start: First value in the range
:param stop: Last value in the range if `closed`, first value outside of the
range if not `closed`.
:param step: The size of the step between elements of the result range.
:param closed: Should `stop` be included in the range?
"""
EPSILON = 1e-9
signed_epsilon = -EPSILON if step < 0 else EPSILON
if closed:
points = numpy.arange(start, stop + signed_epsilon, step)
if points.size > 1 and abs(points[-1] - stop) < EPSILON:
points[-1] = stop
return points
else:
return numpy.arange(start, stop - signed_epsilon, step)
def get_energy_table(fname, term_list):
    """
    Parse an energy-group output file into one table per energy term.

    columns: (0, 0) (0, 1) (1, 1) ... "total"
    rows   : frame1, frame2, frame3, ...
    Separate table for each term. table[row][col]

    The column layout comes from the first line whose first parenthesized
    token is ``(pair)``. Its parenthesized tokens from the third one on
    name the value columns: ``(i,j)`` becomes the tuple ``(i, j)`` and
    ``(total)`` becomes the string ``"total"``.

    Data lines start with a term name, then a time stamp wrapped in one
    delimiter character on each side (e.g. ``[1.000]``), then the values.

    :param fname: File to read.
    :param term_list: Term names to collect.
    :return: ``(col_meaning, table)``. `col_meaning` starts with ``"time"``
        followed by one entry per value column. `table` maps each term to
        a list of rows ``[time, value, value, ...]``.
    """
    with open(fname, "r") as fh:
        lines = fh.readlines()

    title_pattern = re.compile(r"\([^\(\)]+\)")
    col_meaning = ["time"]
    for line in lines:
        titles = title_pattern.findall(line)
        if titles and titles[0].lower() == "(pair)":
            for title in titles[2:]:
                # findall keeps the parentheses, so compare the inner text.
                # Comparing the whole token to "total" never matched, and a
                # "(total)" column then crashed in int("total").
                inner = title[1:-1]
                if inner.strip().lower() == "total":
                    col_meaning.append("total")
                else:
                    col_meaning.append(
                        tuple(int(e) for e in inner.split(",")))
            break

    table = {term: [] for term in term_list}
    for line in lines:
        fields = line.split()
        if fields and fields[0] in table:
            # Strip the one-character delimiters around the time stamp.
            row = [float(fields[1][1:-1])]
            row.extend(float(e) for e in fields[2:])
            table[fields[0]].append(row)
    return col_meaning, table
def get_global_quantity(fname, quantity_list):
    """
    Collect the values of selected ``name=value`` quantities from a file.

    Only lines starting with ``"time"`` are examined. Every
    whitespace-separated token on such a line must be ``name=value``.

    :param fname: File to read.
    :param quantity_list: Names of the quantities to collect.
    :return: dict mapping each requested name to the list of its values as
        floats, in file order.
    """
    with open(fname, "r") as fh:
        lines = fh.readlines()
    values = {name: [] for name in quantity_list}
    for line in lines:
        if not line.startswith("time"):
            continue
        for token in line.split():
            name, raw_value = token.split("=")
            if name in values:
                values[name].append(float(raw_value))
    return values
def get_mean(ene, index=-1, data_structure="table"):
    """
    Returns (mean, std_error, std_dev,).

    :param ene: Either a table (iterable of rows) or a flat iterable of
        numbers, as selected by `data_structure`. Any iterable works,
        including one-shot iterators.
    :param index: Column of each row to use when `data_structure` is
        ``"table"``.
    :param data_structure: ``"table"`` or ``"array"``.
    :return: ``(mean, std_error, std_dev)``. `std_dev` uses the sample
        (n - 1) normalization and `std_error` is ``std_dev / sqrt(n)``.
        With fewer than two values both are ``inf``. An empty input has a
        mean of 0.0.
    :raise ValueError: If `data_structure` is not recognized.
    """
    if data_structure == "table":
        values = [row[index] for row in ene]
    elif data_structure == "array":
        values = list(ene)
    else:
        raise ValueError("Unknown data structure: %s" % data_structure)

    # The values are materialized once above. Iterating `ene` twice used to
    # give a zero spread for generators.
    num_data = max(len(values), 1)
    mean = 0.0
    for v in values:
        mean += v
    mean /= num_data

    sq_dev = 0.0
    for v in values:
        sq_dev += (v - mean) * (v - mean)
    if num_data == 1:
        std_dev = float("inf")
    else:
        std_dev = math.sqrt(sq_dev / (num_data - 1))
    return mean, std_dev / math.sqrt(num_data), std_dev
def parse_eneseq(eneseq_fname):
    """
    Returns a 1D structured array with names equal to the column names.

    The header is the first ``#`` line whose second token is ``0:time``. Its
    tokens alternate ``index:name`` and ``(units)``, and the names become
    the field names of the array.

    :param eneseq_fname: eneseq file name.
    :raise ValueError: If no column header line is found.
    """
    header_names = None
    with open(eneseq_fname, "r") as fh:
        for raw_line in fh.readlines():
            stripped = raw_line.strip()
            if not stripped.startswith("#"):
                continue
            tokens = stripped.split()
            if len(tokens) > 1 and tokens[1] == '0:time':
                # col_id:col_name (units) ...
                header_names = [
                    label.split(":")[1] for label in tokens[1::2]
                ]
                break
    if header_names is None:
        raise ValueError("Column headers not found")
    return numpy.atleast_1d(numpy.genfromtxt(eneseq_fname, names=header_names))
def init_bennett(data: Union[str, Path, numpy.ndarray],
                 n_win=12,
                 temperature=300.0,
                 begin_time=100.0,
                 end_time=-1.0,
                 random_seed=2111839,
                 result_file=None,
                 nresamples=0,
                 file_pattern='gibbs.%d.dE'):
    """
    Create a `bennett.CalcBAR` object and load the dE data into it.

    :param data: Either a directory or a numpy array.
        As a directory, it must contain the dE files, which are named after the
        pattern specified by the `file_pattern` argument.
        As a numpy array, it is the dE data read from the dE files. The data is
        an MxNx3 array, where M is the number of lambda windows, N the number of
        time points, and the 3 are (time, forward energy, reverse energy).
    :param n_win: Number of dE files (``file_pattern % 0 .. n_win - 1``)
        loaded when `data` is a directory.
    :param temperature: Temperature passed to `bennett.CalcBAR`.
    :param begin_time: Start of the analyzed data.
    :param end_time: End of the analyzed data.
    :param random_seed: Seed passed to `bennett.CalcBAR`.
    :param result_file: Output file for BAR results. When `data` is a
        directory, only its basename is kept and the file is placed inside
        `data`.
    :param nresamples: Number of resamples passed to `bennett.CalcBAR`.
    :param file_pattern: %-format pattern of the dE file names.
    :return: The configured `bennett.CalcBAR` instance.
    """
    assert isinstance(data, (str, Path, numpy.ndarray)), \
        "Must be a numpy array or str/Path"
    if are_times_insane(begin_time, end_time):
        # Emit one message. The old text embedded "\n" and was split over
        # two prints, which left a stray blank line in the middle.
        print("Warning, BAR calculation initialized with unreasonable "
              f"begin and end times: {begin_time}, {end_time}")
    bar = bennett.CalcBAR(begin_time=begin_time,
                          end_time=end_time,
                          temperature=temperature,
                          seed=random_seed,
                          nresamples=nresamples)
    if isinstance(data, (str, Path)):
        if result_file is not None:
            result_file = os.path.join(data, os.path.basename(result_file))
        fns = [os.path.join(data, file_pattern % i) for i in range(n_win)]
        bar.load_data(fns)
    else:
        # data is a numpy array
        bar.dE = data
        bar.filter_data(begin_time, end_time)
    bar.set_output(result_file, None, None)
    return bar
def run_bennett(bar, begin_time=100.0, end_time=-1.0, nresamples=None):
    """
    Run the BAR analysis held by `bar`, retrying with earlier start times.

    Whenever an attempt raises, `begin_time` is moved 50 earlier (never
    below 0) and the whole attempt is repeated. Once an attempt fails at a
    `begin_time` that is not positive, the failure is returned rather than
    raised.

    :param bar: BAR calculator, e.g. as returned by `init_bennett`.
    :param begin_time: Start of the data window, passed to `bar.filter_data`.
    :param end_time: End of the data window, passed to `bar.filter_data`.
    :param nresamples: If not None, set on `bar` before every attempt.
    :return: On success, ``(result, bar.err, dF, results)``:
        - `result` is the total free energy as a `measurement.Measurement`.
        - `dF` holds the per-window values as `measurement.Measurement`
          objects. Each uncertainty is the larger of the two errors
          reported for it.
        - `results` are the raw rows from `bar.analyze_data()`.

        On failure, ``(None, message, [], [])``.
    """
    while True:
        if nresamples is not None:
            bar.set_nresamples(nresamples)
        # Reset the random seed before every attempt.
        bar.set_seed(bar.seed)
        bar.filter_data(begin_time, end_time)
        try:
            results = bar.analyze_data()
            bar.write_results(results)
            per_window = [
                measurement.Measurement(row[0], max(row[1], row[2]))
                for row in results[:-1]
            ]
            total_row = results[-1]
            total = measurement.Measurement(total_row[0],
                                            max(total_row[1], total_row[2]))
            return total, bar.err, per_window, results
        except Exception as exc:
            if not begin_time > 0:
                return (None, bar.err + '\n' + repr(exc), [], [])
        # Reduce begin_time gradually until there's enough statistics for the
        # BAR calculation to succeed.
        begin_time -= 50.0
        if begin_time < 0:
            begin_time = 0.0
def are_times_insane(begin_time, end_time):
    """
    Are the given begin and end times unreasonable?

    Returns True if any of these hold:

    - `begin_time` is negative.
    - `end_time` is exactly 0.
    - `end_time` does not come after `begin_time`. An `end_time` of -1.0
      is exempt from this check.
    """
    if begin_time < 0 or end_time == 0:
        return True
    return end_time != -1.0 and end_time <= begin_time
def get_delta_time(begin_time, end_time, delta_time, window=0):
    """
    Resolve the time step used to sample free-energy curves.

    :param begin_time: Start of the analyzed span.
    :param end_time: End of the analyzed span.
    :param delta_time: Either a number, returned unchanged, or a string
        ``"<N>:points"`` requesting N evenly spaced points over
        ``end_time - begin_time - window``. With a non-zero `window`, only
        N - 1 intervals are needed.
    :param window: Width of a sliding window, or 0 for none.
    :return: The time step, or 0 if `window` covers the whole span. In that
        case a single, possibly truncated window is used.
    :raise ValueError: If a string `delta_time` is malformed, or asks for
        too few points to define a step.
    """
    if window >= end_time - begin_time:
        # we can only use a single possibly truncated window
        return 0
    if not isinstance(delta_time, str):
        return delta_time
    tokens = delta_time.split(":")
    # A string without ":" used to fail with an IndexError.
    if len(tokens) < 2 or tokens[1] != "points":
        raise ValueError("Wrong syntax: %s" % delta_time)
    num_point = int(tokens[0])
    if window:
        num_point -= 1
    if num_point <= 0:
        # Previously a ZeroDivisionError or a negative step.
        raise ValueError("Too few points requested: %s" % delta_time)
    span = end_time - begin_time - window
    return span / num_point
def calc_free_energy_time_function(dir,
                                   last_time,
                                   n_win,
                                   temperature=300.0,
                                   begin_time=100.0,
                                   end_time=-1.0,
                                   delta_time=30.0,
                                   random_seed=2111839):
    """
    Calculates the free energy as a function of time.

    Each point uses the data from `begin_time` up to a stop time that grows
    by one step per point, and is labelled by that stop time.

    :param dir: Directory holding the dE files (see `init_bennett`).
    :param last_time: Last time available in the data.
    :param n_win: Number of lambda windows.
    :param delta_time: Step between points, or ``"<N>:points"``.
    :return: ``(data, results)`` from `calc_time_curve`, or ``([], [])`` if
        the begin/end times are not sane.
    """
    try:
        end_time, step = cleanup_time(begin_time, end_time, last_time,
                                      delta_time)
    except TimeSanityException:
        return [], []
    stops = _float_range(begin_time, end_time, step, closed=False) + step
    time_ranges = [(begin_time, stop, stop) for stop in stops]
    return calc_time_curve(dir, n_win, temperature, begin_time, end_time,
                           random_seed, time_ranges)
def calc_free_energy_rtime_function(dir,
                                    last_time,
                                    n_win,
                                    temperature=300.0,
                                    begin_time=100.0,
                                    end_time=-1.0,
                                    delta_time=30.0,
                                    random_seed=2111839):
    """
    Calculates the free energy as a function of reversed time.

    Every window ends at `last_time` and reaches back by a label time. The
    label times step from `begin_time` towards `end_time` (exclusive).

    :param dir: Directory holding the dE files (see `init_bennett`).
    :param last_time: Last time available in the data.
    :param n_win: Number of lambda windows.
    :param delta_time: Step between points, or ``"<N>:points"``.
    :return: ``(data, results)`` from `calc_time_curve`, or ``([], [])`` if
        the begin/end times are not sane.
    """
    try:
        end_time, step = cleanup_time(begin_time, end_time, last_time,
                                      delta_time)
    except TimeSanityException:
        return [], []
    labels = _float_range(begin_time, end_time, step, closed=False)
    time_ranges = [(last_time - label, last_time, label) for label in labels]
    return calc_time_curve(dir, n_win, temperature, begin_time, end_time,
                           random_seed, time_ranges)
def calc_free_energy_stime_function(dir,
                                    last_time,
                                    n_win,
                                    temperature=300.0,
                                    begin_time=100.0,
                                    end_time=-1.0,
                                    delta_time=30.0,
                                    window=500.0,
                                    random_seed=2111839):
    """
    Calculates the free energy as a function of time with sliding window.

    Each point uses data from a window of width `window` ending at a stop
    time. The stop times start at ``begin_time + window`` and run to
    `end_time`. Each point is labelled by its window start.

    :param dir: Directory holding the dE files (see `init_bennett`).
    :param last_time: Last time available in the data.
    :param n_win: Number of lambda windows.
    :param delta_time: Step between points, or ``"<N>:points"``.
    :param window: Width of the sliding window.
    :return: ``(data, results)`` from `calc_time_curve`, or ``([], [])`` if
        the begin/end times are not sane.
    """
    try:
        end_time, step = cleanup_time(begin_time, end_time, last_time,
                                      delta_time, window)
    except TimeSanityException:
        return [], []
    if step == 0:
        # cleanup_time returns a step of 0 when `window` is at least
        # end_time - begin_time. Use a single, possibly truncated window.
        time_ranges = [(begin_time, end_time, begin_time)]
    else:
        # Here begin_time + window < end_time, so the closed range has at
        # least one point for a positive step.
        stops = _float_range(begin_time + window, end_time, step)
        time_ranges = [(stop - window, stop, stop - window) for stop in stops]
    return calc_time_curve(dir, n_win, temperature, begin_time, end_time,
                           random_seed, time_ranges)
# backward compatibility
# The older "freeenergy" spellings stay importable as aliases of the current
# functions, so existing callers keep working.
calc_freeenergy_time_function = calc_free_energy_time_function
calc_freeenergy_rtime_function = calc_free_energy_rtime_function
calc_freeenergy_stime_function = calc_free_energy_stime_function
class TimeSanityException(Exception):
    """
    Raised by `cleanup_time` when the requested begin/end times are not
    sane (see `are_times_insane`).
    """
def cleanup_time(begin_time, end_time, last_time, delta_time, window=0):
    """
    Normalize the end time and resolve the sampling step.

    A negative `end_time`, or one beyond `last_time`, is replaced by
    `last_time`.

    :return: ``(end_time, dt)`` where `dt` comes from `get_delta_time`.
    :raise TimeSanityException: If the resulting times are not sane.
    """
    out_of_bounds = end_time < 0 or end_time > last_time
    effective_end = last_time if out_of_bounds else end_time
    if are_times_insane(begin_time, effective_end):
        raise TimeSanityException
    return effective_end, get_delta_time(begin_time, effective_end,
                                         delta_time, window)
def calc_time_curve(dir, n_win, temperature, begin_time, end_time, random_seed,
                    time_ranges):
    """
    Calculates the free energy as a function of time_ranges.

    :param time_ranges: Sequence of ``(start_time, stop_time, label_time)``.
        BAR is run on the data between start and stop, and the result is
        reported against `label_time`.
    :return: ``(data, results)``:
        - `data` is a list of ``(label_time, total, dF)`` tuples for the
          ranges where BAR succeeded.
        - `results` are the raw BAR results of the last range.
    :raise ValueError: If `time_ranges` is empty, or the last range could not
        be processed.
    """
    bar = init_bennett(dir, n_win, temperature, begin_time, end_time,
                       random_seed)
    data = []
    # Defaults for an empty `time_ranges`. Both names used to be left
    # unbound in that case, so the final check raised UnboundLocalError.
    results = []
    err = "no time ranges to process"
    last_frame = len(time_ranges) - 1
    try:
        for i, (start_time, stop_time, time) in enumerate(time_ranges):
            # Include the uncertainty every 10th frame, and the last frame
            nresamples = 100 if i % 10 == 0 or i == last_frame else 0
            result, err, dF, results = run_bennett(bar, start_time, stop_time,
                                                   nresamples)
            if result is not None:
                data.append((time, result, dF))
    finally:
        # Release the BAR output even if a run raised.
        bar.close_output()
    if len(results) == 0:
        raise ValueError("Could not process data: %s" % err)
    return data, results
class DeltaEnergy(object):
    """
    Column container for the contents of a dE file (see `read_dE_file`).

    The attributes are parallel lists: ``time[i]`` is the time of sample i,
    and ``forward[i]``/``reversed[i]`` are its forward and reverse energy
    differences.
    """

    def __init__(self):
        self.time, self.forward, self.reversed = [], [], []
def read_dE_file(dE_fname, time_range=None):
    """
    Read a dE file into a `DeltaEnergy` object.

    Blank lines and lines starting with ``#`` are skipped. Every other line
    must hold exactly three numbers, read in the order time, reverse energy,
    forward energy.

    :param dE_fname: File to read.
    :param time_range: ``[min_time, max_time]``, inclusive. Samples outside
        it are skipped. A falsy value means no restriction.
    :return: The populated `DeltaEnergy`.
    """
    if not time_range:
        time_range = [0.0, float("inf")]
    t_min, t_max = time_range[0], time_range[1]
    with open(dE_fname, "r") as fh:
        text = fh.read()
    dE = DeltaEnergy()
    for raw_line in text.split("\n"):
        entry = raw_line.strip()
        if not entry or entry.startswith("#"):
            continue
        t, rev, fwd = [float(tok) for tok in entry.split()]
        if t_min <= t <= t_max:
            dE.time.append(t)
            dE.reversed.append(rev)
            dE.forward.append(fwd)
    return dE
def calc_work_prob_distr(energy, energy_range=None):
    """
    Histogram `energy` into a normalized probability density.

    :param energy: Sequence of energy values.
    :param energy_range: ``[low, high, bin_width]`` (any indexable). When
        omitted, the range spans the data padded by 5% of its span on each
        side, with bins 0.9% of the padded width. The caller's object is
        never modified.
    :return: ``(x, density)``. ``x[i]`` is the left edge of bin ``i`` and
        ``density[i]`` is the fraction of in-range samples in that bin divided
        by the bin width. Samples outside the bins are ignored. If none fall
        inside, all densities are 0.
    """
    if energy_range is None:
        e_min = min(energy)
        e_max = max(energy)
        e_span = e_max - e_min
        pad = e_span * 0.05
        low, high, width = e_min - pad, e_max + pad, (e_span + 2 * pad) * 0.009
    else:
        # Unpack into locals. The old code wrote back into the caller's
        # list, which also failed for tuples.
        low, high, width = energy_range[0], energy_range[1], energy_range[2]

    if width != 0:
        num_bin = int((high - low) / width) + 1
    else:
        # Special case if energies are all the same: one bin of unit width.
        num_bin = 1
        width = 1

    counts = [0] * num_bin
    num_dat = 0
    for e in energy:
        index = int((e - low) / width)
        if 0 <= index < num_bin:
            counts[index] += 1
            num_dat += 1

    x = [i * width + low for i in range(num_bin)]
    if num_dat == 0:
        # Nothing fell in range. Avoid the former ZeroDivisionError.
        return x, [0.0] * num_bin
    density = [c / (num_dat * width) for c in counts]
    return x, density
def _padded_energy_range(e_min, e_max):
    """Return ``[low, high, bin_width]`` covering [e_min, e_max] plus 5% padding."""
    e_span = e_max - e_min
    pad = e_span * 0.05
    return [
        e_min - pad,
        e_max + pad,
        (e_span + 2 * pad) * 0.009,
    ]


def _refine_energy_range(x, prob, bin_size, threshold=0.002):
    """
    Trim the low and high tails of a histogram.

    Scans from each end, accumulating ``prob * bin_size``, until more than
    `threshold` is reached.

    :return: ``(x_min, x_max)``. `x_min` is the bin edge one bin below the
        first bin crossing the threshold from the low end. `x_max` is the
        left edge of the first bin crossing it from the high end. The
        default ``x[0]``/``x[-1]`` is kept when the crossing is in the
        outermost bin or never happens.
    """
    x_min = x[0]
    x_max = x[-1]
    cumulative = 0.0
    for i, p in enumerate(prob):
        cumulative += p * bin_size
        if cumulative > threshold:
            if i > 0:
                x_min = x[i - 1]
            break
    cumulative = 0.0
    # Iterate in reverse without mutating the caller's list.
    for i, p in enumerate(reversed(prob)):
        cumulative += p * bin_size
        if cumulative > threshold:
            if i > 0:
                x_max = x[-i - 1]
            break
    return x_min, x_max


def calc_forward_reversed_work_overlap(dE0, dE1):
    """
    Histogram the forward work of one lambda window against the negated
    reverse work of the next, on a common set of bins.

    The first binning covers +/- 3 standard deviations around both means.
    The range is then trimmed to where the probability mass lies and the
    data are binned again.

    :param dE0: `DeltaEnergy` of window i. Its `forward` energies are used.
    :param dE1: `DeltaEnergy` of window i + 1. Its `reversed` energies are
        used with their sign flipped. Neither argument is modified.
    :return: ``(x, prob0, prob1)``: the common bin left edges and the two
        probability densities on those bins.
    """
    forward = dE0.forward
    # Negate into a new list. The old code overwrote dE1.reversed, so a
    # second call on the same object flipped the sign back.
    reverse = [e * -1 for e in dE1.reversed]

    e_for_avg = numpy.mean(forward)
    e_rev_avg = numpy.mean(reverse)
    e_for_std = numpy.std(forward)
    e_rev_std = numpy.std(reverse)
    energy_range = _padded_energy_range(
        min(e_rev_avg - 3 * e_rev_std, e_for_avg - 3 * e_for_std),
        max(e_rev_avg + 3 * e_rev_std, e_for_avg + 3 * e_for_std))
    x, prob0 = calc_work_prob_distr(forward, energy_range)
    x, prob1 = calc_work_prob_distr(reverse, energy_range)

    x0_min, x0_max = _refine_energy_range(x, prob0, energy_range[2])
    x1_min, x1_max = _refine_energy_range(x, prob1, energy_range[2])
    energy_range = _padded_energy_range(min(x0_min, x1_min),
                                        max(x0_max, x1_max))
    x, prob0 = calc_work_prob_distr(forward, energy_range)
    x, prob1 = calc_work_prob_distr(reverse, energy_range)
    return x, prob0, prob1
def calc_lambda_window_overlap(dE_fname0, dE_fname1, time_range):
    """
    Compute the work-distribution overlap of two adjacent lambda windows.

    :param dE_fname0: dE file of window i.
    :param dE_fname1: dE file of window i + 1.
    :param time_range: ``[min_time, max_time]`` passed to `read_dE_file`. A
        falsy value means no restriction.
    :return: ``(x, prob0, prob1)`` from `calc_forward_reversed_work_overlap`.
    :raise RuntimeError: If either file has no samples in `time_range`.
    """
    dE0 = read_dE_file(dE_fname0, time_range)
    dE1 = read_dE_file(dE_fname1, time_range)
    if dE0.forward and dE1.forward and dE0.reversed and dE1.reversed:
        return calc_forward_reversed_work_overlap(dE0, dE1)
    # read_dE_file accepts a falsy time_range. Mirror that here instead of
    # indexing None while building the error message.
    min_time = time_range[0] if time_range else 0.0
    raise RuntimeError("Simulation is too short, not enough data for analysis. "
                       "The simulation should be at least %.2f ps long" %
                       min_time)
def plot_lambda_window_overlap(dE_fname0,
                               dE_fname1,
                               out_fname=None,
                               legend=None,
                               time_range=None,
                               filename=None,
                               reporter=None):
    """
    Compute, optionally save, and optionally plot the overlap of two
    adjacent lambda windows.

    :param out_fname: If given, one ``x prob0 prob1`` line per bin is
        written here.
    :param time_range: ``[min_time, max_time]``. A falsy value means
        ``[0.0, inf]``.
    :param reporter: Object whose ``plot`` method draws the curves.
    :return: The return value of ``reporter.plot``, or None without a
        reporter.
    """
    effective_range = time_range if time_range else [0.0, float("inf")]
    x, prob0, prob1 = calc_lambda_window_overlap(dE_fname0, dE_fname1,
                                                 effective_range)
    if out_fname:
        with open(out_fname, "w") as fh:
            for edge, p0, p1 in zip(x, prob0, prob1):
                print(edge, p0, p1, file=fh)
    if not reporter:
        return None
    return reporter.plot(x,
                         prob0,
                         prob1,
                         x_label="energy (kcal/mol)",
                         legend=legend,
                         filename=filename)
def calc_lambda_sim_matrix(num_lambda, *gibbs_dname, **kw):
    """
    Cross-combine adjacent lambda windows from several FEP runs.

    For every lambda pair (i, i + 1) and every ordered pair of directories
    (g0, g1):

    - ``g0/gibbs.<i>.dE`` and ``g1/gibbs.<i+1>.dE`` are symlinked into the
      current directory as ``gibbs.0.dE`` and ``gibbs.1.dE``.
    - The reversed-time free-energy curve is computed from them.
    - The curve is written to ``<g0>_<i>-<g1>_<i+1>-rtime``.

    :param num_lambda: Number of lambda windows.
    :param gibbs_dname: Directories holding the ``gibbs.<i>.dE`` files.
    :keyword traj_length: Time at which each curve is sampled (default
        2000.0).
    :keyword last_time: Last time of the dE data, passed to
        `calc_free_energy_rtime_function`. Defaults to `traj_length`.
        NOTE(review): this default is an assumption; confirm against the
        actual data length.
    :keyword temperature: Temperature (default 300.0).
    :keyword mat_fname: If given, the resulting matrix is also written here.
    :return: `new_mat`:
        - ``new_mat[0]`` lists the ``(g0, g1)`` combinations.
        - ``new_mat[k]`` (k >= 1) holds one ``(free_energy, probability)``
          tuple per combination for lambda pair ``k - 1``. The
          probabilities are Boltzmann weights normalized over the
          combinations sharing the same `g0`.
    """
    import schrodinger.application.desmond.gchart as gchart
    num_dname = len(gibbs_dname)
    traj_length = kw.get("traj_length", 2000.0)
    temperature = kw.get("temperature", 300.0)
    last_time = kw.get("last_time", traj_length)

    mat = []  # mat[i][j], i is lambda number, j is simulation number.
    for i in range(num_lambda - 1):
        sim = []
        dE_fname0 = "gibbs.%d.dE" % i
        dE_fname1 = "gibbs.%d.dE" % (i + 1)
        for g0 in gibbs_dname:
            util.remove_file("gibbs.0.dE")
            util.symlink(os.path.join(g0, dE_fname0), "gibbs.0.dE")
            for g1 in gibbs_dname:
                util.remove_file("gibbs.1.dE")
                util.symlink(os.path.join(g1, dE_fname1), "gibbs.1.dE")
                # The signature is (dir, last_time, n_win, ...) and the
                # result is (data, results). The two symlinked files are
                # n_win=2 windows. The old call omitted n_win (TypeError)
                # and iterated over the returned tuple.
                data, _ = calc_free_energy_rtime_function(
                    ".", last_time, 2, temperature=temperature)
                with open("%s_%d-%s_%d-rtime" % (g0, i, g1, i + 1),
                          "w") as fh:
                    for d in data:
                        print(d[0], d[1].val, d[1].unc, file=fh)
                    url = gchart.get_xy_url([d[0] for d in data],
                                            [d[1].val for d in data],
                                            err_y=[[d[1].unc for d in data]],
                                            x_label="time (ps)")
                    print("# " + url, file=fh)
                    # HTML-escape the URL. The previous replace("&", "&")
                    # was a no-op left over from an unescaping step.
                    print("# <img src=\"%s\" />" %
                          (url.replace("&", "&amp;"),),
                          file=fh)
                # Value at the label time closest to traj_length. The gap
                # was never updated before, so the last point always won.
                closest_gap = float("inf")
                val_at_closest_time = float("inf")
                for d in data:
                    gap = abs(d[0] - traj_length)
                    if gap < closest_gap:
                        closest_gap = gap
                        val_at_closest_time = d[1]
                sim.append(val_at_closest_time)
        mat.append(sim)

    # new_mat[i][j]: i = 0 is the name of the combination, i >= 1 is the
    # lambda index; j is the result, a (free energy, probability) tuple.
    new_mat = [[(g0, g1) for g0 in gibbs_dname for g1 in gibbs_dname]]
    for m in mat:
        prob = [math.exp(-e.val / temperature / 1.9872E-3) for e in m]
        # Normalize within each consecutive block of num_dname combinations,
        # i.e. over all g1 for the same g0.
        prob_sum = []
        for j, p in enumerate(prob):
            if j % num_dname == 0:
                prob_sum.append(0)
            prob_sum[-1] += p
        new_mat.append([(e, p / prob_sum[j // num_dname])
                        for j, (e, p) in enumerate(zip(m, prob))])

    mat_fname = kw.get("mat_fname")
    if mat_fname:
        with open(mat_fname, "w") as fh:
            print("# lambda,", end=' ', file=fh)
            for g0, g1 in new_mat[0]:
                print(("%s_%s," % (g0, g1)), end=' ', file=fh)
            print(file=fh)
            for i, row in enumerate(new_mat[1:]):
                print(i, end=' ', file=fh)
                for e, p in row:
                    print(("%s (%g)" % (str(e), p)), end=' ', file=fh)
                print(file=fh)
    return new_mat
def _get_pathway_helper(pathway, mat):
"""
:param pathway: a list of tuples. Each tuple consists of three elements: structure name, free energy, and probability.
"""
last_step_name = pathway[-1][0]
next_step_option = mat[0]
all_pathway = []
next_step_mat = [next_step_option] + mat[2:]
is_end = mat[2:] == []
for o, m in zip(next_step_option, mat[1]):
if (last_step_name == o[0]):
new_pathway = copy.copy(pathway)
new_pathway.append((
o[1],
m[0],
m[1],
))
if (is_end):
all_pathway.append(new_pathway)
else:
all_pathway += _get_pathway_helper(new_pathway, next_step_mat)
return all_pathway
def get_pathway(mat):
    """
    Enumerate every pathway through the lambda matrix `mat`.

    :param mat: Matrix as returned by `calc_lambda_sim_matrix`.
    :return: List of pathways. Each is a list of
        ``(name, free_energy, probability)`` steps, starting with
        ``(start_name, Measurement(0.0, 0.0), 1.0)``. Starting names are
        visited in the order they first appear in ``mat[0]``.
    """
    options = mat[0]
    # Use first-seen order. A set made the output order depend on string
    # hashing, so it could change between interpreter runs.
    starts = dict.fromkeys(option[0] for option in options)
    all_pathway = []
    for start in starts:
        origin = (start, measurement.Measurement(0.0, 0.0), 1.0)
        all_pathway += _get_pathway_helper([origin], [options] + mat[1:])
    # The former per-pathway free-energy/probability loop discarded its
    # results and has been removed.
    return all_pathway
def print_pathway(pathway):
    """
    Print each pathway on one line, followed by a per-start summary.

    Each line shows:

    - the 1-based index;
    - the chain of structure names joined by ``<=>``;
    - the summed free energy with the product of step probabilities;
    - the individual steps after the first.

    The summary then prints, for each starting structure, the sum of
    ``free_energy * probability`` over its pathways.

    :param pathway: List of pathways as returned by `get_pathway`.
    """
    pathway_sum = []
    for number, path in enumerate(pathway, start=1):
        free_energy_sum = 0.0
        prob = 1.0
        for step in path:
            free_energy_sum += step[1]
            prob *= step[2]
        pathway_sum.append((path[0][0], free_energy_sum, prob))
        route = "<=>".join("%s" % step[0] for step in path)
        details = "".join("%s (%.4g), " % (str(step[1]), step[2])
                          for step in path[1:])
        # One line per pathway. A Python 2 `print x,` had been ported as a
        # plain print, which broke the line right after the index.
        print("%d  %s: %s (%.4g); detail: %send" %
              (number, route, str(free_energy_sum), prob, details))

    final_sum = {}
    for start, free_energy, prob in pathway_sum:
        final_sum[start] = final_sum.get(start, 0.0) + free_energy * prob
    print("\nSummary:")
    for start, total in final_sum.items():
        print(start, total)
class FreeEnergyContrib(object):
    """
    Free energy split into Coulomb, van der Waals and bonded parts.

    `calc_contrib` fills these with `measurement.Measurement` values. They
    stay None when not computed.
    """

    def __init__(self, coulomb=None, vdw=None, bonded=None):
        """
        :param coulomb: Coulomb (electrostatic) contribution.
        :param vdw: van der Waals contribution.
        :param bonded: Bonded contribution.
        """
        self.fec_bonded = bonded
        self.fec_vdw = vdw
        self.fec_coulomb = coulomb
def _schedule_segment(weights):
    """
    Locate the switching segment of a lambda-weight schedule.

    :param weights: Sequence of config items with a ``.val`` attribute.
    :return: ``(start, end)``. `end` is the index of the first weight equal
        to 1.0. `start` is the index of the last weight equal to 0.0 before
        it. Either is None if not found.
    """
    start = None
    end = None
    for i, w in enumerate(weights):
        if w.val == 0.0:
            start = i
        if w.val == 1.0:
            end = i
            break
    return start, end


def calc_contrib(fname, cfg_fname):
    """
    Split an FEP free energy into van der Waals and Coulomb contributions.

    :param fname: BAR output file. Each non-blank, non-comment line starts
        with the free energy of one lambda step and its uncertainty.
    :param cfg_fname: Simulation config file containing the gibbs settings.
    :return: A `FreeEnergyContrib`:
        - For ``alchemical`` gibbs simulations it is returned empty.
        - Otherwise `fec_vdw` and `fec_coulomb` are the sums of the dG
          values over the vdW and Coulomb switching segments, and
          `fec_bonded` is 0.
    :raise TypeError: If the config does not describe an FEP simulation.
    :raise ValueError: If the vdW and Coulomb schedules overlap.
    """
    contrib = FreeEnergyContrib()
    with open(cfg_fname) as fh:
        sea_map = config.sea.Map(fh.read())
    if "gibbs" in sea_map.force.term.list.val:
        gibbs = sea_map.force.term.gibbs
    elif "gibbs" in sea_map.mdsim.plugin.list.val:
        gibbs = sea_map.mdsim.plugin.gibbs
    else:
        raise TypeError(
            "Attempted to calculate free-energy components from non-FEP simulation."
        )
    if gibbs.type.val == "alchemical":
        return contrib

    # Checks if vdw schedule has overlap with the coulomb schedule:
    # Coulomb may only be switched on once vdW is fully on.
    vdw = gibbs.weights.vdw
    coul = gibbs.weights.es
    prev = None
    for v, c in zip(vdw, coul):
        if v.val < 1.0 and c.val > 0:
            raise ValueError(
                "Cannot decompose the free energy due to schedule overlap.")
        if c.val > 0 and prev and prev.val < 1.0:
            raise ValueError(
                "Cannot decompose the free energy due to schedule overlap.")
        prev = v

    # Reads the output file to get dG.
    dG = []
    with open(fname, "r") as fh:
        for line in fh:
            line = line.strip()
            # Skip blank lines, which used to raise IndexError on line[0].
            if not line or line.startswith("#"):
                continue
            a = line.split()
            dG.append(measurement.Measurement(a[0], a[1]))

    # Gets the startings and endings of the VDW and Coulomb schedules.
    i_vdw_s, i_vdw_e = _schedule_segment(vdw)
    i_coul_s, i_coul_e = _schedule_segment(coul)
    dG_vdw = measurement.Measurement(0.0, 0.0)
    dG_coul = measurement.Measurement(0.0, 0.0)
    if i_vdw_s is not None and i_vdw_e is not None:
        for i in range(i_vdw_s, i_vdw_e):
            dG_vdw += dG[i]
    if i_coul_s is not None and i_coul_e is not None:
        for i in range(i_coul_s, i_coul_e):
            dG_coul += dG[i]
    contrib.fec_vdw = dG_vdw
    contrib.fec_coulomb = dG_coul
    contrib.fec_bonded = measurement.Measurement(0.0, 0.0)
    return contrib
def correct_restr(egout0, egout1, fname_out):
    """
    Free-energy correction from the positional ("posre") restraint energy.

    The mean posre energy is read from the lambda=0 file (`egout0`) and the
    lambda=1 file (`egout1`). The lambda=0 minus lambda=1 difference is
    written to `fname_out` together with both means.

    :return: The correction as a `measurement.Measurement`, or None if either
        file has no posre data.
    """
    posre0 = get_energy_table(egout0, ["posre"])[1]["posre"]
    posre1 = get_energy_table(egout1, ["posre"])[1]["posre"]
    if not posre0 or not posre1:
        return None
    mean0, error0, _ = get_mean(posre0)
    mean1, error1, _ = get_mean(posre1)
    correction = (measurement.Measurement(mean0, error0) -
                  measurement.Measurement(mean1, error1))
    with open(fname_out, "w") as fh:
        print("Restraint energy for lambda=0 (mean, error): %f, %f" %
              (mean0, error0),
              file=fh)
        print("Restraint energy for lambda=1 (mean, error): %f, %f" %
              (mean1, error1),
              file=fh)
        print("Correction due to restraints: %s" % str(correction), file=fh)
    return correction
def long_range_dispersion_energy(r_cut, c6, rho):
    """
    Dispersion energy beyond the cutoff for a uniform surrounding density.

    Evaluates ``-4/3 * pi * rho * c6 / r_cut**3``. This is the integral of
    ``-c6 / r**6`` over a uniform density `rho` outside `r_cut`.

    :param r_cut: cutoff radius (Angstrom).
    :param c6: average dispersion coefficient (kcal/mol * Angstrom**6).
    :param rho: number density (1 / Angstrom**3).
    :return: Energy in kcal/mol.
    """
    # math.pi replaces a truncated literal (3.1415926), which introduced a
    # ~2e-8 relative error.
    return -4.0 / 3.0 * math.pi * rho * c6 / (r_cut * r_cut * r_cut)
def get_field_from_log(field, fname):
    """
    Return the number from the first ``<field> = <number>`` in a log file.

    :param field: Field name. It is matched literally; regex metacharacters
        are escaped.
    :param fname: Log file name.
    :return: The value as a float, or None if the field is not found. Signs
        and exponent notation (e.g. ``-1.5e-3``) are accepted. The old
        ``[.0-9]+`` pattern silently turned ``1.5e+02`` into 1.5.
    """
    with open(fname, "r") as fh:
        content = fh.read()
    pattern = re.compile(
        re.escape(field) +
        r" *= *([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)")
    m = pattern.search(content)
    if m:
        return float(m.group(1))
    return None
def get_number_density_from_cms(model):
    """
    Number density of the full system in `model`.

    :return: ``(number_density, num_atoms, volume)``:
        1. the number density, in 1 / Angstrom**3;
        2. the number of atoms in ``model.fsys_ct``;
        3. the volume of ``model.box``.
    """
    atom_count = model.fsys_ct.atom_total
    box_volume = cms.get_boxvolume(model.box)
    density = old_div(atom_count, box_volume)
    return density, atom_count, box_volume
def get_average_box_volume(fname):
    """
    Mean box volume recorded in a simbox file.

    'fname' must be a `*_simbox.dat` file. After the "Chemical time:" and
    "ps, Box vectors:" markers are removed, each non-comment line holds the
    time followed by the box vectors.

    :return: The average of ``cms.get_boxvolume`` over all data lines.
    :raise ValueError: If the file has no data lines. This used to surface
        as a ZeroDivisionError.
    """
    with open(fname, "r") as fh:
        lines = fh.readlines()
    volume_sum = 0.0
    num_data = 0
    for line in lines:
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        line = line.replace("Chemical time:", " ")
        line = line.replace("ps, Box vectors:", " ")
        numbers = [float(e) for e in line.split()]
        # numbers[0] is the time; the rest are the box vectors.
        volume_sum += cms.get_boxvolume(numbers[1:])
        num_data += 1
    if num_data == 0:
        raise ValueError("No box data found in %s" % fname)
    return volume_sum / num_data
def calc_long_range_dispersion_energy(model,
                                      atom_list,
                                      log_fname=None,
                                      simbox_fname=None,
                                      cfg_fname=None,
                                      r_cut=-1,
                                      average_coefficient=-1):
    """
    Long-range dispersion correction summed over `atom_list`.

    Parameters left at -1 are first looked up in `log_fname`, then derived
    from the model or the config file.

    :param model: cms model. It supplies the atom count, the box, the
        component structures and the per-atom vdW parameters.
    :param atom_list: Atom indices (anything `int` accepts) into
        ``model.get_vdw()``.
    :param log_fname: Log file with ``r_cut = ...`` and
        ``average_dispersion = ...`` entries.
    :param simbox_fname: `*_simbox.dat` file. If it exists, its average box
        volume is used for the number density.
    :param cfg_fname: Config file providing ``cutoff_radius`` as a fallback.
    :param r_cut: Cutoff radius; -1 means look it up.
    :param average_coefficient: Average C6 coefficient. Any negative value
        (or a missing log entry) means compute it from the model.
    :return: ``(energy, r_cut, average_coefficient, rho)``.
    :raise ValueError: If the cutoff radius cannot be determined.
    """
    if log_fname and os.path.isfile(log_fname):
        if r_cut == -1:
            r_cut = get_field_from_log("r_cut", log_fname)
        if average_coefficient == -1:
            average_coefficient = get_field_from_log("average_dispersion",
                                                     log_fname)
    rho, num_atom, _ = get_number_density_from_cms(model)
    # get_field_from_log returns None for a missing field. Treat that as
    # "not provided" instead of failing on `None < 0` (TypeError).
    if average_coefficient is None or average_coefficient < 0:
        average_coefficient = cms.calc_average_vdw_coeff(model.comp_ct)
    if r_cut is None or r_cut < 0:
        if cfg_fname and os.path.isfile(cfg_fname):
            with open(cfg_fname) as fh:
                sea_map = config.sea.Map(fh.read())
            r_cut = sea_map.cutoff_radius.val
        else:
            raise ValueError(
                "Lack of information to determine the cutoff radius.")
    if simbox_fname and os.path.isfile(simbox_fname):
        average_volume = get_average_box_volume(simbox_fname)
        rho = old_div(num_atom, average_volume)
    vdw = model.get_vdw()
    energy = 0.0
    for atom in atom_list:
        atom_c6 = vdw[int(atom)].c6()
        # Geometric mean of the atom's C6 and the system-average C6.
        mixed_c6 = math.sqrt(atom_c6 * average_coefficient)
        energy += long_range_dispersion_energy(r_cut, mixed_c6, rho)
    return energy, r_cut, average_coefficient, rho
def calc_free_energy_for_abfe_cross_link(restr: List[Restraint],
                                         temperature=300.0):
    """
    Calculates correction for the cross link restraints in absolute binding
    free energy simulations.

    Reference: Boresch, Stefan, Franz Tettinger, Martin Leitgeb, and Martin
    Karplus. Absolute Binding Free Energies: A Quantitative Approach for Their
    Calculation. The Journal of Physical Chemistry B 107, no. 35 (September
    2003): 9535-9551. http://pubs.acs.org/doi/abs/10.1021/jp0217839.

    Note: We could not reproduce the numbers on the 5th row of Table 5 in the reference.

    :param restr: Exactly one stretch (2 atoms), two angle (3 atoms) and
        three torsion (4 atoms) restraints. Angle and torsion references are
        in degrees. Restraints with other atom counts are ignored. The
        inputs are not modified; a deep copy is used.
    :param temperature: Temperature; kT = temperature * BOLTZMANN.
    :return: ``-kT * ln(a)`` with `a` from the reference's formula.
    :raise ValueError: If the restraint counts are not exactly 1/2/3.
    """
    stretch, angle, torsion = [], [], []
    by_atom_count = {2: stretch, 3: angle, 4: torsion}
    for e in copy.deepcopy(restr):
        group = by_atom_count.get(len(e.atom))
        if group is not None:
            group.append(e)
    expected = ((stretch, 1, "stretch"), (angle, 2, "angle"),
                (torsion, 3, "torsion"))
    for group, count, kind in expected:
        if len(group) > count:
            raise ValueError("More than %d %s restraints" % (count, kind))
    # The formula indexes every restraint below. Too few used to surface
    # as an IndexError.
    for group, count, kind in expected:
        if len(group) < count:
            raise ValueError("Expected %d %s restraint(s), got %d" %
                             (count, kind, len(group)))

    # Presumably the 1 M standard-state volume in Angstrom**3; confirm.
    volume = 1660.0
    for e in angle + torsion:
        e.ref = math.radians(e.ref)
    kT = temperature * BOLTZMANN
    a = 8 * math.pi**2 * volume * \
        math.sqrt(stretch[0].k * angle[0].k * angle[1].k * torsion[0].k * torsion[1].k * torsion[2].k) / \
        (stretch[0].ref**2 * math.sin(angle[0].ref) * math.sin(angle[1].ref) * (2 * math.pi * kT)**3)
    return -kT * math.log(a)
def calc_free_energy_for_abfe_cross_link_xu(restr: List[Restraint],
                                            temperature=300.0):
    r"""
    Calculates correction for the cross link restraints in absolute binding free energy simulations.

    We use Huangfeng Xu's formula. The partition function of the restraint terms is this::

        Z = \int dr r^2 exp( -\beta kr (r - r0)^2 )
            \Prod_{i = 1, 2 } \int_0^\pi d\theta_i \sin\theta_i exp( -\beta ka (\theta_i - \theta_{i0})^2 )
            \Prod_{i = 1, 2, 3} \int_{\psi_{i0} - \pi}^{\psi_{i0} + \pi} d\psi_i exp( -\beta kd (\psi_i - \psi_{i0})^2 )

    The integration over theta is approximated by integration over {-\inf, \inf}.

    :param restr: Up to one stretch (2 atoms), two angle (3 atoms) and three
        torsion (4 atoms) restraints. Angle and torsion references are in
        degrees. Terms with a zero force constant are skipped. The inputs
        are not modified; a deep copy is used.
    :param temperature: Temperature; kT = temperature * BOLTZMANN.
    :return: ``kT * ln(Z / Z0)``.
    :raise ValueError: If there are too many restraints of one kind.
    """
    buckets = {2: [], 3: [], 4: []}
    for item in copy.deepcopy(restr):
        bucket = buckets.get(len(item.atom))
        if bucket is not None:
            bucket.append(item)
    stretch, angle, torsion = buckets[2], buckets[3], buckets[4]
    for group, limit, kind in ((stretch, 1, "stretch"), (angle, 2, "angle"),
                               (torsion, 3, "torsion")):
        if len(group) > limit:
            raise ValueError("More than %d %s restraints" % (limit, kind))

    volume = 1660.0
    deg2rad = math.pi / 180.0
    for item in angle + torsion:
        item.ref *= deg2rad
    kT = temperature * BOLTZMANN
    beta = 1 / kT
    sqrt_pi = math.sqrt(math.pi)

    ratio = 1  # Z / Z0
    for s in stretch:
        r0 = s.ref
        bk = beta * s.k
        if bk == 0:
            continue
        radial = math.exp(-r0 * r0 * bk) * r0 / 2 / bk + sqrt_pi * (
            1 + 2 * r0 * r0 * bk) * (1 + math.erf(r0 * math.sqrt(bk))) / (
                4 * math.sqrt(bk) * bk)
        ratio *= 4 * math.pi * radial / volume
    for a in angle:
        bk = beta * a.k
        if bk == 0:
            continue
        ratio *= math.exp(-1 / 4 / bk) * math.sqrt(
            math.pi / bk) * math.sin(a.ref) / 2
    for d in torsion:
        bk = beta * d.k
        if bk == 0:
            continue
        ratio *= math.sqrt(math.pi / bk) * math.erf(
            math.pi * math.sqrt(bk)) / 2 / math.pi
    return kT * math.log(ratio)
def calc_free_energy_correction_due_to_restraint(r, fc, temperature) -> float:
    """
    Free-energy correction for a cross-link restraint set.

    :param r: Cross-link restraint to calculate the free energy correction for
    :param fc: Three force constants for the stretch, the angle, and the
        torsion restraints, respectively.
    :param temperature: Temperature passed to
        `calc_free_energy_for_abfe_cross_link_xu`.
    :return: The correction from `calc_free_energy_for_abfe_cross_link_xu`.
    """
    from schrodinger.application.desmond.packages.restraint import \
        CrossLinkRestraint
    assert (isinstance(r, CrossLinkRestraint) and 3 == len(fc))
    k_stretch, k_angle, k_torsion = fc[0], fc[1], fc[2]
    restraints = [
        Restraint([r.A, r.a], r.Aa[0], k_stretch),
        Restraint([r.B, r.A, r.a], r.BAa[0], k_angle),
        Restraint([r.A, r.a, r.b], r.Aab[0], k_angle),
        Restraint([r.B, r.A, r.a, r.b], r.BAab[0], k_torsion),
        Restraint([r.A, r.a, r.b, r.c], r.Aabc[0], k_torsion),
        Restraint([r.C, r.B, r.A, r.a], r.CBAa[0], k_torsion),
    ]
    return calc_free_energy_for_abfe_cross_link_xu(restraints, temperature)
def get_abfep_cross_link(model,
                         ligand,
                         r_clone,
                         traj_fname,
                         first_frame=0,
                         max_frame=256):
    """
    Generate cross-link restraints between a ligand and the protein.

    The model should be created by the topo.read_cms function.

    :param ligand: list of ligand atoms' indices.
    :param r_clone: Passed through to `restraint.gen_cross_link_restraint`.
    :param traj_fname: Trajectory to sample. Frames from `first_frame` up to
        (but excluding) ``first_frame + max_frame`` are used.
    :return: Whatever `restraint.gen_cross_link_restraint` returns.
    """
    # Imported lazily because these modules depend on the Desmond product.
    from schrodinger.application.desmond.packages import restraint
    from schrodinger.application.desmond.packages import traj
    frames = traj.read_traj(traj_fname)
    sampled = frames[first_frame:first_frame + max_frame]
    ligand_asl = "a.n " + " ".join(str(index) for index in ligand)
    return restraint.gen_cross_link_restraint(
        model,
        sampled,
        ligand_asl=ligand_asl,
        receptor_asl="protein and (not a.e H)",
        r_clone=r_clone)
def calc_free_energy(dir, last_time: float, n_win: int, temperature: float,
                     bennett_options: Dict, random_seed: int) -> Dict:
    """
    Return forward, reverse and slide energies.

    :param bennett_options: Maps each of ``'forward_time'``,
        ``'reversed_time'`` and ``'sliding_time'`` to a dict with keys
        ``'begin'``, ``'end'`` and ``'dt'``. ``'sliding_time'`` also needs
        ``'window'``. Extra keys are ignored.
    :return: dict mapping the same three keys to the ``(data, results)``
        tuple returned by the matching time-curve function.
    """
    curves = (
        ('forward_time', calc_free_energy_time_function),
        ('reversed_time', calc_free_energy_rtime_function),
        ('sliding_time', calc_free_energy_stime_function),
    )
    output = {}
    for key, curve_func in curves:
        opts = bennett_options[key]
        # Copy only the needed entries; `opts` may hold more.
        kwargs = {
            'temperature': temperature,
            'begin_time': opts['begin'],
            'end_time': opts['end'],
            'delta_time': opts['dt'],
            'random_seed': random_seed,
        }
        if key == 'sliding_time':
            kwargs['window'] = opts['window']
        data, results = curve_func(dir, last_time, n_win, **kwargs)
        output[key] = (data, results)
    return output
def _convergence_axis(x, x_label):
    """Return `x` (scaled to ns if it reaches 10000 ps) and the matching axis label."""
    if x[-1] >= 10000:
        return [e * 0.001 for e in x], x_label + " (ns)"
    return x, x_label + " (ps)"


def _write_convergence_series(fname, data, get_measurement):
    """
    Write one ``time value uncertainty`` line per entry of `data` to `fname`.

    :param get_measurement: Picks the measurement (with ``val``/``unc``) out
        of an entry of `data`.
    :return: ``(x, y, err)``: the times, values and uncertainties.
    """
    x, y, err = [], [], []
    with open(fname, "w") as fh:
        for d in data:
            m = get_measurement(d)
            print(d[0], '%.4f' % m.val, '%.4f' % m.unc, file=fh)
            x.append(d[0])
            y.append(m.val)
            err.append(m.unc)
    return x, y, err


def plot_convergence(data,
                     dG_fname,
                     dF_fname_pattern,
                     x_label,
                     dF_color,
                     dG_color="black",
                     reporter=None) -> Dict[str, Union[List[str], str]]:
    """process the `data` and write png files.

    With more than two points, writes the dG curve to `dG_fname`, and the
    curve of the i-th window to ``dF_fname_pattern % (i, i + 1)``. If a
    `reporter` is given, each curve is also plotted to ``<file>.png``.

    :param data: List of ``(time, dG, dF_list)`` as produced by the
        time-curve functions.
    :return: ``{'url': url, 'dF': df, 'dG': dG_fname}``. `url` holds the
        ``reporter.plot`` results and `df` the dF file names; both are empty
        with two or fewer points.
    """
    url = []
    df = []
    if len(data) <= 2:
        return {'url': url, 'dF': df, 'dG': dG_fname}

    x, y, err = _write_convergence_series(dG_fname, data, lambda d: d[1])
    x, x_label_ = _convergence_axis(x, x_label)
    if reporter:
        url.append(
            reporter.plot(x,
                          y,
                          err_y=[err],
                          x_label=x_label_,
                          y_label="dG",
                          color=[dG_color],
                          filename=dG_fname + ".png"))

    for i in range(len(data[0][2])):
        fname = dF_fname_pattern % (i, i + 1)
        df.append(fname)
        x, y, err = _write_convergence_series(fname, data,
                                              lambda d, i=i: d[2][i])
        x, x_label_ = _convergence_axis(x, x_label)
        if reporter:
            url.append(
                reporter.plot(x,
                              y,
                              err_y=[err],
                              x_label=x_label_,
                              y_label="dF",
                              legend=["%d_%d" % (i, i + 1)],
                              color=[dF_color],
                              filename=fname + ".png"))
    return {'url': url, 'dF': df, 'dG': dG_fname}