'''
Basic CML reader: http://www.xml-cml.org/
'''
import contextlib
import json
import xml.etree.ElementTree as ET
from collections import defaultdict
from rdkit import Chem
from . import common
#------------------------------------------------------------------------------#
# Lookup table: value of the CML bond `order` attribute -> RDKit bond type.
# Both the numeric ('1', '2', '3') and the letter ('S', 'D', 'T', 'A')
# spellings are accepted.
_CML_BOND_ORDER_TO_RDK = {
    '1': Chem.BondType.SINGLE,
    'S': Chem.BondType.SINGLE,
    '2': Chem.BondType.DOUBLE,
    'D': Chem.BondType.DOUBLE,
    '3': Chem.BondType.TRIPLE,
    'T': Chem.BondType.TRIPLE,
    'A': Chem.BondType.AROMATIC,
}
# CML atom attributes that the reader interprets itself (element, formal
# charge, 2D/3D coordinates); these are *not* copied to RDKit atom properties.
_SPECIAL_CML_ATOM_PROPS = {
    'elementType',
    'formalCharge',
    'x2',
    'y2',
    'x3',
    'y3',
    'z3',
}
# CML bond attributes that the reader interprets itself (bonded atom IDs,
# bond order); these are *not* copied to RDKit bond properties.
_SPECIAL_CML_BOND_PROPS = {
    'atomRefs2',
    'order',
}
#------------------------------------------------------------------------------#
def parse_and_drop_xmlns(source):
    '''
    Parses `source` via `xml.etree.ElementTree.iterparse`,
    and removes namespace prefix (if present) from tags and
    attributes.
    Namespaced attributes are stripped on every element, including
    elements whose own tag carries no namespace.
    :param source: File name or file-like object.
    :type source: str or file-like
    :return: Root of the tree.
    :rtype: `xml.etree.ElementTree.Element`
    '''
    it = ET.iterparse(source)
    # adapted from https://stackoverflow.com/a/33997423
    for _, el in it:
        # ElementTree spells qualified names as '{namespace-uri}local'
        if '}' in el.tag:
            el.tag = el.tag.split('}', 1)[1]
        # attributes are handled independently of the tag: an element
        # without a namespace may still carry prefixed attributes
        for at in list(el.attrib):
            if '}' in at:
                el.attrib[at.split('}', 1)[1]] = el.attrib.pop(at)
    # `root` becomes available once the iterator has been exhausted
    return it.root
#------------------------------------------------------------------------------#
def _get_xml_atom_coordinates(xml_atom):
    '''
    Extracts Cartesian coordinates from the CML atom element.
    The 3D attributes (`x3`, `y3`, `z3`) are preferred; if any of them
    is absent, the 2D attributes (`x2`, `y2`) are used with `z` set to 0.
    :param xml_atom: CML atom element.
    :type xml_atom: `xml.etree.ElementTree.Element`
    :return: Cartesian coordinates, or None if neither complete
             set of attributes is present.
    :rtype: [float, float, float] or None
    '''
    attempts = (
        (('x3', 'y3', 'z3'), []),
        (('x2', 'y2'), [0.0]),
    )
    for axes, padding in attempts:
        try:
            # Element.get() yields None for a missing attribute,
            # and float(None) raises TypeError (not KeyError)
            return [float(xml_atom.get(axis)) for axis in axes] + padding
        except TypeError:
            continue
    return None
#------------------------------------------------------------------------------#
def _make_rdk_atom_from_xml_atom(xml_atom, prop_prefix=common.CML_PROP_PREFIX):
    '''
    Builds `rdkit.Chem.Atom` out of the XML element that describes CML atom.
    A missing/empty `elementType`, as well as "R" and "X" (any case),
    produces the "*" atom. `formalCharge` is applied when it parses as
    an integer and is ignored otherwise. The remaining attributes (those
    not listed in `_SPECIAL_CML_ATOM_PROPS`) are stored as string
    properties named `prop_prefix` + attribute name.
    :param xml_atom: XML element that represents CML atom.
    :type xml_atom: `xml.etree.ElementTree.Element`
    :param prop_prefix: Prefix to be added to the CML property names.
    :type prop_prefix: str
    :return: RDKit atom and its coordinates.
    :rtype: rdkit.Chem.Atom, [float, float, float] or None
    '''
    symbol = xml_atom.get('elementType')
    is_placeholder = not symbol or symbol.upper() in ('R', 'X')
    rdk_atom = Chem.Atom('*' if is_placeholder else symbol)
    # absent charge defaults to '', which int() rejects with ValueError
    with contextlib.suppress(ValueError):
        rdk_atom.SetFormalCharge(int(xml_atom.get('formalCharge', '')))
    for name, value in xml_atom.attrib.items():
        if name in _SPECIAL_CML_ATOM_PROPS:
            continue
        rdk_atom.SetProp(prop_prefix + name, value)
    coordinates = _get_xml_atom_coordinates(xml_atom)
    return rdk_atom, coordinates
#------------------------------------------------------------------------------#
def _decorate_rdk_bond_from_xml_bond(rdk_bond,
                                     xml_bond,
                                     prop_prefix=common.CML_PROP_PREFIX):
    '''
    Propagates data from the XML element that represents CML bond
    to the RDKit bond.
    The `order` attribute sets the bond type (unknown or missing orders
    leave it untouched). `bondStereo` children with text "W" / "H"
    (any case) set the bond direction to BEGINWEDGE / BEGINDASH; other
    or empty `bondStereo` elements are ignored. The remaining attributes
    (those not listed in `_SPECIAL_CML_BOND_PROPS`) are stored as string
    properties named `prop_prefix` + attribute name.
    :param rdk_bond: RDKit bond.
    :type rdk_bond: rdkit.Chem.Bond
    :param xml_bond: XML element that represents CML bond.
    :type xml_bond: `xml.etree.ElementTree.Element`
    :param prop_prefix: Prefix to be added to the CML property names.
    :type prop_prefix: str
    '''
    # a missing `order` yields None, which is not a key of the map
    bond_type = _CML_BOND_ORDER_TO_RDK.get(xml_bond.get('order'))
    if bond_type is not None:
        rdk_bond.SetBondType(bond_type)
    for bs in xml_bond.findall('./bondStereo'):
        # `text` is None for an element without character data
        # (e.g. stereo given via attributes only); do not crash on it
        stereo = (bs.text or '').strip().upper()
        if stereo == 'W':
            rdk_bond.SetBondDir(Chem.BondDir.BEGINWEDGE)
        elif stereo == 'H':
            rdk_bond.SetBondDir(Chem.BondDir.BEGINDASH)
    for (k, v) in xml_bond.attrib.items():
        if k not in _SPECIAL_CML_BOND_PROPS:
            rdk_bond.SetProp(prop_prefix + k, v)
#------------------------------------------------------------------------------#
def _adapt_enhanced_stereo(mol, prop_prefix):
    '''
    Converts MRV "enhanced stereochemistry" atom annotations into
    RDKit stereo groups.
    Atoms are bucketed by the value of their `mrvStereoGroup` property
    (looked up under `prop_prefix`). Buckets whose identifier starts with
    "or", "and" or "abs" become STEREO_OR, STEREO_AND or STEREO_ABSOLUTE
    groups respectively; other identifiers are skipped. The resulting
    groups replace the stereo groups of `mol`, and the property is
    cleared on the atoms that ended up in a group.
    :param mol: R/W molecule.
    :type mol: `rdkit.Chem.RWMol`
    :param prop_prefix: Prefix used for the CML property names.
    :type prop_prefix: str
    '''
    marker = prop_prefix + 'mrvStereoGroup'
    members_by_label = defaultdict(list)
    for atom in mol.GetAtoms():
        try:
            label = atom.GetProp(marker)
        except (KeyError, ValueError):
            # atom carries no stereo group annotation
            continue
        members_by_label[label].append(atom.GetIdx())
    prefix_to_type = (
        ('or', Chem.StereoGroupType.STEREO_OR),
        ('and', Chem.StereoGroupType.STEREO_AND),
        ('abs', Chem.StereoGroupType.STEREO_ABSOLUTE),
    )
    stereo_groups = []
    for label, members in members_by_label.items():
        for prefix, sg_type in prefix_to_type:
            if label.startswith(prefix):
                stereo_groups.append(
                    Chem.CreateStereoGroup(sg_type, mol, members))
                break
    mol.SetStereoGroups(stereo_groups)
    for stereo_group in stereo_groups:
        for atom in stereo_group.GetAtoms():
            atom.ClearProp(marker)
#------------------------------------------------------------------------------#
def rdk_mol_from_cml_element(xml_molecule, prop_prefix=common.CML_PROP_PREFIX):
    '''
    Instantiates `rdkit.Chem.Mol` from the XML element.
    Atoms are read from `./atomArray/atom` and bonds from
    `./bondArray/bond`. A conformer is attached only if coordinates
    were found for every atom. Attributes of the nested `./molecule`
    elements (S-groups) are stored as a JSON list in the molecule
    property named `prop_prefix` + "sgroups".
    :param xml_molecule: XML element that represents CML molecule.
    :type xml_molecule: `xml.etree.ElementTree.Element`
    :param prop_prefix: Prefix to be added to the CML property names.
    :type prop_prefix: str
    :return: RDKit molecule.
    :rtype: rdkit.Chem.ROMol
    :raises AssertionError: if `atomRefs2` of a bond does not list
                            exactly two atom IDs.
    :raises KeyError: if a bond refers to an unknown atom ID.
    '''
    rwmol = Chem.RWMol(Chem.Mol())
    # atoms
    atom_id_property = prop_prefix + 'id'
    atom_id2idx = {}  # ID -> RDKit index
    coordinates = []
    for xml_atom in xml_molecule.findall('./atomArray/atom'):
        rdk_atom, crd = _make_rdk_atom_from_xml_atom(xml_atom,
                                                     prop_prefix=prop_prefix)
        rdk_index = rwmol.AddAtom(rdk_atom)
        if rdk_atom.HasProp(atom_id_property):
            atom_id2idx[rdk_atom.GetProp(atom_id_property)] = rdk_index
        if crd:
            coordinates.append(crd)
    # bonds
    for xml_bond in xml_molecule.findall('./bondArray/bond'):
        atomRefs = xml_bond.get('atomRefs2').split()
        assert len(atomRefs) == 2, \
            'atomRefs2 must list exactly two atom IDs, got %r' % (atomRefs,)
        atom1_idx = atom_id2idx[atomRefs[0]]
        atom2_idx = atom_id2idx[atomRefs[1]]
        # AddBond() returns the updated number of bonds,
        # hence the new bond is the last one
        num_bonds = rwmol.AddBond(atom1_idx, atom2_idx)
        rdk_bond = rwmol.GetBondWithIdx(num_bonds - 1)
        _decorate_rdk_bond_from_xml_bond(rdk_bond,
                                         xml_bond,
                                         prop_prefix=prop_prefix)
    _adapt_enhanced_stereo(rwmol, prop_prefix=prop_prefix)
    mol = rwmol.GetMol()
    # partial coordinates are useless: all-or-nothing
    if len(coordinates) == mol.GetNumAtoms():
        conformer = Chem.Conformer(len(coordinates))
        for (i, xyz) in enumerate(coordinates):
            conformer.SetAtomPosition(i, xyz)
        mol.AddConformer(conformer)
    Chem.AssignChiralTypesFromBondDirs(mol)
    Chem.DetectBondStereochemistry(mol)
    # molecules (S-groups)
    sgroups = [
        dict(xml_sgroup.attrib)
        for xml_sgroup in xml_molecule.findall('./molecule')
    ]
    if sgroups:
        mol.SetProp(prop_prefix + 'sgroups', json.dumps(sgroups))
    return mol
#------------------------------------------------------------------------------#
class CmlFileReader(contextlib.AbstractContextManager):
    '''
    Iterator over the molecules found in a CML file under
    `./MDocument/MChemicalStruct/molecule`; each item is converted
    by `rdk_mol_from_cml_element`.
    The file is parsed on entering the context, so iterate within
    the `with` block.
    Does not need to be context manager (such need may
    arise in case we switch to a different XML parser).
    '''

    def __init__(self, filename, prop_prefix=common.CML_PROP_PREFIX):
        '''
        :param filename: File name or file-like object
                         (handed over to `parse_and_drop_xmlns`).
        :type filename: str or file-like
        :param prop_prefix: Prefix to be added to the CML property names.
        :type prop_prefix: str
        '''
        self._filename = filename
        self._prop_prefix = prop_prefix

    def __enter__(self):
        root = parse_and_drop_xmlns(self._filename)
        self._xml_molecules = root.findall(
            './MDocument/MChemicalStruct/molecule')
        # reverse the list to .pop() from it later
        self._xml_molecules.reverse()
        return self

    def __exit__(self, *exc_details):
        # nothing to release; returning None lets exceptions propagate
        return None

    def __iter__(self):
        return self

    def __next__(self):
        try:
            xml_molecule = self._xml_molecules.pop()
        except IndexError:
            raise StopIteration from None
        # honor the prefix requested at construction time
        return rdk_mol_from_cml_element(xml_molecule,
                                        prop_prefix=self._prop_prefix)
#------------------------------------------------------------------------------#