"""
A script for predicting the temperature profile (ladder) of REMD simulations.
The prediction is based on the work by Patriksson and van der Spoel [Phys.
Chem. Chem. Phys., 10:2073-2077 (2008). http://dx.doi.org/10.1039/b716554d].
This script can optionally generate a new .cms file in which a portion of the
system is effectively frozen (by assigning its atoms very large masses).
Freezing part of the system reduces its number of degrees of freedom and can
thus increase the temperature span of REMD.
Copyright Schrodinger, LLC. All rights reserved.
"""
# Contributors: Yujie Wu
import math
import os
import sys
from past.utils import old_div
import schrodinger.utils.cmdline as cmdline
from schrodinger.application.desmond import cms
from schrodinger.application.desmond.constants import CT_TYPE
try:
    import scipy.special as scipy_special
    erf = scipy_special.erf
except Exception:
    # SciPy is an optional dependency: if it is missing (or fails to load),
    # fall back to the error function of the standard library. Unlike a bare
    # `except:`, this does not swallow KeyboardInterrupt/SystemExit.
    # math.erf computes erf directly, which is more accurate than the
    # Abramowitz & Stegun 7.1.26 polynomial approximation (max. absolute
    # error ~1.5e-7).
    erf = math.erf
# Sites whose mass is at least this value are counted as frozen; the same
# number is the default factor by which freeze_atoms() multiplies the masses
# of the atoms it freezes.
FROZEN_ATOM_MASS_THRESHOLD = 1.0e9
def _is_water(site_block, constraint):
    """
    """
    # Checks the constraint block.
    if (len(constraint) == 1):
        if (constraint[0].func == "HOH"):
            return True
        # Checks the sites block.
    num_o, num_h = 0, 0
    for site in site_block:
        if (site.mass > 0):
            if (int(abs(site.mass - 16.0) + 0.01) == 0):
                num_o += 1
            elif (int(abs(site.mass - 1.0) + 0.01) == 0):
                num_h += 1
            else:
                return False
    if (num_o == 1 and num_h == 1):
        return True
    return False
def _get_num_wateratom(model,
                       frozen_atom_mass_threshold=FROZEN_ATOM_MASS_THRESHOLD):
    """
    Count the water atoms of `model`, split into free and frozen ones.

    Only component CTs recognized by `_is_water` are considered. A site is
    free when its mass is positive and below `frozen_atom_mass_threshold`.
    Each water molecule is taken to have three atoms, so a CT's atoms are
    split between free and frozen in the ratio (free sites) : (3 - free sites).

    :raise ValueError: if a water template has more than three free sites.
    :return: tuple (num_free, num_frozen, num_total) of atom counts.
    """
    constraint_blocks = model.get_constraint()
    n_free = n_frozen = n_total = 0
    for comp, constraints in zip(model.comp_ct, constraint_blocks):
        if not _is_water(comp.ffio.site, constraints):
            continue
        free_sites = sum(1 for s in comp.ffio.site
                         if 0 < s.mass < frozen_atom_mass_threshold)
        if free_sites > 3:
            raise ValueError("water molecule has %d atoms!" % free_sites)
        # +0.01 guards against the float product landing just below an int.
        n_free += int(comp.atom_total * free_sites / 3.0 + 0.01)
        n_frozen += int(comp.atom_total * (3 - free_sites) / 3.0 + 0.01)
        n_total += comp.atom_total
    return n_free, n_frozen, n_total
def _get_num_nonwateratom(model,
                          frozen_atom_mass_threshold=FROZEN_ATOM_MASS_THRESHOLD
                         ):
    """
    Count the non-water atoms of `model`, split into free and frozen ones.

    Only component CTs NOT recognized by `_is_water` are considered. A site
    is free when its mass is positive and below `frozen_atom_mass_threshold`.
    A CT's atoms are split between free and frozen in proportion to the free
    and frozen sites among its sites with a positive mass.

    :return: tuple (num_free, num_frozen, num_total) of atom counts.
    """
    num_free = 0
    num_frozen = 0
    num_total = 0
    constr_list = model.get_constraint()
    for ct, constr_block in zip(model.comp_ct, constr_list):
        sites = ct.ffio.site
        if _is_water(sites, constr_block):
            continue
        free_atom = sum(
            1 for site in sites if 0 < site.mass < frozen_atom_mass_threshold)
        # Only sites with a positive mass count as atoms, consistently with
        # the free-site test above and with the water counting, which assumes
        # 3 atoms per molecule. Massless sites (presumably virtual/pseudo
        # sites that are not CT atoms -- confirm) would otherwise dilute the
        # ratio. Falls back to all sites if none has a positive mass.
        num_site = sum(1 for site in sites if site.mass > 0) or len(sites)
        # +0.01 guards against the float product landing just below an int.
        num_free += int(ct.atom_total * free_atom / num_site + 0.01)
        num_frozen += int(ct.atom_total *
                          (num_site - free_atom) / num_site + 0.01)
        num_total += ct.atom_total
    return num_free, num_frozen, num_total
def _get_num_constraint(model,
                        frozen_atom_mass_threshold=FROZEN_ATOM_MASS_THRESHOLD):
    """
    Count the constraints that act on free (non-frozen) atoms of `model`.

    A pairwise constraint is counted when at least one of its two atoms is
    free, i.e. its site mass is below `frozen_atom_mass_threshold`.
    Recognized constraint functional forms:

    - "HOH": `atom_i` (O) with `atom_j` and `atom_k` (H); contributes the
      O-H1, O-H2 and H1-H2 distances.
    - "AH<n>": heavy atom `atom_i` with n hydrogens stored in `atom_j`,
      `atom_k`, `atom_l`, ...; contributes one constraint per hydrogen.

    Other functional forms are ignored. The per-CT count is multiplied by the
    CT's entry in `model._unroll`.

    :return: total number of constraints (int).
    """
    # Attribute names holding the hydrogens of an AH<n> constraint, in order.
    hydrogen_attrs = ("atom_j", "atom_k", "atom_l", "atom_m", "atom_n",
                      "atom_o", "atom_p", "atom_q")

    def pair_is_free(sites, a, b):
        """Return 1 if site `a` or site `b` is free, else 0."""
        return int(sites[a].mass < frozen_atom_mass_threshold or
                   sites[b].mass < frozen_atom_mass_threshold)

    site_list = [ct.ffio.site for ct in model.comp_ct]
    constr_list = model.get_constraint()
    num_constraint = 0
    for site_block, constraint_block, unroll in zip(site_list, constr_list,
                                                    model._unroll):
        count = 0
        for constraint in constraint_block:
            func = constraint.func
            if func == "HOH":
                o, h1, h2 = (constraint.atom_i, constraint.atom_j,
                             constraint.atom_k)
                # Same "either end free" rule as used for AH<n>.
                count += (pair_is_free(site_block, o, h1) +
                          pair_is_free(site_block, o, h2) +
                          pair_is_free(site_block, h1, h2))
            elif (isinstance(func, str) and func.startswith("AH") and
                  func[2:].isdigit()):
                # AH<n> has exactly n hydrogens (e.g. AH3: one heavy atom
                # with three hydrogens), each constrained to `atom_i`.
                heavy = constraint.atom_i
                for attr in hydrogen_attrs[:int(func[2:])]:
                    hydrogen = getattr(constraint, attr, None)
                    # NOTE(review): skips hydrogens the constraint object does
                    # not expose as an attribute -- confirm the attribute set.
                    if hydrogen is not None:
                        count += pair_is_free(site_block, heavy, hydrogen)
        num_constraint += count * unroll
    return num_constraint
def _get_degrees_of_freedom(
        model, frozen_atom_mass_threshold=FROZEN_ATOM_MASS_THRESHOLD):
    """
    Return the number of degrees of freedom of the free part of `model`.

    This is three times the number of free (water + non-water) atoms minus the
    number of constraints acting on free atoms.
    """
    water_free = _get_num_wateratom(model, frozen_atom_mass_threshold)[0]
    other_free = _get_num_nonwateratom(model, frozen_atom_mass_threshold)[0]
    n_constraints = _get_num_constraint(model, frozen_atom_mass_threshold)
    return 3 * (water_free + other_free) - n_constraints
def predict_temperature(low_temp,
                        high_temp,
                        exchange_probability,
                        model,
                        frozen_atom_mass_threshold=FROZEN_ATOM_MASS_THRESHOLD,
                        should_fix=True):
    """
    Predict a REMD temperature ladder with the method of Patriksson and van
    der Spoel (Phys. Chem. Chem. Phys., 10:2073-2077, 2008).

    Starting from `low_temp`, each next temperature is located by bisection
    so that the predicted exchange probability with the previous replica is
    within 0.01 of `exchange_probability`. This repeats until a temperature
    within 0.01 of `high_temp` is reached.

    :param low_temp: minimum temperature (K)
    :param high_temp: maximum temperature (K)
    :param exchange_probability: target exchange probability between
        neighboring replicas. A good default is 30% (0.3).
    :param model: should be a `Cms` object.
    :param frozen_atom_mass_threshold: sites whose mass is at least this
        value are treated as frozen when counting atoms and constraints.
    :param should_fix: if True, there are more than two gaps, and the last
        gap's probability exceeds twice the target (i.e. the last gap is too
        small), then the last gap is distributed over the preceding gaps in
        proportion to their sizes and the probabilities are recomputed.
    :raise ValueError: if `exchange_probability` is larger than 1.
    :return: a tuple of (temp_profile, prob_profile):
        - temp_profile is a list of temperature values.
        - prob_profile is a list of predicted exchange probabilities; entry i
          is for the pair (temp_profile[i], temp_profile[i + 1]).
    """
    if exchange_probability > 1.0:
        # Such a target is unreachable; the bisection would degrade into
        # ~0.01 K steps and produce a huge, meaningless ladder.
        raise ValueError("exchange_probability must not be larger than 1, "
                         "got %r" % (exchange_probability,))

    kj2kcal = 1.0 / 4.184
    # Fit parameters of the paper, converted from kJ to kcal: A1 (per water
    # molecule) and B1 (per non-water atom) in kJ/(mol*K); D0 in kJ/mol and
    # D1 in kJ/(mol*K) enter the width sigma12. The intercepts A0 = -59.22
    # and B0 = -22.84 kJ/mol are not needed: only temperature differences
    # enter mu12.
    A1 = 0.07594 * kj2kcal
    B1 = 0.01347 * kj2kcal
    D0 = 1.168 * kj2kcal
    D1 = 0.002976 * kj2kcal
    KB = 0.0083145 * kj2kcal  # Boltzmann constant, kcal/(mol*K)
    sqrt2 = math.sqrt(2.0)

    ndf = _get_degrees_of_freedom(model, frozen_atom_mass_threshold)
    # Free water molecules (3 atoms each).
    n_water = _get_num_wateratom(model, frozen_atom_mass_threshold)[0] // 3
    n_other = _get_num_nonwateratom(model, frozen_atom_mass_threshold)[0]
    n_constr = _get_num_constraint(model, frozen_atom_mass_threshold)

    def calc_probability(t1, t2):
        """
        Predicted exchange probability between replicas at `t1` and `t2`.

        Note that `t1` should be higher than `t2`.
        """
        # Explicit float division: integer temperatures must not floor-divide.
        c = (1.0 / t1 - 1.0 / t2) / KB
        mu12 = (t1 - t2) * (A1 * n_water + B1 * n_other - KB * n_constr / 2.0)
        sigma12 = math.sqrt(ndf * (D1 * D1 * (t2 * t2 + t1 * t1) + 2 * D1 * D0 *
                                   (t2 + t1) + 2 * D0 * D0))
        part1 = (1 + erf(-mu12 / (sigma12 * sqrt2))) / 2.0
        # Original formula:
        #   part2 = exp(c * mu12 + c * c * sigma12 * sigma12 / 2.0)
        #           * (1 + erf((mu12 + c * sigma12 ** 2) / (sigma12 * sqrt2))) / 2.0
        # The exponential can overflow when its argument is big; in that
        # situation the erf term is effectively zero. So the erf term is
        # computed first and, if it is negligible, part2 is set to zero
        # without evaluating the exponential.
        tail = (1 + erf((mu12 + c * sigma12 * sigma12) / (sigma12 * sqrt2))) / 2.0
        if tail < 1E-200:
            part2 = 0.0
        else:
            part2 = math.exp(c * mu12 + c * c * sigma12 * sigma12 / 2.0) * tail
        return part1 + part2

    def find_t2(t1, t2_guess, t2_bottom, t2_top):
        """
        Bisect in [t2_bottom, t2_top] for the temperature above `t1` whose
        predicted probability is within 0.01 of the target.

        :return: (temperature, probability)
        """
        while True:
            if abs(t2_guess - t2_top) < 1E-2:
                return t2_top, calc_probability(t2_top, t1)
            p = calc_probability(t2_guess, t1)
            if p + 1E-2 < exchange_probability:
                # Gap too large: search the lower half.
                t2_guess, t2_bottom, t2_top = ((t2_guess + t2_bottom) * 0.5,
                                               t2_bottom, t2_guess)
            elif p - 1E-2 > exchange_probability:
                # Gap too small: search the upper half.
                t2_guess, t2_bottom = (t2_guess + t2_top) * 0.5, t2_guess
            else:
                return t2_guess, p

    temp_profile = [low_temp]
    prob_profile = []
    t2 = low_temp
    while abs(t2 - high_temp) > 1E-2:
        t1 = temp_profile[-1]
        t2, p = find_t2(t1, (t1 + high_temp) * 0.5, t1, high_temp)
        temp_profile.append(t2)
        prob_profile.append(p)

    # Check the length first: prob_profile is empty when low == high.
    if (should_fix and len(prob_profile) > 2 and
            prob_profile[-1] > 2 * exchange_probability):
        gaps = [hi - lo for lo, hi in zip(temp_profile, temp_profile[1:])]
        last_gap = gaps.pop()
        # Spread the (too small) last gap over the other gaps in proportion
        # to their sizes relative to the first gap.
        ratios = [gap / gaps[0] for gap in gaps]
        share = last_gap / sum(ratios)
        gaps = [gap + ratio * share for gap, ratio in zip(gaps, ratios)]
        temp_profile = [temp_profile[0]]
        for gap in gaps:
            temp_profile.append(temp_profile[-1] + gap)
        prob_profile = [
            calc_probability(hi, lo)
            for lo, hi in zip(temp_profile, temp_profile[1:])
        ]
    return temp_profile, prob_profile
def _split_ct(ct, selected_atom, model=None):
    """
    Split `ct` into a CT without the selected atoms and a CT with only them.

    Atoms are identified by 1-based indices. Single-atom restraints taken
    from `model.get_restrain()` are re-added to whichever CT now holds their
    atom, under the atom's new index. Pseudo-site coordinates are
    redistributed per molecule.

    :param ct: component CT to split. It is modified in place and returned
        as the first element of the result.
    :param selected_atom: 1-based indices of the atoms that go into the
        second CT (any iterable; converted to a set).
    :param model: `Cms` object providing the restraints. If None, the
        module-level name `model` is used; it exists only when this file is
        run as a script.
    :raise NameError: if `model` is None and no module-level `model` exists.
    :return: (ct0, ct1) -- `ct0` holds the non-selected atoms and `ct1` the
        selected atoms.
    """
    if model is None:
        model = globals().get("model")
        if model is None:
            raise NameError("_split_ct() requires a `model` argument: no "
                            "module-level `model` is defined")
    selected_atom = set(selected_atom)

    # Tag each atom with its original index so that the new indices can be
    # mapped back after atoms are deleted.
    for a in ct.atom:
        a.property["i_des_remd_orig_index"] = int(a)
    ct0 = ct  # `ct0' will contain only the non-selected atoms.
    ct1 = ct.copy()  # `ct1' will contain only the selected atoms.

    # Molecule numbers touched by selected atoms go to ct1; the rest to ct0.
    mol1 = sorted({ct0.atom[i_atom].molecule_number for i_atom in selected_atom})
    mol0 = sorted(set(range(1, ct0.mol_total + 1)) - set(mol1))

    # Single-atom restraints keyed by the original index of their atom.
    # NOTE(review): model.get_restrain() is not CT-specific; if it also
    # returns restraints of other CTs, the index maps below may raise
    # KeyError -- confirm against the Cms API.
    rb = {e.atom[0]: e for e in model.get_restrain() if len(e.atom) == 1}
    restrained_atom = set(rb)
    restrained_atom0 = restrained_atom - selected_atom
    restrained_atom1 = restrained_atom & selected_atom

    ct0.deleteAtoms(selected_atom)
    ct1.deleteAtoms(set(range(1, ct1.atom_total + 1)) - selected_atom)

    # Original atom index -> new atom index in each split CT.
    index_map0 = {a.property["i_des_remd_orig_index"]: int(a) for a in ct0.atom}
    index_map1 = {a.property["i_des_remd_orig_index"]: int(a) for a in ct1.atom}
    rb0 = {index_map0[i_atom]: rb[i_atom] for i_atom in restrained_atom0}
    rb1 = {index_map1[i_atom]: rb[i_atom] for i_atom in restrained_atom1}

    ct0.ffio.deleteRestraints(list(range(1, len(ct0.ffio.restraint) + 1)))
    ct1.ffio.deleteRestraints(list(range(1, len(ct1.ffio.restraint) + 1)))

    def set_restrain(target_ct, restraints):
        """Re-add `restraints` ({new atom index: restraint}) to `target_ct`."""
        for new_index, e in restraints.items():
            if isinstance(e.k, list):
                # Harmonic position restraint
                a = target_ct.ffio.addRestraint()
                a.ai = new_index  # the atom's index in `target_ct`
                a.c1 = e.k[0]
                a.c2 = e.k[1]
                a.c3 = e.k[2]
                a.t1 = e.ref[0]
                a.t2 = e.ref[1]
                a.t3 = e.ref[2]
                a.funct = "harm"
            elif isinstance(e.ref, list) and len(e.ref) == 3:
                # position fbhw
                a = target_ct.ffio.addPosfbhws()
                a.ai = new_index  # the atom's index in `target_ct`
                a.fc = e.k
                a.x0 = e.ref[0]
                a.y0 = e.ref[1]
                a.z0 = e.ref[2]
                a.sigma = e.sigma

    set_restrain(ct0, rb0)
    set_restrain(ct1, rb1)

    # Resets the pseudo block.
    pseudo = ct0.ffio.pseudo
    if len(pseudo) > 0:
        # NOTE(review): `pseudo` is indexed here by molecule number (1-based)
        # but refilled below with 0-based indices, and one pseudo site per
        # molecule is assumed -- confirm against the ffio API.
        pseudo0 = [
            cms.Pseudo(pseudo[i].x, pseudo[i].y, pseudo[i].z) for i in mol0
        ]
        pseudo1 = [
            cms.Pseudo(pseudo[i].x, pseudo[i].y, pseudo[i].z) for i in mol1
        ]
        # Delete through the ffio handle, matching ffio.addPseudos() below.
        ct0.ffio.deletePseudos(list(range(1, len(ct0.ffio.pseudo) + 1)))
        ct1.ffio.deletePseudos(list(range(1, len(ct1.ffio.pseudo) + 1)))
        ct0.ffio.addPseudos(len(pseudo0))
        ct1.ffio.addPseudos(len(pseudo1))
        for i, e in enumerate(pseudo0):
            ct0.ffio.pseudo[i].x = e.x
            ct0.ffio.pseudo[i].y = e.y
            ct0.ffio.pseudo[i].z = e.z
        for i, e in enumerate(pseudo1):
            ct1.ffio.pseudo[i].x = e.x
            ct1.ffio.pseudo[i].y = e.y
            ct1.ffio.pseudo[i].z = e.z
    return ct0, ct1
def _set_freezing_atommass(model, mass_scale):
    """
    """
    for ct in model.comp_ct:
        if (ct.frozen_atom is not None):
            if (ct.frozen_atom == -1):
                # All atom sites will be set with the new mass.
                for e in ct.ffio.site:
                    e.mass *= mass_scale
            else:
                # Only some atom sites will be set with the new mass.
                for i in ct.frozen_atom:
                    ct.ffio.site[i].mass *= mass_scale
            ct.property["r_des_frozenatommass_threshold"] = mass_scale
def freeze_atoms(model,
                 asl,
                 frozen_atom_mass_threshold=FROZEN_ATOM_MASS_THRESHOLD):
    """
    Freeze atoms in <model> specified by the <asl>.

    Atoms are kept free if they are solute atoms not matched by `asl`, or if
    they belong to a non-solute molecule that has at least one atom not
    matched by `asl` (ASL `fillmol`). All other atoms are frozen, i.e. their
    masses are multiplied by `frozen_atom_mass_threshold`. A component CT
    with an unroll factor other than 1, of which only some atoms are frozen,
    is split into a free CT and a frozen CT.

    :param model: `Cms` object; modified in place.
    :param asl: ASL expression of the atoms to freeze.
    :param frozen_atom_mass_threshold: mass scale factor applied to frozen
        atoms (it also serves as the frozen-mass threshold elsewhere).
    :return: `model`
    """
    free_asl = (
        f"(not {asl} and atom.{CT_TYPE} '{CT_TYPE.VAL.SOLUTE}') or "
        f"(fillmol(not {asl} and not atom.{CT_TYPE} '{CT_TYPE.VAL.SOLUTE}'))")
    free_atom = model.select_atom_comp(free_asl)

    # Frozen atoms per component CT: all 1-based atom indices minus free ones.
    frozen_atom = [
        set(range(1, ct.atom_total + 1)) - set(selected_atom)
        for ct, selected_atom in zip(model.comp_ct, free_atom)
    ]

    new_ct = []
    for ct, unroll, selected_atom in zip(model.comp_ct, model._unroll,
                                         frozen_atom):
        num_selected_atom = len(selected_atom)
        if unroll == 1:
            ct.frozen_atom = None if num_selected_atom == 0 else selected_atom
            new_ct.append(ct)
        elif num_selected_atom == ct.atom_total:
            ct.frozen_atom = -1  # Value `-1' means all atoms will be frozen.
            new_ct.append(ct)
        elif num_selected_atom == 0:
            ct.frozen_atom = None  # Value `None' means none of atoms will be frozen.
            new_ct.append(ct)
        else:
            # Partially frozen, unrolled CT: split it in two.
            free_ct, frozen_ct = _split_ct(ct, selected_atom)
            free_ct.frozen_atom = None
            frozen_ct.frozen_atom = -1
            new_ct.extend([free_ct, frozen_ct])
    model.comp_ct = new_ct
    model.synchronize_fsys_ct()
    _set_freezing_atommass(model, frozen_atom_mass_threshold)
    return model
if __name__ == '__main__':
    # NOTE: `model` is intentionally bound at module scope below; _split_ct()
    # can look it up as a global.
    # Parses arguments.
    cmd = cmdline.SingleDashOptionParser(
        usage="$SCHRODINGER/run %prog [options] <input .cms file>\n\n"
        "A script for predicting temperature profile for REMD simulations.\n"
        "It can optionally generate a new .cms file with a designated\n"
        "portion of the system made effectively frozen. Such a treatment can\n"
        "increase the temperature span of REMD.",
        version_source="$Revision: 1.7 $")
    cmd.add_option("-t",
                   dest="temp",
                   help="temperature range. Default: '300 400' (with quotes)")
    cmd.add_option("-asl", dest="asl", help="ASL expression of frozen atoms")
    cmd.add_option("-prob",
                   dest="prob",
                   type="float",
                   help="replica exchange probability. Default: 0.3")
    cmd.add_option("-interval",
                   dest="interval",
                   type="float",
                   help="exchange trial interval (ps). Default: 12")
    cmd.add_option("-o",
                   dest="out",
                   help="output .cms file. Default: predict_remd_temp-out.cms")
    cmd.add_option(
        "-mass-scale",
        dest="mass_scale",
        type="float",
        help=
        "scale factor to increase the masses of atoms to be frozen. We effectively freeze "
        "atoms in simulations by assigning them a huge mass. Default: 1E9")
    cmd.set_defaults(temp="300 400",
                     prob=0.3,
                     interval=12.0,
                     out="predict_remd_temp-out.cms",
                     mass_scale=1E9)
    opt, args = cmd.parse_args()

    num_args = len(args)
    if num_args < 1:
        cmd.error("The input .cms file is not given.")
    elif num_args > 1:
        cmd.error(
            "More than one input .cms file. You can specify only one input .cms file."
        )
    in_fname = args[0]
    if not os.path.isfile(in_fname):
        cmd.error("File not found: %s" % in_fname)
        sys.exit(1)

    temp_range_error = (
        "The temperature range should be a pair of numbers with quotes: e.g., '300 400'."
    )
    temp = opt.temp.split()
    if len(temp) != 2:
        cmd.error(temp_range_error)
    try:
        low_temp = float(temp[0])
        high_temp = float(temp[1])
    except ValueError:
        # Non-numeric input: report it like any other usage error.
        cmd.error(temp_range_error)
    if low_temp > high_temp:
        cmd.error("The first temperature of the range should not be higher "
                  "than the second one: '%s'." % opt.temp)
    if not 0.0 < opt.prob <= 1.0:
        cmd.error("The exchange probability (-prob) should be larger than 0 "
                  "and not larger than 1.")
    if opt.mass_scale <= 1:
        # The check rejects every value up to and including 1.
        cmd.error(
            "The number you gave for -mass-scale should be larger than 1")
    if opt.mass_scale < 1E3:
        print("predict_remd_temp.py: warning: The number you gave for -mass-scale is likely too small"
              " to effectively freeze the atoms.\a")

    model = cms.Cms(file=in_fname)
    sys_fname = in_fname
    if opt.asl is not None:
        print("Degrees of freedom of the original system:",
              _get_degrees_of_freedom(model, frozen_atom_mass_threshold=opt.mass_scale))
        model = freeze_atoms(model,
                             opt.asl,
                             frozen_atom_mass_threshold=opt.mass_scale)
        model.write(opt.out)
        model = cms.Cms(file=opt.out)
        sys_fname = opt.out
        print("Degrees of freedom of the system after freezing some atoms:",
              _get_degrees_of_freedom(model, frozen_atom_mass_threshold=opt.mass_scale))

    temp_profile, prob_profile = predict_temperature(
        low_temp,
        high_temp,
        opt.prob,
        model,
        should_fix=False,
        frozen_atom_mass_threshold=opt.mass_scale)

    print(" Replica#   Temperature   Probability  <Exch. Interval>")
    print("----------+-------------+-------------+----------------")
    print("%10d%14.3f%14.3s%16.3s" % (
        0,
        low_temp,
        "n/a",
        "n/a",
    ))
    for i, (temp, prob) in enumerate(zip(temp_profile[1:], prob_profile),
                                     start=1):
        # A zero predicted probability means exchanges never succeed: report
        # an infinite interval instead of dividing by zero.
        interval = opt.interval / prob if prob > 0 else float("inf")
        print("%10d%14.3f%14.3f%16.3f" % (
            i,
            temp,
            prob,
            interval,
        ))
    print()
    print("Here is how to use the temperatures:")
    print()
    print("1. If you want to launch a Desmond REMD job from console,")
    print("   add the following settings into the production simulate")
    print("   stage:")
    print("remd = [")
    for temp in temp_profile:
        print("   {temperature = %.3f}" % temp)
    print("]")
    print()
    print("2. If you want to launch the job from the Replica Exchange\n"
          "   panel, you can set the Temperature profile to 'manual'\n"
          "   and then fill in the temperatures.")