import itertools
import more_itertools
from . import atomlist
from . import linknode
from . import markush
from . import posvarbond
#------------------------------------------------------------------------------#
def _get_linknodes_and_pvbonds(mol, linknodes=None, pvbonds=None):
    '''
    Gathers the "repeating units" and "position variant bonds" of `mol`
    and checks that no position variant bond touches a repeating unit.
    :param mol: Molecule.
    :type mol: rdkit.Chem.Mol
    :param linknodes: Link nodes (repeating units); collected from `mol`
        when None.
    :type linknodes: list(linknode.LinknodeSgroup) or NoneType
    :param pvbonds: Position variant bonds; collected from `mol` when None.
    :type pvbonds: list(posvarbond.MulticenterSgroup) or NoneType
    :return: Lists of "repeating units" and "position variant bonds".
    :rtype: (list(linknode.LinknodeSgroup), list(posvarbond.MulticenterSgroup))
    :raises ValueError: if a position variant bond shares an atom with a
        repeating unit, or has its center inside one.
    '''
    # Whatever is collected here is only checked for mutual overlap; the
    # specifications themselves are validated later.
    if linknodes is None:
        linknodes = linknode._collect_linknodes(mol)
    if pvbonds is None:
        pvbonds = posvarbond._collect_posvarbonds(mol)
    for unit in linknodes:
        clash = any(
            unit.atoms & set(bond.atoms) or bond.center in unit.atoms
            for bond in pvbonds)
        if clash:
            raise ValueError(
                'position variable bonds within repeating units')
    return linknodes, pvbonds
#------------------------------------------------------------------------------#
def _get_random_flat_realization(mol, prng, linknodes, pvbonds):
    '''
    Draws one random realization of `mol` by running the not nestable
    enumeration flavors one after another:
    * repeating units,
    * position variant bonds, and
    * atom lists.
    The order matters: repeating units are expanded before position variant
    bonds, because enumerating the latter deletes atoms.
    :param mol: Molecule.
    :type mol: rdkit.Chem.Mol
    :param prng: MT19937 pseudorandom number generator from numpy.
    :type prng: `numpy.random.RandomState`
    :param linknodes: Link nodes (repeating units), pre-validated.
    :type linknodes: list(linknode.LinknodeSgroup)
    :param pvbonds: Position variant bonds, pre-validated.
    :type pvbonds: list(posvarbond.MulticenterSgroup)
    :return: Molecule
    :rtype: rdkit.Chem.Mol
    '''
    # each stage wraps the previous stage's random realization
    stages = (
        lambda m: linknode.LinknodeEnumerable(m, linknodes=linknodes),
        lambda m: posvarbond.PosVarBondEnumerable(m, pvbonds=pvbonds),
        lambda m: atomlist.AtomListEnumerable(m),
    )
    realization = mol
    for make_enumerable in stages:
        realization = make_enumerable(realization).getRandomRealization(prng)
    return realization
#------------------------------------------------------------------------------#
def flat_enumerator(mol, prng=None, linknodes=None, pvbonds=None):
    '''
    Returns iterator over realizations obtained via not nestable flavors of
    enumeration:
    * repeating units,
    * position variant bonds, and
    * atom lists
    (in that order: "repeating units" must be expanded prior to "position
    variant bonds" because enumeration of the latter involves atom deletions).
    :param mol: Unadulterated molecule.
    :type mol: rdkit.Chem.Mol
    :param prng: MT19937 pseudorandom number generator from numpy or None.
        When given, the returned iterator is endless and every item is an
        independent random realization; otherwise all realizations are
        yielded sequentially.
    :type prng: `numpy.random.RandomState` or NoneType
    :param linknodes: Link nodes (repeating units) or None.
    :type linknodes: list(linknode.LinknodeSgroup) or NoneType
    :param pvbonds: Position variant bonds or None.
    :type pvbonds: list(posvarbond.MulticenterSgroup) or NoneType
    :return: Iterator over molecules.
    :rtype: iterator over rdkit.Chem.Mol
    :raises ValueError: (eagerly, at call time) if a position variant bond
        lies within a repeating unit.
    '''
    # FIXME: can atom indices change in RDKit::ROMol(const ROMol&)?
    # Validation only: raises ValueError on overlap. The collected lists are
    # not forwarded; the enumerables below receive the caller's `linknodes` /
    # `pvbonds` unchanged (possibly None).
    # NOTE(review): presumably so that position variant bonds are collected
    # from each linknode-expanded molecule rather than from `mol` -- confirm
    # before forwarding the collected lists.
    _get_linknodes_and_pvbonds(mol, linknodes=linknodes, pvbonds=pvbonds)
    if prng is not None:  # random
        def get_realization():
            return _get_random_flat_realization(mol,
                                                prng,
                                                linknodes=linknodes,
                                                pvbonds=pvbonds)
        return more_itertools.repeatfunc(get_realization)
    # sequential: expand repeating units, then position variant bonds of every
    # expanded molecule, then atom lists of every resulting molecule
    ln_iter = linknode.LinknodeEnumerable(mol, linknodes=linknodes).getIter()
    pvb_iter = itertools.chain.from_iterable(
        posvarbond.PosVarBondEnumerable(m, pvbonds=pvbonds).getIter()
        for m in ln_iter)
    return itertools.chain.from_iterable(
        atomlist.AtomListEnumerable(m).getIter() for m in pvb_iter)
#------------------------------------------------------------------------------#
def flat_list_enumerator(molecules, prng=None):
    '''
    Returns iterator over structures obtained by applying "flat"
    enumeration to the `molecules`.
    :param molecules: List of (kinky) molecules.
    :type molecules: list(rdkit.Chem.Mol)
    :param prng: MT19937 pseudorandom number generator from numpy or None.
        When given, the returned iterator is endless: every item is one
        random flat realization of a molecule drawn (via ``prng.choice``)
        from `molecules`. Otherwise all flat realizations of every molecule
        are yielded, molecule by molecule.
    :type prng: `numpy.random.RandomState` or NoneType
    :return: Iterator over molecules.
    :rtype: iterator over rdkit.Chem.Mol
    '''
    if prng is not None:  # random
        def get_realization():
            return next(flat_enumerator(prng.choice(molecules), prng=prng))
        return more_itertools.repeatfunc(get_realization)
    # sequential: realizations of molecules[0], then of molecules[1], ...
    return itertools.chain.from_iterable(
        flat_enumerator(m, prng=None) for m in molecules)
#------------------------------------------------------------------------------#
def _product_or_sample(iterables, prng):
    '''
    Returns iterator over tuples that take one item from each of `iterables`.
    Sequential mode (`prng` is None): the full Cartesian product.
    Random mode: the iterables are zipped, i.e. the k-th tuple holds the k-th
    item of every iterable, and iteration stops with the shortest one.
    With no iterables at all, both modes yield exactly one empty tuple, so
    callers still reach R-group placement once (with nothing to place), just
    as in sequential mode. A bare ``zip()`` would yield nothing, and random
    enumeration of a molecule without resolvable R-groups would produce no
    realization at all.
    :param iterables: Iterables to combine.
    :type iterables: list
    :param prng: MT19937 pseudorandom number generator from numpy or None.
    :type prng: `numpy.random.RandomState` or NoneType
    :return: Iterator over tuples.
    '''
    if prng is not None and iterables:
        return zip(*iterables)
    return itertools.product(*iterables)
#------------------------------------------------------------------------------#
def place_rgroups(mol, todo, rgroups, prng, homo):
    '''
    Generator that yields realizations of `mol` with
    (some) atoms replaced by R-groups from `todo` (and, potentially,
    `rgroups`).
    :param mol: Scaffold molecule.
    :type mol: rdkit.Chem.Mol
    :param todo: List of pairs: (atom indices in `mol`, R-group molecule to
        place at each of those atoms).
    :type todo: list((list(int), rdkit.Chem.Mol))
    :param rgroups: Dictionary that maps R-group numbers (positive
        integers) onto list of molecules. Assumed to contain "original"
        R-groups that may be necessary if some of the R-groups in `todo`
        include R-group references.
    :type rgroups: dict(int, list(rdkit.Chem.Mol))
    :param prng: MT19937 pseudorandom number generator from numpy or None.
    :type prng: `numpy.random.RandomState` or NoneType
    :param homo: IDs (positive integers) of the homo R-groups.
    :type homo: set(int) or NoneType
    :return: Iterator over molecules.
    '''
    todo_iters = []  # in `todo` order
    for _, rgmol in todo:
        if markush.get_rlabels_set(rgmol) & rgroups.keys():
            # `rgmol` references nested R-groups that need to be enumerated
            todo_iters.append(enumerate_rgroups(rgmol, rgroups, prng, homo))
        else:
            todo_iters.append([rgmol])
    for rgs in _product_or_sample(todo_iters, prng):
        # pair every atom index with the R-group chosen for its `todo` entry
        atoms_and_rgroups = [
            (idx, rg)
            for (indices, _), rg in zip(todo, rgs)
            for idx in indices]
        yield from markush.place_rgroups(mol, atoms_and_rgroups)
#------------------------------------------------------------------------------#
def enumerate_rgroups(mol, rgroups, prng=None, homo=None):
    '''
    Enumerates R-groups in `mol` using R-groups from `rgroups`.
    R-group labels that are missing from `rgroups`, or that have no viable
    replacement, are left in place.
    :param mol: Scaffold molecule.
    :type mol: rdkit.Chem.Mol
    :param rgroups: Dictionary that maps R-group numbers (positive
        integers) onto list of molecules.
    :type rgroups: dict(int, list(rdkit.Chem.Mol))
    :param prng: MT19937 pseudorandom number generator from numpy or None.
    :type prng: `numpy.random.RandomState` or NoneType
    :param homo: IDs (positive integers) of the homo R-groups.
    :type homo: set(int) or NoneType
    :return: Iterator over molecules.
    '''
    # resolve/flatten R-groups
    flat_rgroup_iters = []
    flat_rgroup_atoms = []
    for rlabel, atoms in markush.get_rlabels_map(mol).items():
        try:
            replacements = rgroups[rlabel]
        except KeyError:
            continue
        viable_replacements = []
        for rgmol in replacements:
            rgmol_rlabels = markush.get_rlabels_set(rgmol)
            if rgmol_rlabels:
                # example: R2 in R2-C should not be replaced with
                # anything that references R1 or R2; may reference R3.
                # Requiring strictly larger labels also keeps the
                # place_rgroups/enumerate_rgroups recursion finite.
                if rlabel < min(rgmol_rlabels):
                    viable_replacements.append(rgmol)
            else:
                viable_replacements.append(rgmol)
        if viable_replacements:
            if homo and rlabel in homo:
                # all `atoms` to be replaced with the same R-group realization
                flat_rgroup_atoms.append(atoms)
                flat_rgroup_iters.append(
                    flat_list_enumerator(viable_replacements, prng))
            else:
                # enumerate R-group realizations at different `atoms`
                for idx in atoms:
                    flat_rgroup_atoms.append([idx])
                    flat_rgroup_iters.append(
                        flat_list_enumerator(viable_replacements, prng))
    # loop over R-group combinations (at least one, even with no R-groups)
    for rgs in _product_or_sample(flat_rgroup_iters, prng):
        atoms_and_rgroups = list(zip(flat_rgroup_atoms, rgs))
        yield from place_rgroups(mol,
                                 atoms_and_rgroups,
                                 rgroups=rgroups,
                                 prng=prng,
                                 homo=homo)
#------------------------------------------------------------------------------#
def collection(mol, rgroups=None, prng=None, homo=None):
    '''
    Top-level API: generator that yields molecules obtained from `mol`.
    :param mol: Scaffold molecule.
    :type mol: rdkit.Chem.Mol
    :param rgroups: Dictionary that maps R-group numbers (positive
        integers) onto list of molecules; treated as empty when None.
    :type rgroups: dict(int, list(rdkit.Chem.Mol)) or NoneType
    :param prng: MT19937 pseudorandom number generator from numpy or None.
        When given, the generator is endless and every item is one random
        realization; otherwise all realizations are yielded sequentially.
    :type prng: `numpy.random.RandomState` or NoneType
    :param homo: IDs (positive integers) of the "homo" R-groups.
        Regular R-groups that share the same label get enumerated
        independently (e.g., four outcomes are expected for ``R1-CO-R1`` with
        ``R1 = [*Cl, *F]``). OTOH, homo R-groups with the same label end up
        with the same realization (within a nesting level), so that only
        two outcomes would be obtained in the example above.
    :type homo: set(int) or NoneType
    :return: Iterator over molecules.
    :rtype: iterator over rdkit.Chem.Mol
    '''
    if rgroups is None:
        rgroups = {}
    for flat_mol in flat_enumerator(mol, prng):
        if prng is not None:  # random
            # one random R-group realization per random flat realization.
            # NOTE: should enumerate_rgroups() yield nothing, the
            # StopIteration from next() surfaces as RuntimeError (PEP 479).
            yield next(
                enumerate_rgroups(flat_mol,
                                  rgroups=rgroups,
                                  prng=prng,
                                  homo=homo))
        else:  # sequential
            yield from enumerate_rgroups(flat_mol,
                                         rgroups=rgroups,
                                         prng=prng,
                                         homo=homo)
#------------------------------------------------------------------------------#