"""
Restraint generation for cross-CT terms and all terms supported
by desmond backend, including alchemical terms.
Either a single term or a generator can be used.
For single term, each selection corresponds to a single atom.
Two kinds of generators are implemented now:
product: The product of all selections is used to generate all
         the terms.  Use case is to keep alchemical ions away from
         places they may get stuck.
connected: One selection is evaluated to generate terms for bond, angle
           and torsion.  Use case is the alchemical restraints
           on protein conformations.
Reference distance, angle and torsion values are computed for generated
terms.  For alchemical terms, reference coordinates saved previously will
be used for these calculations if available.
Copyright Schrodinger, LLC. All rights reserved.
"""
import copy
import dataclasses
import enum
import functools
import itertools
import json
import pprint
from typing import List
from typing import Tuple
import numpy as np
from schrodinger.application.desmond import cms
from schrodinger.application.desmond import constants
from schrodinger.application.desmond.packages import msys
from schrodinger.structutils import analyze
from schrodinger.structutils import measure
from schrodinger.utils import log
from schrodinger.utils import sea
from .utils import PERSISTENT_RESTRAINT_KEY
from .utils import RESTRAINT_KEY
from .utils import b64_decode
from .utils import b64_encode
from .utils import get_encoded_restraints
from .utils import set_encoded_restraints
# Module-level logger used for restraint-building diagnostics.
logger = log.get_logger(name="restraint_builder")
# Column name under which parameter tables keep the atoms of each term.
_ATOMS_KEY = "atoms"
# Term arity (2: bond, 3: angle, 4: torsion) -> structutils iterator used by
# `_generic_terms` to enumerate such terms among a selection of atoms.
_TERM_ITERATORS = {
    2: analyze.bond_iterator,
    3: analyze.angle_iterator,
    4: analyze.torsion_iterator,
}
# Term arity -> structutils function used by `_measure_ref` (and by
# `generate_conf_pose_restraints`) to compute the reference value.
_MEASURE_FUNCTIONS = {
    2: measure.measure_distance,
    3: measure.measure_bond_angle,
    4: measure.measure_dihedral_angle,
}
# Name fragment identifying flat-bottom harmonic-well tables/specs.
_FBHW = 'fbhw'
# Prefix shared by all force-constant property names (see
# `_find_force_constants`); also the force-constant key of posre_fbhw.
_FC = 'fc'
# Names of the two positional restraint tables.
_POSRE_HARM = 'posre_harm'
_POSRE_FBHW = 'posre_fbhw'
# Property name of the flat-bottom width.
_SIGMA = 'sigma'
# Both positional tables, for disposition handling (`IGNORE_POSRE`).
_POSRE_HARM_OR_FBHW = frozenset({_POSRE_HARM, _POSRE_FBHW})
class GeneratorType(enum.Enum):
    """
    Kinds of term generators; the enum values are the strings compared
    against the `generator` field of a restraint spec.
    """
    # terms are built from the product of all atom selections
    PRODUCT = 'product'
    # terms are the bonded terms found within a single atom selection
    CONNECTED = 'connected'
# Shared by `get_natoms_in_term` and `get_table_schema`, which add tables
# to it from their schemas purely to inspect them.
_dummy_msys = msys.CreateSystem()  # empty msys to get table schemas
# Names exported by `from <module> import *`.
__all__ = [
    'AtomID',
    'Restraints',
    'RestraintBuilder',
    'generate_conf_pose_restraints',
    'has_positional_restraints',
]
def decode_atom_ids(encoded_ids):
    """
    Lazily decodes serialized atom ids into `AtomID` objects.

    :param encoded_ids: iterable of atom ids, each either a (ct, atom)
        pair or a {'ct': ct, 'atom': atom} dict (see `AtomID.make`)
    :return: iterator (not a list) of `AtomID`; callers in this module
        consume it with `list()` or `next()`
    """
    return map(AtomID.make, encoded_ids)
def get_natoms_in_term(table_name: str) -> int:
    """
    Returns arity (number of atoms for each term) of the desmond term
    table with name `table_name`.

    :param table_name: name of the desmond term table
    :return: number of atoms for each term
    """
    # addTableFromSchema only generate a new table once
    # if the named table already exists, a pointer will be returned
    # c.f. http://opengrok/xref/desmond-gpu-src/other/msys/python/__init__.py#1277
    return _dummy_msys.addTableFromSchema(table_name).natoms
def get_table_schema(table_name: str):
    """
    Returns schema for desmond term table with name `table_name`.

    :param table_name: name of the desmond term table
    :return: (param_props, table_props) where both param_props
        and table_props are frozensets of property name strings;
        the 'constrained' term property is excluded from table_props
    :rtype: `tuple(frozenset(str), frozenset(str))`
    """
    # the following only adds table when named table does not exist
    # (see the comment in `get_natoms_in_term`)
    table = _dummy_msys.addTableFromSchema(table_name)
    param_props = frozenset(table.params.props)
    term_props = frozenset(k for k in table.term_props if k != 'constrained')
    return (param_props, term_props)
@dataclasses.dataclass
class AtomID:
    """
    Atom is specified by two numbers: ct index and atom index
    within this ct; ct indices starts from 0; ct == 0 indicates that
    the atom number is a gid used by desmond backend; ct == 1
    is the "full system"; ct >= 2 correspond to the component cts
    """
    ct: int
    atom: int

    @staticmethod
    def make(v):
        """
        Builds an `AtomID` from its serialized form.

        :param v: either a dict with 'ct' and 'atom' keys, or an indexable
            (ct, atom) pair
        :return: the corresponding `AtomID`
        """
        if isinstance(v, dict):
            return AtomID(**v)
        return AtomID(ct=v[0], atom=v[1])
class _GenericParams(dict):
    """
    Column-oriented parameter table for a non-positional desmond term
    table: maps "atoms" and every schema property name to a list holding
    one entry per term.
    """

    def __init__(self, *, table_name: str, dct=None):
        """
        :param table_name: desmond term table name
        :type table_name: str
        :param dct: dictionary deserialized from JSON
        :type dct: dict
        """
        self._is_alchemical = table_name.startswith('alchemical_')
        param_props, term_props = get_table_schema(table_name)
        # start with an empty column for the atoms and each schema property
        for column in sorted({_ATOMS_KEY} | param_props | term_props):
            self[column] = []
        if not dct:
            return
        for column, values in dct.items():
            if column == _ATOMS_KEY:
                # atom ids are serialized; turn them back into AtomID lists
                values = [list(decode_atom_ids(ids)) for ids in values]
            self[column] += values

    def addTerm(self, atoms: List[AtomID], props: dict) -> int:
        """
        Appends one term to the table.

        props contain the actual force-field parameters, force constants,
        equilibrium angles and other parameters that specific to the term,
        e.g. schedule for alchemical terms

        :param atoms: list of atom ids
        :param props: dictionary of parameters keyed by name of the
                      parameter (`str`), including both table properties
                      (e.g. schedule) and force field parameter properties
        :return: index of the term added
        """
        term_atoms = self[_ATOMS_KEY]
        term_atoms.append(atoms)
        for name, value in props.items():
            self[name].append(value)
        return len(term_atoms) - 1

    @property
    def is_positional(self):
        return False

    @property
    def is_alchemical(self):
        return self._is_alchemical

    def incorporate(self, other: '_GenericParams'):
        """
        merges `other` into `self`; only possible while at most one of the
        two tables holds terms
        """
        assert self.keys() == other.keys()
        if not other[_ATOMS_KEY]:
            # nothing to merge
            return
        if self[_ATOMS_KEY]:
            raise NotImplementedError(
                "merging is supported for positional restraints only")
        for name, column in self.items():
            column += copy.deepcopy(other[name])
class _PosreHarmParams:
    """
    Stores group of position restraints as a numpy structured array
    for faster creation (if from existing arrays) and merging.
    """
    AID_TYPE = [('ct', 'i8'), ('atom', 'i8')]
    DTYPE = [("atom_id", AID_TYPE), ("k", "f8", (3,)), ("ref", "f8", (3,))]
    INTERSECT_KEYS = "atom_id"
    def __init__(self, *, atom_ids=None, k=None, ref=None, dct=None, **kwargs):
        """
        :param atom_ids: list of lists of AtomIDs
        :param k: force constants
        :param ref: reference positions
        :param dct: dictionary de-serialized from json; if truthy,
            overrides other arguments
        """
        if dct:
            atom_ids = self._transcode_atom_ids(dct[_ATOMS_KEY])
            ref = list(zip(dct['x0'], dct['y0'], dct['z0']))
            k = list(zip(dct['fcx'], dct['fcy'], dct['fcz']))
        num_entries = len(atom_ids or [])
        self.arr = np.empty(shape=num_entries, dtype=self.DTYPE)
        if num_entries:
            self.arr['atom_id'] = atom_ids
            self.arr['ref'] = ref
            self.arr['k'] = k
    @staticmethod
    def _transcode_atom_ids(list_of_encoded_aid_lists):
        """
        Helper function that takes list-of-lists of atom IDs encoded
        either as (ct, atom) pairs or {'ct': ct, 'atom': atom} dicts
        and returns list of (ct, atom) tuples (see c-tor).
        """
        list_of_tuples = []
        for encoded_aids in list_of_encoded_aid_lists:
            aid = next(decode_atom_ids(encoded_aids))
            list_of_tuples.append(dataclasses.astuple(aid))
        return list_of_tuples
    def _intersect_with(self, other) -> (np.array, np.array):
        """
        Find the duplicate entries in self and other and return their indices,
        respectively
        """
        _, self_idx, other_idx = np.intersect1d(self.arr[self.INTERSECT_KEYS],
                                                other.arr[self.INTERSECT_KEYS],
                                                assume_unique=True,
                                                return_indices=True)
        return self_idx, other_idx
    def incorporate(self, other: '_PosreHarmParams'):
        """
        merges `other` into `self`
        """
        assert isinstance(other, _PosreHarmParams)
        self_idx, other_idx = self._intersect_with(other)
        other_k = other.arr['k'][other_idx]
        self.arr['k'][self_idx] = np.maximum(self.arr['k'][self_idx], other_k)
        self.arr['ref'][self_idx] = other.arr['ref'][other_idx]
        mask = np.ones(other.arr.size, dtype=bool)
        mask[other_idx] = False
        self.arr = np.concatenate([self.arr, other.arr[mask]])
    def copyRefs(self, other: '_PosreHarmParams'):
        """
        copies reference positions from the `other`
        """
        assert isinstance(other, _PosreHarmParams)
        self_idx, other_idx = self._intersect_with(other)
        self.arr['ref'][self_idx] = other.arr['ref'][other_idx]
    @property
    def atom_ids_as_tuples(self):
        return [
            [(int(aid['ct']), int(aid['atom']))] for aid in self.arr['atom_id']
        ]
    @property
    def refs(self):
        for a, v in zip('xyz', self.arr['ref'].transpose().tolist()):
            yield f'{a}0', v
    def asdict(self):
        outcome = {_ATOMS_KEY: self.atom_ids_as_tuples}
        outcome.update(self.refs)
        for a, v in zip('xyz', self.arr['k'].transpose().tolist()):
            outcome[f'{_FC}{a}'] = v
        return outcome
    def __len__(self):
        return len(self.arr)
    @property
    def is_positional(self):
        return True
    @property
    def is_alchemical(self):
        return False
class _PosreFBHWParams(_PosreHarmParams):
    """
    Flat-bottom harmonic-well position restraints: same storage as
    `_PosreHarmParams`, but with one scalar force constant per entry and
    an extra `sigma` column. Entries are considered duplicates only when
    both atom id and sigma coincide (see `INTERSECT_KEYS`).
    """
    DTYPE = [("atom_id", _PosreHarmParams.AID_TYPE), ("k", "f8"),
             ("ref", "f8", (3,)), (_SIGMA, "f8")]
    INTERSECT_KEYS = ["atom_id", _SIGMA]

    def __init__(self,
                 *,
                 atom_ids=None,
                 k=None,
                 ref=None,
                 sigma=None,
                 dct=None,
                 **kwargs):
        """
        :param atom_ids: list of (ct, atom) tuples
        :param k: force constant(s)
        :param ref: reference positions
        :param sigma: flat-bottom width(s); scalar or one value per entry
        :param dct: dictionary de-serialized from json; if truthy,
            overrides other arguments
        """
        if dct:
            atom_ids = self._transcode_atom_ids(dct[_ATOMS_KEY])
            ref = list(zip(dct['x0'], dct['y0'], dct['z0']))
            k = dct[_FC]
            sigma = dct[_SIGMA]
        super().__init__(atom_ids=atom_ids, k=k, ref=ref)
        # Test against None rather than truthiness: sigma == 0.0 is a valid
        # value that must be stored (the array comes from np.empty, so a
        # skipped assignment leaves garbage), and a numpy array has no
        # unambiguous truth value.
        if sigma is not None:
            # will broadcast as needed
            self.arr[_SIGMA] = sigma

    def asdict(self):
        """
        :return: JSON-serializable dictionary with atoms, force constants,
            sigmas and reference positions
        """
        outcome = {
            _ATOMS_KEY: self.atom_ids_as_tuples,
            _FC: self.arr['k'].tolist(),
            _SIGMA: self.arr[_SIGMA].tolist()
        }
        outcome.update(self.refs)
        return outcome
class _JsonEncoder(json.JSONEncoder):
    """
    JSON encoder aware of the restraint types of this module; dispatches
    on the type of the object being encoded.
    """

    @functools.singledispatchmethod
    def default(self, obj):
        # unknown type: defer to the stock encoder
        return super().default(obj)

    @default.register(AtomID)
    def _(self, obj):
        # an atom id is serialized as a (ct, atom) pair
        return dataclasses.astuple(obj)

    @default.register(_PosreHarmParams)
    def _(self, obj):
        # positional tables (including subclasses) know their dict form
        return obj.asdict()
class Restraints:
    """
    Holds restraint terms parameters. Assumes that "persistent"
    (aka "permanent") tables support merging (only "posre_*" at the moment).

    Tables are kept in two groups, both keyed by desmond table name:
    regular restraints and persistent restraints.
    """

    def __init__(self, *, text=None):
        """
        :param text: pre-existing serialized restraints (as json) to build upon
        :type text: str or NoneType
        """
        if text is None:
            data = {RESTRAINT_KEY: {}, PERSISTENT_RESTRAINT_KEY: {}}
        else:
            data = json.loads(text)
        self._restraints = self._instantiate(data.get(RESTRAINT_KEY, {}))
        self._persistent = self._instantiate(
            data.get(PERSISTENT_RESTRAINT_KEY, {}))

    @classmethod
    def _instantiate(cls, tables: dict) -> dict:
        """
        Turns {table name: deserialized dict} into {table name: table object}.
        """
        return {
            name: cls._getTableClass(name)(table_name=name, dct=dct)
            for name, dct in tables.items()
        }

    @staticmethod
    def _getTableClass(table_name):
        """
        :return: class used to store the table with the given name;
            `_GenericParams` for anything but the positional tables
        """
        return {
            _POSRE_HARM: _PosreHarmParams,
            _POSRE_FBHW: _PosreFBHWParams,
        }.get(table_name, _GenericParams)

    def getTable(self, table_name: str, persistent: bool = False) -> object:
        """
        Gets parameters table by name, creating (and registering) an empty
        one if it does not exist yet.

        :param table_name: name of the desmond term table
        :param persistent: persistent vs regular table
        :return: requested table
        :rtype: `_GenericParams` or `_PosreHarmParams` or `_PosreFBHWParams`
        """
        group = self._persistent if persistent else self._restraints
        try:
            return group[table_name]
        except KeyError:
            table = self._getTableClass(table_name)(table_name=table_name)
            group[table_name] = table
            return table

    def addTerm(self,
                table_name: str,
                atoms: List[AtomID],
                props: dict,
                persistent: bool = False) -> int:
        """
        Adds single restrain term.

        table_name is the desmond interaction table, stretch_harm,
        alchemical_improper_harm etc.

        An atom is specified by two numbers, ct number and atom number
        in ct.  ct number starts from 0, that means the atom number is
        gid used by desmond backend.  ct number 1 means full system.
        ct numbers greater than or equal to 2 mean component cts.

        props contain the actual force-field parameters, force constants,
        equilibrium angles and other parameters that specific to the term,
        e.g. schedule for alchemical terms

        :param table_name: name of the table, this is one of the term
                           tables supported by desmond
        :param atoms: atom ids
        :param props: dictionary of parameter keyed by name of the
                      parameter (`str`), including both table properties
                      (e.g. schedule) and force field parameter properties.
        :param persistent: is this a persistent term
        :return: index of the term added
        """
        table = self.getTable(table_name, persistent)
        return table.addTerm(atoms, props)

    @property
    def has_persistent(self) -> bool:
        """
        True if any persistent table is registered (even an empty one).
        """
        return bool(self._persistent)

    def toJson(self) -> str:
        """
        Serializes the restraints. When persistent tables exist, they are
        merged into a copy of the regular tables (which is what gets
        applied) and are additionally written out on their own.

        :return: json string to be loaded by msys
        """
        if not self.has_persistent:
            return json.dumps({RESTRAINT_KEY: self._restraints},
                              cls=_JsonEncoder)
        # merge into a copy so that the regular tables stay untouched
        restraints = copy.deepcopy(self._restraints)
        for name, persistent_table in self._persistent.items():
            if table := restraints.get(name):
                # `table` is the object already stored under `name`,
                # so merging in place is sufficient
                table.incorporate(persistent_table)
            else:
                restraints[name] = persistent_table
        dct = {
            RESTRAINT_KEY: restraints,
            PERSISTENT_RESTRAINT_KEY: self._persistent
        }
        return json.dumps(dct, cls=_JsonEncoder)

    @property
    def has_positional(self) -> bool:
        """
        Set to True if the Restraint has any positional restraints.
        """
        all_tables = itertools.chain(self._restraints.values(),
                                     self._persistent.values())
        return any(table.is_positional for table in all_tables)
def _measure_ref(atoms):
    """
    Measures the distance, angle or torsion defined by 2, 3 or 4 atoms.

    :param atoms: list/tuple of atoms
    :type atoms: list of `structure.Atom` or list of 3 floats
    :rtype: float
    :raise ValueError: if the number of atoms is not 2, 3 or 4
    """
    count = len(atoms)
    measure_func = _MEASURE_FUNCTIONS.get(count)
    if not measure_func:
        raise ValueError("Cannot measure on %d number of atoms" % count)
    return measure_func(*atoms)
def _generic_terms(struct, atom_selection, n_atom):
    """
    Enumerates the bonds (n_atom == 2), angles (3) or dihedrals (4)
    found among the selected atoms.

    :param struct: structure for connection
    :type struct: `structure.Structure`
    :param atom_selection: selected atoms
    :type atom_selection: list of atom indices
    :param n_atom: number of atoms in a term
    :type n_atom: integer value of 2, 3 or 4
    :rtype: list of list
    :raise ValueError: if `n_atom` is not 2, 3 or 4
    """
    term_iterator = _TERM_ITERATORS.get(n_atom)
    if not term_iterator:
        raise ValueError("Unknown term using %d atoms" % n_atom)
    return term_iterator(struct=struct, atoms=atom_selection)
def _select_atoms_single_term(cms_mol, atom_sel, n_atoms):
    """
    Select atoms for a single term, only the first atom selected
    in each CT will be used.  Atoms returned are in the form of
    (CT number, `structure.Atom`)
    :param cms_mol: input system for selection
    :type cms_mol: `cms.Cms` object
    :param atom_sel: atom selection specification
    :type atom_sel: `sea.List` object
    :param n_atoms: number of atoms in term
    :type n_atoms: `int`
    :rtype: list of atom tuples, (`int`, `structure.Atom`)
    """
    atoms_ret = []
    for i in range(n_atoms):
        for ct_idx, atoms in enumerate(cms_mol.select_atom_comp(
                atom_sel[i].val),
                                       start=2):
            if len(atoms) == 1:
                atoms_ret.append(
                    (ct_idx, cms_mol.comp_ct[ct_idx - 2].atom[atoms[0]]))
                break
            elif len(atoms) > 1:
                raise ValueError(
                    "More than 1 atom selected for the term, selection: %s" %
                    atom_sel[i].val)
    if len(atoms_ret) == n_atoms:
        return atoms_ret
    else:
        raise ValueError(
            "Selected atoms do not match required %d versus %d, selection: %s\n"
            % (len(atoms_ret), n_atoms, atom_sel))
def _select_atoms_for_generator(cms_mol, atom_sel):
    """
    Select atoms to be used for actual term generation.
    Atoms retuned are in (CT number, `structure.Atom`) form.
    :param cms_mol: input system for selection
    :type cms_mol: `cms.Cms`
    :param atom_sel: atom selection specification, ASL
    :type atom_sel: `sea.Map`
    :rtype: list of atom tuples, (`int`, `structure.Atom`)
    """
    atoms_ret = []
    for ct_idx, atoms in enumerate(cms_mol.select_atom_comp(atom_sel.val),
                                   start=2):
        for a in atoms:
            atoms_ret.append((ct_idx, cms_mol.comp_ct[ct_idx - 2].atom[a]))
    return atoms_ret
def _find_force_constants(
        param_props: List[str]) -> Tuple[List[str], List[str]]:
    """
    Splits parameter property names into force constants (names starting
    with the force-constant prefix) and all the other properties.

    :param param_props: force field parameters
    :return: (force constant names, other property names), each sorted
    """
    fcs = []
    ref = []
    for prop in param_props:
        (fcs if prop.startswith(_FC) else ref).append(prop)
    return (sorted(fcs), sorted(ref))
class RestraintBuilder:
    """
    Builds restraint terms from restraint specs and stores them, encoded,
    on the cms model given to the constructor (see `addRestraints`).
    """

    def __init__(self,
                 restraint_terms: sea.List,
                 existing: constants.EXISTING_RESTRAINT,
                 cms_sys: cms.Cms,
                 persistent: bool = False):
        """
        :param restraint_terms: all restraint terms to be added
        :param existing: One of `constants.EXISTING_RESTRAINT`, determines
            whether to `IGNORE` current restraints and replace them with
            `restraint_terms` or `RETAIN` them and update them with
            `restraint_terms`.
        :param cms_sys: cms object for molecules
        :param persistent: build "persistent" restraints
        """
        self._cms_sys = cms_sys
        self._all_restraints = restraint_terms
        self._persistent = persistent
        encoded = get_encoded_restraints(cms_sys)
        if encoded:
            text = self._applyRestraintsDisposition(b64_decode(encoded),
                                                    existing)
            self._restrain = Restraints(text=text)
        else:
            self._restrain = Restraints()
        fep_cts = cms_sys.get_fep_cts()
        coords = (None, None)
        if all(fep_cts):
            try:
                from schrodinger.application.scisol.packages.core_hopping.int_fepio import (
                    get_reference_coordinates_for_two_molecules)
                wt_ct, mut_ct = fep_cts
                coords = get_reference_coordinates_for_two_molecules(
                    wt_ct, mut_ct)
                if wt_ct is not None and mut_ct is not None:
                    logger.debug(
                        "Reference coordinates found, using them for alchemical terms."
                    )
            except ImportError:
                logger.debug(
                    "Cannot import scisol function, use current coordinates.")
            except RuntimeError as err:
                # get_reference_coordinates_for_two_molecules throws
                # RuntimeError; falling back to the current coordinates is
                # deliberate, but leave a trace of it
                logger.debug(
                    "Reference coordinates unavailable (%s), "
                    "use current coordinates.", err)
        fep_ref_coord = {k: v for k, v in zip(fep_cts, coords) if v is not None}
        # CT number -> coordinates used to measure alchemical references:
        # the stored reference coordinates when available, otherwise the
        # current coordinates of the component CT
        self._ref_coords = {
            ct_idx: fep_ref_coord[ct] if ct in fep_ref_coord else ct.getXYZ()
            for ct_idx, ct in enumerate(cms_sys.comp_ct, start=2)
        }

    def _applyRestraintsDisposition(
            self, text: str, existing: constants.EXISTING_RESTRAINT) -> str:
        """
        Drops from the pre-existing restraints what the disposition says
        must not be kept. Only the group being built (persistent or
        regular) is affected.

        :param text: Serialized restraints (see `Restraints.toJson()`).
        :param existing: Restraints disposition (ignore/retain/ignore_posre).
        :return: Serialized restraints ready for `Restraints` constructor.
        """
        dct = json.loads(text)
        stage_restraints = dct.get(RESTRAINT_KEY, {})
        persistent = dct.get(PERSISTENT_RESTRAINT_KEY, {})
        group = persistent if self._persistent else stage_restraints
        if existing == constants.EXISTING_RESTRAINT.IGNORE_POSRE:
            # keep everything but the positional tables
            for k in _POSRE_HARM_OR_FBHW:
                group.pop(k, None)
        elif existing != constants.EXISTING_RESTRAINT.RETAIN:
            group.clear()
        return json.dumps({
            RESTRAINT_KEY: stage_restraints,
            PERSISTENT_RESTRAINT_KEY: persistent
        })

    def _addAllPosreTerms(self, spec: 'sea.Map'):
        """
        Add positional restraint terms as directed by the `spec`.

        The `ref` field of the spec (default 'reset') is either 'reset'
        (use current positions), 'retain' (keep the reference positions of
        already restrained atoms) or explicit reference coordinates.

        :param spec: restraint specs for posre_harm or posre_fbhw
        """
        fc = spec.force_constants.val
        ref = spec.ref.val if 'ref' in spec else 'reset'
        # explicit reference coordinates, if given
        explicit_ref = None if ref in ('reset', 'retain') else np.array(ref)
        model = self._cms_sys
        atom_idxs = model.select_atom_comp(spec.atoms.val if isinstance(
            spec.atoms, sea.Atom) else spec.atoms[0].val)
        table_name = spec.name.val
        table = self._restrain.getTable(table_name, self._persistent)
        make_table = functools.partial(
            _PosreFBHWParams,
            sigma=spec.sigma.val) if _FBHW in table_name else _PosreHarmParams
        for ct_idx, (ct, ct_atom_idxs) in enumerate(zip(model.comp_ct,
                                                        atom_idxs),
                                                    start=2):
            # atom indices are 1-based, coordinate rows are 0-based
            xyz = ct.getXYZ(copy=False)[np.array(ct_atom_idxs, dtype=int) - 1]
            assert xyz.base is None  #  owns the memory
            if explicit_ref is not None:
                num_atoms = len(ct_atom_idxs)
                # three coordinates are needed per atom
                if explicit_ref.size < 3 * num_atoms:
                    logger.warning(
                        "WARNING: restrain reference array is too short, "
                        "using existing positions for remaining atoms.")
                xyz.flat[:explicit_ref.size] = explicit_ref[:3 * num_atoms]
            to_merge = make_table(atom_ids=list(
                zip(itertools.repeat(ct_idx), ct_atom_idxs)),
                                  k=fc,
                                  ref=xyz)
            if ref == 'retain':
                to_merge.copyRefs(table)
            table.incorporate(to_merge)

    def _addOneTerm(self, table_name, atoms_selected, res_spec, fc_keys,
                    ref_keys, term_props):
        """
        Add a single term, if reference is not provided in the res_spec,
        measure from the coordinates.  For alchemical terms, prestored
        reference coordinates are used rather than the current coordinates
        to compute term reference values if possible.

        :param table_name: name of the term table
        :type table_name:  `str`
        :param atoms_selected:    atoms in the term, each atom is
                                  (ct_number, `structure.Atom`) pair
        :type atoms_selected:     `tuple` of (`int`, `structure.Atom`)'s
        :param res_spec: parameters for the term
        :type res_spec: `Sea.Map` object
        :param fc_keys: force constant keys
        :type fc_keys: `list` of `str`
        :param ref_keys: keys of reference distance, angle and dihedral
        :type ref_keys: `list` of `str`
        :param term_props: other term properties that define the force field
                           e.g. schedule for alchemical terms
        :type term_props: `list` of `str`
        """
        term_atoms = [
            AtomID(ct_idx, atom.index) for ct_idx, atom in atoms_selected
        ]
        params_dict = {
            k: v.val for k, v in zip(fc_keys, res_spec.force_constants)
        }
        for k in term_props:
            params_dict[k] = res_spec[k].val
        # Look the table up in the group the term is added to; asking for
        # the regular group while building persistent terms would register
        # a spurious empty regular table.
        is_alchemical = self._restrain.getTable(
            table_name, self._persistent).is_alchemical
        for k in ref_keys:
            if k in res_spec:
                params_dict[k] = res_spec[k].val
            elif is_alchemical:
                # alchemical potential, need to measure from reference
                # coordinates (atom indices are 1-based)
                coords = tuple(self._ref_coords[ct_idx][atom.index - 1]
                               for ct_idx, atom in atoms_selected)
                params_dict[k] = _measure_ref(coords)
            else:
                params_dict[k] = _measure_ref(
                    [atom for _, atom in atoms_selected])
        self._restrain.addTerm(table_name,
                               term_atoms,
                               params_dict,
                               persistent=self._persistent)

    def _addTermFromGenerator(self, table_name, n_atoms, res_spec, fc_keys,
                              ref_keys, term_props):
        """
        Generate terms from the atom selections.

        :param table_name: name of the term table
        :type table_name:  `str`
        :param n_atoms:    number of atoms in each term
        :type n_atoms:     `int`
        :param res_spec: parameters to generate terms
        :type res_spec: `sea.Map` object
        :param fc_keys: force constant keys
        :type fc_keys: `list` of `str`
        :param ref_keys: keys of reference distance, angle and dihedral
        :type ref_keys: `list` of `str`
        :param term_props: other term properties that define the force field
                           e.g. schedule for alchemical terms
        :type term_props: `list` of `str`
        """
        generator = res_spec.generator.val
        if generator == GeneratorType.PRODUCT.value:
            # one selection per atom of the term; every combination of the
            # selected atoms becomes a term
            assert n_atoms == len(res_spec.atoms)
            atoms_selected = [
                _select_atoms_for_generator(self._cms_sys, sel)
                for sel in res_spec.atoms
            ]
            for prod_atoms in itertools.product(*atoms_selected):
                self._addOneTerm(table_name, prod_atoms, res_spec, fc_keys,
                                 ref_keys, term_props)
        elif generator == GeneratorType.CONNECTED.value:
            # a single selection; the bonded terms found within it, in each
            # component CT, become the terms
            per_ct_atoms = self._cms_sys.select_atom_comp(
                res_spec.atoms[0].val)
            for ct_idx, atoms in enumerate(per_ct_atoms, start=2):
                # filter out empty selections
                if len(atoms) == 0:
                    continue
                current_ct = self._cms_sys.comp_ct[ct_idx - 2]
                for t in _generic_terms(current_ct, atoms, n_atoms):
                    self._addOneTerm(
                        table_name,
                        [(ct_idx, current_ct.atom[a]) for a in t], res_spec,
                        fc_keys, ref_keys, term_props)
        else:
            # previously ignored without any trace
            logger.warning(
                "Unknown restraint generator '%s' for table '%s', "
                "no terms generated.", generator, table_name)

    def addRestraints(self):
        """
        Add all restraint terms to the cms object passed in the constructor.
        This should be the only function called to process all the restraints specified
        """
        for r in self._all_restraints:
            table_name = r.name.val
            table = self._restrain.getTable(table_name, self._persistent)
            if table.is_positional:
                self._addAllPosreTerms(r)
                continue
            param_props, term_props = get_table_schema(table_name)
            fcs, refs = _find_force_constants(param_props)
            n_atoms = get_natoms_in_term(table_name)
            if "generator" in r:
                self._addTermFromGenerator(table_name, n_atoms, r, fcs, refs,
                                           term_props)
            else:
                atoms_selected = _select_atoms_single_term(
                    self._cms_sys, r.atoms, n_atoms)
                self._addOneTerm(table_name, atoms_selected, r, fcs, refs,
                                 term_props)
        set_encoded_restraints(self._cms_sys, self.getEncoded())

    def getEncoded(self):
        """
        Reports restraints built as b64 encoded JSON string.

        :rtype: `str`
        """
        return b64_encode(self.getJson())

    def getJson(self):
        """
        Reports restraints built as JSON string.

        :rtype: `str`
        """
        return self._restrain.toJson()

    def getString(self, skip_tables=None, **kwargs) -> str:
        """
        Reports restraints built as a pretty-printed string, one section
        per restraint group.

        :param skip_tables: Skip the listed tables in the result string.
        :param kwargs: passed on to `pprint.pformat`
        """
        skip_tables = skip_tables or {}
        sections = []
        for group_name, tables in json.loads(self.getJson()).items():
            kept = {
                name: table
                for name, table in tables.items()
                if name not in skip_tables
            }
            sections.append(
                f"'{group_name}':\n{pprint.pformat(kept, **kwargs)}")
        return '\n'.join(sections)
def generate_conf_pose_restraints(cts,
                                  ct_numbers,
                                  enable_pose_restraint=False,
                                  pose_restraint_cfg=None,
                                  pose_restraint_terms=None,
                                  enable_conf_restraint=False,
                                  conf_restraint_cfg=None,
                                  conf_restraint_terms=None):
    """
    Generate pose and conf restraints according to cfg.

    :param cts: tuple of reference and mutant structures
    :type cts: Tuple(structure.Structure, structure.Structure)
    :param ct_numbers: tuple of reference and mutant ct numbers
                       in the original Maestro file
    :type ct_numbers: Tuple(int, int)
    :param enable_pose_restraint: flag for ligand pose restraints
    :type enable_pose_restraint: bool
    :param pose_restraint_cfg: ligand pose specification
    :type pose_restraint_cfg: Dict
    :param pose_restraint_terms: dihedrals need to be restrained
    :type pose_restraint_terms: AlchemicalInteractions.pose_restraint
    :param enable_conf_restraint: flag for conformation restraints
    :type enable_conf_restraint: bool
    :param conf_restraint_cfg: conformation specification
    :type conf_restraint_cfg: Dict
    :param conf_restraint_terms: dihedrals need to be restrained
    :type conf_restraint_terms: AlchemicalInteractions.conf_restraint
    :rtype: restraint.Restraints
    """
    from schrodinger.application.desmond.struc import (
        get_atom_reference_coordinates)

    def _get_AB_keys(key: str) -> Tuple[str, str]:
        # names of the A-state and B-state variants of a parameter
        return key + 'A', key + 'B'

    _NAME = 'name'
    _SCHEDULE = 'schedule'
    _SOFT = 'soft'
    _ALPHA = 'alpha'
    _STRETCH_REF_KEY = 'r0'
    _DIHEDRAL_REF_KEY = 'phi0'
    # number of atoms in the term -> name of its reference parameter
    _REF_PREFIX = {2: _STRETCH_REF_KEY, 4: _DIHEDRAL_REF_KEY}
    _FCA, _FCB = _get_AB_keys(_FC)
    _SIGMA_A, _SIGMA_B = _get_AB_keys(_SIGMA)
    _ALPHA_A, _ALPHA_B = _get_AB_keys(_ALPHA)

    def _restraint_params(ct,
                          ct_num,
                          atoms,
                          param_spec,
                          is_wt=True,
                          restraint_type=None):
        """
        Builds the atoms and parameters of one alchemical restraint term;
        the reference value is measured from the atoms' reference
        coordinates and used for both states.

        :return: (tuple of `AtomID`, parameter dictionary)
        """
        natoms = len(atoms)
        _REF_KEY_A, _REF_KEY_B = _get_AB_keys(_REF_PREFIX[natoms])
        ref = _MEASURE_FUNCTIONS[natoms](*[
            get_atom_reference_coordinates(ct.atom[atoms[a]])
            for a in range(natoms)
        ])
        param_dict = {
            _FCA: 0.0,
            _FCB: 0.0,
            _REF_KEY_A: ref,
            _REF_KEY_B: ref,
            _SCHEDULE: param_spec[_SCHEDULE]
        }
        # the force constant is non-zero in one state only: B for the
        # first (wt) structure, A for the second
        if is_wt:
            param_dict[_FCB] = param_spec[_FC]
        else:
            param_dict[_FCA] = param_spec[_FC]
        if param_spec[_NAME] == _FBHW:
            param_dict[_SIGMA_A] = param_spec[_SIGMA]
            param_dict[_SIGMA_B] = param_spec[_SIGMA]
        if (restraint_type == constants.ConfRestraintType.CALPHA_RUNG or
                param_spec[_NAME] == _SOFT):
            param_dict[_ALPHA_A] = param_spec[_ALPHA]
            param_dict[_ALPHA_B] = param_spec[_ALPHA]
        term_atoms = tuple(AtomID(ct_num, atom) for atom in atoms)
        return (term_atoms, param_dict)

    # (is_wt, state key, structure, ct number) for each of the two states
    _PACKED_ARGUMENTS = list(
        zip([True, False], constants.FEP_STATE_KEYS, cts, ct_numbers))
    _ALCHEMICAL_IMPROPER = 'alchemical_improper_'
    _ALCHEMICAL_SOFTSTRETCH = 'alchemical_softstretch_'
    all_restraints = Restraints()
    if enable_pose_restraint:
        table_name = _ALCHEMICAL_IMPROPER + pose_restraint_cfg[_NAME]
        for is_wt, key, ct, ct_num in _PACKED_ARGUMENTS:
            for dihe in pose_restraint_terms[key]:
                term_atoms, param_dict = _restraint_params(ct,
                                                           ct_num,
                                                           dihe,
                                                           pose_restraint_cfg,
                                                           is_wt=is_wt)
                all_restraints.addTerm(table_name, term_atoms, param_dict)
    if enable_conf_restraint:
        table_name_prefix = {
            constants.ConfRestraintType.BACKBONE: _ALCHEMICAL_IMPROPER,
            constants.ConfRestraintType.SIDECHAIN: _ALCHEMICAL_IMPROPER,
            constants.ConfRestraintType.CALPHA_RUNG: _ALCHEMICAL_SOFTSTRETCH
        }
        for r in constants.ConfRestraintType:
            conf_spec = conf_restraint_cfg[r]
            table_name = table_name_prefix[r] + conf_spec[_NAME]
            for is_wt, key, ct, ct_num in _PACKED_ARGUMENTS:
                for atoms in conf_restraint_terms[key][r]:
                    term_atoms, param_dict = _restraint_params(ct,
                                                               ct_num,
                                                               atoms,
                                                               conf_spec,
                                                               is_wt=is_wt,
                                                               restraint_type=r)
                    all_restraints.addTerm(table_name, term_atoms, param_dict)
    return all_restraints
def has_positional_restraints(model: cms.Cms) -> bool:
    """
    Return True if the cms model has positional restraints, defined either
    as ffio restraints on the model or among its encoded restraints.

    :param model: cms model to inspect
    """
    # First check for ffio_restraints defined on the cms model
    posre_keys = {
        constants.RestrainTypes.POS, constants.RestrainTypes.POS_FBHW
    }
    for r in model.get_restrain():
        if not posre_keys.isdisjoint(r.keys()):
            return True
    # Then check new restraints
    encoded = get_encoded_restraints(model)
    if not encoded:
        return False
    return Restraints(text=b64_decode(encoded)).has_positional