import math
import warnings
from typing import List
from typing import Tuple
from typing import Union
from schrodinger import structure
from schrodinger.infra import mmbitset
from schrodinger.infra import structure as infrastructure
from schrodinger.structutils import measure
from schrodinger.structutils import pbc_tools
# Message attached to the DeprecationWarning emitted by the legacy clash
# helpers in this module (clash_iterator, clash_volume); it points callers at
# the contact-based replacement API.
DEPRECATION_WARNING = (
    "Steric clashes should be found with "
    "schrodinger.infra.structure.get_contacts()"
)
def _iter_clashes(struc1, atoms1, struc2, atoms2, allowable_overlap):
    """
    Generator that does the actual clash search for `clash_iterator`.

    It is kept separate from `clash_iterator` so that the public function can
    emit its deprecation warning at call time: a generator body does not run
    until the first `next()`, so a warning placed inside it would be delayed
    (or never issued if the result is not iterated).

    Arguments and yielded values are exactly those documented on
    `clash_iterator`.
    """
    if struc2 is None:
        struc2 = struc1
        both_strucs = (struc1,)
    else:
        both_strucs = (struc1, struc2)
    # Make sure that the distance cell distance is large enough to include all
    # possible clashes, but no larger
    max_vdw_radius = max(
        (atom.radius for struc in both_strucs for atom in struc.atom),
        default=None)
    if max_vdw_radius is None:
        # No atoms at all, so nothing can clash (and a plain max() would raise
        # ValueError on the empty sequence).
        return
    cell_dist = 2 * max_vdw_radius - allowable_overlap
    if cell_dist <= 0:
        # For every pair, dist - r1 - r2 + overlap >= overlap - 2 * max_radius
        # = -cell_dist >= 0, so the clash test below can never succeed.
        return
    dist_cell = measure.DistanceCellIterator(struc2, cell_dist, atoms2)
    neighbor_iter = dist_cell.iterateNeighboringAtoms(struc1, atoms1)
    for atom1_num, neighbor_nums in neighbor_iter:
        atom1 = struc1.atom[atom1_num]
        atom1_radius = atom1.radius
        for atom2_num in neighbor_nums:
            atom2 = struc2.atom[atom2_num]
            dist = measure.measure_distance(atom1, atom2)
            # Negative means the vdW spheres overlap by more than the
            # allowed amount.
            clash = dist - atom1_radius - atom2.radius + allowable_overlap
            if clash < 0:
                yield (atom1, atom2, dist)


def clash_iterator(struc1,
                   atoms1=None,
                   struc2=None,
                   atoms2=None,
                   allowable_overlap=0.4):
    """
    Iterate through all steric clashes between two groups of atoms.

    Deprecated: a DeprecationWarning is issued as soon as this function is
    called, whether or not the returned generator is iterated.

    :param struc1: The first structure to examine
    :type struc1: `schrodinger.structure.Structure`
    :param atoms1: A list of atom numbers for `struc1`. If given, only the
                   specified atoms will be examined. If not given, all atoms of
                   `struc1` will be used.
    :type atoms1: list(int)
    :param struc2: The second structure to examine. If not given `struc1` will
                   be used
    :type struc2: `schrodinger.structure.Structure`
    :param atoms2: A list of atom numbers for `struc2`. If given, only the
                   specified atoms will be examined. If not given, all atoms of
                   `struc2` will be used.
    :type atoms2: list(int)
    :param allowable_overlap: Steric clashes smaller than this will be ignored.
                              The default (0.4 A) is a reasonable value when
                              examining protein-protein interactions.
    :type allowable_overlap: float
    :return: A generator that iterates through steric clashes. Each iteration
             will yield a tuple of (atom from `struc1`, atom from `struc2`,
             distance). The distance between the two specified atoms in
             Angstroms (float). Note that this is the distance between the
             nuclear centers, not the size of the overlap. Nothing is yielded
             if the structures contain no atoms.
    :rtype: generator
    """
    # Warn here, in a regular function, so that the warning fires at call time
    # and stacklevel=2 points at the caller of clash_iterator.
    warnings.warn(DEPRECATION_WARNING, DeprecationWarning, stacklevel=2)
    return _iter_clashes(struc1, atoms1, struc2, atoms2, allowable_overlap)
def sphere_overlap_volume(radius1, radius2, dist):
    """
    Calculate the volume of the overlap between two spheres.

    :param radius1: The radius of the first sphere
    :type radius1: float
    :param radius2: The radius of the second sphere
    :type radius2: float
    :param dist: The distance between the centers of the two spheres
    :type dist: float
    :return: The overlap volume. When one sphere lies entirely inside the
        other this is the volume of the smaller sphere; when the spheres do
        not intersect (or only touch) it is 0.0.
    :rtype: float
    :note: The equation implemented here is taken from
        http://mathworld.wolfram.com/Sphere-SphereIntersection.html, equation 16
    """
    if dist <= abs(radius1 - radius2):
        # One sphere is completely contained in the other. This also catches
        # concentric spheres (dist == 0), so the lens formula below never
        # divides by zero.
        return 4.0 / 3.0 * math.pi * min(radius1, radius2)**3
    if dist >= radius1 + radius2:
        # Disjoint or merely touching spheres; return a float to honor the
        # documented return type.
        return 0.0
    # Rename the variables so the names exactly match Mathworld to lower the
    # chances of a typo
    R, r, d = radius1, radius2, dist
    num_term1 = (R + r - d)**2
    num_term2 = d**2 + 2 * d * r - 3 * r**2 + 2 * d * R + 6 * r * R - 3 * R**2
    denom = 12.0 * d
    return (math.pi * num_term1 * num_term2) / denom
def clash_volume(struc1, atoms1=None, struc2=None, atoms2=None):
    """
    Calculate the volume of the steric overlap between two structures.

    :param struc1: The first structure to examine
    :type struc1: `schrodinger.structure.Structure`
    :param atoms1: A list of atom numbers for `struc1`.  If given, only the
        specified atoms will be examined.  If not given, all atoms of `struc1`
        will be used.
    :type atoms1: list
    :param struc2: The second structure to examine.  If not given `struc1`
        will be used
    :type struc2: `schrodinger.structure.Structure`
    :param atoms2: A list of atom numbers for `struc2`.  If given, only the
        specified atoms will be examined.  If not given, all atoms of `struc2`
        will be used.
    :type atoms2: list
    :return: The steric overlap volume in cubic Angstroms (0.0 when no atom
        pairs overlap)
    :rtype: float
    :note: The overlap volume is calculated as a sum of pair-wise steric
        clashes.  This may double count overlaps for very severe steric clashes
        where there are really three spheres overlapping.
    """
    warnings.warn(DEPRECATION_WARNING, DeprecationWarning, stacklevel=2)
    # An allowable overlap of 0.0 makes clash_iterator report every pair whose
    # van der Waals spheres intersect, so no overlap volume is skipped.
    # NOTE: clash_iterator issues its own DeprecationWarning as well.
    clashes = clash_iterator(struc1, atoms1, struc2, atoms2, 0.0)
    # Start the sum at 0.0 so a float is returned even when nothing clashes;
    # pairs are summed in the same order as they are produced.
    return sum((sphere_overlap_volume(atom1.radius, atom2.radius, dist)
                for atom1, atom2, dist in clashes), 0.0)
def get_steric_clashes(
    st1: structure.Structure,
    st1_atoms: Union[List[int], None] = None,
    st2: Union[structure.Structure, None] = None,
    st2_atoms: Union[List[int], None] = None,
    cutoff: float = 0.75,
    hbond_params: Union[infrastructure.AtomQueryParams, None] = None,
    salt_bridge_params: Union[infrastructure.SaltBridgeParams, None] = None,
) -> List[Tuple[structure._StructureAtom, structure._StructureAtom]]:
    """
    Return all pairs of atoms with steric clashes.

    To determine a clash, get the ratio of the distance between two atoms to
    the sum of their van der Waals radii::

        sqrt(square_distance) / (atom1_radius + atom2_radius)

    and ignore all values above a certain cutoff.

    :param st1: First structure.
    :type st1: structure.Structure
    :param st1_atoms: List of atom indices in the first structure,
            if `None` will use all atoms.
    :type st1_atoms: list[int] or NoneType
    :param st2: Second structure, if `None` will use the first structure.
    :type st2: structure.Structure or NoneType
    :param st2_atoms: List of atom indices in the second structure, if `None`
            will use all atoms.
    :type st2_atoms: list[int] or NoneType
    :param cutoff: Cutoff for clash consideration.
    :type cutoff: float
    :param hbond_params: Don't include hydrogen bonds matching `hbond_params`
            in clash list.
    :type hbond_params: schrodinger.infra.structure.AtomQueryParams
    :param salt_bridge_params: Don't include salt bridges matching
            `salt_bridge_params` in the clash list.
    :type salt_bridge_params: schrodinger.infra.structure.SaltBridgeParams
    :return: List of all atom pairs with a steric clash; the first atom of
            each pair is from `st1`, the second from `st2`.
    :rtype: list[tuple(structure._StructureAtom, structure._StructureAtom)]
    """
    # Test for None explicitly (the documented contract) rather than relying
    # on the truthiness of a Structure object.
    if st2 is None:
        st2 = st1
    # Atom numbers run from 1 to atom_total inclusive.
    if st1_atoms is None:
        st1_atoms = list(range(1, st1.atom_total + 1))
    if st2_atoms is None:
        st2_atoms = list(range(1, st2.atom_total + 1))
    bs1 = mmbitset.Bitset.from_list(size=st1.atom_total, on_list=st1_atoms)
    bs2 = mmbitset.Bitset.from_list(size=st2.atom_total, on_list=st2_atoms)
    pbc = pbc_tools.get_pbc(st1, st2, True)
    contact_params = infrastructure.ContactParams()
    contact_params.setCutoff(cutoff)
    contacts = infrastructure.get_contacts(st1, bs1, st2, bs2, contact_params,
                                           hbond_params, salt_bridge_params,
                                           pbc)
    # Each contact holds structure_atom.StructureAtom objects; map them back to
    # the atoms of the Python-level structures via their indices.
    return [(st1.atom[contact.getAtom1().getIndex()],
             st2.atom[contact.getAtom2().getIndex()]) for contact in contacts]