"""
Find salt-bridge interactions.
Examples
========
Find all salt bridge interactions within a protein::
    st = structure.Structure.read("protein.mae.gz")
    for atom1, atom2 in get_salt_bridges(st):
        print(f"Salt bridge between atoms {atom1.index} and {atom2.index}")
Find all salt bridges within a single protein chain::
    st = structure.Structure.read("protein.mae.gz")
    atoms = st.chain["C"].getAtomIndices()
    for atom1, atom2 in get_salt_bridges(st, atoms):
        print(f"Salt bridge between atoms {atom1.index} and {atom2.index}")
Find all salt bridges between a protein and a ligand::
    with StructureReader("protein_and_ligand.mae.gz") as reader:
        prot, lig = reader
    for atom1, atom2 in get_salt_bridges(prot, struc2=lig):
        print(f"Salt bridge between atoms {atom1.index} and {atom2.index}")
"""
# Copyright Schrodinger, LLC. All rights reserved.
import enum
from schrodinger import structure
from schrodinger.infra import mmbitset
from schrodinger.infra import structure as infrastructure
from schrodinger.structutils import pbc_tools
# NOTE(review): get_salt_bridges uses its own default of cutoff=5 rather than
# this value -- confirm which distance is intended as the module default.
DEFAULT_MAX_DIST = 4.0


class OrderBy(enum.Enum):
    """
    Ordering of the two atoms in each salt bridge returned by
    `get_salt_bridges`.
    """
    # Each salt bridge is returned as (anion atom, cation atom).
    AnionCation = 1
    # Each salt bridge is returned as (struc1/group1 atom, struc2/group2 atom).
    InputOrder = 2
def get_salt_bridges(struc1,
                     group1=None,
                     struc2=None,
                     group2=None,
                     cutoff=5,
                     order_by=OrderBy.AnionCation,
                     honor_pbc=True):
    """
    Calculate all salt bridges within or between the specified atoms.  If struc2
    or group2 are given, then this function will return salt bridges between the
    two structures/groups of atoms.  If neither struc2 nor group2 are given,
    then this function will return salt bridges within a single structure/group
    of atoms.
    :param struc1: The structure to analyze
    :type struc1: `schrodinger.structure.Structure`
    :param group1: The list of atom indices in `struc1` to analyze.  If not
        given, all atoms in struc1 will be analyzed.
    :type group1: list
    :param struc2: The second structure to analyze.  If `group2` is given
        but `struc2` is not, then `struc1` will be used.
    :type struc2: `schrodinger.structure.Structure`
    :param group2: The list of atom indices in `struc2` to analyze.  If
        `struc2` is given but `group2` is not, then all atoms in `struc2` will be
        analyzed.
    :type group2: list
    :param cutoff: The maximum distance allowed for salt bridges
    :type cutoff: float
    :param order_by: How the returned salt bridge atom should be ordered.  If
        `OrderBy.AnionCation`, then each salt bridge will be returned as a tuple of
        (anion atom, cation atom).  If `OrderBy.InputOrder`, then each salt bridge
        will be returned as a tuple of (atom from struc1/group1, atom from
        struc2/group2).
    :type order_by: `OrderBy`
    :param honor_pbc: Forwarded to `pbc_tools.get_pbc` together with the input
        structures; controls whether periodic boundary conditions are taken
        into account.
    :type honor_pbc: bool
    :return: A list of salt bridges, where each salt bridge is represented by a
        tuple of two `schrodinger.structure._StructureAtom` objects.
    :rtype: list
    :raise ValueError: If `order_by` is `OrderBy.InputOrder` but neither
        `struc2` nor `group2` is given, since there is no second input to order
        the atoms against.
    """
    if struc2 is None and group2 is None and order_by is OrderBy.InputOrder:
        err = ("Cannot order by input order when finding salt bridges within a "
               "single structure or region.")
        raise ValueError(err)
    infra_bridges, bs1 = _get_wrapped_sb_list(struc1, group1, struc2, group2,
                                              cutoff, honor_pbc)
    # Each infra SaltBridge becomes an (anion, cation) tuple of Python atoms.
    salt_bridges = [_convert_salt_bridge(sb) for sb in infra_bridges]
    if order_by is OrderBy.InputOrder:
        # bs1 identifies the group-1 atoms, which must come first in each pair.
        salt_bridges = [_to_input_order(sb, struc1, bs1) for sb in salt_bridges]
    return salt_bridges
def _get_wrapped_sb_list(struc1,
                         group1,
                         struc2,
                         group2,
                         cutoff,
                         honor_pbc=True):
    """
    Run `schrodinger.infra.structure.get_salt_bridges` on the requested
    structure(s)/atom group(s) and return its raw result.
    :param struc1: The structure to analyze
    :type struc1: `schrodinger.structure.Structure`
    :param group1: Atom indices in `struc1` to analyze.  None means every atom
        in `struc1`.
    :type group1: list
    :param struc2: The second structure to analyze.  When only `group2` is
        supplied, `struc1` stands in for it.
    :type struc2: `schrodinger.structure.Structure`
    :param group2: Atom indices in `struc2` to analyze.  None (with `struc2`
        supplied) means every atom in `struc2`.
    :type group2: list
    :param cutoff: The maximum distance allowed for salt bridges
    :type cutoff: float
    :param honor_pbc: Forwarded to `pbc_tools.get_pbc`.
    :type honor_pbc: bool
    :return: A 2-tuple of (salt bridges, group-1 bitset).  The salt bridges are
        the infra-level `schrodinger.infra.structure.SaltBridge` objects (anion
        and cation available via getAnion()/getCation()); the bitset is the
        `mmbitset.Bitset` built from `group1`, returned so callers can tell
        which atoms belong to the first group.
    :rtype: tuple
    """
    sb_params = get_salt_bridge_params(cutoff=cutoff)
    first_bits = _convert_group(group1, struc1)
    # The caller's struc2 (possibly None) is what get_pbc receives.
    periodic = pbc_tools.get_pbc(struc1, struc2, honor_pbc)
    within_one_region = struc2 is None and group2 is None
    if within_one_region:
        bridges = infrastructure.get_salt_bridges(struc1, first_bits,
                                                  sb_params, periodic)
        return bridges, first_bits
    # Only group2 given: the second region lives in the first structure.
    other_struc = struc1 if struc2 is None else struc2
    second_bits = _convert_group(group2, other_struc)
    bridges = infrastructure.get_salt_bridges(struc1, first_bits, other_struc,
                                              second_bits, sb_params, periodic)
    return bridges, first_bits
def _convert_group(group, struc):
    """
    Convert a list of atom indices to a bitset.
    :param group: Atom indices (any iterable of ints) to set in the bitset.
        If None, the bitset will contain all atoms in `struc`.  An empty
        iterable yields an empty bitset.
    :type group: list
    :param struc: The structure that the `group` atom indices refer to
    :type struc: `schrodinger.structure.Structure`
    :return: The bitset, sized to `struc.atom_total`
    :rtype: `mmbitset.Bitset`
    """
    bs = mmbitset.Bitset(size=struc.atom_total)
    if group is None:
        # No explicit group means every atom in the structure.
        bs.fill()
    else:
        # bs.set is called purely for its side effect, so use a plain loop
        # instead of materializing a throwaway list via map().
        for atom_index in group:
            bs.set(atom_index)
    return bs
def _convert_salt_bridge(salt_bridge):
    """
    Turn a `schrodinger.infra.structure.SaltBridge` into an (anion, cation)
    tuple of `schrodinger.structure._StructureAtom` objects.
    """
    # The anion is converted before the cation, matching the tuple order.
    return (_convert_atom(salt_bridge.getAnion()),
            _convert_atom(salt_bridge.getCation()))
def _convert_atom(infra_atom):
    """
    Map a `schrodinger.infra.structure.StructureAtom` to the corresponding
    `schrodinger.structure._StructureAtom`.
    A new `schrodinger.structure.Structure` is wrapped around the atom's
    underlying infra structure, and the atom is looked up in it by index.
    """
    idx = infra_atom.getIndex()
    owner = structure.Structure(infra_atom.getStructure())
    return owner.atom[idx]
def _to_input_order(salt_bridge, struc1, bs1):
    """
    Switch the salt bridge tuple from (anion, cation) order to input order.
    :param salt_bridge: The salt bridge to re-order, as a tuple of
        `schrodinger.structure._StructureAtom` objects.
    :type salt_bridge: tuple
    :param struc1: The first structure
    :type struc1: `schrodinger.structure.Structure`
    :param bs1: The bitset of group 1 atoms
    :type bs1: `mmbitset.Bitset`
    :return: The re-ordered salt bridge, as a tuple of
        `schrodinger.structure._StructureAtom` objects.
    :rtype: tuple
    """
    anion, cation = salt_bridge
    if anion._ct == struc1 and bs1.get(anion.index):
        return anion, cation
    else:
        return cation, anion
def get_salt_bridge_params(
        cutoff: float = None) -> infrastructure.SaltBridgeParams:
    """
    Return salt bridge `SaltBridgeParams` object with the given criteria.
    :param cutoff: See `get_salt_bridges`.  If None, setCutoff is not called,
        so the cutoff of a default-constructed `SaltBridgeParams` is kept.
    :return: `SaltBridgeParams` with the given salt bridge criteria.
    """
    params = infrastructure.SaltBridgeParams()
    # The signature defaults cutoff to None; forwarding None to the wrapped
    # setter (which presumably expects a number -- confirm) would fail, so only
    # override the cutoff when one is actually supplied.
    if cutoff is not None:
        params.setCutoff(cutoff)
    return params