'''
Utilities to find cavities in groups of atoms and cover them with grids.
'''
import math
import numpy as np
import scipy.spatial as spatial
#------------------------------------------------------------------------------#
def dist3(a, b):
    '''
    Euclidean distance between two 3D points.

    Only the first three components of `a` and `b` are used.

    :param a: First point (indexable, at least 3 components).
    :param b: Second point (indexable, at least 3 components).
    :return: Distance between `a` and `b`.
    :rtype: float
    '''
    dx, dy, dz = a[0] - b[0], a[1] - b[1], a[2] - b[2]
    return math.sqrt(dx**2 + dy**2 + dz**2)
#------------------------------------------------------------------------------#
class _UnionOfSpheres(object):
    def __init__(self, centers, radii):
        '''
        Does not make copies of the arguments.
        :param centers: Positions of the spheres.
        :type centers: `numpy.ndarray`
        :param radii: Radii of the spheres.
        :type radii: `numpy.ndarray`
        '''
        self.centers = centers
        self.radii = radii
        self.max_radius = np.amax(self.radii)
        self.kdt = spatial.cKDTree(self.centers)
    def __contains__(self, pos):
        '''
        Check whether point `pos` is inside one of the spheres in the union.
        '''
        return self.intersects(pos=pos, radius=0.0)
    def intersects(self, pos, radius):
        '''
        Check whether the sphere of given `radius` with center at `pos`
        intersects with one of the spheres in the union.
        '''
        neighbors = self.kdt.query_ball_point(pos, self.max_radius + radius)
        for n in neighbors:
            r = dist3(pos, self.centers[n])
            if r < self.radii[n] + radius:
                return True
        return False
#------------------------------------------------------------------------------#
def surfnet_gaps_supplier(centers,
                          radii,
                          min_radius=1.0,
                          max_radius=4.0,
                          ligand_centers=None,
                          ligand_radii=None):
    '''
    Generator of the "gap" spheres obtained using the SURFNET algorithm:

    Roman A. Laskowski
    "SURFNET: A program for visualizing molecular surfaces,
    cavities, and intermolecular interactions"
    https://doi.org/10.1016/0263-7855(95)00073-9

    For every pair of input spheres, a sphere is placed so that the two
    input spheres touch it from opposite sides. If the gap sphere overlaps
    any other input sphere, its radius is reduced (keeping the center fixed)
    until no overlap remains. Only spheres with a radius between
    `min_radius` and `max_radius` (1 to 4 angstrom by default) are kept.
    The result is a number of separate groups of interpenetrating spheres,
    called gap regions, both inside the protein and on its surface, which
    correspond to the protein's cavities and clefts.

    :param centers: Positions of the spheres to be considered.
    :type centers: `numpy.ndarray`
    :param radii: Radii of the spheres to be considered.
    :type radii: `numpy.ndarray`
    :param min_radius: Smallest allowed radius of the "gap" sphere.
    :type min_radius: float
    :param max_radius: Largest allowed radius of the "gap" sphere.
    :type max_radius: float
    :param ligand_centers: Positions of ligand atoms or `None`. If not
        `None`, only emit spheres that intersect with one of the "ligand"
        spheres. An empty array means no sphere can qualify.
    :type ligand_centers: `numpy.ndarray`
    :param ligand_radii: Radii of the ligand atoms or `None`. If not `None`,
        the number of radii must be equal to the number of `ligand_centers`.
    :type ligand_radii: `numpy.ndarray`
    :return: Generator of `(pos, radius)` tuples, where `pos` is a
        `numpy.ndarray` holding the center of the gap sphere.
    '''
    centers = np.asarray(centers)
    radii = np.asarray(radii)
    # With no spheres there are no pairs; also avoids np.amax() on an
    # empty array, which raises.
    if len(centers) == 0:
        return
    if ligand_centers is None:
        ligand = None
    elif len(ligand_centers) == 0:
        # Nothing to intersect with, so no gap sphere can be emitted.
        return
    else:
        ligand = _UnionOfSpheres(ligand_centers, ligand_radii)
    kdt = spatial.cKDTree(centers)
    rmax = np.amax(radii)
    num_spheres = centers.shape[0]
    for i in range(num_spheres):
        pos_i = centers[i]
        radius_i = radii[i]
        # A partner j can yield a gap sphere of radius <= max_radius only if
        # dist_ij <= r_i + r_j + 2 * max_radius <= 2 * (rmax + max_radius).
        candidates = kdt.query_ball_point(pos_i, 2 * (rmax + max_radius))
        for j in candidates:
            if j <= i:
                # Each unordered pair is handled once.
                continue
            pos_j = centers[j]
            radius_j = radii[j]
            x_ij = pos_j - pos_i
            dist_ij = np.linalg.norm(x_ij)
            # +---+----------x---------+-----------+
            # i  r_i        pos       r_j          j
            radius = (dist_ij - radius_i - radius_j) / 2
            if radius < min_radius or radius > max_radius:
                continue
            pos = pos_i + ((radius_i + radius) / dist_ij) * x_ij
            # Cheap pre-filter: shrinking only makes intersection less likely.
            if ligand is not None and not ligand.intersects(pos, radius):
                continue
            # Shrink the sphere (fixed center) until it clears all atoms.
            neighbors = kdt.query_ball_point(pos, radius + rmax)
            for n in neighbors:
                dist = dist3(pos, centers[n])
                radius = min(radius, dist - radii[n])
                if radius < min_radius:
                    break
            if radius < min_radius:
                continue
            if ligand is not None and not ligand.intersects(pos, radius):
                continue
            yield (pos, radius)
#------------------------------------------------------------------------------#
def get_surfnet_gaps(centers, radii, min_radius=1.0, max_radius=4.0,
                     ligand_centers=None, ligand_radii=None):
    '''
    Merges spheres generated by `surfnet_gaps_supplier()` into
    connected groups (clefts).

    Two spheres belong to the same cleft if they overlap (strictly), directly
    or through a chain of overlapping spheres. Clefts are built
    incrementally: each new sphere joins every existing cleft it overlaps,
    and those clefts are merged into one.

    :param centers: Positions of the spheres to be considered.
    :type centers: `numpy.ndarray`
    :param radii: Radii of the spheres to be considered.
    :type radii: `numpy.ndarray`
    :param min_radius: Smallest allowed radius of a "gap" sphere.
    :type min_radius: float
    :param max_radius: Largest allowed radius of a "gap" sphere.
    :type max_radius: float
    :param ligand_centers: Positions of ligand atoms or `None`; forwarded to
        `surfnet_gaps_supplier()` to keep only spheres touching the ligand.
    :type ligand_centers: `numpy.ndarray`
    :param ligand_radii: Radii of the ligand atoms or `None`.
    :type ligand_radii: `numpy.ndarray`
    :return: List of clefts. Each cleft is a list of `((x, y, z), radius)`
        tuples. Clefts are sorted by their size in descending order.
    :rtype: list(list(tuple))
    '''
    clefts = []
    supplier = surfnet_gaps_supplier(centers, radii, min_radius, max_radius,
                                     ligand_centers=ligand_centers,
                                     ligand_radii=ligand_radii)
    for (pos, radius) in supplier:
        # Indices of every existing cleft that the new sphere overlaps.
        touching = []
        for (i, cleft) in enumerate(clefts):
            for (other_pos, other_radius) in cleft:
                delta = pos - other_pos
                if np.dot(delta, delta) < (radius + other_radius)**2:
                    touching.append(i)
                    break
        if not touching:
            # new cleft
            clefts.append([(pos, radius)])
            continue
        target = clefts[touching[0]]
        target.append((pos, radius))
        if len(touching) > 1:
            # The new sphere bridges several clefts: fold them into the first.
            for j in touching[1:]:
                target.extend(clefts[j])
            absorbed = set(touching[1:])
            clefts = [c for (k, c) in enumerate(clefts) if k not in absorbed]
    return sorted(clefts, key=len, reverse=True)
#------------------------------------------------------------------------------#
def get_atom_positions(atoms, want_radii=True):
    '''
    If `want_radii` is `True`, returns a tuple of `numpy.ndarray` instances
    holding positions and vdW radii of the atoms. Returns a single
    `numpy.ndarray` holding the coordinates otherwise.

    The positions array always has shape `(N, 3)`, including `(0, 3)` when
    `atoms` is empty.

    :param atoms: Container of atoms (each exposing `xyz` and `radius`).
    :type atoms: container of `schrodinger.structure._StructureAtom`
    :param want_radii: Whether to also return the radii.
    :type want_radii: bool
    :return: Positions (and vdW radii) of the atoms.
    :rtype: (numpy.ndarray, numpy.ndarray) or `numpy.ndarray`
    '''
    if not atoms:
        # Keep the (N, 3) shape contract even for N == 0.
        empty_positions = np.empty((0, 3))
        if want_radii:
            return (empty_positions, np.empty(0))
        return empty_positions
    num_atoms = len(atoms)
    positions = np.empty((num_atoms, 3))
    radii = np.empty(num_atoms) if want_radii else None
    for (i, atom) in enumerate(atoms):
        positions[i] = atom.xyz
        if want_radii:
            radii[i] = atom.radius
    return (positions, radii) if want_radii else positions
#------------------------------------------------------------------------------#
def get_com_and_axes(spheres):
    '''
    Calculates the weighted centroid and principal axes of the spheres.

    Each sphere is weighted by ``radius**3``. The axes are the eigenvectors
    of the weighted inertia tensor of the sphere centers about the centroid,
    in the order returned by `numpy.linalg.eigh` (ascending moments).

    :param spheres: Spheres to be considered.
    :type spheres: container of ((x, y, z), radius)
    :return: Weighted centroid and a 3x3 matrix whose columns hold the
        principal axes of the spheres.
    :rtype: tuple(numpy.ndarray, numpy.ndarray)
    '''
    num_spheres = len(spheres)
    weight = np.empty(num_spheres)
    xyz = np.empty((num_spheres, 3))
    for (i, (pos, radius)) in enumerate(spheres):
        weight[i] = radius**3
        xyz[i] = pos
    weight_sum = weight.sum()
    if weight_sum > 0.0:
        weight /= weight_sum
    com = np.sum(xyz * weight[:, np.newaxis], axis=0)
    # Inertia tensor: I_ij = sum_k w_k * (delta_ij * |v_k|^2 - v_ki * v_kj)
    centered = xyz - com
    sq_norms = np.sum(centered * centered, axis=1)
    weighted = centered * weight[:, np.newaxis]
    tensor = np.dot(weight, sq_norms) * np.eye(3) - weighted.T @ centered
    _, axes = np.linalg.eigh(tensor)
    return com, axes
#------------------------------------------------------------------------------#
def get_grid_parameters(spheres, com=None, axes=None, step=0.5, border=0.5):
    '''
    Determines parameters of the grid necessary to cover the spheres.

    The grid is aligned with `axes`. Each axis is oriented so that the
    spheres reach farther along its positive direction. The bounding box is
    padded by `border` on every side and the grid is centered over it.

    :param spheres: The spheres to be considered.
    :type spheres: container of ((x, y, z), radius)
    :param com: Center of mass of the spheres (or None for the origin).
    :type com: `numpy.ndarray`
    :param axes: 3x3 matrix whose columns hold principal axes of the spheres
        (or None for the Cartesian axes).
    :type axes: `numpy.ndarray`
    :param step: Grid step; must be positive.
    :type step: float
    :param border: Border around the spheres.
    :type border: float
    :return: Grid size, origin and vectors.
    :rtype: ((int, int, int), numpy.ndarray, numpy.ndarray)
    :raise ValueError: If `spheres` is empty or `step` is not positive.
    '''
    if step <= 0:
        raise ValueError('step must be positive')
    if com is None:
        com = np.zeros(3)
    if axes is None:
        axes = np.eye(3)
    # Rows of `rotation` are the axes, so `rotation @ v` yields the
    # components of `v` along each axis.
    rotation = np.transpose(axes)
    lo = None
    hi = None
    for (pos, radius) in spheres:
        x = rotation @ (np.asarray(pos) - com)
        if lo is None:
            lo, hi = x - radius, x + radius
        else:
            lo = np.minimum(lo, x - radius)
            hi = np.maximum(hi, x + radius)
    if lo is None:
        raise ValueError('need at least one sphere')
    # pick direction for the principal axes
    chosenaxes = np.array(axes, dtype=float)
    for i in range(3):
        if abs(lo[i]) > abs(hi[i]):
            lo[i], hi[i] = -hi[i], -lo[i]
            chosenaxes[:, i] = -chosenaxes[:, i]
    # add "buffer" around the orthogonal parallelepiped covering the spheres
    lo -= border
    hi += border
    # determine the grid size
    size = tuple(int((hi[i] - lo[i]) / step) + 1 for i in range(3))
    # Extent of `size` steps; shift `lo` so that the excess over the padded
    # box is split evenly between both ends ("centers" the grid).
    actual = step * np.asarray(size)
    lo -= 0.5 * (actual - (hi - lo))
    # origin and grid vectors
    origin = chosenaxes @ lo + com
    # NOTE(review): `vectors` holds the grid step vectors as *columns*,
    # whereas `Grid` in this module reads `vectors[0..2]` as rows. The two
    # agree only when `chosenaxes` is symmetric (e.g. the default identity
    # axes). Confirm which convention callers expect.
    vectors = chosenaxes * step
    return size, origin, vectors
#------------------------------------------------------------------------------#
class Cavity(_UnionOfSpheres):
    '''
    A cavity described as a union of spheres, built from `(pos, radius)`
    pairs (the same format as one cleft from `get_surfnet_gaps()`).
    '''
    def __init__(self, spheres):
        '''
        :param spheres: Spheres that define the cavity.
        :type spheres: container of (pos, radius) tuples
        :raise ValueError: If `spheres` is empty.
        '''
        if not spheres:
            raise ValueError('need at least one sphere')
        num_spheres = len(spheres)
        radii = np.empty(num_spheres)
        centers = np.empty((num_spheres, 3))
        for (i, (pos, radius)) in enumerate(spheres):
            radii[i] = radius
            centers[i] = pos
        # Freshly allocated arrays: the base class keeps references, so this
        # avoids aliasing the caller's data.
        super().__init__(centers, radii)
#------------------------------------------------------------------------------#
class Grid(object):
    '''
    Maps grid indices to Cartesian coordinates using the provided
    grid origin and grid step vectors.

    Grid point ``(i, j, k)`` is located at
    ``origin + i * vectors[0] + j * vectors[1] + k * vectors[2]``,
    i.e. the rows of `vectors` are the step vectors.
    '''
    def __init__(self, origin, vectors):
        '''
        :param origin: Cartesian position of grid point (0, 0, 0).
        :type origin: `numpy.ndarray` or sequence of 3 numbers
        :param vectors: 3x3 array whose rows are the grid step vectors.
        :type vectors: `numpy.ndarray` or nested sequence
        '''
        # np.asarray does not copy existing ndarrays; it lets plain sequences
        # work with the row indexing used below.
        self.origin = np.asarray(origin)
        self.vectors = np.asarray(vectors)

    def __call__(self, ijk):
        '''
        Return the Cartesian position of the grid point with indices `ijk`.
        '''
        return (self.origin
                + ijk[0] * self.vectors[0]
                + ijk[1] * self.vectors[1]
                + ijk[2] * self.vectors[2])

    def indicesToPositions(self, indices):
        '''
        Returns positions for the indices as single `numpy.ndarray`.

        :param indices: Container of triples of integers.
        :type indices: container
        :return: Array of positions with shape `(len(indices), 3)`.
        :rtype: `numpy.ndarray`
        '''
        num_points = len(indices)
        outcome = np.empty((num_points, 3))
        for (i, ijk) in enumerate(indices):
            outcome[i] = self(ijk)
        return outcome

    @property
    def delta(self):
        '''
        Lengths of the three grid step vectors (rows of `vectors`).
        '''
        return tuple(np.linalg.norm(self.vectors[i, :]) for i in range(3))
#------------------------------------------------------------------------------#