import os
import re
import sqlite3
from collections import defaultdict
from past.utils import old_div
import numpy
from schrodinger import structure
from schrodinger.infra import canvas
from schrodinger.infra import fast3d
from schrodinger.infra import mm
from schrodinger.utils import log
#------------------------------------------------------------------------------#
# Name passed to log.get_output_logger() by chop() when the caller does
# not supply a logger of its own.
_LOGGER_NAME = 'cgxutils'
#------------------------------------------------------------------------------#
class SerializationError(Exception):
    '''
    Error type for problems with serialized conformation data.

    NOTE(review): nothing in the visible part of this module raises it;
    presumably the (de)serialization helpers do -- confirm.
    '''
#------------------------------------------------------------------------------#
#------------------------------------------------------------------------------#
#------------------------------------------------------------------------------#
class FragmentLibrary:
    '''
    Encapsulates access to CGX fragment library (SQLite3 file).

    The library is a single ``Conformations`` table that maps a unique
    SMILES string to a blob of serialized conformations. Instances can be
    used as context managers; the connection is closed on exit.
    '''

    def __init__(self, filename):
        '''
        Opens existing library,
        or creates new one if `filename` does not exist.

        :param filename: File name.
        :type filename: `str`
        :raises ValueError: If `filename` exists but is not recognized as
                a CGX fragments library.
        '''
        create = not os.path.exists(filename)
        if not create and not fast3d.Engine.isFragmentLibraryFile(filename):
            raise ValueError("'%s' is not a CGX fragments library file" %
                             filename)
        self.conn = sqlite3.connect(filename)
        try:
            self.conn.text_factory = str
            if create:
                self.conn.execute(
                    "CREATE TABLE Conformations(smiles TEXT UNIQUE, confs BLOB)"
                )
                self.conn.commit()
        except Exception:
            # do not leak the connection if the library cannot be set up
            self.conn.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, *a):
        self.close()

    def close(self):
        '''
        Closes the underlying SQLite connection.
        '''
        self.conn.close()

    def _lookup(self, smiles):
        '''
        Returns the ``(confs,)`` row stored for `smiles`, or `None`.
        '''
        cur = self.conn.execute(
            'SELECT confs FROM Conformations WHERE smiles=?', (smiles,))
        return cur.fetchone()

    def _insert(self, smiles, blob, replace=False):
        '''
        Stores `blob` under `smiles` and commits. With `replace` set an
        existing row is overwritten; otherwise a plain INSERT is issued
        (the ``smiles`` column is declared UNIQUE).
        '''
        sql = 'REPLACE' if replace else 'INSERT'
        sql += ' INTO Conformations VALUES (?, ?)'
        with self.conn:
            self.conn.execute(sql, (smiles, sqlite3.Binary(blob)))

    def __contains__(self, smiles):
        return self._lookup(smiles) is not None

    def __len__(self):
        cur = self.conn.execute('SELECT Count(smiles) FROM Conformations')
        return int(cur.fetchone()[0])

    def __delitem__(self, smiles):
        # NOTE: deleting a SMILES that is not stored is silently a no-op
        with self.conn:
            self.conn.execute('DELETE FROM Conformations WHERE smiles=?',
                              (smiles,))

    def __missing__(self, smiles):
        '''
        Called by `__getitem__()` for a SMILES that is not in the library.

        :raises KeyError: Always; carries the missing SMILES.
        '''
        raise KeyError(smiles)

    def __getitem__(self, smiles):
        '''
        Returns list of conformations. Each conformation is
        a (float, numpy.ndarray) tuple holding energy and coordinates.

        :raises KeyError: If `smiles` is not in the library (via
                `__missing__()`).
        '''
        row = self._lookup(smiles)
        if row is None:
            return self.__missing__(smiles)
        return deserialize_conformations(row[0])

    def __setitem__(self, smiles, conformations):
        '''
        Serializes `conformations` and stores them under `smiles`,
        replacing whatever was stored for that SMILES before.
        '''
        self._insert(smiles,
                     serialize_conformations(conformations).data,
                     replace=True)

    def iteritems(self):
        '''
        Yields `(smiles, conformations)` pairs ordered by SMILES.
        '''
        cur = self.conn.execute(
            'SELECT smiles, confs FROM Conformations ORDER BY smiles')
        for row in cur:
            yield (row[0], deserialize_conformations(row[1]))
#------------------------------------------------------------------------------#
def _order(i, j):
    '''
    Returns `i` and `j` as a 2-tuple with the smaller value first.
    '''
    if i < j:
        return (i, j)
    return (j, i)
#------------------------------------------------------------------------------#
def _get_stereo(st, logger):
    '''
    Extracts and parses mmstereo properties starting with ``s_st_``.

    Values that cannot be parsed are reported through `logger` and skipped.

    :param st: Structure.
    :type st: `schrodinger.Structure`
    :param logger: Logger that receives warnings about unparsable values.
    :return: Dictionaries for chiral atoms and stereo bonds. The
            chirality dictionary maps chiral atom indices to 4- or 5-tuples
            `(X, a1, a2, a3, a4)` (4-tuples lack a4) where `X` denotes the
            "chirality" (R, S, ANR, ANS or ? for undefined), and
            `a1, a2, a3, a4` are bonded atom indices in the appropriate
            order. The stereo dictionary maps bonds (as ordered 2-tuples
            made of the second and third atom listed in the property) to
            lists like `[X, a1, a2, a3, a4]` where `X` is E, Z, P, M or ?
            and the remaining members are indices of the atoms involved.
    :rtype: `tuple(dict, dict)`
    '''
    chirality_key = re.compile(r'^s_st_Chirality_(\d+)$')
    chirality_value = re.compile(
        r'^(\d+)_(R|S|ANR|ANS|\?)_(\d+)_(\d+)_(\d+)(_(\d+))?$')
    cistrans_key = re.compile(r'^s_st_EZ_(\d+)$')
    cistrans_value = re.compile(r'^((\d+)_)+(E|Z|P|M|\?)$')

    chiral_centers = dict()
    cistrans_bonds = dict()
    for (name, value) in st.property.items():
        if chirality_key.match(name):
            parsed = chirality_value.match(value)
            if not parsed:
                logger.warning('could not parse chirality %s', value)
                continue
            atom = int(parsed.group(1))
            data = [
                parsed.group(2),
                int(parsed.group(3)),
                int(parsed.group(4)),
                int(parsed.group(5))
            ]
            # the fourth neighbor is optional
            if parsed.group(7) is not None:
                data.append(int(parsed.group(7)))
            chiral_centers[atom] = tuple(data)
        elif cistrans_key.match(name):
            if not cistrans_value.match(value):
                logger.warning('could not parse cis/trans %s', value)
                continue
            tokens = value.split('_')
            # need at least three atom indices followed by the label
            if len(tokens) < 4:
                logger.warning('too few atoms in cis/trans %s', value)
                continue
            data = [tokens[-1]] + [int(s) for s in tokens[:-1]]
            cistrans_bonds[_order(data[2], data[3])] = data
    return (chiral_centers, cistrans_bonds)
#------------------------------------------------------------------------------#
def recompute_stereo(st):
    '''
    Drops the ``s_st_`` properties of `st` and then has mmstereo annotate
    the structure again (modifies `st` in place).

    :param st: Structure.
    :type st: `schrodinger.Structure`
    '''
    _discard_stereo(st)
    _compute_stereo(st)
#------------------------------------------------------------------------------#
def has_undefined_stereo(st, logger):
    '''
    Extracts and parses mmstereo properties (s_st), return True if
    there is an undefined (labelled as '?') chirality or cis/trans.

    :param st: Structure
    :type st: `schrodinger.Structure`
    :param logger: Logger handed to `_get_stereo()` for parse warnings.
    :return: Whether any parsed chirality or cis/trans label is '?'.
    :rtype: `bool`
    '''
    chiral_centers, cistrans_bonds = _get_stereo(st, logger)
    return (any(data[0] == '?' for data in chiral_centers.values()) or
            any(data[0] == '?' for data in cistrans_bonds.values()))
#------------------------------------------------------------------------------#
def _discard_stereo(st):
    '''
    Removes every structure-level property of `st` whose name starts
    with ``s_st_`` (in place).
    '''
    # collect first: the property container must not change while iterated
    doomed = []
    for key in st.property:
        if re.match(r'^s_st_', key):
            doomed.append(key)
    for key in doomed:
        del st.property[key]
#------------------------------------------------------------------------------#
def _compute_stereo(st):
    '''
    Has mmstereo add stereo information to `st`. The mmstereo handle
    created for the call is deleted even if the annotation step raises.
    '''
    stereo_handle = mm.mmstereo_new(st)
    try:
        mm.mmstereo_add_stereo_information_to_ct(st, stereo_handle)
    finally:
        mm.mmstereo_delete(stereo_handle)
#------------------------------------------------------------------------------#
def _parity(*x):
    '''
    Returns a parity flag for the ordering of the arguments (O(N^2)).

    The flag is meant to be compared (XOR-ed) with the flag of another
    sequence of the same length: for some lengths the sorted order gives
    True, e.g. ``_parity(1, 2, 3, 4)``.
    '''
    assert len(x) > 1
    # start from the last pair, then fold in the earlier elements one by
    # one, each against everything that follows it
    flag = x[-2] > x[-1]
    for pos in range(len(x) - 3, -1, -1):
        pivot = x[pos]
        for later in x[pos + 1:]:
            flag ^= (later > pivot)
    return flag
#------------------------------------------------------------------------------#
def _set_stereo(origin_cc, origin_sb, st, logger, i_f3d_origin):
    '''
    Set fragment's ``s_st_`` properties from the origin's data in `origin_cc`
    and `origin_sb`.

    The fragment's own stereo annotations are computed only to learn which
    atoms/bonds are stereo elements and in which order their neighbors are
    listed; the labels themselves are taken from the origin. Elements that
    were not annotated in the origin get no property.

    :param origin_cc: Chiral centers dictionary returned by `_get_stereo()`
            for the origin (parent) structure.
    :type origin_cc: `dict`
    :param origin_sb: EZPM dictionary returned by `_get_stereo()`
            for the origin (parent) structure.
    :type origin_sb: `dict`
    :param st: Fragment structure (modified in place).
    :type st: `schrodinger.Structure`
    :param logger: Logger that receives warnings.
    :param i_f3d_origin: Name of the atom property holding the index of
            the atom in the origin structure (negative for "cap" atoms).
    :type i_f3d_origin: `str`
    '''
    flip = {'R': 'S', 'S': 'R', 'ANS': 'ANR', 'ANR': 'ANS'}
    _compute_stereo(st)
    cc, sb = _get_stereo(st, logger)
    _discard_stereo(st)

    # chiralities
    num_written = 0
    for (atom, data) in cc.items():
        oatom = st.atom[atom].property[i_f3d_origin]
        try:
            origin_data = origin_cc[oatom]
        except KeyError:
            continue  # was not annotated in the parent structure
        if len(data) != len(origin_data):
            logger.warning('number of chiral center neighbors has '
                           'changed after fragmentation')
            continue
        # "origin" indices can be negative (implying "cap" atoms)
        mapped_neighbors = \
            [abs(st.atom[j].property[i_f3d_origin]) for j in data[1:]]
        origin_parity = _parity(*origin_data[1:])
        mapped_parity = _parity(*mapped_neighbors)
        new_chirality = origin_data[0]
        if origin_parity ^ mapped_parity:
            # ranking of atoms in the fragment has different parity
            new_chirality = flip.get(new_chirality, new_chirality)
        num_written += 1
        st.property['s_st_Chirality_%d' % num_written] = '%d_%s_%s' % (
            atom, new_chirality, '_'.join(map(str, data[1:])))

    # cis/trans
    num_written = 0
    for (atom1, atom2), data in sb.items():
        oatom1 = st.atom[atom1].property[i_f3d_origin]
        oatom2 = st.atom[atom2].property[i_f3d_origin]
        try:
            origin_data = origin_sb[_order(oatom1, oatom2)]
        except KeyError:
            continue  # was not annotated in the parent structure
        num_written += 1
        # fragment atom indices, label from the origin
        st.property['s_st_EZ_%d' % num_written] = \
            '_'.join(map(str, data[1:])) + '_' + origin_data[0]
#------------------------------------------------------------------------------#
def chop(engine, st, logger=None, i_f3d_origin='i_f3d_origin'):
    '''
    Chops structure into ConfGenX fragments (side effect: modifies
    atom properties of the input structure).

    Atoms on the far side of an inter-fragment bond are kept in the
    fragment as "cap" atoms: they are turned into neutral hydrogens and
    their `i_f3d_origin` property is set to the negated original index.

    :param engine: Fast3D engine.
    :type engine: `schrodinger.infra.fast3d.Engine`
    :param st: Structure (may include several molecules).
    :type st: `schrodinger.Structure`
    :param logger: Logger for warnings; the module's output logger is
                   used if omitted.
    :param i_f3d_origin: Name of the atom property to keep track of the
                         original atom indices.
    :type i_f3d_origin: `str`
    :return: List of fragments.
    :rtype: `list(schrodinger.Structure)`
    '''
    if logger is None:
        logger = log.get_output_logger(_LOGGER_NAME)

    chiral_centers, stereo_bonds = _get_stereo(st, logger)

    for atom in st.atom:
        atom.property[i_f3d_origin] = atom.index

    # partition[] is indexed by atom index and holds 1-based fragment numbers
    num_frags, partition = engine.partition(st)

    frag_atom_indices = [[] for _ in range(num_frags)]
    for atom in st.atom:
        frag_atom_indices[partition[atom.index] - 1].append(atom.index)

    for bond in st.bond:
        i1 = bond.atom1.index
        f1 = partition[i1]
        i2 = bond.atom2.index
        f2 = partition[i2]
        if f1 != f2:
            # inter-fragment bond, keep "cap" atoms
            frag_atom_indices[f1 - 1].append(i2)
            frag_atom_indices[f2 - 1].append(i1)

    frags = [st.extract(i, copy_props=True) for i in frag_atom_indices]

    # discard s_st_ properties from fragments to avoid low-level warnings
    for f in frags:
        _discard_stereo(f)

    # update the "cap" atoms
    for (frag_number, f) in enumerate(frags, 1):
        for atom in f.atom:
            original_index = atom.property[i_f3d_origin]
            if partition[original_index] != frag_number:
                atom.property[i_f3d_origin] = -original_index
                atom.atomic_number = 1
                atom.formal_charge = 0
                atom.retype()

    # propagate stereochemistry
    for f in frags:
        _set_stereo(chiral_centers, stereo_bonds, f, logger, i_f3d_origin)

    return frags
#------------------------------------------------------------------------------#
def is_builtin_fragment(smiles):
    '''
    Checks whether the smiles is for one of the "built-in" fragments.

    :param smiles: SMILES to look up; it is compared verbatim against the
                   built-in strings.
    :type smiles: `str`
    :rtype: `bool`
    '''
    return smiles in {
        '[H]O[H]',
        '[H]C([H])([H])[H]',
        'O=C([H])N([H])[H]',
    }
#------------------------------------------------------------------------------#
def uniquesmiler():
    '''
    Returns unique SMILES generator.

    The returned callable takes a structure, converts it with a canvas
    mmct adaptor (stereo taken from annotations) and hands back the result
    of `getSmilesAndMap()` of a unique-SMILES generator configured to
    include stereo, charges and all hydrogens. The adaptor and generator
    are created once and shared by all calls of the returned callable.

    :return: Function `smiler(st)`.
    :rtype: `callable`
    '''
    adaptor = canvas.ChmMmctAdaptor()
    generator = canvas.ChmUniqueSmilesGenerator()
    generator.wantStereo(True)
    generator.wantCharge(True)
    generator.wantAllHydrogens(True)

    def smiler(st):
        # NOTE(review): presumably returns the SMILES together with an
        # atom map -- confirm against the canvas API
        chmol = adaptor.create(st.handle,
                               canvas.ChmMmctAdaptor.StereoFromAnnotation_Safe)
        return generator.getSmilesAndMap(chmol)

    return smiler
#------------------------------------------------------------------------------#
#------------------------------------------------------------------------------#
def default_custom_fraglib():
    '''
    Path to the custom fragments library that is to be used by default.

    Delegates to `fast3d.Engine.defaultCustomRepository()`.
    '''
    return fast3d.Engine.defaultCustomRepository()
#------------------------------------------------------------------------------#