"""
Utilities for REMD.
Copyright Schrodinger, LLC. All rights reserved.
"""
# Contributors: Yujie Wu
import collections
import copy
import math
from past.utils import old_div
from typing import List
import numpy
import schrodinger.application.desmond.cms as cms
from schrodinger.application.desmond.constants import CT_TYPE
try:
    # scipy's erf is vectorized (accepts numpy arrays as well as scalars).
    import scipy.special as scipy_special
    erf = scipy_special.erf
except ImportError:
    # Scalar-only fallback when scipy is unavailable. The standard-library
    # implementation is accurate to double precision, unlike the previous
    # hand-rolled Abramowitz & Stegun 7.1.26 approximation.
    erf = math.erf
# Default factor by which `freeze_atom` (via `set_freezing_atommass`)
# multiplies the masses of frozen atom sites.
FROZEN_ATOM_MASS_THRESHOLD = 1.0e9

# Coefficients of the exchange-probability model in `_calc_probability`
# (energy units: kcal/mol). A1 (per selected water) and B1 (per selected
# non-water atom) enter the mean `mu12`; D0 and D1 enter the width `sigma12`.
# Presumably empirically fitted values -- their source is not recorded here.
A1 = 1.81501e-2
B1 = 3.2194e-3
D0 = 2.791587e-1
D1 = 7.113e-4
# Boltzmann (gas) constant in kcal/(mol*K); used as 1/(KB*T).
KB = 1.9872131e-3
def is_water(site, constraint):
    """
    Return True if an ffio molecule description looks like a water molecule.

    Two heuristics are tried in order:

    1. The constraint block has exactly one entry and its ``funct`` is
       ``"HOH"`` (rigid-water constraint).
    2. Among sites with positive mass (massless virtual sites are ignored)
       there is exactly one site of mass ~16 (oxygen), exactly two of mass ~1
       (hydrogen), and no other massive site.

    :param site: iterable of site records exposing ``mass``
    :param constraint: constraint block; it is indexed 1-based here
        (``constraint[1]`` is read when its length is 1) -- presumably an
        ffio block, verify against callers
    :rtype: bool
    """
    # Checks the constraint block.
    if len(constraint) == 1 and constraint[1].funct == "HOH":
        return True
    # Checks the sites block.
    num_o, num_h = 0, 0
    for e in site:
        if e.mass > 0:
            # `int(abs(d) + 0.01) == 0` means the mass is within ~0.99 of target.
            if int(abs(e.mass - 16.0) + 0.01) == 0:
                num_o += 1
            elif int(abs(e.mass - 1.0) + 0.01) == 0:
                num_h += 1
            else:
                return False
    # One oxygen and two hydrogens (the old check required a single hydrogen,
    # so real water was never recognized by this path).
    return num_o == 1 and num_h == 2
def get_nonduplicated_atom(ct):
    """
    Return the ``aj`` index of every FEP atom-map entry whose ``ai`` is
    negative, in atom-map order.

    NOTE(review): presumably these are the atoms that exist only on the
    ``aj`` side of the FEP mapping (i.e. are not duplicated) -- confirm.

    :param ct: structure with an ``fepio.atommap`` sequence
    :rtype: list
    """
    return [entry.aj for entry in ct.fepio.atommap if entry.ai < 0]
def get_num_wateratom(model, selected_atom):
    """
    Count selected, unselected and total atoms over the water components.

    Only components for which `is_water` is true contribute. For components
    with FEP information only the atoms from `get_nonduplicated_atom` are
    counted, and the selection is restricted to them.

    :param model: model exposing ``comp_ct``
    :param selected_atom: one collection of selected atom indices per
        component ct (e.g. from ``model.select_atom_comp(asl)``)
    :return: (num_selected, num_unselected, num_total)
    :rtype: tuple(int, int, int)
    """
    num_selected = 0
    num_unselected = 0
    num_total = 0
    for sa, ct in zip(selected_atom, model.comp_ct):
        atom_total = ct.atom_total
        if ct.hasFepio():
            nonduplicated_atom = get_nonduplicated_atom(ct)
            atom_total = len(nonduplicated_atom)
            sa = set(sa) & set(nonduplicated_atom)
        if is_water(ct.ffio.site, ct.ffio.constraint):
            num_selected_here = len(sa)
            num_selected += num_selected_here
            # Unselected atoms of *this* component, not of the running total.
            num_unselected += atom_total - num_selected_here
            num_total += atom_total
    return num_selected, num_unselected, num_total
def get_num_nonwateratom(model, selected_atom):
    """
    Count selected, unselected and total atoms over the non-water components.

    Only components for which `is_water` is false contribute. For components
    with FEP information only the atoms from `get_nonduplicated_atom` are
    counted, and the selection is restricted to them.

    :param model: model exposing ``comp_ct``
    :param selected_atom: one collection of selected atom indices per
        component ct (e.g. from ``model.select_atom_comp(asl)``)
    :return: (num_selected, num_unselected, num_total)
    :rtype: tuple(int, int, int)
    """
    num_selected = 0
    num_unselected = 0
    num_total = 0
    for sa, ct in zip(selected_atom, model.comp_ct):
        atom_total = ct.atom_total
        if ct.hasFepio():
            nonduplicated_atom = get_nonduplicated_atom(ct)
            atom_total = len(nonduplicated_atom)
            sa = set(sa) & set(nonduplicated_atom)
        if not is_water(ct.ffio.site, ct.ffio.constraint):
            num_selected_here = len(sa)
            num_selected += num_selected_here
            # Unselected atoms of *this* component, not of the running total.
            num_unselected += atom_total - num_selected_here
            num_total += atom_total
    return num_selected, num_unselected, num_total
def get_num_constraint(model, selected_atom):
    """
    Count the constrained atom pairs whose two atoms are both selected.

    Each component's ffio constraint block is applied to each of its
    ``unroll`` copies, the i-th copy being offset by
    ``i * atom_total // unroll`` atom indices.

    (The previous implementation incremented a local integer inside a
    nested helper, so it always returned 0.)

    :param model: model exposing ``comp_ct`` and ``_unroll``
    :param selected_atom: one collection of selected atom indices per
        component ct
    :return: number of constraints entirely inside the selection
    :rtype: int
    :raises SyntaxError: for an unrecognized constraint ``funct``
    """
    # Atom attributes whose pairwise distances each constraint type fixes.
    pairs_by_funct = {
        "HOH": (("ai", "aj"), ("aj", "ak"), ("ai", "ak")),
        "AH1": (("ai", "aj"),),
        "AH2": (("ai", "aj"), ("ai", "ak")),
        "AH3": (("ai", "aj"), ("ai", "ak"), ("ai", "al")),
        "AH4": (("ai", "aj"), ("ai", "ak"), ("ai", "al"), ("ai", "am")),
        "AH5": (("ai", "aj"), ("ai", "ak"), ("ai", "al"), ("ai", "am"),
                ("ai", "an")),
    }
    num_constraint = 0
    for sa, ct, unroll in zip(selected_atom, model.comp_ct, model._unroll):
        sa = set(sa)
        for i in range(unroll):
            # Integer offset of the i-th unrolled copy of the template.
            shift = i * ct.atom_total // unroll
            for e in ct.ffio.constraint:
                pairs = pairs_by_funct.get(e.funct)
                if pairs is None:
                    raise SyntaxError("unknown type of constraint: %s" %
                                      e.funct)
                for first, second in pairs:
                    if (getattr(e, first) + shift in sa and
                            getattr(e, second) + shift in sa):
                        num_constraint += 1
    return num_constraint
def get_degrees_of_freedom(model, selected_atom):
    """
    Return the number of degrees of freedom of the selected atoms.

    Computed as three times the number of selected atoms (water plus
    non-water) minus the number of constraints lying entirely within the
    selection.

    :param model: model exposing ``comp_ct`` and ``_unroll``
    :param selected_atom: one collection of selected atom indices per
        component ct
    :rtype: int
    """
    n_water = get_num_wateratom(model, selected_atom)[0]
    n_nonwater = get_num_nonwateratom(model, selected_atom)[0]
    n_constraint = get_num_constraint(model, selected_atom)
    n_atoms = n_water + n_nonwater
    return n_atoms * 3 - n_constraint
def get_rest_params(model, asl):
    """
    Returns parameters that are required for temperature ladder prediction

    :type model: `schrodinger.application.desmond.cms.Cms`
    :type asl: str
    :return: the tuple with the following values: degrees of freedom, number of
        selected waters, number of selected non-water atoms, number of
        constraints.
    :rtype: tuple(int, int, int, int)
    :raises Exception: if `asl` selects no atoms
    """
    selected_atom = model.select_atom_comp(asl)
    # any() avoids the quadratic list concatenation of sum(lists, []).
    if not any(selected_atom):
        raise Exception(f"No atoms selected by ASL expression: {asl}")
    num_water_atoms = get_num_wateratom(model, selected_atom)[0]
    num_nonwater_atoms = get_num_nonwateratom(model, selected_atom)[0]
    nc = get_num_constraint(model, selected_atom)
    # Same formula as get_degrees_of_freedom(), reusing the counts above
    # instead of recomputing all three passes.
    ndf = (num_water_atoms + num_nonwater_atoms) * 3 - nc
    # Assumes three atoms per water molecule.
    nw = num_water_atoms // 3
    return ndf, nw, num_nonwater_atoms, nc
def predict_temperature_ladder(temperature,
                               exchange_probability,
                               model,
                               asl,
                               should_fix=True,
                               floaty=False):
    """
    Predict a temperature ladder spanning `temperature` for the atoms of
    `model` selected by `asl`.

    :param temperature: a tuple (t0, t1) temperature range in Kelvin
    :type  temperature: tuple(float, float)
    :param exchange_probability: target exchange probability between
        neighbouring rungs (each rung is searched to within +/-0.01 of it)
    :type exchange_probability: float
    :type model: `schrodinger.application.desmond.cms.Cms`
    :param asl: ASL expression selecting the atoms whose temperature is raised
    :type asl: str
    :param should_fix: if True and the ladder has more than two exchange
        probabilities with the last one above twice `exchange_probability`,
        the final (short) temperature gap is dropped and its width is
        redistributed over the preceding gaps in proportion to their size;
        the probabilities are then recomputed
    :type  should_fix: bool
    :param floaty: if True, when the search for the next rung reaches the
        current upper bound, the bound is pushed further up instead of being
        accepted, so rungs may go beyond the requested maximum temperature
    :type  floaty: bool
    :return: returns a tuple with two lists. First list is a temperature
            profile, second list is probability profile.
    :rtype: tuple(list of floats, list of floats)
    """
    ndf, n_water, n_nonwater, n_constraint = get_rest_params(model, asl)
    ladder = _predict_temperature_ladder(temperature, exchange_probability,
                                         ndf, n_water, n_nonwater,
                                         n_constraint, should_fix, floaty)
    return ladder
def get_prob_from_temp_ladder(temp_ladder: List[float], model: cms.Cms,
                              asl: str) -> List[float]:
    """
    Return the exchange probability between each pair of adjacent rungs of
    `temp_ladder`, for the hot atoms of `model` selected by `asl`.

    The result has one entry fewer than `temp_ladder`. Each pair is passed
    to the probability model with the higher temperature first, so the
    ladder need not be sorted.
    """
    ndf, n_water, n_nonwater, n_constraint = get_rest_params(model, asl)
    adjacent_pairs = zip(temp_ladder, temp_ladder[1:])
    return [
        _calc_probability(max(t_a, t_b), min(t_a, t_b), ndf, n_water,
                          n_nonwater, n_constraint)
        for t_a, t_b in adjacent_pairs
    ]
def _calc_probability(t1: float, t2: float, ndf: int, nw: int, np: int,
                      nc: int) -> float:
    """
    Return the predicted exchange probability between temperatures t1 and t2.

    The energy difference Delta between the two rungs is modelled as a normal
    distribution with mean `mu12` and standard deviation `sigma12`. The
    function evaluates the expected Metropolis acceptance
    E[min(1, exp(c * Delta))] with c = 1/(KB*t1) - 1/(KB*t2) <= 0:

    * part1 = P(Delta <= 0)                        (always accepted)
    * part2 = E[exp(c * Delta); Delta > 0]         (accepted with exp(c*Delta))

    Note that t1 should be greater than or equal to t2 (t1>=t2)

    :param t1: high temperature (in Kelvin)
    :type  t1: float
    :param t2: low temperature (in Kelvin)
    :type  t2: float
    :param ndf: degrees of freedom
    :param nw: number of selected waters
    :param np: number of selected non-water atoms
    :param nc: number of constraints
    :rtype: float
    """
    # NOTE(review): assert is stripped under `python -O`; callers in this file
    # always pass the higher temperature first.
    assert t1 >= t2
    try:
        # Difference of inverse thermal energies; <= 0 because t1 >= t2.
        c = (((1 / t1) - (1 / t2)) / KB)
    except ZeroDivisionError:
        # Reached when a temperature is exactly 0. NOTE(review): if t1 == 0
        # this line divides by zero again and the error propagates.
        c = (-((1 / t1)) / KB)
    # Mean of the energy difference between the two rungs.
    mu12 = (t1 - t2) * (A1 * nw + B1 * np - KB * nc / 2.0)
    # Width of the energy difference; NOTE(review): 0 when ndf == 0, which
    # makes the divisions below raise ZeroDivisionError.
    sigma12 = math.sqrt(ndf * (D1**2 * (t2**2 + t1**2) + 2 * D1 * D0 *
                               (t2 + t1) + 2 * D0**2))
    # Normal CDF of N(mu12, sigma12**2) at 0; 1.414 approximates sqrt(2).
    part1 = ((1 + erf(-mu12 / sigma12 / 1.414)) / 2.0)
    part_ = (mu12 + c * sigma12 * sigma12) / sigma12 / 1.414
    part_ = 1 + erf(part_)
    # Skip the exponential when its erf factor is negligible (presumably to
    # avoid inf * ~0 when the exponent is large).
    if abs(part_) < 1E-10:
        return part1
    exp_ = numpy.exp(c * mu12 + c**2 * sigma12**2 / 2.0)
    part2 = exp_ * part_ / 2.0
    return part1 + part2
def _predict_temperature_ladder(temperature, exchange_probability, ndf, nw, np,
                                nc, should_fix, floaty):
    """
    Backend calculation for predicting temperature ladder
    """
    def find_t2(t1, t2_guess, t2_bottom, t2_top):
        if (abs(t2_guess - t2_top) < 0.1):
            if (floaty):
                t2_top += t2_top - t1
            else:
                return t2_top, _calc_probability(t2_top, t1, ndf, nw, np, nc)
        p = _calc_probability(t2_guess, t1, ndf, nw, np, nc)
        if (p + 1E-2 < exchange_probability):
            return find_t2(t1, (t2_guess + t2_bottom) * 0.5, t2_bottom,
                           t2_guess)
        elif (p - 1E-2 > exchange_probability):
            return find_t2(t1, (t2_guess + t2_top) * 0.5, t2_guess, t2_top)
        return t2_guess, p
    temp_low = float(min(temperature))
    temp_high = float(max(temperature))
    temp_profile = [temp_low]
    prob_profile = []
    t2 = temp_low
    while (t2 + 1E-2 < temp_high):
        t1 = temp_profile[-1]
        t2_guess = (t1 + temp_high) * 0.5
        t2, p = find_t2(t1, t2_guess, t1, temp_high)
        temp_profile.append(t2)
        prob_profile.append(p)
    if (should_fix and prob_profile[-1] > 2 * exchange_probability and
            len(prob_profile) > 2):
        delta_temp = []
        for i, temp in enumerate(temp_profile[1:]):
            delta_temp.append(temp - temp_profile[i])
        ratio = []
        for dt in delta_temp[:-1]:
            ratio.append(old_div(dt, delta_temp[0]))
        s = sum(ratio)
        x = old_div(delta_temp[-1], s)
        delta_temp = delta_temp[:-1]
        for i, t in enumerate(delta_temp):
            delta_temp[i] += ratio[i] * x
        tp = [temp_profile[0]]
        for dt in delta_temp:
            tp.append(tp[-1] + dt)
        pp = []
        for i, t in enumerate(tp[:-1]):
            p = _calc_probability(tp[i + 1], t, ndf, nw, np, nc)
            pp.append(p)
        temp_profile = tp
        prob_profile = pp
    return temp_profile, prob_profile
def predict_with_temp_and_exch(temp, exchange_probability, model, asl):
    """
    Predict a temperature ladder from a temperature range and a target
    exchange probability; the number of replicas is whatever the ladder needs.

    :param temp: (min, max) temperature range in Kelvin
    :type  temp: tuple(float, float)
    :type exchange_probability: float
    :type model: `schrodinger.application.desmond.cms.Cms`
    :type asl: str
    :return: (temperature profile, exchange-probability profile) as returned
        by `predict_temperature_ladder` with its default options
    :rtype: tuple(list of floats, list of floats)
    """
    ladder = predict_temperature_ladder(temp, exchange_probability, model, asl)
    return ladder
def predict_with_nreplica_and_exch(n_replica, exchange_probability, base_temp,
                                   model, asl, max_iter=1000):
    """
    Given the base temperature, number of replicas, and exchange_probability,
    predicts the top temperature and returns the temperature ladder.

    The top temperature starts at ``base_temp + 20`` K. The bracket is
    widened upward while the ladder has too few rungs, then bisected until
    the ladder has exactly `n_replica` rungs. Ladders are built with
    ``should_fix=True`` and ``floaty=True``.

    :type  n_replica: int
    :type exchange_probability: float
    :type base_temp: float
    :type model: `schrodinger.application.desmond.cms.Cms`
    :type asl: str
    :param max_iter: maximum number of trial ladders to evaluate
    :type  max_iter: int
    :return: a list of temperatures in Kelvin (length n_replica) and a list
            of corresponding exchange probabilities (length n_replica-1)
    :rtype: (list of floats, list of floats)
    :raises ValueError: if `n_replica` is smaller than 1
    :raises RuntimeError: if no ladder with exactly `n_replica` rungs is
        found within `max_iter` trials (the search previously looped forever)
    """
    if n_replica < 1:
        raise ValueError(f"n_replica must be at least 1, got {n_replica}")
    ndf, nw, np, nc = get_rest_params(model, asl)

    def ladder(top_temp):
        return _predict_temperature_ladder((base_temp, top_temp),
                                           exchange_probability,
                                           ndf,
                                           nw,
                                           np,
                                           nc,
                                           should_fix=True,
                                           floaty=True)

    high, low = base_temp + 20.0, base_temp
    trial_temp = high
    temp_profile, prob_profile = ladder(trial_temp)
    n_trials = 1
    while len(temp_profile) != n_replica:
        if n_trials >= max_iter:
            raise RuntimeError(
                f"Could not find a ladder with {n_replica} replicas "
                f"within {max_iter} trials")
        if len(temp_profile) < n_replica:
            if trial_temp == high:
                # Still expanding: move the bracket up and widen it.
                high, low = high + (high - low) * 2, high
                trial_temp = high
            else:
                low = trial_temp
                trial_temp += (high - low) * 0.5
        else:
            high = trial_temp
            trial_temp -= (high - low) * 0.5
        temp_profile, prob_profile = ladder(trial_temp)
        n_trials += 1
    return temp_profile, prob_profile
def predict_with_temp_and_nreplica(temperature, n_replica, model, asl,
                                   max_iter=1000):
    """
    Given the temperature range and number of replicas, predicts the exchange
    probability and returns the temperature ladder.

    The target exchange probability starts at 0.95 and is lowered in steps
    of 0.1 while the ladder has more than `n_replica` rungs. Whenever a step
    overshoots (too few rungs), it is undone and the step is halved. If the
    initial ladder already has no more than `n_replica` rungs, it is returned
    as is.

    :param temperature: a tuple of (min, max) temperatures, in Kelvin
    :type  temperature: tuple(float, float)
    :param n_replica: desired number of replicas (ladder rungs)
    :type  n_replica: int
    :type model: `schrodinger.application.desmond.cms.Cms`
    :type asl: str
    :param max_iter: maximum number of refinement steps
    :type  max_iter: int
    :return: a list of temperatures (length n_replica) and a list of
            corresponding exchange probabilities (length n_replica-1)
    :rtype: (list of floats, list of floats)
    :raises RuntimeError: if no ladder with exactly `n_replica` rungs is
        found within `max_iter` steps (the search previously looped forever)
    """
    # Select atoms once instead of on every trial ladder.
    ndf, nw, np, nc = get_rest_params(model, asl)

    def ladder(probability):
        # Same as predict_temperature_ladder() with its default options.
        return _predict_temperature_ladder(temperature, probability, ndf, nw,
                                           np, nc, True, False)

    exchange_probability = 0.95
    step = 0.1
    temp_profile, prob_profile = ladder(exchange_probability)
    n = len(temp_profile)
    n_steps = 0
    while n > n_replica:
        if n_steps >= max_iter:
            raise RuntimeError(
                f"Could not find a ladder with {n_replica} replicas "
                f"within {max_iter} steps")
        n_steps += 1
        exchange_probability -= step
        temp_profile, prob_profile = ladder(exchange_probability)
        n = len(temp_profile)
        if n < n_replica:
            # Overshot: undo the step, halve it, and keep searching.
            exchange_probability += step
            step *= 0.5
            n = n_replica + 1
    return temp_profile, prob_profile
def split_ct(ct, selected_atom):
    """
    Split `ct` into a structure with the unselected atoms and a structure
    with the selected atoms.

    `ct` itself is modified in place and returned as the first structure.
    The second structure is ``ct.copy()`` with a duplicated ffio force-field
    handle. Every atom is tagged with its original index in the
    ``i_des_remd_orig_index`` property, so restrain-block entries can be
    remapped after atoms are deleted. When a pseudo block is present, its
    entries are split between the two structures by molecule number and
    stored on their ``pseudo`` attribute.

    :param ct: component structure to split (modified in place)
    :param selected_atom: 1-based indices of the atoms to move into the
        second structure; any iterable is accepted
    :return: (ct0, ct1) = (unselected-atom structure, selected-atom structure)
    """
    selected_atom = set(selected_atom)
    for a in ct.atom:
        a.property["i_des_remd_orig_index"] = int(a)
    ct0 = ct  # `ct0' will contain only the non-selected atoms.
    ct1 = ct.copy()  # `ct1' will contain only the selected atoms.
    ct1._ffh = cms.mm.mmffio_ff_duplicate(ct0._ffh)

    # Molecule numbers touched by the selection, and all the others.
    mol1 = {ct0.atom[i_atom].molecule_number for i_atom in selected_atom}
    mol0 = sorted(set(range(1, ct0.mol_total + 1)) - mol1)
    mol1 = sorted(mol1)

    # Split the restrain block (keyed by the original atom indices) before
    # deleting atoms, then remap the keys to each structure's new indices.
    rb = cms.Ffio.get_restrain_block(ct0)
    restrained_atom = set(rb)
    restrained_atom0 = restrained_atom - selected_atom
    restrained_atom1 = restrained_atom & selected_atom
    ct0.deleteAtoms(selected_atom)
    ct1.deleteAtoms(set(range(1, ct1.atom_total + 1)) - selected_atom)
    index_map0 = {a.property["i_des_remd_orig_index"]: int(a) for a in ct0.atom}
    index_map1 = {a.property["i_des_remd_orig_index"]: int(a) for a in ct1.atom}
    rb0 = cms.RestrainBlock()
    rb1 = cms.RestrainBlock()
    for i_atom in restrained_atom0:
        rb0[index_map0[i_atom]] = rb[i_atom]
    for i_atom in restrained_atom1:
        rb1[index_map1[i_atom]] = rb[i_atom]
    cms.Ffio.set_restrain_block(ct0, rb0)
    cms.Ffio.set_restrain_block(ct1, rb1)

    # Split the pseudo block by molecule.
    # NOTE(review): indexing by molecule number assumes one pseudo entry per
    # molecule -- confirm.
    pseudo = cms.Ffio.get_pseudo_block(ct0)
    if len(pseudo) > 0:
        ct0.pseudo = [pseudo[i - 1] for i in mol0]
        ct1.pseudo = [pseudo[i - 1] for i in mol1]
    return ct0, ct1
def set_freezing_atommass(model, mass_scale):
    """
    Multiply the masses of frozen sites by `mass_scale`.

    Components whose ``frozen_atom`` is None are left untouched. A value of
    -1 scales every site of the component's site block; any other value is
    treated as a collection of 1-based atom indices whose sites are scaled.
    The updated site block is written back, and the
    ``r_des_frozenatommass_threshold`` property is set to `mass_scale`.

    :param model: model exposing ``comp_ct`` and ``_site_block``
    :param mass_scale: factor applied to the masses of frozen sites
    """
    for ct, sites in zip(model.comp_ct, model._site_block):
        frozen = ct.frozen_atom
        if frozen is None:
            continue
        if frozen == -1:
            targets = sites
        else:
            targets = (sites[index - 1] for index in frozen)
        for frozen_site in targets:
            frozen_site.mass *= mass_scale
        cms.Ffio.set_site_block(ct, sites)
        ct.property["r_des_frozenatommass_threshold"] = mass_scale
def freeze_atom(model,
                asl,
                frozen_atom_mass_threshold=FROZEN_ATOM_MASS_THRESHOLD):
    """
    Freeze the atoms selected by `asl` by scaling their masses.

    Free atoms are:

    * solute atoms not matched by `asl`, and
    * whole molecules (``fillmol``) of non-solute atoms not matched by `asl`.

    Every other atom is frozen, and its site mass is multiplied by
    `frozen_atom_mass_threshold` (see `set_freezing_atommass`). An unrolled
    component that is only partly frozen is split with `split_ct` into a free
    and a frozen component. `model.comp_ct` is replaced. If any component
    carries pseudo atoms, the model is rebuilt from its string form and the
    pseudo atoms are re-added, so a new Cms object is returned in that case.

    :param model: input model (modified in place)
    :type model: `schrodinger.application.desmond.cms.Cms`
    :param asl: ASL expression selecting the atoms to freeze
    :type asl: str
    :param frozen_atom_mass_threshold: mass multiplier for frozen atoms
    :type frozen_atom_mass_threshold: float
    :return: the model with frozen atoms
    :rtype: `schrodinger.application.desmond.cms.Cms`
    """
    # The former version %-formatted this f-string with (asl, asl) although
    # it has no %s placeholders, which raised TypeError. `asl` is now
    # parenthesised so compound expressions are negated as a whole.
    free_asl = (
        f"(not ({asl}) and atom.{CT_TYPE} '{CT_TYPE.VAL.SOLUTE}') or "
        f"(fillmol(not ({asl}) and not atom.{CT_TYPE} '{CT_TYPE.VAL.SOLUTE}'))")
    free_atom = model.select_atom_comp(free_asl)
    frozen_atom = [
        set(range(1, ct.atom_total + 1)) - set(selected_atom)
        for ct, selected_atom in zip(model.comp_ct, free_atom)
    ]
    new_ct = []
    for ct, unroll, selected_atom in zip(model.comp_ct, model._unroll,
                                         frozen_atom):
        num_selected_atom = len(selected_atom)
        if unroll == 1:
            ct.frozen_atom = None if (num_selected_atom == 0) else selected_atom
            new_ct.append(ct)
        elif num_selected_atom == ct.atom_total:
            ct.frozen_atom = -1  # Value `-1' means all atoms will be frozen.
            new_ct.append(ct)
        elif num_selected_atom == 0:
            ct.frozen_atom = None  # Value `None' means none of atoms will be frozen.
            new_ct.append(ct)
        else:
            # Partly frozen unrolled component: split into free/frozen parts.
            free_ct, frozen_ct = split_ct(ct, selected_atom)
            free_ct.frozen_atom = None
            frozen_ct.frozen_atom = -1
            new_ct.extend([free_ct, frozen_ct])
    model.comp_ct = new_ct
    model.synchronize_fsys_ct()
    model.get_site_block()
    set_freezing_atommass(model, frozen_atom_mass_threshold)

    # Deals with pseudo atoms (e.g. those attached by split_ct). Indices are
    # 1-based positions in comp_ct (0 is the full-system ct below).
    pseudo_ct_index = [
        i + 1 for i, ct in enumerate(new_ct) if hasattr(ct, "pseudo")
    ]
    if pseudo_ct_index:
        # NOTE(review): the result of this call is unused; it is kept in case
        # write_to_string() synchronizes internal state -- verify and remove.
        model.write_to_string()
        struc = [
            model._raw_fsys_ct,
        ] + model.comp_ct
        for i in pseudo_ct_index:
            ah = cms.mm.mmct_ct_m2io_get_additional_data_handle(struc[i].handle)
            cms.mm.m2io_goto_block(ah, "ffio_ff", 1)
            cms.mm.m2io_delete_named_block(ah, "ffio_pseudo")
        model = cms.Cms(string=model.write_to_string())
        for i in pseudo_ct_index:
            pseudo = new_ct[i - 1].pseudo
            ffh = model.comp_ct[i - 1]._ffh
            cms.mm.mmffio_add_pseudos(ffh, len(pseudo))
            for j, p in enumerate(pseudo, start=1):
                cms.mm.mmffio_pseudo_set_x_coord(ffh, j, p.x)
                cms.mm.mmffio_pseudo_set_y_coord(ffh, j, p.y)
                cms.mm.mmffio_pseudo_set_z_coord(ffh, j, p.z)
    return model
def write_ff_solute_tempering(model,
                              out_fname,
                              asl,
                              scaling_factor,
                              should_scale_torsion=True):
    """
    Rescale the force field of `model` for solute tempering and write the
    result to `out_fname`.

    :type model: cms.Cms
    :param model: input cms.Cms
    :type out_fname: str
    :param out_fname: output file name
    :type asl: str
    :param asl: ASL expression to define rest atoms
    :type scaling_factor: float
    :param scaling_factor: T_ref/T_hot
    :type should_scale_torsion: Boolean
    :param should_scale_torsion: whether to scale dihedral terms
    """
    rescale_ff_solute_tempering(model, asl, scaling_factor,
                                should_scale_torsion).write(out_fname)
def rescale_ff_solute_tempering(model,
                                asl,
                                scaling_factor,
                                should_scale_torsion=True):
    """
    Return a copy of `model` whose force-field terms are scaled for the rest
    atoms selected by `asl`.

    NOTE(review): ``copy.copy`` is a shallow copy unless Cms defines
    ``__copy__``; if the component cts are shared, the input `model` is
    rescaled as well -- confirm.

    :type model: cms.Cms
    :param model: input cms.Cms
    :type asl: str
    :param asl: ASL expression to define rest atoms
    :type scaling_factor: float
    :param scaling_factor: T_ref/T_hot
    :type should_scale_torsion: Boolean
    :param should_scale_torsion: whether to scale dihedral terms
    :rtype: cms.Cms
    :return: rescaled cms.Cms
    """
    result = copy.copy(model)
    hot_atoms_per_ct = model.select_atom_comp(asl)
    for component, rest_atoms in zip(result.comp_ct, hot_atoms_per_ct):
        rescale_ff(component, rest_atoms, scaling_factor, should_scale_torsion)
    return result
def rescale_ff(ct, rest_atoms, scaling_factor, should_scale_torsion=True):
    """
    Scale the charge, van der Waals and (optionally) dihedral terms of the
    rest atoms of `ct`, in place.

    :type ct: ffiostructure.FFIOStructure
    :param ct: input/output structure handle
    :type rest_atoms: list of int
    :param rest_atoms: sequence of atom indices in rest region
    :type scaling_factor: float
    :param scaling_factor: T_ref/T_hot
    :type should_scale_torsion: Boolean
    :param should_scale_torsion: whether to scale dihedral terms
    """
    rescalers = [_rescale_charge, _rescale_vdw]
    if should_scale_torsion:
        rescalers.append(_rescale_dihedral)
    for rescale in rescalers:
        rescale(ct, rest_atoms, scaling_factor)
def _rescale_charge(ct, rest_atoms, scaling_factor):
    """
    scales rest atom charge
    :type ct: ffiostructure.FFIOStructure
    :param ct: input/output structure handle
    :type rest_atoms: list of int
    :param rest_atoms: sequence of atom indices in rest region
    :type scaling_factor: float
    :param scaling_factor: T_ref/T_hot
    Charge scaling in place::
        q_i *= sf_i
        where
            sf_i = sqrt(scaling_factor) when i is a rest atom
            otherwise sf_i = 1.0
    """
    sqrt_scaling_factor = math.sqrt(scaling_factor)
    for i_site in rest_atoms:
        site = ct.ffio.site[i_site]
        site.charge *= sqrt_scaling_factor
    # Scale virtual site when its parent atom is in rest region
    for virtual in ct.ffio.virtual:
        if virtual.ai in rest_atoms:
            virtual_site = ct.ffio.site[virtual.virtual_index]
            virtual_site.charge *= sqrt_scaling_factor
def _rescale_vdw(ct, rest_atoms, scaling_factor):
    """
    scales vdwtype and creates new vdwtype by appending _S for rest atoms

    Each rest site gets a new vdw type named ``<old>_S``. The new type copies
    c1, c2 and funct from the old type. For the ``OPLS_2005`` force field its
    c2 is multiplied by `scaling_factor`; otherwise its t1 is set to
    `scaling_factor`.

    Explicit pair entries (``vdwtypescombined``) involving a renamed type are
    renamed in place and their c2 multiplied by `scaling_factor`. Copies are
    also added so that the remaining old/new name pairs still exist: the
    unscaled original pair, and the mixed (new, old)/(old, new) pairs when
    both types are rescaled.

    :type ct: ffiostructure.FFIOStructure
    :param ct: input/output structure handle
    :type rest_atoms: list of int
    :param rest_atoms: sequence of atom indices in rest region
    :type scaling_factor: float
    :param scaling_factor: T_ref/T_hot

    vdw.t1 = scaling_factor

    NOTE(review): when both names of an explicit pair are rescaled, the
    renamed (new, new) entry ends up with c2 * scaling_factor**2 and the
    mixed entries with c2 * scaling_factor. In contrast, the per-type scaling
    above (OPLS_2005: c2 *= scaling_factor) gives scaling_factor for hot-hot
    and sqrt(scaling_factor) for hot-cold under a geometric combining rule.
    Confirm which is intended.
    """
    # Lookup of vdw types by name; newly created `_S` types are added to it.
    vdwtype_dict = {vdw.name: vdw for vdw in ct.ffio.vdwtype}
    # Explicit pair entries indexed by their first and by their second name.
    vdwcombined_name1 = collections.defaultdict(list)
    vdwcombined_name2 = collections.defaultdict(list)
    # (name1, name2) -> (c1, c2, funct) of pair entries to add at the end.
    vdw_combined_scaled = {}
    for vdw in ct.ffio.vdwtypescombined:
        vdwcombined_name1[vdw.name1].append(vdw)
        vdwcombined_name2[vdw.name2].append(vdw)
    new_to_old = {}
    # NOTE(review): old_to_new is filled but never read.
    old_to_new = {}
    for i_site in rest_atoms:
        s = ct.ffio.site[i_site]
        old_vdwtype = s.vdwtype
        new_vdwtype = s.vdwtype + '_S'
        s.vdwtype = new_vdwtype
        new_to_old[new_vdwtype] = old_vdwtype
        old_to_new[old_vdwtype] = new_vdwtype
        # Create the scaled type only once per old type.
        if new_vdwtype not in vdwtype_dict:
            old_vdw = vdwtype_dict[old_vdwtype]
            new_vdw = ct.ffio.addVdwtype()
            new_vdw.name = new_vdwtype
            new_vdw.c1 = old_vdw.c1
            new_vdw.c2 = old_vdw.c2
            new_vdw.funct = old_vdw.funct
            if ct.ffio.name == 'OPLS_2005':
                new_vdw.c2 *= scaling_factor
            else:
                new_vdw.t1 = scaling_factor
            vdwtype_dict[new_vdwtype] = new_vdw
        # Pairs whose first name is the old type: rename in place to the new
        # type (scaling c2 once) and remember the entries to re-create.
        for v in vdwcombined_name1.get(old_vdwtype, []):
            if v.name1 != new_vdwtype:
                if not (v.name1, v.name2) in vdw_combined_scaled:
                    vdw_combined_scaled[(v.name1, v.name2)] = (v.c1, v.c2,
                                                               v.funct)
                    # both names get scaled, make sure (new, old), (old, new)
                    # are preserved
                    if v.name2 in new_to_old:
                        vdw_combined_scaled[(new_vdwtype,
                                             new_to_old[v.name2])] = (v.c1,
                                                                      v.c2,
                                                                      v.funct)
                    v.c2 *= scaling_factor
                v.name1 = new_vdwtype
        # Same for pairs whose second name is the old type.
        for v in vdwcombined_name2.get(old_vdwtype, []):
            if v.name2 != new_vdwtype:
                if not (v.name1, v.name2) in vdw_combined_scaled:
                    vdw_combined_scaled[(v.name1, v.name2)] = (v.c1, v.c2,
                                                               v.funct)
                    # both names get scaled, make sure (new, old), (old, new)
                    # are preserved
                    if v.name1 in new_to_old:
                        vdw_combined_scaled[(new_to_old[v.name1],
                                             new_vdwtype)] = (v.c1, v.c2,
                                                              v.funct)
                    v.c2 *= scaling_factor
                v.name2 = new_vdwtype
    # Re-create the recorded pair entries alongside the renamed ones.
    for k, v in vdw_combined_scaled.items():
        new_combo = ct.ffio.addVdwtypescombined()
        new_combo.name1, new_combo.name2 = k
        new_combo.c1, new_combo.c2, new_combo.funct = v
def _rescale_dihedral(ct, rest_atoms, scaling_factor):
    """
    scales dihedral terms
    :type ct: ffiostructure.FFIOStructure
    :param ct: input/output structure handle
    :type rest_atoms: list of int
    :param rest_atoms: sequence of atom indices in rest region
    :type scaling_factor: float
    :param scaling_factor: T_ref/T_hot
    Dihedral scaling::
        d.t1 = sf_i * sf_l
        where
            d.t1 = scaling factor for a dihedral term with i, j, k, l atoms
            sf_i = sqrt(scaling_factor) when i is a rest atom
            sf_l = sqrt(scaling_factor) when l is a rest atom
            otherwise sf_i = sf_l = 1.0
    """
    sqrt_scaling_factor = math.sqrt(scaling_factor)
    for dihedral in ct.ffio.dihedral:
        if ct.ffio.name == 'OPLS_2005':
            opls_scaling_factor = 1.0
            if dihedral.ai in rest_atoms:
                opls_scaling_factor *= sqrt_scaling_factor
            if dihedral.al in rest_atoms:
                opls_scaling_factor *= sqrt_scaling_factor
            dihedral.c0 *= opls_scaling_factor
            dihedral.c1 *= opls_scaling_factor
            dihedral.c2 *= opls_scaling_factor
            dihedral.c3 *= opls_scaling_factor
            dihedral.c4 *= opls_scaling_factor
            dihedral.c5 *= opls_scaling_factor
            dihedral.c6 *= opls_scaling_factor
        else:
            dihedral.t1 = 1.0
            if dihedral.ai in rest_atoms:
                dihedral.t1 *= sqrt_scaling_factor
            if dihedral.al in rest_atoms:
                dihedral.t1 *= sqrt_scaling_factor
if __name__ == "__main__":
    # Ad-hoc manual checks; the get_num_*/get_degrees_of_freedom helpers
    # need a per-component atom selection, so one is built from `asl`.

    def test_get_num_wateratom(model, asl="solute"):
        print(get_num_wateratom(model, model.select_atom_comp(asl)))

    def test_get_num_nonwateratom(model, asl="solute"):
        print(get_num_nonwateratom(model, model.select_atom_comp(asl)))

    def test_get_num_constraint(model, asl="solute"):
        print(get_num_constraint(model, model.select_atom_comp(asl)))

    def test_get_degrees_of_freedom(model, asl="solute"):
        print(get_degrees_of_freedom(model, model.select_atom_comp(asl)))

    def test_predict_temperature(model):
        print(predict_temperature_ladder((300.0, 340.0), 0.3, model,
                                         "solute"))

    def test_freeze_atom(model):
        model = freeze_atom(model, "atom.num 2600-2700")
        model.write("splitted_ct.cms")

    def test_predict_with_nreplica_and_exch(model):
        print(predict_with_nreplica_and_exch(5, 0.3, 300.0, model, "solute"))

    def test_predict_with_temp_and_nreplica(model):
        print(
            predict_with_temp_and_nreplica((300.0, 340.0), 10, model,
                                           "solute"))

    model = cms.Cms(file="test.cms")
    #test_predict_with_temp_and_nreplica(model)
    test_predict_with_nreplica_and_exch(model)
    #model = cms.Cms(file="desmond_job_restrained.cms")
    #test_get_num_wateratom(model)
    #test_get_num_nonwateratom(model)
    #test_get_num_constraint(model)
    #test_get_degrees_of_freedom(model)
    #test_freeze_atom(model)