'''
Logic to compute "interaction fields" between (protein) atoms and
"probe" particles.
'''
import math
import numpy as np
import scipy.spatial as spatial
from . import atomtyping
from . import common
#------------------------------------------------------------------------------#
# Interaction parameters between the probes and protein atom types: a mapping
# from atom type to {probe: parameter}. Negative/positive parameters represent
# favorable/unfavorable interactions with the probe; a probe absent from an
# entry does not interact with that atom type. Probe names:
#   "arm" -- aromatic, "hyd" -- hydrophobic, "acc" -- h-bond acceptor,
#   "don" -- h-bond donor, "pos"/"neg" -- positive/negative charge.
# yapf: disable
INTERACTIONS = {
    'C.2': {},
    'C.3': {'hyd': -1},
    'C.ar': {'arm': -1},
    'C.cat': {'arm': -1, 'neg': -1},
    'N.3': {},
    'N.4': {'hyd': 1, 'acc': -1, 'neg': -1},
    'N.am': {'hyd': 1, 'acc': -1},
    'N.his': {'hyd': 1, 'arm': -1, 'acc': -1, 'don': -1, 'neg': -1},
    'N.pl3': {'hyd': 1, 'acc': -1},
    'O.2': {'hyd': 1, 'don': -1},
    'O.3': {'hyd': 1, 'acc': -1, 'don': -1},
    'O.co2': {'hyd': 1, 'don': -1, 'pos': -1},
    'S.3': {'hyd': -1, 'arm': -1},
    # +2 metal ion (note the fractional charge parameters)
    'm.2': {'hyd': 1, 'arm': -1, 'neg': -1.25, 'acc': -1, 'pos': 0.5},
    # aromatic N
    'N.ar': {'arm': -1},
    # phosphate P
    'P.3': {},
}
# yapf: enable
# All known atom types, in sorted order.
ATYPES = sorted(INTERACTIONS)
# Every probe referenced by at least one atom type, in sorted order
# (iterating a dict yields its keys, so the union collects probe names).
PROBES = sorted(set().union(*INTERACTIONS.values()))
# (default) thresholds for the interactions to be taken into account.
# NOTE(review): not used in this module; presumably compared by callers
# against the potentials returned by `Field` -- confirm.
THRESHOLDS = dict(
    hyd=-0.0185,
    arm=-0.0500,
    don=-0.0250,
    acc=-0.0250,
    pos=-0.0315,
    neg=-0.0315,
)
#------------------------------------------------------------------------------#
def get_hbond_direction(atom):
    '''
    Estimates the direction of the "ideal" hydrogen bond formed by `atom`.

    The direction is the unit vector pointing from the centroid of the
    non-hydrogen atoms bonded to `atom` towards `atom` itself. If `atom`
    is bonded only to hydrogens (e.g. the O in HOH), all bonded atoms are
    used for the centroid instead. The heuristic is not always right
    (e.g. for O in C-O-H); it needs to be refined or proven irrelevant.

    :param atom: Atom providing `xyz` and `bonded_atoms`; the bonded atoms
        must provide `xyz` and `atomic_number`.
    :return: Unit direction vector; a zero vector if `atom` coincides with
        the centroid of its neighbours; None if `atom` has no bonded atoms.
    :rtype: numpy.ndarray or None
    '''
    neighbours = [a for a in atom.bonded_atoms if a.atomic_number > 1]
    if not neighbours:
        # Only hydrogens (or nothing) attached: fall back to every neighbour.
        neighbours = list(atom.bonded_atoms)
        if not neighbours:
            return None
    coordinates = (np.asarray(a.xyz) for a in neighbours)
    centroid = sum(coordinates, np.zeros(3)) / len(neighbours)
    direction = np.asarray(atom.xyz) - centroid
    length = np.linalg.norm(direction)
    # Leave a degenerate (zero-length) vector as-is rather than dividing by 0.
    return direction / length if length > 0.0 else direction
#------------------------------------------------------------------------------#
def get_interaction_sites(st, atom_indices=None, probes=None, logger=None):
    '''
    Identifies interaction sites among the protein atoms according to
    their atom types.

    Atoms are typed with `atomtyping.AtomTyper`. An atom becomes a site
    when its type is listed in `INTERACTIONS` with a nonzero parameter for
    at least one of the requested probes. Atoms that cannot be typed are
    reported through `logger.warning` and skipped. Atoms whose type is not
    listed in `INTERACTIONS` are skipped silently.

    :param st: Structure.
    :type st: `schrodinger.structure.Structure`
    :param atom_indices: Iterable over the contributing atom indices
        (defaults to all atoms, ``1 .. st.atom_total``).
    :type atom_indices: iterable over int
    :param probes: Probes of interest (defaults to `PROBES`). Any iterable
        is accepted; it is consumed once, and its order defines the order
        of the per-atom parameters.
    :type probes: iterable over str
    :param logger: Logger used for warnings; passed through
        `common.ensure_logger` (presumably supplies a default when None).
    :return: List of atoms that interact with the requested probes.
        Each entry pairs an atom with a tuple holding, for every probe in
        order, the parameter from `INTERACTIONS` (0 if the atom type does
        not interact with that probe). Parameters are typically -1 or 1
        but may be fractional (e.g. -1.25/0.5 for "m.2").
    :rtype: list(tuple(schrodinger.structure._StructureAtom, tuple(float)))
    '''
    logger = common.ensure_logger(logger)
    if atom_indices is None:
        atom_indices = range(1, st.atom_total + 1)
    # Materialize once: `probes` is iterated for every typed atom, and a
    # one-shot iterator would otherwise yield empty tuples after the first.
    probes = tuple(PROBES if probes is None else probes)
    sites = []
    typer = atomtyping.AtomTyper()
    for (i, typ) in typer(st, atom_indices=atom_indices, logger=logger):
        if typ is None:
            logger.warning('no atom type for atom %d (%s/%s)', i,
                           st.atom[i].pdbname, st.atom[i].pdbres)
            continue
        params = INTERACTIONS.get(typ)
        if params is None:
            # Atom type without any probe interactions.
            continue
        data = tuple(params.get(p, 0) for p in probes)
        if any(data):
            sites.append((st.atom[i], data))
    return sites
#------------------------------------------------------------------------------#
class Field(object):
    '''
    Handles computation of the "interaction potentials" generated by
    the "interaction sites" (protein atoms) acting on "probes".
    '''

    def __init__(self,
                 st,
                 atom_indices=None,
                 probes=None,
                 alpha=1.0,
                 r_cut=4.0,
                 a_cut=60.0,
                 logger=None):
        '''
        :param st: Structure.
        :type st: `schrodinger.structure.Structure`
        :param atom_indices: Iterable over the contributing atom indices
            (defaults to all atoms, ``1 .. st.atom_total``).
        :type atom_indices: iterable over int
        :param probes: Probes of interest (defaults to `PROBES`).
        :type probes: container of str
        :param alpha: Interaction range (length scale of exponential decay).
        :type alpha: float
        :param r_cut: Ignore contributions from atoms further than `r_cut`
            from a probe.
        :type r_cut: float
        :param a_cut: Ignore hydrogen bond interactions for angles
            (in degrees) exceeding `a_cut`.
        :type a_cut: float
        :param logger: Logger forwarded to `get_interaction_sites`.
        :raises ValueError: If no atom interacts with the requested probes.
        '''
        self.probes = tuple(PROBES if probes is None else probes)
        if atom_indices is None:
            atom_indices = range(1, st.atom_total + 1)
        sites = get_interaction_sites(st,
                                      atom_indices,
                                      probes=self.probes,
                                      logger=logger)
        if not sites:
            raise ValueError('no interaction sites found')
        self.alpha = alpha
        self.r_cut = r_cut
        self.a_cut = a_cut
        num_sites = len(sites)
        # Per-site h-bond directions: zero rows for sites without acc/don
        # parameters, NaN rows for h-bonding sites whose direction is
        # undefined (see below).
        self.hbdir = np.zeros((num_sites, 3))
        # Per-site probe parameters (columns follow `self.probes`).
        self.params = np.empty((num_sites, len(self.probes)))
        self.positions = np.empty((num_sites, 3))
        self._acc_index = (self.probes.index('acc')
                           if 'acc' in self.probes else None)
        self._don_index = (self.probes.index('don')
                           if 'don' in self.probes else None)
        for (i, (atom, data)) in enumerate(sites):
            self.params[i] = data
            self.positions[i] = atom.xyz
            need_direction = (
                (self._acc_index is not None and data[self._acc_index]) or
                (self._don_index is not None and data[self._don_index]))
            if not need_direction:
                continue
            direction = get_hbond_direction(atom)
            # An atom without bonded atoms has no preferred direction; NaN
            # makes __call__ skip the angular cutoff for this site.
            self.hbdir[i] = np.nan if direction is None else direction
        self.kdt = spatial.cKDTree(self.positions)

    def __call__(self, pos):
        '''
        Evaluates potentials at `pos`.

        Every site within `r_cut` of `pos` adds ``param * exp(-r / alpha)``
        to each probe, where `r` is the site-to-`pos` distance. For the
        h-bond probes ("acc"/"don") the contribution is dropped when the
        angle between the site's h-bond direction and the site->`pos`
        vector exceeds `a_cut` degrees. Sites whose direction is undefined
        (NaN in `hbdir`) are not filtered by angle.

        :param pos: Cartesian position of the probe.
        :type pos: sequence of 3 floats
        :return: Potentials acting on probes at position `pos`, in the
            order of `self.probes`.
        :rtype: list(float)
        '''
        pos = np.asarray(pos)
        energy = [0.0] * len(self.probes)
        # Probe columns subject to the angular (h-bond) cutoff.
        directional = {self._acc_index, self._don_index} - {None}
        for s in self.kdt.query_ball_point(pos, self.r_cut):
            delta = pos - self.positions[s]
            r = np.linalg.norm(delta)
            cosine = np.dot(delta, self.hbdir[s])
            if r > 0:
                cosine /= r
            # Rounding can push |cosine| slightly past 1, making math.acos
            # raise; np.clip keeps NaN (undefined direction) as NaN.
            angle = math.degrees(math.acos(np.clip(cosine, -1.0, 1.0)))
            weight = math.exp(-r / self.alpha)
            for i in range(len(energy)):
                # A NaN angle compares False, so undirected sites always count.
                if i in directional and angle > self.a_cut:
                    continue
                energy[i] += self.params[s, i] * weight
        return energy

    def nearest_atom_distance(self, pos):
        '''
        Returns distance to the nearest atom that contributes to the potentials.

        :param pos: Cartesian position.
        :type pos: sequence of 3 floats
        :rtype: float
        '''
        return self.kdt.query(pos)[0]
#------------------------------------------------------------------------------#