from collections import defaultdict
from rdkit import Chem
from rdkit.Chem import rdChemReactions
from rdkit.Chem import rdqueries
from . import common
# Names of the atom properties used throughout this module.
# Cleared on R-group atoms before a reaction is run and checked on
# product atoms afterwards (see `_place_rgroup`, `_propagate_properties`).
_WAS_DUMMY = 'was_dummy'  # RDKit native
# Unsigned integer R-label of an atom (see `_get_rlabel`, `_set_rlabel`).
_RLABEL_PROP = '_MolFileRLabel'  # RDKit native
# String label; this module interprets the `_AP<n>` form as an attachment
# order and the `_R<n>` form as an R-label.
_ATOM_LABEL_PROP = 'atomLabel'  # RDKit native
# Read from product atoms to map them back to reactant atom indices.
_REACT_ATOM_IDX_PROP = 'react_atom_idx'  # RDKit native
# Temporary integer marker put on the placeholder atom while
# `_place_rgroup` runs its reaction; removed again afterwards.
_ACTIVE_PLACEHOLDER_PROP = '_activePlaceHolder'
# Carries a bond ID on an attachment point atom so that it can be
# restored on the product bond by `_propagate_properties`.
_BOND_ID_PROP = '_bond_id'
def _get_rlabel(atom):
    """
    Looks up the R-label stored on `atom`.
    :param atom: Atom.
    :type atom: rdkit.Chem.Atom
    :return: R-label (positive integer), or zero when the atom does not
        carry the R-label property.
    :rtype: int
    """
    try:
        rlabel = atom.GetUnsignedProp(_RLABEL_PROP)
    except KeyError:
        # the property is not set on this atom
        rlabel = 0
    return rlabel
def _set_rlabel(atom, rlabel):
    """
    Stores `rlabel` on `atom` when it is positive; removes the R-label
    property from `atom` otherwise.
    :param atom: Atom.
    :type atom: rdkit.Chem.Atom
    :param rlabel: R-label.
    :type rlabel: int
    """
    if rlabel <= 0:
        # non-positive value means "no R-label"
        atom.ClearProp(_RLABEL_PROP)
        return
    atom.SetUnsignedProp(_RLABEL_PROP, rlabel)
def _clear_attachment_order(atom):
    """
    Removes the attachment order information from `atom`: the CML
    attachment order property, and the atom label if it has
    the `_AP...` form.
    :param atom: Atom.
    :type atom: rdkit.Chem.Atom
    """
    atom.ClearProp(common.CML_ATTACHMENT_ORDER_PROP)
    try:
        atom_label = atom.GetProp(_ATOM_LABEL_PROP)
    except KeyError:
        # no atom label, nothing else to clear
        return
    if atom_label.startswith('_AP'):
        atom.ClearProp(_ATOM_LABEL_PROP)
def _get_attachment_order(atom):
    """
    Reads the attachment order of `atom`. The CML attachment order
    property takes precedence; if it is absent, the number following
    the `_AP` prefix of the atom label is used.
    :param atom: Atom.
    :type atom: rdkit.Chem.Atom
    :return: Attachment order, or None if neither source provides one.
    :rtype: int or NoneType
    """
    order = None
    try:
        order = atom.GetIntProp(common.CML_ATTACHMENT_ORDER_PROP)
    except KeyError:
        # fall back to the "_AP<number>" atom label
        try:
            atom_label = atom.GetProp(_ATOM_LABEL_PROP)
            if atom_label.startswith('_AP'):
                order = int(atom_label[3:])
        except (KeyError, ValueError):
            # label is missing or its suffix is not an integer
            pass
    return order
def _is_attachment_point(atom):
    """
    Tells whether `atom` is an "attachment point": an atom with atomic
    number zero and degree one that carries no R-label.
    :param atom: Atom.
    :type atom: rdkit.Chem.Atom
    :return: True if `atom` is an "attachment point".
    :rtype: bool
    """
    if atom.GetAtomicNum() != 0:
        return False
    if atom.GetDegree() != 1:
        return False
    return _get_rlabel(atom) == 0
def _get_attachment_points(mol):
    """
    Collects the "attachment point" atoms of `mol`. The atoms are
    sorted by "attachment order" only if every one of them has it;
    otherwise they are returned in molecule order.
    :param mol: Molecule.
    :type mol: rdkit.Chem.ROMol
    :return: List of "attachment point" atoms and a flag telling
        whether all "attachment orders" were available.
    :rtype: list(rdkit.Chem.Atom), bool
    """
    ordered = [
        (_get_attachment_order(candidate), candidate)
        for candidate in mol.GetAtoms()
        if _is_attachment_point(candidate)
    ]
    orders_known = not any(order is None for (order, _) in ordered)
    if orders_known:
        ordered = sorted(ordered, key=lambda pair: pair[0])
    atoms = [candidate for (_, candidate) in ordered]
    return atoms, orders_known
def _propagate_properties(product, ratom, rgroup):
    """
    Helper function called from `_place_rgroup`. Updates `product` in
    place: restores R-labels, bond IDs and bond types known from the
    scaffold (the molecule that owns `ratom`), and drops "attachment
    order" properties that did not come from the scaffold.
    :param product: Product molecule obtained via RDKit "chemical reaction".
    :type product: rdkit.Chem.Mol
    :param ratom: "placeholder" (R-group) atom from the scaffold molecule.
    :type ratom: rdkit.Chem.Atom
    :param rgroup: R-group molecule (not used by this function).
    :type rgroup: rdkit.Chem.Mol
    """
    scaffold = ratom.GetOwningMol()
    def scaffold_atom(atom):
        # Maps a product atom back to the scaffold atom it originates
        # from using the "react_atom_idx" property; gives None if the
        # property is absent or is not an integer.
        try:
            react_atom_idx = atom.GetProp(_REACT_ATOM_IDX_PROP)
            return scaffold.GetAtomWithIdx(int(react_atom_idx))
        except (KeyError, ValueError):
            return None
    for atom in product.GetAtoms():
        # prefer R-group references from `scaffold`
        # over the ones from the `rgroup`
        sc_atom = scaffold_atom(atom)
        if sc_atom and atom.GetAtomicNum() == 0:
            # copies the scaffold R-label; `_set_rlabel` clears the
            # label when the scaffold atom has none
            _set_rlabel(atom, _get_rlabel(sc_atom))
        # drop "attachment orders" inherited from `rgroup`
        if not sc_atom or atom.HasProp(_WAS_DUMMY):
            _clear_attachment_order(atom)
        atom.ClearProp(_WAS_DUMMY)
        if not sc_atom:
            # atoms without a scaffold counterpart need no bond fix-ups
            continue
        # bond IDs and types
        for bond in atom.GetBonds():
            neigh = bond.GetOtherAtom(atom)
            sc_neigh = scaffold_atom(neigh)
            # the index comparison makes each bond between two
            # scaffold-derived atoms get handled from one end only
            if sc_neigh and neigh.GetIdx() > atom.GetIdx():
                # unaffected scaffold bond
                sc_bond = scaffold.GetBondBetweenAtoms(sc_atom.GetIdx(),
                                                       sc_neigh.GetIdx())
                try:
                    bond.SetProp(common.CML_ID_PROP,
                                 sc_bond.GetProp(common.CML_ID_PROP))
                except KeyError:
                    # scaffold bond has no ID to copy
                    pass
            elif not sc_neigh:
                # bond to R-group in the scaffold
                try:
                    # the bond ID was stashed on the atom by `_place_rgroup`
                    bond.SetProp(common.CML_ID_PROP,
                                 atom.GetProp(_BOND_ID_PROP))
                    atom.ClearProp(_BOND_ID_PROP)
                except KeyError:
                    # no bond ID was stashed on this atom
                    pass
                # take the bond type from the scaffold bond that used
                # to connect this atom with the placeholder atom
                sc_bond = scaffold.GetBondBetweenAtoms(ratom.GetIdx(),
                                                       sc_atom.GetIdx())
                bond.SetBondType(sc_bond.GetBondType())
def _place_rgroup(atom, rgroup):
    """
    Replaces `atom` with `rgroup`. The replacement is carried out as an
    RDKit "chemical reaction": the reactant template matches `atom`
    together with its neighbors, and a copy of `rgroup` serves as the
    product template. The molecule that owns `atom` is not modified
    apart from a temporary marker property that is removed before
    returning.
    :param atom: Atom.
    :type atom: rdkit.Chem.Atom
    :param rgroup: R-group molecule.
    :type rgroup: rdkit.Chem.Mol
    :return: List of molecules (which may be empty). In particular, the
        list is empty if the number of "attachment points" in `rgroup`
        differs from the degree of `atom`.
    :rtype: list(rdkit.Chem.Mol)
    """
    # work on a copy, with atom map numbers and
    # reaction-related properties reset
    rgroup_copy = Chem.Mol(rgroup)
    for rgatom in rgroup_copy.GetAtoms():
        rgatom.SetAtomMapNum(0)
        rgatom.ClearProp(_WAS_DUMMY)
        rgatom.ClearProp(_REACT_ATOM_IDX_PROP)
    attachment_points, have_attachment_orders = \
        _get_attachment_points(rgroup_copy)
    if len(attachment_points) != atom.GetDegree():
        return []
    # gather (bond ID, bond) pairs for the bonds of `atom`; stop at the
    # first bond without an ID, in which case the list is left unsorted
    # (the `else` branch runs only when the loop was not broken)
    bonds = []
    for b in atom.GetBonds():
        try:
            bond_id = b.GetProp(common.CML_ID_PROP)
            bonds.append((bond_id, b))
        except KeyError:
            break
    else:
        # ordered by bond ID, compared as returned by `GetProp`
        bonds.sort(key=lambda x: x[0])
    have_bond_ids = len(bonds) == atom.GetDegree()
    # create "reactant" query that matches the `atom`; require
    # "bond priorities" to match "attachment orders" in case both
    # are available, and bond orders are consistent; match only
    # bond orders (from `rgroup`) otherwise
    match_ids = have_bond_ids and have_attachment_orders
    if match_ids:
        # plexus-like matching that relies on labels;
        # check whether bond orders are consistent
        # (i-th attachment point is paired with the i-th bond)
        for (ap, (_, b)) in zip(attachment_points, bonds):
            # `ap` has exactly one bond
            if b.GetBondType() != ap.GetBonds()[0].GetBondType():
                match_ids = False
                break
    # "reactant" template
    reactant = Chem.RWMol()
    # R-group (placeholder) atom; the query matches only the atom
    # that carries the temporary marker property set below
    r_atom = rdqueries.HasIntPropWithValueQueryAtom(_ACTIVE_PLACEHOLDER_PROP,
                                                    atom.GetIdx())
    r_atom_index = reactant.AddAtom(r_atom)
    # R-group (placeholder) atom bonds; the i-th attachment point and
    # the i-th template neighbor share the atom map number `i`
    for i, ap in enumerate(attachment_points, 1):
        ap.ClearProp(_BOND_ID_PROP)
        ap.SetAtomMapNum(i)
        r_neighbor = Chem.AtomFromSmarts('[*]')
        r_neighbor.SetAtomMapNum(i)
        r_neighbor_index = reactant.AddAtom(r_neighbor)
        if match_ids:
            # Plexus-like, match "bond_id"
            bond_id = bonds[i - 1][0]
            bond_query = rdqueries.HasStringPropWithValueQueryBond(
                common.CML_ID_PROP, bond_id)
            # the value returned by `AddBond` is used as the bond count,
            # so the bond just added has index `r_num_bonds - 1`
            r_num_bonds = reactant.AddBond(r_atom_index, r_neighbor_index)
            reactant.ReplaceBond(r_num_bonds - 1, bond_query)
            # Store "bond_id" as "dummy" atom property
            # because RDKit reactions do not propagate bond
            # properties, but do copy over atom properties (at
            # least to some extent)
            ap.SetProp(_BOND_ID_PROP, bond_id)
        else:
            # match attachment point bond order
            reactant.AddBond(r_atom_index, r_neighbor_index,
                             ap.GetBonds()[0].GetBondType())
    # mark the `atom` temporarily
    atom.SetIntProp(_ACTIVE_PLACEHOLDER_PROP, atom.GetIdx())
    try:
        if not atom.GetOwningMol().HasSubstructMatch(reactant):
            # ENUM-290: favor scaffold bond orders
            # (the template did not match as built, so its
            # bond types are reset to UNSPECIFIED)
            for bond in reactant.GetBonds():
                bond.SetBondType(Chem.BondType.UNSPECIFIED)
        rxn = rdChemReactions.ChemicalReaction()
        rxn.AddReactantTemplate(reactant)
        rxn.AddProductTemplate(rgroup_copy)
        rdChemReactions.UpdateProductsStereochemistry(rxn)
        # every reaction outcome is unpacked as a single product molecule
        products = [m for (m,) in rxn.RunReactant(atom.GetOwningMol(), 0)]
    finally:
        # remove the temporary marker even if the reaction failed
        atom.ClearProp(_ACTIVE_PLACEHOLDER_PROP)
    for p in products:
        _propagate_properties(p, atom, rgroup_copy)
    return products
def _get_r2p_map(mol):
    """
    Builds the reactant-to-product atom index map for `mol`, which is
    assumed to be a product of "chemical reaction". Atoms of `mol` that
    lack the reactant atom index property (or carry a non-integer
    value) are left out of the map.
    :param mol: Molecule.
    :type mol: rdkit.Chem.Mol
    :return: Reactant-to-product atom index map.
    :rtype: dict(int, int)
    """
    r2p = {}
    for product_atom in mol.GetAtoms():
        try:
            reactant_idx = int(product_atom.GetProp(_REACT_ATOM_IDX_PROP))
        except (KeyError, ValueError):
            # this atom cannot be traced back to a reactant atom
            continue
        r2p[reactant_idx] = product_atom.GetIdx()
    return r2p
def place_rgroups(mol, atom_indices_and_rgroups, appended_rgroups=None):
    """
    Generator that yields realizations of `mol`
    with (some) atoms replaced by R-groups.
    The first (atom index, R-group) pair is placed via `_place_rgroup`;
    the remaining pairs are then placed recursively on each resulting
    product. Once no pairs remain, the molecule is annotated with
    `s_rge_R{n}_cxsmiles` and `s_rge_R{n}` properties for each appended
    R-group, renamed to `{scaffold_name}_{rgroup1}_{rgroup2}...`, and
    yielded. Note that when `atom_indices_and_rgroups` is empty, `mol`
    itself is annotated and yielded.
    :param mol: Molecule.
    :type mol: rdkit.Chem.Mol
    :param atom_indices_and_rgroups: List of atom indices paired with
        corresponding R-groups.
    :type atom_indices_and_rgroups: list(int, rdkit.Chem.Mol)
    :param appended_rgroups: List of appended rgroups so far. It is
        extended while recursing and restored to its initial content
        afterwards, also when the generator is closed early.
    :type appended_rgroups: list(rdkit.Chem.Mol)
    :return: Generator of molecules.
    :rtype: generator(rdkit.Chem.Mol)
    """
    if appended_rgroups is None:
        appended_rgroups = []
    mol_name = mol.GetProp('_Name') if mol.HasProp('_Name') else 'scaffold'
    if atom_indices_and_rgroups:
        curr = atom_indices_and_rgroups[0]
        remaining = atom_indices_and_rgroups[1:]
        atom, rgroup = mol.GetAtomWithIdx(curr[0]), curr[1]
        for p in _place_rgroup(atom, rgroup):
            # propagate name from mol to products returned by _place_rgroup
            p.SetProp('_Name', mol_name)
            # atom indices change in the product: translate
            # the indices of the placeholders still to be replaced
            mol2p = _get_r2p_map(p)
            todo = [(mol2p[i], rg) for (i, rg) in remaining]
            # track added rgroup(s)
            appended_rgroups.append(rgroup)
            try:
                yield from place_rgroups(p, todo, appended_rgroups)
            finally:
                # remove added rgroup(s), even if the consumer stops
                # iterating early or an exception propagates
                appended_rgroups.pop()
    else:
        for index, rgroup in enumerate(appended_rgroups, 1):
            cxsmiles = Chem.MolToCXSmiles(rgroup)
            mol.SetProp(f's_rge_R{index}_cxsmiles', cxsmiles)
            rgroup_id = rgroup.GetProp('_Name') if rgroup.HasProp(
                '_Name') else ''
            mol.SetProp(f's_rge_R{index}', rgroup_id)
            mol_name += f'_{rgroup_id}'
        # naming scheme: {scaffold_name}_{rgroup1}_{rgroup2} etc
        mol.SetProp('_Name', mol_name)
        yield mol
def canonicalize_R_labels(mol):
    """
    Translates different conventions of R-group labelling
    into the RDKit "native" (AtomRLabel). Atoms that already carry
    a positive R-label are left alone. For the remaining atoms, the
    CML R-group reference property is tried first, then an atom label
    of the `_R<number>` form. Modifies the atoms of `mol` in place.
    :param mol: Molecule.
    :type mol: rdkit.Chem.Mol
    """
    for atom in mol.GetAtoms():
        if _get_rlabel(atom) > 0:
            continue
        try:
            _set_rlabel(atom, int(atom.GetProp(common.CML_RGROUP_REF_PROP)))
            continue
        except (KeyError, ValueError):
            # no (usable) CML reference, try the atom label next
            pass
        try:
            label = atom.GetProp(_ATOM_LABEL_PROP)
            if label.startswith('_R'):
                _set_rlabel(atom, int(label[2:]))
        except (KeyError, ValueError):
            # no atom label, or its suffix is not an integer
            pass
def get_rlabels_set(mol):
    """
    Returns set of R-labels carried by the atoms in the `mol`.
    Atoms without an R-label do not contribute.
    :param mol: Molecule.
    :type mol: rdkit.Chem.Mol
    :return: Set of R-labels.
    :rtype: set(int)
    """
    labels = (_get_rlabel(atom) for atom in mol.GetAtoms())
    # zero stands for "no R-label"
    return {label for label in labels if label > 0}
def get_rlabels_map(mol):
    """
    Returns map from R-labels to atom indices.
    Atoms without an R-label are not included.
    :param mol: Molecule.
    :type mol: rdkit.Chem.Mol
    :return: Map from R-labels to atom indices. The returned object is
        a `collections.defaultdict(list)`, so looking up a label that
        is absent gives (and stores) an empty list.
    :rtype: dict of int:list(int)
    """
    outcome = defaultdict(list)
    for atom in mol.GetAtoms():
        label = _get_rlabel(atom)
        # zero stands for "no R-label"
        if label > 0:
            outcome[label].append(atom.GetIdx())
    return outcome