"""
Utilities for swapping fragments.
Copyright Schrodinger, LLC. All rights reserved.
"""
import ast

from schrodinger.infra import structure as infrastructure
from schrodinger.structutils import measure
from schrodinger.structutils import rmsd
from schrodinger.application.matsci import msutils
from schrodinger.application.matsci import msprops
# Boolean structure property marking a structure as a reaction workflow
# structure (checked and set by _update_properties)
REACTION_WF_STRUCTURE_KEY = 'b_matsci_Reaction_Workflow_Structure'
# Atom property marking atoms to keep; atoms with a truthy value are
# returned by get_keep_idxs
KEEP_ATOM_PROP = 'b_matsci_Reaction_Workflow_Keep_Atom'
# Integer atom property marking superposition atoms; its value determines
# the order in which get_superposition_idxs returns them
SUPERPOSITION_ATOM_PROP = 'i_matsci_Reaction_Workflow_Superposition_Atom'
# String structure properties naming workflow groups; when the reference is
# a reaction workflow structure they are copied onto the assembled structure
CONFORMERS_GROUP_KEY = 's_matsci_Reaction_Workflow_Conformers_Group'
SIBLING_GROUP_KEY = 's_matsci_Reaction_Workflow_Sibling_Group'
# the property name below keeps its original 'Parent_Conformer_Groups'
# spelling for backwards compatibility, although the constant now refers
# to parent sibling groups
PARENT_SIBLING_GROUPS_KEY = 's_matsci_Reaction_Workflow_Parent_Conformer_Groups'
# Boolean structure property marking a transition state structure
TRANSITION_STATE_STRUCTURE_KEY = 'b_matsci_Reaction_Workflow_Transition_State_Structure'
# String structure properties holding restrained atom index groups in the
# format written by get_idx_groups_str, e.g. '(1,2);(3,4)'
RESTRAINED_DISTANCES_KEY = 's_matsci_Reaction_Workflow_Restrained_Distances'
RESTRAINED_ANGLES_KEY = 's_matsci_Reaction_Workflow_Restrained_Angles'
RESTRAINED_DIHEDRALS_KEY = 's_matsci_Reaction_Workflow_Restrained_Dihedrals'
# Separates the indices within one group, e.g. '1,2,3'
INDEX_SEPARATOR = ','
# Separates index groups from each other, e.g. '(1,2);(3,4)'
SEPARATOR = ';'
# Distance passed to infrastructure.DistanceCell when looking for the
# reference atom closest to a novel atom (units presumably Angstrom -
# TODO confirm)
DISTANCE_CELL_LEN = 0.75
class SwapFragmentsException(Exception):
    """
    Raised when fragments cannot be swapped, for example when index group
    text cannot be parsed or when the novel and reference superposition
    atom counts differ.
    """
def get_idx_groups_str(idx_groups):
    """
    Get a string representation of the given index groups.

    Each group becomes its indices joined by INDEX_SEPARATOR inside
    parentheses, and the groups are joined by SEPARATOR, e.g.
    [[1, 2], [3, 4]] -> '(1,2);(3,4)'.

    :type idx_groups: list
    :param idx_groups: contains lists of indices
    :rtype: str
    :return: the string, empty when there are no groups
    """
    rendered = (
        '(' + INDEX_SEPARATOR.join(map(str, group)) + ')'
        for group in idx_groups)
    return SEPARATOR.join(rendered)
def get_idxs_marked_atoms(st, prop):
    """
    Return a list of indices of atoms in the given
    structure that have the given property defined.

    Atoms whose value for the property is falsy (e.g. False or 0) are
    not included.

    :type st: schrodinger.structure.Structure
    :param st: the structure
    :type prop: str
    :param prop: the property that marks the atoms
    :rtype: list
    :return: contains indices of atoms, in atom order
    """
    # skip the per-atom scan entirely when the structure has no such
    # atom property
    if msutils.has_atom_property(st, prop):
        return [atom.index for atom in st.atom if atom.property.get(prop)]
    return []
def get_idx_groups(text):
    """
    Get index groups from the given string.

    Valid strings look like '(i,j,...);(k,l,...);...', the format written
    by get_idx_groups_str. Empty or whitespace-only tokens (e.g. from a
    trailing separator) are skipped. A one-index group written as '(i)'
    is accepted, so get_idx_groups_str output always round-trips.

    :type text: str
    :param text: the string
    :raise SwapFragmentsException: if there is an issue
    :rtype: list
    :return: contains list of indices
    """
    idxs = []
    for token in text.split(SEPARATOR):
        token = token.strip()
        if not token:
            continue
        # The text is read from structure properties, which can come from
        # arbitrary input files, so only Python literals are parsed here;
        # eval would execute any expression embedded in the file.
        try:
            obj = ast.literal_eval(token)
        except (SyntaxError, ValueError, TypeError, RecursionError) as err:
            raise SwapFragmentsException(
                'Invalid index group: {!r}'.format(token)) from err
        # '(i)' parses as a parenthesized int, not a one-element tuple
        if isinstance(obj, int) and token.startswith('('):
            obj = (obj,)
        if not isinstance(obj, tuple) or not all(
                isinstance(i, int) for i in obj):
            raise SwapFragmentsException(
                'Invalid index group: {!r}'.format(token))
        idxs.append(list(obj))
    return idxs
def _get_restrain_group_idxs(st, prop):
    """
    Return a list of lists of indices of a restrain group
    in the given structure.
    :type st: schrodinger.structure.Structure
    :param st: the structure
    :type prop: str
    :param prop: the restrain property
    :rtype: list
    :return: contains lists of restrain indices
    """
    idxs = st.property.get(prop, [])
    if idxs:
        idxs = get_idx_groups(idxs)
    return idxs
def _get_new_restrain_group_idxs(st, key, old_to_new, offset=0):
    """
    Return new restrain group indices.

    Each restrain group read from the structure is translated through the
    old-to-new map and shifted by the offset. A group is dropped as a
    whole if any of its atoms is absent from the map or maps to None.

    :type st: schrodinger.structure.Structure
    :param st: the structure
    :type key: str
    :param key: determines the type of restrain coordinate
    :type old_to_new: dict
    :param old_to_new: the old-to-new atom index map
    :type offset: int
    :param offset: an offset added to all new restrain
        group indices, useful when assembling structures
        from parts of other structures
    :rtype: list
    :return: contains new restrain group indices
    """
    remapped = []
    for old_group in _get_restrain_group_idxs(st, key):
        new_group = [old_to_new.get(old_idx) for old_idx in old_group]
        # an atom that was not carried over invalidates the whole group
        if any(new_idx is None for new_idx in new_group):
            continue
        remapped.append([new_idx + offset for new_idx in new_group])
    return remapped
def get_keep_idxs(st):
    """
    Return the indices of the atoms marked to be kept, i.e. those with a
    truthy KEEP_ATOM_PROP value.

    :type st: schrodinger.structure.Structure
    :param st: the structure
    :rtype: list
    :return: contains indices of keep atoms
    """
    keep_idxs = get_idxs_marked_atoms(st, prop=KEEP_ATOM_PROP)
    return keep_idxs
def _get_remaining_atom_idxs(struct, idxs):
    """
    Return remaining atom indices in the given structure
    that are not in the given indices.
    :type struct: schrodinger.structure.Structure
    :param struct: the structure
    :type idxs: list
    :param idxs: the indices
    :rtype: list
    :return: the remaining indices
    """
    all_idxs = set(range(1, struct.atom_total + 1))
    return sorted(all_idxs.difference(idxs))
def get_superposition_idxs(st):
    """
    Return the indices of the superposition atoms in the given structure,
    ordered by their SUPERPOSITION_ATOM_PROP value.

    :type st: schrodinger.structure.Structure
    :param st: the structure
    :rtype: list
    :return: contains indices of superposition atoms
    """

    def _superposition_order(atom_idx):
        return st.atom[atom_idx].property[SUPERPOSITION_ATOM_PROP]

    marked = get_idxs_marked_atoms(st, SUPERPOSITION_ATOM_PROP)
    return sorted(marked, key=_superposition_order)
def get_extracted_and_maps(st, idxs):
    """
    Extract and return a structure from the given indices as
    well as old-to-new and new-to-old atom index maps.

    The indices are sorted and de-duplicated first. The same ordered list
    is used both for the extraction and for building the maps, so the
    maps agree with the extracted structure even when the caller passes
    unsorted or repeated indices.

    :type st: schrodinger.structure.Structure
    :param st: the structure
    :type idxs: list
    :param idxs: the indices
    :rtype: schrodinger.structure.Structure, dict, dict
    :return: the extracted structure and old-to-new and
        new-to-old index maps
    """
    # NOTE(review): assumes Structure.extract keeps extracted atoms in
    # ascending original-index order; passing a sorted, unique list keeps
    # the maps consistent with the result in any case
    ordered_idxs = sorted(set(idxs))
    est = st.extract(ordered_idxs, copy_props=True)
    new_idxs = range(1, len(ordered_idxs) + 1)
    old_to_new_map = dict(zip(ordered_idxs, new_idxs))
    new_to_old_map = dict(zip(new_idxs, ordered_idxs))
    return est, old_to_new_map, new_to_old_map
def get_cut_bonds(st, keep_idxs):
    """
    Return a list of (keep, replace, order) tuples, where
    keep is an index that will be kept, replace is an index
    that will be replaced, and order is the bond order, for
    bonds involving atoms specified in the given keep indices.

    Only bonds with exactly one kept atom are returned. Bonds whose atoms
    are both kept, or both not kept, are ignored.

    :type st: schrodinger.structure.Structure
    :param st: the structure
    :type keep_idxs: iterable
    :param keep_idxs: the keep indices
    :rtype: list
    :return: contains (keep, replace, order) tuples, in bond order
    """
    # build the lookup set once: list membership per bond is O(n), and a
    # one-shot iterable would be exhausted after the first test
    keep = set(keep_idxs)
    cut_bonds = []
    for bond in st.bond:
        idx_1, idx_2 = bond.atom1.index, bond.atom2.index
        keep_1, keep_2 = idx_1 in keep, idx_2 in keep
        if keep_1 == keep_2:
            # both kept or both replaced: not a cut bond
            continue
        if keep_1:
            cut_bonds.append((idx_1, idx_2, bond.order))
        else:
            cut_bonds.append((idx_2, idx_1, bond.order))
    return cut_bonds
def get_closest_atom(nov_st, ref_st, nov_idx, ref_cell):
    """
    Return the reference atom index closest to the given novel
    atom index.

    Only reference atoms reported by the distance cell query at the novel
    atom's position are considered. Among equally close atoms, the first
    one reported wins.

    :type nov_st: schrodinger.structure.Structure
    :param nov_st: first structure, called novel
    :type ref_st: schrodinger.structure.Structure
    :param ref_st: second structure, called reference
    :type nov_idx: int
    :param nov_idx: the novel index
    :type ref_cell: infrastructure.DistanceCell
    :param ref_cell: distance cell for the reference structure
    :rtype: int or None
    :return: the reference index, or None if the query found no
        reference atoms
    """
    nov_xyz = nov_st.atom[nov_idx].xyz
    candidates = [hit.getIndex() for hit in ref_cell.query_atoms(*nov_xyz)]
    best_idx = None
    best_dist = None
    for cand_idx in candidates:
        dist = measure.measure_distance(nov_xyz, ref_st.atom[cand_idx].xyz)
        # strict comparison keeps the earliest candidate on ties
        if best_dist is None or dist < best_dist:
            best_idx, best_dist = cand_idx, dist
    return best_idx
def _get_bonds_for_assembling(nov_st,
                              ref_st,
                              nov_keep_idxs,
                              ref_keep_idxs,
                              require_identical_bonds=True):
    """
    Return a list of (novel atom index, reference atom index, order)
    tuples of bonds that should be created upon assembly of a new
    structure containing the specified keep indices from the specified
    structures.

    For each novel cut bond, the reference atom closest to the novel keep
    atom stands in for the reference replace atom, and the reference atom
    closest to the novel replace atom stands in for the reference keep
    atom. The caller (get_assembled_structure) superimposes the structures
    before calling this.

    :type nov_st: schrodinger.structure.Structure
    :param nov_st: first structure, called novel
    :type ref_st: schrodinger.structure.Structure
    :param ref_st: second structure, called reference
    :type nov_keep_idxs: list
    :param nov_keep_idxs: novel keep indices
    :type ref_keep_idxs: list
    :param ref_keep_idxs: reference keep indices
    :type require_identical_bonds: bool
    :param require_identical_bonds: whether to require that bonds to be
        created must exist in both novel and reference structures and
        be of the same bond order
    :rtype: list
    :return: contains (novel atom index, reference atom index, order)
        tuples of bonds to be created
    """
    nov_cut_bonds = get_cut_bonds(nov_st, nov_keep_idxs)
    ref_cut_bonds = get_cut_bonds(ref_st, ref_keep_idxs)
    # joining requires a cut on both sides
    if not (nov_cut_bonds and ref_cut_bonds):
        return []
    ref_cell = infrastructure.DistanceCell(ref_st, DISTANCE_CELL_LEN)
    new_bonds = []
    for nov_keep, nov_replace, order in nov_cut_bonds:
        ref_replace = get_closest_atom(nov_st, ref_st, nov_keep, ref_cell)
        ref_keep = get_closest_atom(nov_st, ref_st, nov_replace, ref_cell)
        matches_ref_bond = (ref_keep, ref_replace, order) in ref_cut_bonds
        if matches_ref_bond or (not require_identical_bonds and
                                ref_keep in ref_keep_idxs):
            new_bonds.append((nov_keep, ref_keep, order))
    return new_bonds
def _set_property_on_assembled(st, ref_st, key):
    """
    Set the given property on the assembled structure.
    :type st: schrodinger.structure.Structure
    :param st: the assembled structure
    :type ref_st: schrodinger.structure.Structure
    :param ref_st: the reference structure
    :type key: str
    :param key: the structure property key
    """
    prop = st.property.get(key)
    if prop is None:
        prop = ref_st.property.get(key)
        if prop is not None:
            st.property[key] = prop
def _update_properties(st,
                       nov_st,
                       ref_st,
                       nov_old_to_new,
                       ref_old_to_new,
                       offset,
                       title=None):
    """
    Update the properties of the given assembled structure.

    The title is always set. The remaining updates are made only when the
    novel or reference structure is a reaction workflow structure.

    :type st: schrodinger.structure.Structure
    :param st: the assembled structure
    :type nov_st: schrodinger.structure.Structure
    :param nov_st: first structure, called novel
    :type ref_st: schrodinger.structure.Structure
    :param ref_st: second structure, called reference
    :type nov_old_to_new: dict
    :param nov_old_to_new: the old-to-new atom index map for the
        extracted novel structure
    :type ref_old_to_new: dict
    :param ref_old_to_new: the old-to-new atom index map for the
        extracted reference structure
    :type offset: int
    :param offset: an offset for reference indices given that the
        assembled structure is a copy of part of the novel extended
        by part of the reference
    :type title: str
    :param title: the title to be given to the assembled structure
    """
    st.title = title or '_'.join([nov_st.title, ref_st.title])

    rxn_key = REACTION_WF_STRUCTURE_KEY
    if not (nov_st.property.get(rxn_key) or ref_st.property.get(rxn_key)):
        return

    # the assembled structure has inherited structure properties from the
    # novel and atom properties from both the novel and reference; fill in
    # charge and multiplicity from the reference only if they are missing
    for prop_key in (msprops.CHARGE_PROP, msprops.MULTIPLICITY_PROP):
        _set_property_on_assembled(st, ref_st, prop_key)

    if ref_st.property.get(rxn_key):
        st.property[rxn_key] = True
        st.property[CONFORMERS_GROUP_KEY] = \
            ref_st.property[CONFORMERS_GROUP_KEY]
        # sibling and parent group keys may be missing on older structures,
        # so drop any inherited value and copy the reference's, if any
        for group_key in (SIBLING_GROUP_KEY, PARENT_SIBLING_GROUPS_KEY):
            st.property.pop(group_key, None)
            group_value = ref_st.property.get(group_key)
            if group_value:
                st.property[group_key] = group_value
        msutils.set_project_group_hierarchy(
            st, msutils.get_project_group_hierarchy(st=ref_st))

    ts_key = TRANSITION_STATE_STRUCTURE_KEY
    if any(src.property.get(ts_key) for src in (nov_st, ref_st)):
        st.property[ts_key] = True

    # keep and superposition atoms are marked by atom properties, which
    # are inherited automatically; restrained distances, angles, and
    # dihedrals are structure properties holding atom indices, so they
    # must be renumbered for the assembled structure
    sources = ((nov_st, nov_old_to_new, 0), (ref_st, ref_old_to_new, offset))
    restrain_keys = (RESTRAINED_DISTANCES_KEY, RESTRAINED_ANGLES_KEY,
                     RESTRAINED_DIHEDRALS_KEY)
    for restrain_key in restrain_keys:
        st.property.pop(restrain_key, None)
        groups = []
        for src_st, old_to_new, shift in sources:
            groups.extend(
                _get_new_restrain_group_idxs(src_st,
                                             restrain_key,
                                             old_to_new,
                                             offset=shift))
        text = get_idx_groups_str(groups)
        if text:
            st.property[restrain_key] = text
def get_assembled_structure(nov_st,
                            ref_st,
                            nov_superposition_idxs,
                            ref_superposition_idxs,
                            nov_keep_idxs,
                            ref_keep_idxs,
                            title=None,
                            require_identical_bonds=True):
    """
    Return an assembled structure from the given structures
    using superposition followed by extraction.

    The kept novel atoms come first in the assembled structure, followed
    by the kept reference atoms. Bonds across the cut are then re-created.
    Note that rmsd.superimpose is called with move_which=rmsd.CT and
    nov_st as the first structure, so nov_st is presumably moved in place
    (confirm against the rmsd docs).

    :type nov_st: schrodinger.structure.Structure
    :param nov_st: first structure, called novel
    :type ref_st: schrodinger.structure.Structure
    :param ref_st: second structure, called reference
    :type nov_superposition_idxs: list
    :param nov_superposition_idxs: novel superposition indices
    :type ref_superposition_idxs: list
    :param ref_superposition_idxs: reference superposition indices
    :type nov_keep_idxs: list
    :param nov_keep_idxs: novel keep indices
    :type ref_keep_idxs: list
    :param ref_keep_idxs: reference keep indices
    :type title: str
    :param title: the title to be given to the assembled structure
    :type require_identical_bonds: bool
    :param require_identical_bonds: whether to require that bonds to be
        created must exist in both novel and reference structures and
        be of the same bond order
    :raise SwapFragmentsException: if there is an issue
    :rtype: schrodinger.structure.Structure
    :return: the assembled structure
    """
    if len(nov_superposition_idxs) != len(ref_superposition_idxs):
        msg = ('The number of reference and novel superposition '
               'atoms must be equivalent.')
        raise SwapFragmentsException(msg)

    rmsd.superimpose(nov_st,
                     nov_superposition_idxs,
                     ref_st,
                     ref_superposition_idxs,
                     use_symmetry=False,
                     move_which=rmsd.CT)

    nov_est, nov_old_to_new, _ = get_extracted_and_maps(nov_st, nov_keep_idxs)
    ref_est, ref_old_to_new, _ = get_extracted_and_maps(ref_st, ref_keep_idxs)

    bonds_in_old_idxs = _get_bonds_for_assembling(
        nov_st,
        ref_st,
        nov_keep_idxs,
        ref_keep_idxs,
        require_identical_bonds=require_identical_bonds)

    # reference atoms are appended after the novel atoms, so their new
    # indices are shifted by the number of novel atoms
    offset = nov_est.atom_total
    bonds_in_new_idxs = [
        (nov_old_to_new[nov_idx], ref_old_to_new[ref_idx] + offset, order)
        for nov_idx, ref_idx, order in bonds_in_old_idxs
    ]

    assembled = nov_est.copy()
    assembled.extend(ref_est)
    _update_properties(assembled,
                       nov_st,
                       ref_st,
                       nov_old_to_new,
                       ref_old_to_new,
                       offset,
                       title=title)
    for bond_args in bonds_in_new_idxs:
        assembled.addBond(*bond_args)
    return assembled
def get_idxs_str(idxs, sort=True):
    """
    Get a string representation of the given indices, joined by
    INDEX_SEPARATOR.

    :type idxs: list
    :param idxs: the idxs
    :type sort: bool
    :param sort: whether to sort the indices first; otherwise the given
        order is kept
    :rtype: str
    :return: the string
    """
    ordered = sorted(idxs) if sort else list(idxs)
    return INDEX_SEPARATOR.join(map(str, ordered))
def get_idxs(le):
    """
    Get indices from the given QLineEdit.

    The text is split on INDEX_SEPARATOR. Empty or whitespace-only tokens
    (e.g. '1, ,2' or a trailing separator) are ignored.

    :type le: QtWidgets.QLineEdit
    :param le: the line edit
    :raise ValueError: if a non-empty token is not an integer
    :rtype: list
    :return: the indices
    """
    text = le.text().strip()
    if not text:
        return []
    # int() tolerates surrounding whitespace but fails on a blank token,
    # so blank tokens are filtered out explicitly
    return [int(s) for s in text.split(INDEX_SEPARATOR) if s.strip()]