import itertools
from enum import Enum
from schrodinger.application.msv.gui.homology_modeling.constants import HomologyStatus
from schrodinger.infra.util import enum_speedup
from schrodinger.protein import annotation
# Alias for the protein sequence annotation types. It is not used in the
# visible part of this module; presumably kept for importers — verify.
SEQ_ANNO_TYPES = annotation.ProteinSequenceAnnotations.ANNOTATION_TYPES

# Modes in which residues can be picked. Each member has a pick handler and a
# reset handler registered in the _MODE_PICK_MAP / _MODE_RESET_MAP tables at
# the end of this module.
PickMode = enum_speedup(
    Enum('PickMode', (
        'Pairwise',
        'HMChimera',
        'HMProximity',
        'HMBindingSite',
    )))
def get_pickable_seqs(aln, mode):
    """
    Return the sequences in the alignment that are pickable in the given mode.

    - ``PickMode.HMChimera``: every sequence whose homology status is
      ``HomologyStatus.Template``, yielded lazily (a generator) in alignment
      iteration order.
    - ``PickMode.Pairwise``: the first two sequences of the alignment, or an
      empty list if the alignment holds fewer than two sequences.
    - Any other mode: an empty list.

    :param aln: Alignment
    :type aln: gui_alignment.GuiProteinAlignment
    :param mode: Picking mode
    :type mode: PickMode
    :return: Iterable of pickable sequences
    :rtype: collections.Iterable[sequence.ProteinSequence]
    """
    if mode is PickMode.HMChimera:
        return (seq for seq in aln
                if aln.getHomologyStatus(seq) is HomologyStatus.Template)
    if mode is PickMode.Pairwise:
        if len(aln) < 2:
            return []
        return aln[:2]
    # No pickable-sequence rule is defined here for the remaining modes
    # (HMProximity, HMBindingSite). Return an empty iterable rather than
    # implicitly returning None, which callers cannot iterate.
    return []
def handle_reset_pick(aln, mode):
    """
    Dispatch reset_pick to the correct method for the mode.

    :param aln: Alignment
    :type aln: gui_alignment.GuiProteinAlignment
    :param mode: Picking mode
    :type mode: PickMode
    :raises ValueError: If no reset function is registered for `mode`.
    """
    # .get() lets an unregistered mode reach the ValueError below. A plain
    # subscript would raise KeyError and leave the None check unreachable.
    reset_func = _MODE_RESET_MAP.get(mode)
    if reset_func is None:
        raise ValueError("Reset func not implemented for picking mode", mode)
    reset_func(aln)
def handle_pick(aln, mode, *args, **kwargs):
    """
    Dispatch the pick to the correct method for the mode. args and kwargs are
    passed unchanged to the picking method, so their meaning depends on the
    mode's handler.

    :param aln: Alignment
    :type aln: gui_alignment.GuiProteinAlignment
    :param mode: Picking mode
    :type mode: PickMode
    :raises ValueError: If no picking function is registered for `mode`.
    """
    # .get() lets an unregistered mode reach the ValueError below. A plain
    # subscript would raise KeyError and leave the None check unreachable.
    pick_func = _MODE_PICK_MAP.get(mode)
    if pick_func is None:
        raise ValueError("Picking func not implemented for picking mode", mode)
    pick_func(aln, *args, **kwargs)
def _handle_pick_chimeric(aln, residues, selected):
    """
    Update residue highlights for chimeric picking.

    Exactly one residue from each column should be picked. Therefore, picking
    in one sequence should un-pick from the other sequences. Un-picking from a
    non-default sequence should re-pick in the default sequence. Un-picking
    from the default sequence has no effect.

    The default template is the first sequence returned by
    get_pickable_seqs(aln, PickMode.HMChimera). Residues that do not belong to
    a template sequence are ignored. The computed pick/unpick sets are
    processed as follows before the alignment's composite residues are
    updated:

    - gaps are dropped;
    - same-column conflicts are resolved;
    - the sets are passed through fix_chimeric_pick().

    :param aln: Alignment
    :type aln: gui_alignment.GuiProteinAlignment
    :param residues: Residues to pick or unpick (a list is also accepted, as
        passed by _handle_reset_chimeric)
    :type residues: set[residue.Residue]
    :param selected: Whether the residues are being selected
    :type selected: bool
    """
    if not residues:
        return
    # Template sequences in get_pickable_seqs order; the first is the default.
    templates = list(get_pickable_seqs(aln, PickMode.HMChimera))
    if not templates:
        return
    default_template, *other_templates = templates
    # Rank of each template, used below to break ties within one column.
    template_indices = {v: i for i, v in enumerate(templates)}
    # Rebound to a set for fast membership tests in the loop.
    templates = set(templates)
    picked_residues = aln.homology_composite_residues
    to_pick = set()
    to_unpick = set()
    for res in residues:
        if res.sequence not in templates:
            continue
        # idx_in_seq is used as a column index shared by all templates
        # (assumes gapped sequences are index-aligned — confirm).
        col_idx = res.idx_in_seq
        # Every template residue in this column except `res` itself;
        # templates too short to reach the column are skipped.
        unpick_res = [seq[col_idx] for seq in templates if col_idx < len(seq)]
        unpick_res.remove(res)
        already_picked = res in picked_residues
        if res.sequence == default_template:
            # Default template: selecting picks it and unpicks the rest of
            # the column; deselecting does nothing.
            if selected and not already_picked:
                to_pick.add(res)
                to_unpick.update(unpick_res)
        elif res.sequence in other_templates:
            # Default-template residue of this column, re-picked when this
            # residue is deselected (None if the default is too short).
            if col_idx < len(default_template):
                default_res = default_template[col_idx]
            else:
                default_res = None
            if selected and not already_picked:
                to_pick.add(res)
                to_unpick.update(unpick_res)
            elif not selected and already_picked:
                to_unpick.add(res)
                if default_res is not None:
                    to_pick.add(default_res)
    # Filter out gaps
    to_pick = {res for res in to_pick if not res.is_gap}
    to_unpick = {res for res in to_unpick if not res.is_gap}
    # Handle selection of multiple template sequences
    # When residues from several templates in one column are selected
    # together, each lands in to_pick and also (through the others'
    # unpick_res) in to_unpick. For each such column keep only the residue
    # whose template is ranked last, and drop the rest from to_pick; they stay
    # in to_unpick.
    picked_and_unpicked = to_pick & to_unpick
    seq_idx_key = lambda res: res.idx_in_seq
    # groupby only merges adjacent items, hence the sort by the same key.
    for k, group in itertools.groupby(
            sorted(picked_and_unpicked, key=seq_idx_key), seq_idx_key):
        group = sorted(group, key=lambda res: template_indices[res.sequence])
        # Pick exactly one residue from this column
        keep = group.pop()
        to_unpick.remove(keep)
        # Don't pick the rest of the residues from the column
        to_pick.difference_update(group)
    # Also unpick template residues that line up with reference-sequence gaps.
    to_pick, to_unpick = fix_chimeric_pick(aln,
                                           to_pick=to_pick,
                                           to_unpick=to_unpick)
    aln.updateHomologyCompositeResidues(to_add=to_pick, to_remove=to_unpick)
def fix_chimeric_pick(aln, to_pick=None, to_unpick=None):
    """
    Given an alignment and optional sets of residues to pick/unpick, unpick
    template residues aligned with reference sequence gaps.

    A template residue is forced into `to_unpick` (and removed from `to_pick`)
    if it sits in a column that is one of:

    - before the first reference residue;
    - a gap in the reference sequence;
    - after the last reference residue, up to ``aln.num_columns``.

    If the alignment has no reference sequence, or the reference sequence
    contains no residues, the sets are returned unchanged.

    :param aln: Alignment
    :type aln: gui_alignment.GuiProteinAlignment
    :param to_pick: Set of residues to pick to be updated in place
    :type to_pick: set or NoneType
    :param to_unpick: Set of residues to unpick to be updated in place
    :type to_unpick: set or NoneType
    :return: Sets of residues to pick and unpick
    :rtype: tuple(set, set)
    """
    if to_pick is None:
        to_pick = set()
    if to_unpick is None:
        to_unpick = set()
    return_value = (to_pick, to_unpick)
    ref_seq = aln.getReferenceSeq()
    if ref_seq is None:
        # getReferenceSeq() may return None (other handlers in this module
        # check for it). Without a reference there are no gaps to enforce.
        return return_value
    first_ref_res = None
    ref_gap_idxs = []
    for elem in ref_seq:
        if elem.is_gap:
            ref_gap_idxs.append(elem.idx_in_seq)
        elif first_ref_res is None:
            first_ref_res = elem
    if first_ref_res is None:
        return return_value
    last_ref_res = next(res for res in reversed(ref_seq) if res.is_res)
    # Columns in which no template residue may be picked. Computed once rather
    # than per template; as a set it also drops indexes that appear both in
    # the leading range and in ref_gap_idxs.
    blocked_idxs = set(
        itertools.chain(
            range(0, first_ref_res.idx_in_seq),
            ref_gap_idxs,
            range(last_ref_res.idx_in_seq + 1, aln.num_columns),
        ))
    # Unpick template residues aligned with ref gaps
    for seq in get_pickable_seqs(aln, PickMode.HMChimera):
        seq_len = len(seq)
        for idx in blocked_idxs:
            if idx >= seq_len:
                continue
            res = seq[idx]
            if res.is_res:
                to_pick.discard(res)
                to_unpick.add(res)
    return return_value
def _handle_pick_pairwise(aln, residues, selected):
    """
    Update pairwise constraints.

    A single residue from `residues` is used. If it belongs to the first
    pickable sequence it becomes the reference constraint. If it belongs to
    the second, it becomes the other constraint. Nothing happens if
    `residues` is empty or the alignment has fewer than two sequences.
    `selected` is accepted for signature compatibility with the other pick
    handlers but is not used.

    :param aln: Alignment
    :type aln: gui_alignment.GuiProteinAlignment
    :param residues: Residues to pick or unpick (not modified)
    :type residues: set[residue.Residue]
    :param selected: Whether the residues are being selected
    :type selected: bool
    """
    if not residues:
        return
    pickable = list(get_pickable_seqs(aln, PickMode.Pairwise))
    if len(pickable) < 2:
        # A pair of sequences is required. Unpacking fewer than two would
        # raise ValueError.
        return
    ref_seq, other_seq = pickable
    # Read one residue without popping it, so the caller's collection is not
    # mutated as a side effect of picking.
    res = next(iter(residues))
    if res in ref_seq:
        aln.setRefConstraint(res)
    elif res in other_seq:
        aln.setOtherConstraint(res)
def _handle_pick_binding_site(aln, res, lig):
    """
    Set or unset a ligand constraint on the reference residue in the given
    residue's column (via ``aln.setHMLigandConstraint``). The constraints are
    not changed in any of these cases:

    - the alignment has no reference sequence;
    - the column lies past the end of the reference sequence;
    - the reference sequence has a gap in that column.

    :param aln: Alignment
    :type aln: gui_alignment.GuiProteinAlignment
    :param res: The residue whose column should be constrained to the ligand
    :type res: residue.Residue
    :param lig: Name of the ligand to constrain
    :type lig: str
    """
    ref_seq = aln.getReferenceSeq()
    # getReferenceSeq() may return None (the proximity handler in this module
    # checks for it too). len(None) would otherwise raise TypeError.
    if ref_seq is None:
        return
    res_column = res.idx_in_seq
    if res_column >= len(ref_seq):
        return
    ref_res = ref_seq[res_column]
    if ref_res.is_gap:
        return
    aln.setHMLigandConstraint(ref_res, lig)
def _handle_pick_proximity(aln, res):
    """
    Set a proximity constraint on `res` if it belongs to the alignment's
    reference sequence. Does nothing when there is no reference sequence or
    `res` is not part of it.

    :param aln: Alignment
    :type aln: gui_alignment.GuiProteinAlignment
    :param res: The residue to constrain
    :type res: residue.Residue
    """
    reference = aln.getReferenceSeq()
    on_reference = reference is not None and res in reference
    if on_reference:
        aln.setHMProximityConstraint(res)
def _handle_reset_chimeric(aln):
    """
    Restore chimeric picking to its starting state.

    1. Every currently picked composite residue is removed, with
       ``signal=False`` passed to the update (presumably suppressing the
       change notification — confirm).
    2. Every residue of the default template is picked through handle_pick().
       The default template is the first sequence yielded by
       get_pickable_seqs() for PickMode.HMChimera.

    With no templates, nothing is re-picked.

    :param aln: Alignment
    :type aln: gui_alignment.GuiProteinAlignment
    """
    chimera = PickMode.HMChimera
    # NOTE(review): to_remove is the alignment's own composite-residue
    # collection, not a copy; confirm updateHomologyCompositeResidues does
    # not iterate it while mutating it.
    aln.updateHomologyCompositeResidues(
        to_add=(),
        to_remove=aln.homology_composite_residues,
        signal=False)
    default_template = next(iter(get_pickable_seqs(aln, chimera)), None)
    if default_template is None:
        return
    handle_pick(aln, chimera, list(default_template), selected=True)
def _handle_reset_pairwise(aln):
    """
    Reset pairwise picking by clearing the alignment's pairwise constraints.

    :param aln: Alignment
    :type aln: gui_alignment.GuiProteinAlignment
    """
    aln.resetPairwiseConstraints()
def _handle_reset_binding_site(aln):
    """
    Reset binding-site picking by clearing all ligand constraints on the
    alignment.

    :param aln: Alignment
    :type aln: gui_alignment.GuiProteinAlignment
    """
    aln.clearHMLigandConstraints()
def _handle_reset_proximity(aln):
    """
    Reset proximity picking by clearing all proximity constraints on the
    alignment.

    :param aln: Alignment
    :type aln: gui_alignment.GuiProteinAlignment
    """
    aln.clearHMProximityConstraints()
# Dispatch table consumed by handle_pick(): PickMode -> pick handler. Extra
# positional/keyword arguments given to handle_pick() are forwarded as-is, so
# each handler has its own mode-specific signature after `aln`.
_MODE_PICK_MAP = {
    PickMode.Pairwise: _handle_pick_pairwise,
    PickMode.HMChimera: _handle_pick_chimeric,
    PickMode.HMProximity: _handle_pick_proximity,
    PickMode.HMBindingSite: _handle_pick_binding_site,
}
# Dispatch table consumed by handle_reset_pick(): PickMode -> reset handler.
# Every reset handler takes only the alignment.
_MODE_RESET_MAP = {
    PickMode.Pairwise: _handle_reset_pairwise,
    PickMode.HMChimera: _handle_reset_chimeric,
    PickMode.HMProximity: _handle_reset_proximity,
    PickMode.HMBindingSite: _handle_reset_binding_site,
}