import contextlib
import copy
import itertools
import time
from functools import partial
from string import ascii_uppercase
from unittest import mock
from schrodinger import structure
from schrodinger.application.msv import command
from schrodinger.application.msv import seqio
from schrodinger.application.msv import structure_model
from schrodinger.application.msv.gui import gui_alignment
from schrodinger.application.msv.gui import msv_gui
from schrodinger.application.msv.gui.viewconstants import RowType
from schrodinger.protein import annotation
from schrodinger.protein import sequence
from schrodinger.Qt import QtCore
from schrodinger.Qt import QtWidgets
from schrodinger.Qt.QtCore import Qt
from schrodinger.test import mmshare_testfile
from schrodinger.utils import profiling
ALN_ANNO_TYPES = annotation.ProteinAlignmentAnnotations.ANNOTATION_TYPES
SEQ_ANNO_TYPES = annotation.ProteinSequenceAnnotations.ANNOTATION_TYPES
# Sequence annotation types to skip.  Judging by the name, the first group is not
# shown in the GUI; the remaining entries are skipped for the reasons noted
# inline.
ANNS_NOT_IN_GUI = {
    SEQ_ANNO_TYPES.sasa,
    SEQ_ANNO_TYPES.hydrophobicity,
    SEQ_ANNO_TYPES.isoelectric_point,
    SEQ_ANNO_TYPES.rescode,
    # prevent performance tests for alignment set annotation until MSV-2595 is
    # completed
    SEQ_ANNO_TYPES.alignment_set,
    # We currently don't have predictions data loaded for the performance
    # tests, so skip. See MSV-2604.
    *annotation.ProteinSequenceAnnotations.PRED_ANNOTATION_TYPES,
    SEQ_ANNO_TYPES.kinase_conservation,
}
def aln_has_residues(aln_info):
    """
    Report whether any sequence in `aln_info` contains a non-gap residue.

    :param aln_info: Object whose `seqs` attribute is an iterable of sequences,
        each an iterable of residues exposing `is_gap`
    :return: True if at least one residue is not a gap, False otherwise
        (including when there are no residues at all)
    :rtype: bool
    """
    all_residues = itertools.chain.from_iterable(aln_info.seqs)
    # De Morgan: "some residue is not a gap" == "not every residue is a gap"
    return not all(res.is_gap for res in all_residues)
def make_alignment(AlnClass, aln_info):
    """
    Build and populate an instance of `AlnClass` from `aln_info`.

    Combined-chain GUI alignment classes are constructed from an intermediate
    split-chain `gui_alignment.GuiProteinAlignment`; every other class is
    constructed directly from the sequences.  Anchor residues and
    inter-sequence disulfide bonds described by `aln_info` are then applied.

    :param AlnClass: An alignment class
    :type  AlnClass: type
    :param aln_info: Information to populate the instance
    :type  aln_info: AlignmentInfo
    :return: The newly constructed alignment
    :rtype: AlnClass
    """
    needs_split_aln = issubclass(
        AlnClass, gui_alignment.GuiCombinedChainProteinAlignment)
    if needs_split_aln:
        ctor_arg = gui_alignment.GuiProteinAlignment(aln_info.seqs)
    else:
        ctor_arg = aln_info.seqs
    new_aln = AlnClass(ctor_arg)
    new_aln.anchorResidues(aln_info.anchor_residues)
    for first_cys, second_cys in aln_info.cysteines_to_bond:
        new_aln.addDisulfideBond(first_cys, second_cys)
    return new_aln
@contextlib.contextmanager
def timing_anno(anno, msv_widget):
    """
    A context manager for timing the painting of a given annotation or row type.
    All caches are cleared at the start of the context, and a function is
    yielded that will enable the row type.  Runtime for the execution of this
    function should be counted as part of the initial paint time, since the view
    will calculate size hints while this function is run.  If we're timing
    an annotation row type (as opposed to sequence rows), then painting and
    data-fetching for sequence rows will be disabled during the context, and the
    annotation will be disabled at the end of the context.

    :param anno: The desired annotation or row type to show.
    :type  anno: enum.Enum
    :param msv_widget: The widget to enable the annotations on.
    :type  msv_widget: AbstractMSVWidget
    :return: Yields a function for enabling the specified row type.  This
        function will be a no-op if timing sequence rows.
    :rtype: function
    """
    # changing the cell size clears all view caches, including the row delegate
    # caches
    msv_widget.view._updateCellSize()
    msv_widget._table_model._base_model._cache.clear()
    msv_widget.getAlignment().clearAllCaching()
    if anno is RowType.Sequence:
        yield lambda: None
        return
    show_anno = partial(msv_widget._table_model._setVisibilityForRowType,
                        anno,
                        show=True)
    seq_delegate = msv_widget.view._row_delegates_by_row_type[RowType.Sequence]
    try:
        # Stub out sequence-row painting and per-cell data roles so only the
        # annotation's cost is measured.  patch.object restores the original
        # attributes on exit (removing the instance-level overrides when the
        # originals came from the class), even if the body raises.
        with mock.patch.object(seq_delegate, 'paintRow', mock.Mock()):
            with mock.patch.object(seq_delegate, 'PER_CELL_PAINT_ROLES',
                                   frozenset()):
                yield show_anno
    finally:
        msv_widget._table_model._setVisibilityForRowType(anno, show=False)
def get_view_expanded_states(msv_wid):
    """
    Get expanded states for all rows in all views in the msv widget.

    :param msv_wid: The MSV widget whose views should be inspected
    :return: One tuple per view, in the order `view`, `aln_info_view`,
        `aln_metrics_view`.  Each tuple holds the `isExpanded` state of every
        row (column 0 index) of that view's model, in row order.
    :rtype: list(tuple(bool))
    """
    views = (msv_wid.view, msv_wid.aln_info_view, msv_wid.aln_metrics_view)
    states_per_view = []
    for cur_view in views:
        cur_model = cur_view.model()
        row_indices = (cur_model.index(row, 0)
                       for row in range(cur_model.rowCount()))
        states_per_view.append(
            tuple(cur_view.isExpanded(row_idx) for row_idx in row_indices))
    return states_per_view
def assert_view_expansion_is_synced_with_model(msv_wid):
    """
    Assert that every view's row expansion matches the alignment's per-sequence
    expansion state.

    Alignment sequence `i` is compared against view row `i + 1`.  Presumably
    row 0 is a non-sequence row in each view (TODO confirm).

    :param msv_wid: The MSV widget to check
    :raise AssertionError: If any view disagrees with the alignment
    """
    per_view_states = get_view_expanded_states(msv_wid)
    aln = msv_wid.getAlignment()
    seq_states = [aln.isSeqExpanded(seq) for seq in aln]
    for seq_idx, expected in enumerate(seq_states):
        view_row = seq_idx + 1
        matches = [states[view_row] is expected for states in per_view_states]
        assert all(matches)
def assert_all_views_fully_expanded(msv_wid):
    """
    Assert that every row in every view of the MSV widget is expanded.

    :param msv_wid: The MSV widget to check
    :raise AssertionError: If any row in any view is collapsed
    """
    assert all(map(all, get_view_expanded_states(msv_wid)))
def process_events_multiple(app, num_iterations=25):
    """
    Call processEvents multiple times in case any pending timer slots
    trigger additional timers.  Also explicitly processes deleteLater calls.

    :param app: The Qt application
    :type app: QtWidgets.QApplication
    :param num_iterations: How many rounds of event processing to run.
        Defaults to 25.
    :type num_iterations: int
    """
    for _ in range(num_iterations):
        app.processEvents()
        # deleteLater() works by posting DeferredDelete events; dispatch them
        # explicitly each round so pending deletions actually happen
        app.sendPostedEvents(None, QtCore.QEvent.DeferredDelete)
class StandaloneStructureModelWithWorkspace(
        structure_model.StandaloneStructureModel):
    """
    A structure model with a fake workspace alignment so we can test how the MSV
    panel will respond when run inside Maestro.  Used in
    TestMSVPanelWithWorkspaceTab.
    """

    def __init__(self, parent, undo_stack):
        # NOTE(review): `parent` and `undo_stack` are accepted but not forwarded
        # to the base class; presumably they exist only to match the caller's
        # signature -- confirm against StandaloneStructureModel.__init__
        super().__init__()
        self._aln = gui_alignment.GuiProteinAlignment()

    def setGuiModel(self, gui_model):
        """
        Set the GUI model, then add the fake workspace alignment as its
        workspace page if it does not already have one.
        """
        super().setGuiModel(gui_model)
        if not gui_model.hasWorkspacePage():
            gui_model.addWorkspacePage(self._aln)

    def getWorkspaceAlignment(self):
        """
        :return: The fake workspace alignment
        :rtype: gui_alignment.GuiProteinAlignment
        """
        return self._aln

    def getWorkspaceColors(self):
        # Return the correct type to avoid NotImplementedError
        return {}

    def _readStructures(self, filename):
        seqs = super()._readStructures(filename)
        # add sequences to our fake workspace alignment, but don't return the
        # sequences from the workspace alignment
        ws_seqs = copy.deepcopy(seqs)
        self._aln.addSeqs(ws_seqs)
        return seqs

    def importStructuresIntoWorkspace(self, filename):
        # Call the base reader directly so that _readStructures above doesn't
        # also add a deep copy to the workspace alignment
        seqs = super()._readStructures(filename)
        # add sequences to our fake workspace alignment, and then return the
        # sequences from the workspace alignment
        self._aln.addSeqs(seqs)
        return seqs

    def syncSelectionToMaestro(self, selection):
        """
        This method should be patched instead of actually called.

        :raise AssertionError: Always.  An explicit raise (rather than
            `assert False`) keeps this guard active under `python -O`.
        """
        raise AssertionError(
            "syncSelectionToMaestro should be patched in tests, not called")

    def getMsvAutosaveProjectName(self):
        return ""

    def updateViewPages(self, gui_model):
        # Called by stu test openProject
        pass
class BaseCheckUndoMixin:
    """
    Mixin for checking that alignment undo operations are correct. `checkUndo`
    is a no-op, so this class can be used for testing a non-undoable alignment.
    """

    @contextlib.contextmanager
    def checkUndo(self, aln):
        """
        Override to check that calling undo on the specified alignment works as
        expected.  This base implementation performs no checks.

        :param aln: An alignment to check
        :type aln: schrodinger.protein.alignment.Alignment
        """
        yield

    def _alnSetIdCounter(self, aln):
        """
        :return: The alignment's set id counter
        """
        return aln._set_id_counter

    def _alnSetInfo(self, aln):
        """
        Get information about alignment sets in the given alignment.

        :param aln: The alignment to get alignment set info for
        :type aln: schrodinger.protein.alignment.BaseAlignment
        :return: A tuple of
            - A dictionary of {set name: (set id, indices of sequences in the
              set)}
            - The alignment's set id counter
        :rtype: tuple(dict(str, tuple(object, set(int))), object)
        """
        set_info = {
            aln_set.name: (aln_set.set_id, {aln.index(seq) for seq in aln_set})
            for aln_set in aln.alnSets()
        }
        return set_info, self._alnSetIdCounter(aln)
class CheckUndoMixin(BaseCheckUndoMixin):
    """
    Mixin for checking that GuiAlignment undo operations are correct.

    :cvar bool DELETE_UNDO_STACK: Whether checkUndo should set a new undo stack
        on entry and delete it on exit
    """
    DELETE_UNDO_STACK = False

    @contextlib.contextmanager
    def checkUndo(self, aln):
        """
        Check that exactly one undo command is pushed inside the context and that
        undoing it restores the alignment to the state it was in on entry; the
        command is redone on exit.

        :note: redo is called on the alignment on exiting the context, so that
            the alignment can be checked for other properties.
        :note: Because the context manager calls redo and undo, signals must be
            checked inside the context in tests.
        :note: Only one method that alters the alignment should be called inside
            the context
        :param aln: An undoable alignment to check
        :type aln:
            schrodinger.application.msv.gui.gui_alignment._ProteinAlignment
        :raise AssertionError: If the number of undo commands pushed inside the
            context is not exactly one, or if undo doesn't restore the alignment
        """
        # On entering the context, we make a deep copy of the alignment and
        # set an undo stack on it
        original_aln = copy.deepcopy(aln)
        if self.DELETE_UNDO_STACK:
            aln.setUndoStack(command.UndoStack())
        else:
            aln.undo_stack.clear()
        try:
            yield
            undo_count = aln.undo_stack.count()
            # The message (listing every pushed command) is only built when the
            # assertion fails
            assert undo_count == 1, (
                f"Expected exactly 1 undo command, got {undo_count}: "
                f"{[aln.undo_stack.text(idx) for idx in range(undo_count)]}")
            # Call undo on the alignment and compare it to the deep copy
            aln.undo_stack.undo()
            self._checkAlignment(original_aln, aln)
            # Finally, restore the alignment so that other checks can be
            # performed on it
            aln.undo_stack.redo()
        finally:
            # Dispose of the undo stack we created even if the body or one of
            # the checks raised
            if self.DELETE_UNDO_STACK:
                aln.undo_stack.deleteLater()

    def _checkAlignment(self, original_aln, aln):
        """
        Assert that `aln` matches `original_aln`.  Compares sequence order and
        contents, reference sequence, column count, anchors, residue selection,
        known and predicted disulfide bonds, alignment sets, residue highlight
        count, and sequence expansion/shown/outline state.  This method does not
        call undo or redo itself.

        :param original_aln: A deep copy of aln made before modification
        :type original_aln: schrodinger.protein.alignment._ProteinAlignment or
            schrodinger.application.msv.gui.gui_alignment._ProteinAlignment
        :param aln: The alignment being checked for changes after undo
        :type aln: schrodinger.protein.alignment._ProteinAlignment or
            schrodinger.application.msv.gui.gui_alignment._ProteinAlignment
        :raise AssertionError: If any compared property differs
        """
        ## alignment level checks ##
        # check the ordering of sequences is preserved
        original_ordering = [seq.fullname for seq in original_aln]
        ordering = [seq.fullname for seq in aln]
        assert original_ordering == ordering
        assert len(original_aln) == len(aln)
        assert str(original_aln.getReferenceSeq()) == str(aln.getReferenceSeq())
        assert original_aln.num_columns == aln.num_columns

        # anchor checks
        def getAnchorIdxs(alignment):
            return {(alignment.index(res.sequence), res.idx_in_seq)
                    for res in alignment.getAnchoredResidues()}

        assert getAnchorIdxs(original_aln) == getAnchorIdxs(aln)

        ## sequence level checks ##
        for original_seq, seq in zip(original_aln, aln):
            assert str(original_seq) == str(seq)

        ## selection model checks ##
        assert (original_aln.res_selection_model.getSelectionIndices() ==
                aln.res_selection_model.getSelectionIndices())

        def bondIndices(alignment, bonds):
            return {tuple(alignment.getResidueIndices(bond)) for bond in bonds}

        # check known disulfide bonds
        assert (bondIndices(original_aln, original_aln.disulfide_bonds) ==
                bondIndices(aln, aln.disulfide_bonds))
        # check predicted disulfide bonds
        assert (bondIndices(original_aln, original_aln.pred_disulfide_bonds) ==
                bondIndices(aln, aln.pred_disulfide_bonds))
        # check alignment sets
        assert self._alnSetInfo(original_aln) == self._alnSetInfo(aln)
        # check residue highlights for length
        assert (len(original_aln._residue_highlights) == len(
            aln._residue_highlights))
        # check expansion
        self._assertSameSeqExpansion(original_aln, aln)
        assert original_aln.getSeqShownStates() == aln.getSeqShownStates()
        assert len(original_aln.getOutlineMap()) == len(aln.getOutlineMap())

    def _assertSameSeqExpansion(self, aln1, aln2):
        for seq1, seq2 in itertools.zip_longest(aln1, aln2):
            assert aln1.isSeqExpanded(seq1) == aln2.isSeqExpanded(seq2)

    def _alnSetIdCounter(self, aln):
        """
        :return: The alignment's set id counter
        """
        return aln._aln._set_id_counter
def add_dummy_structure(seq):
    """
    Attach a 1 atom structure to `seq`.  This makes the `getStructure` and
    `hasStructure` methods work for many testing purposes.

    :param seq: The sequence to attach the structure to; its `_get_structure`
        attribute is overwritten
    """
    dummy_st = structure.create_new_structure(1)

    def _get_dummy_structure():
        return dummy_st

    seq._get_structure = _get_dummy_structure
def get_seqs_for_structure(filename):
    """
    Create sequences for the structure in the specified test file.

    Each returned sequence's `_get_structure` is set to return the structure
    read from the file, so its `getStructure` and `hasStructure` methods work.
    (In the MSV panel, the structure model handles this; setting it manually is
    typically sufficient for testing purposes.)

    :param filename: The file to read the structure from
    :type filename: str
    :return: Sequences for each chain in the structure
    :rtype: list(sequence.Sequence)
    """
    file_path = mmshare_testfile(filename)
    read_st = structure.Structure.read(file_path)
    chain_seqs = seqio.StructureConverter.convert(read_st)
    for chain_seq in chain_seqs:
        chain_seq._get_structure = lambda: read_st
    return chain_seqs
@contextlib.contextmanager
def enable_mock_BlastTask():
    """
    Within the context, patch two methods of
    `schrodinger.protein.tasks.blast.BlastTask`:

    - `_parseHits` is replaced by a mock that returns
      `mock_blast_output.MOCK_OUTPUT` (the blast results for 1cmy:a).
    - `_initBlastPlus` is replaced by a default mock.

    The intent is that a BLAST task run inside the context uses canned results
    rather than a real search.  NOTE(review): confirm these two methods are the
    only paths that reach the NCBI server.
    """
    blast_task_path = 'schrodinger.protein.tasks.blast.BlastTask'
    with mock.patch(f'{blast_task_path}._parseHits') as parse_hits_mock:
        from schrodinger.protein.tasks import mock_blast_output
        parse_hits_mock.return_value = mock_blast_output.MOCK_OUTPUT
        with mock.patch(f'{blast_task_path}._initBlastPlus'):
            yield
def strings_to_multichain_seqs(strings):
    """
    Convert a list of strings to a list of split-chain sequences.  Each string
    represents a single combined-chain sequence, with pipes ("|") used to
    represent chain breaks.  Spaces within a chain are removed.  The chains of
    the Nth string are named "Seq N" and get chain names "A", "B", ...

    :param strings: The strings to convert.
    :type strings: list[str]
    :return: The newly constructed split-chain sequences: all chains of the
        first string, then all chains of the second string, etc.
    :rtype: list[sequence.ProteinSequence]
    :raise ValueError: If a string contains more chains than there are chain
        names available (A-Z).
    """
    seqs = []
    for seq_num, seq_str in enumerate(strings, start=1):
        seq_name = f"Seq {seq_num}"
        seq_strings = seq_str.split("|")
        # zip() below would silently drop any chains beyond "Z"
        if len(seq_strings) > len(ascii_uppercase):
            raise ValueError(
                f"{seq_name} has {len(seq_strings)} chains, but at most "
                f"{len(ascii_uppercase)} chains are supported")
        for chain_name, chain_string in zip(ascii_uppercase, seq_strings):
            # presumably spaces are allowed so test strings can be visually
            # aligned; they are not part of the sequence
            chain_string = chain_string.replace(" ", "")
            cur_seq = sequence.ProteinSequence(chain_string,
                                               name=seq_name,
                                               chain=chain_name)
            seqs.append(cur_seq)
    return seqs
def strings_to_combined_aln(
        strings,
        split_aln_class=gui_alignment.GuiProteinAlignment,
        combined_aln_class=gui_alignment.GuiCombinedChainProteinAlignment):
    """
    Convert a list of strings to a combined-chain alignment.  Each string
    represents a single combined-chain sequence, with pipes ("|") used to
    represent chain breaks.

    :param strings: The strings to convert.
    :type strings: list[str]
    :param split_aln_class: The class of split-chain alignment to use.
    :type split_aln_class: Type[schrodinger.protein.alignment.BaseAlignment] or
        Type[gui_alignment.GuiProteinAlignment]
    :param combined_aln_class: The class of combined-chain alignment to use.
    :type combined_aln_class:
        Type[schrodinger.protein.alignment.CombinedChainProteinAlignment] or
        Type[gui_alignment.GuiCombinedChainProteinAlignment]
    :return: The newly constructed combined-chain alignment (an instance of
        `combined_aln_class` wrapping an instance of `split_aln_class`)
    :rtype: schrodinger.protein.alignment.CombinedChainProteinAlignment or
        gui_alignment.GuiCombinedChainProteinAlignment
    """
    seqs = strings_to_multichain_seqs(strings)
    split_aln = split_aln_class(seqs)
    return combined_aln_class(split_aln)
def combined_aln_to_strings(aln):
    """
    Render a combined-chain alignment as strings: one string per sequence, with
    the string form of each chain joined by pipes ("|") to mark chain breaks.

    :param aln: The alignment to convert
    :type aln: schrodinger.protein.alignment.CombinedChainProteinAlignment or
        gui_alignment.GuiCombinedChainProteinAlignment
    :return: A list of strings containing the sequences in `aln`, in order.
    :rtype: list[str]
    """
    rendered = []
    for combined_seq in aln:
        chain_texts = map(str, combined_seq.chains)
        rendered.append("|".join(chain_texts))
    return rendered
def compare_params(param1, param2):
    """
    Compare two concrete compound params, excluding items in getJsonBlacklist.

    Every sub-param of `param1` whose name is not in `param1`'s JSON blacklist
    must equal the same-named sub-param of `param2`.

    :raise AssertionError: On the first mismatch; the message is the sub-param
        name
    """
    skipped_names = set()
    for blacklisted in param1.getJsonBlacklist():
        skipped_names.add(blacklisted.paramName())
    for name, expected in param1.getSubParams().items():
        if name not in skipped_names:
            actual = param2.getSubParam(name)
            assert expected == actual, name