"""
Collection of functions used for RGroup Decomposition.
Copyright Schrodinger, LLC. All rights reserved.
"""
from typing import List
from typing import Optional

from rdkit import Chem
from rdkit.Chem import rdmolops
from rdkit.Chem import rdRGroupDecomposition

from schrodinger.livedesign import substructure
from schrodinger.livedesign.molhash import ATOM_PROP_MAP_NUMBER
from schrodinger.livedesign.preprocessor import ATOM_PROP_ATOM_LABEL
from schrodinger.livedesign.preprocessor import ATOM_PROP_DUMMY_LABEL
from schrodinger.livedesign.preprocessor import MOL_PROP_ATTACHPT
from schrodinger.livedesign.preprocessor import MOL_PROP_R_LABEL
# Atom property written by _replace_core_query onto every dummy atom it creates
# when an internal rgroup atom is split. It stores the R label of the atom that
# was split, and _close_partial_rings reads it to find attachment points that
# came from the same original atom.
ATOM_PROP_MATCH_TO_RG = 'match_to_rg'
def replace_rgroups_with_dummy_atoms(mol: Chem.rdchem.Mol) -> Chem.rdchem.Mol:
    """
    Return a copy of an rgroup whose labelled attachment atoms have been
    turned into plain dummy atoms.

    Every atom that carries all three of the map number, dummy label and
    R label properties has those properties cleared and its atomic number
    set to 0. The input molecule itself is not modified.

    :param mol: rgroup molecule to convert
    :return: copy of ``mol`` after the conversion, passed through
        ``AdjustQueryProperties`` with only ``makeDummiesQueries`` enabled
    """
    updated_rgroup = Chem.Mol(mol)
    labelling_props = (ATOM_PROP_MAP_NUMBER, ATOM_PROP_DUMMY_LABEL,
                       MOL_PROP_R_LABEL)
    for at in updated_rgroup.GetAtoms():
        # Only atoms that carry the complete set of labels are converted
        if all(at.HasProp(prop) for prop in labelling_props):
            for prop in labelling_props:
                at.ClearProp(prop)
            at.SetAtomicNum(0)

    # dummy atoms should be represented as * in sdfs
    params = rdmolops.AdjustQueryParameters.NoAdjustments()
    params.makeDummiesQueries = True
    return rdmolops.AdjustQueryProperties(updated_rgroup, params)
def _get_rlabel_bundles(edit_core):
    """
    Group the r labels found on a core by connected fragment.

    Each fragment of ``edit_core`` contributes one tuple holding the r labels
    of its atoms, in the order the atoms are reached while walking the
    fragment (an empty tuple if the fragment carries no r labels). For
    example, a core that looks like:
        R1 - C - R3   R2 - C - C - R4
    produces one tuple with the labels 1 and 3 and one with 2 and 4.

    :param edit_core: core molecule, possibly made of several fragments
    :return: set of tuples of r labels, one tuple per fragment
    """
    seen = [False] * edit_core.GetNumAtoms()
    bundles = set()
    for seed in edit_core.GetAtoms():
        if seen[seed.GetIdx()]:
            continue

        # Depth-first walk over the fragment that contains the seed atom
        labels = []
        stack = [seed.GetIdx()]
        while stack:
            atom_idx = stack.pop()
            if seen[atom_idx]:
                continue
            seen[atom_idx] = True
            atom = edit_core.GetAtomWithIdx(atom_idx)
            stack.extend(nbr.GetIdx() for nbr in atom.GetNeighbors())
            if atom.HasProp(MOL_PROP_R_LABEL):
                labels.append(atom.GetIntProp(MOL_PROP_R_LABEL))
        bundles.add(tuple(labels))
    return bundles
def _join_attchpts_at_atom(edit_core, attch_pts, rg_num):
    """
    Merge several attachment points of a core into one new wildcard atom.

    A new atom built from the SMARTS ``*`` is added to the core. Every bond
    from an attachment point to one of its neighbors is recreated between
    that neighbor and the new atom, keeping the bond type and the conjugated
    and aromatic flags; the attachment points themselves are then removed.
    All of this happens inside a single batch edit.

    :param edit_core: editable core; it is modified in place
    :param attch_pts: indices of the attachment point atoms to merge
    :param rg_num: rgroup name such as 'R2'; the integer after the first
        character becomes the R label of the new atom
    :return: the edited core
    """
    edit_core.BeginBatchEdit()
    merged_idx = edit_core.AddAtom(Chem.AtomFromSmarts('*'))
    for ap_idx in attch_pts:
        for nbr in edit_core.GetAtomWithIdx(ap_idx).GetNeighbors():
            nbr_idx = nbr.GetIdx()
            replaced = edit_core.GetBondBetweenAtoms(ap_idx, nbr_idx)
            edit_core.AddBond(nbr_idx, merged_idx, replaced.GetBondType())
            # ensure conjugated and aromatic properties are the same as in the
            # original core, otherwise the stitched core will not match with
            # the original
            rewired = edit_core.GetBondBetweenAtoms(nbr_idx, merged_idx)
            rewired.SetIsConjugated(replaced.GetIsConjugated())
            rewired.SetIsAromatic(replaced.GetIsAromatic())
            edit_core.RemoveBond(ap_idx, nbr_idx)
        edit_core.RemoveAtom(ap_idx)

    merged_atom = edit_core.GetAtomWithIdx(merged_idx)
    merged_atom.SetIntProp(MOL_PROP_R_LABEL, int(rg_num[1:]))
    edit_core.CommitBatchEdit()
    return edit_core
def _close_partial_rings(edit_core, attch_pts, rg_num):
    """
    Closes rings that were closed on the original scaffold.

    The attachment points are scanned in order for the first pair that share
    the same ATOM_PROP_MATCH_TO_RG value; that pair is joined at a single
    atom. Attachment points without the property are ignored.

    :param edit_core: core split at rgroups
    :param attch_pts: atoms with attachment points to join together
    :param rg_num: rgroup number
    :return: core with partial rings closed; the core is returned untouched
        when no two attachment points share a value
    """
    first_seen = {}
    for ap_idx in attch_pts:
        atom = edit_core.GetAtomWithIdx(ap_idx)
        if not atom.HasProp(ATOM_PROP_MATCH_TO_RG):
            continue
        origin_label = atom.GetIntProp(ATOM_PROP_MATCH_TO_RG)
        if origin_label in first_seen:
            pair = [first_seen[origin_label], ap_idx]
            return _join_attchpts_at_atom(edit_core, pair, rg_num)
        first_seen[origin_label] = ap_idx

    # no ring to close
    return edit_core
def _rgroup_includes_ring_bonds(rlabels, rlabel_bundles, rg):
    """
    Decide whether an rgroup with several attachment points closes a ring.

    :param rlabels: r labels found on the rgroup
    :param rlabel_bundles: groups of r labels, one group per core fragment
    :param rg: rgroup molecule
    :return: True when all of ``rlabels`` occur together in at least one
        bundle and no atom of ``rg`` has exactly two neighbors that carry an
        r label; False otherwise
    """
    wanted = set(rlabels)
    # check if rgroup closes a ring
    if not any(wanted.issubset(bundle) for bundle in rlabel_bundles):
        return False

    # check if one atom has two attachment points
    for atom in rg.GetAtoms():
        labelled_nbrs = sum(
            1 for nbr in atom.GetNeighbors() if nbr.HasProp(MOL_PROP_R_LABEL))
        if labelled_nbrs == 2:
            return False
    return True
def _get_stitched_core(res):
    """
    Reconstructs the core returned from rgroup decomposition to match the
    original core by stitching the core back together at the internal rgroups.

    :param res: a single decomposition result, mapping 'Core' to the core
        molecule and every other key (an rgroup name) to an rgroup molecule
    :return: editable copy of the core with its internal rgroups stitched
    """
    # Each 'bundle' is a set of rlabels on some fragment of the core
    rlabel_bundles = _get_rlabel_bundles(res['Core'])
    edit_core = Chem.RWMol(res['Core'])
    handled_label_sets = set()
    for rg_num, rg in res.items():
        if rg_num == 'Core':
            continue

        rlabels = tuple(
            sorted({
                at.GetIntProp(MOL_PROP_R_LABEL)
                for at in rg.GetAtoms()
                if at.HasProp(MOL_PROP_R_LABEL)
            }))

        # Skip if this rgroup has already been considered
        if rlabels in handled_label_sets:
            continue
        handled_label_sets.add(rlabels)

        # Skip terminal RGroups
        if len(rlabels) == 1:
            continue

        # Find where rgroup belongs
        attch_pts = [
            at.GetIdx()
            for at in edit_core.GetAtoms()
            if at.HasProp(MOL_PROP_R_LABEL) and
            at.GetIntProp(MOL_PROP_R_LABEL) in rlabels
        ]
        if not attch_pts:
            continue

        # Replace rgroup with a single atom if that was the structure of the
        # original core
        if _rgroup_includes_ring_bonds(rlabels, rlabel_bundles, rg):
            edit_core = _close_partial_rings(edit_core, attch_pts, rg_num)
        else:
            edit_core = _join_attchpts_at_atom(edit_core, attch_pts, rg_num)
    return edit_core
def _remove_duplicate_rgroups(original_scaffold, res):
    """
    Rebuild a decomposition result so that it is keyed by the rgroups of the
    original scaffold.

    The core is stitched back together and matched against the original
    scaffold to determine which RGroup in the result corresponds to which
    RGroup in the original scaffold.

    :param original_scaffold: scaffold as provided, before it was split
    :param res: decomposition result obtained with the split scaffold
    :return: dict holding the original scaffold under 'Core' and one entry
        per labelled scaffold atom, or None if the substructure match does not
        cover every atom of the stitched core
    """
    stitched_core = _get_stitched_core(res)
    atom_map = stitched_core.GetSubstructMatch(original_scaffold)
    if len(atom_map) != stitched_core.GetNumAtoms():
        return None

    remapped = {'Core': original_scaffold}
    for scaffold_atom in original_scaffold.GetAtoms():
        if not scaffold_atom.HasProp(MOL_PROP_R_LABEL):
            continue
        stitched_atom = stitched_core.GetAtomWithIdx(
            atom_map[scaffold_atom.GetIdx()])
        found_label = stitched_atom.GetProp(MOL_PROP_R_LABEL)
        wanted_label = scaffold_atom.GetProp(MOL_PROP_R_LABEL)
        remapped[f'R{wanted_label}'] = res[f'R{found_label}']
    return remapped
def _replace_core_query(cores):
    """
    Split each of the cores at the rgroups to allow for RGroup decomposition
    when an rgroup is 'inside' a scaffold

    Every dummy atom with two or more connections is removed and replaced by
    one new dummy atom per neighbor. The first replacement inherits the dummy
    label and R label of the removed atom; each further replacement gets a
    fresh R label. All replacements record the R label of the removed atom
    under ATOM_PROP_MATCH_TO_RG.

    :param cores: core query molecules
    :return: tuple ``(new_cores, scaff_split)``. Cores without internal
        rgroups are passed through unchanged; ``scaff_split`` is True if at
        least one core was split.
    """
    scaff_split = False
    new_cores = []
    for core in cores:
        # find dummy atoms with a degree>1
        core_query = Chem.MolFromSmarts('[#0D{2-}]')
        matches = core.GetSubstructMatches(core_query)
        if not matches:
            new_cores.append(core)
            continue

        # Number the fresh rgroups after the highest label already in use
        # rather than after the number of labelled atoms: with gaps in the
        # numbering (e.g. only R2 and R3) the count would hand out a label
        # that already exists.
        used_labels = [
            at.GetIntProp(MOL_PROP_R_LABEL)
            for at in core.GetAtoms()
            if at.HasProp(MOL_PROP_R_LABEL)
        ]
        next_rgroup = max(used_labels, default=0) + 1

        res = Chem.RWMol(core)
        atoms_removed = set()
        res.BeginBatchEdit()
        for match_idxs in matches:
            midx = match_idxs[0]
            split_atom = res.GetAtomWithIdx(midx)
            first = True
            for nbr in split_atom.GetNeighbors():
                nbr_idx = nbr.GetIdx()
                # neighbors that are internal rgroups handled earlier have
                # already been scheduled for removal
                if nbr_idx in atoms_removed:
                    continue
                idx = res.AddAtom(Chem.Atom(0))
                atom = res.GetAtomWithIdx(idx)
                if first:
                    atom.SetProp(ATOM_PROP_DUMMY_LABEL,
                                 split_atom.GetProp(ATOM_PROP_DUMMY_LABEL))
                    atom.SetIntProp(MOL_PROP_R_LABEL,
                                    split_atom.GetIntProp(MOL_PROP_R_LABEL))
                    first = False
                else:
                    atom.SetProp(ATOM_PROP_DUMMY_LABEL, f'R{next_rgroup}')
                    atom.SetIntProp(MOL_PROP_R_LABEL, next_rgroup)
                    next_rgroup += 1
                # remember which original rgroup this piece came from
                atom.SetIntProp(ATOM_PROP_MATCH_TO_RG,
                                split_atom.GetIntProp(MOL_PROP_R_LABEL))
                bond = res.GetBondBetweenAtoms(nbr_idx, midx)
                res.AddBond(nbr_idx, idx, bond.GetBondType())
                res.GetBondBetweenAtoms(nbr_idx, idx).SetIsConjugated(
                    bond.GetIsConjugated())
                res.RemoveBond(nbr_idx, midx)
            res.RemoveAtom(midx)
            atoms_removed.add(midx)
        res.CommitBatchEdit()
        new_cores.append(res)
        scaff_split = True
    return new_cores, scaff_split
def _is_atom_conjugated(atom):
    """
    Tell whether an atom takes part in a conjugated bond.

    :param atom: atom to inspect
    :return: True if at least one bond of ``atom`` is flagged as conjugated
    """
    return any(bond.GetIsConjugated() for bond in atom.GetBonds())
def _adjust_bonds_to_degree_one_neighbors_of_conjugated_atoms(scaff):
    """
    Relax single bonds that connect a terminal atom to a conjugated atom.

    For every atom of degree one whose only bond is a single bond to an atom
    with at least one conjugated bond, that bond is replaced by the query
    bond built from the SMARTS ``-,=,:``.

    :param scaff: scaffold molecule; it is not modified
    :return: editable copy of ``scaff`` with the affected bonds replaced
    """
    res = Chem.RWMol(scaff)
    for terminal in scaff.GetAtoms():
        if terminal.GetDegree() == 1:
            only_bond = terminal.GetBonds()[0]
            is_single = only_bond.GetBondType() == Chem.BondType.SINGLE
            if is_single and _is_atom_conjugated(
                    only_bond.GetOtherAtom(terminal)):
                res.ReplaceBond(only_bond.GetIdx(),
                                Chem.BondFromSmarts('-,=,:'))
    return res
def _clean_bonds_to_attachment_points(match):
    """
    Normalize the attachment points of every rgroup in a decomposition
    result.

    For each molecule other than the one stored under 'Core', the attachment
    point property is cleared from all atoms, atoms carrying an r label are
    labelled '_AP1', '_AP2', ... in atom order, and all bonds of those atoms
    are made single and non-aromatic. The molecules are edited in place.

    :param match: decomposition result mapping names to molecules
    :return: the same ``match`` dict
    """
    rgroups = (mol for name, mol in match.items() if name != 'Core')
    for rgroup in rgroups:
        attch_pt_num = 0
        for atom in rgroup.GetAtoms():
            # Assign correct attachment point properties
            if atom.HasProp(MOL_PROP_ATTACHPT):
                atom.ClearProp(MOL_PROP_ATTACHPT)
            if not atom.HasProp(MOL_PROP_R_LABEL):
                continue
            attch_pt_num += 1
            atom.SetProp(ATOM_PROP_ATOM_LABEL, f'_AP{attch_pt_num}')
            # All bonds to attachment points should be single, nonaromatic
            # bonds for consistency with read from .sdf
            for bond in atom.GetBonds():
                bond.SetIsAromatic(False)
                bond.SetBondType(Chem.rdchem.BondType.SINGLE)
    return match
def get_rgroup_decomp(scaffold_mol: Chem.rdchem.Mol, match_mol: Chem.rdchem.Mol,
                      options: substructure.QueryOptions) -> Optional[dict]:
    """
    Decompose a molecule into the rgroups defined by a scaffold.

    :param scaffold_mol: scaffold carrying the rgroups; it is not modified
    :param match_mol: molecule to decompose
    :param options: query options. ``stereospecific``,
        ``adjust_conjugated_five_rings`` and
        ``adjust_single_bonds_between_aromatic_atoms`` are read here, and the
        object is also handed to ``substructure.expand_query``.
    :return: the first decomposition result, a dict with a copy of
        ``scaffold_mol`` under 'Core' plus the rgroup entries, or None when
        the molecule does not match the scaffold or the split scaffold cannot
        be mapped back onto the original one
    """
    scaffold_copy = Chem.Mol(scaffold_mol)
    if not options.stereospecific:
        # remove bond stereochemistry and chiral tags from the scaffold copy
        Chem.RemoveStereochemistry(scaffold_copy)

    if (options.adjust_conjugated_five_rings or
            options.adjust_single_bonds_between_aromatic_atoms):
        scaffold_copy = _adjust_bonds_to_degree_one_neighbors_of_conjugated_atoms(
            scaffold_copy)

    scaffolds = list(substructure.expand_query(scaffold_copy, options))
    scaffolds, scaff_split = _replace_core_query(scaffolds)

    decomp_params = rdRGroupDecomposition.RGroupDecompositionParameters()
    decomp_params.onlyMatchAtRGroups = True
    decomp_params.allowNonTerminalRGroups = True
    decomp_params.substructMatchParams.useEnhancedStereo = True
    decomp_params.removeAllHydrogenRGroups = False
    decomp_params.removeAllHydrogenRGroupsAndLabels = False

    matches, no_match = rdRGroupDecomposition.RGroupDecompose(
        scaffolds, [match_mol], options=decomp_params)
    # An empty result list is treated like a failed match instead of letting
    # pop() raise IndexError
    if no_match or not matches:
        return None

    # Otherwise return the first result with duplicate rgroups removed
    match = matches.pop(0)
    match = _clean_bonds_to_attachment_points(match)
    if scaff_split:
        # returns None if the stitched core cannot be matched to the scaffold
        match = _remove_duplicate_rgroups(Chem.Mol(scaffold_mol), match)

    # ensure core returned is the provided scaffold
    if match is not None:
        match['Core'] = Chem.Mol(scaffold_mol)
    return match