"""
Present a list of all ligands that have been included in the Workspace or
selected in the project, and allow the user to select their desired ligand or
ligands.
`LigandListWidget` is the standard widget for including this list in a panel.
"""
import contextlib
import enum
from functools import total_ordering
import schrodinger
from schrodinger import project
from schrodinger.infra import mm
from schrodinger.Qt import QtCore
from schrodinger.Qt import QtWidgets
from schrodinger.Qt.QtCore import Qt
from schrodinger.structutils import analyze
from schrodinger.ui.qt import table_helper
from schrodinger.ui.qt.appframework2 import maestro_callback
from schrodinger.ui.qt.appframework2 import markers
from schrodinger.ui.qt.utils import maestro_required
# Module-wide Maestro handle.  Code in this module tests it with
# `if not maestro` and skips Maestro-dependent work when it is falsy (i.e.
# when running outside of Maestro, such as in unit tests).
maestro = schrodinger.get_maestro()
# Where the ligand list takes its ligands from: Project Table entries that
# are included in the Workspace ("included") or entries that are selected in
# the Project Table ("selected").
LigSource = enum.Enum("LigSource", ["included", "selected"])
@total_ordering
class Ligand(object):
    """
    An object representing one ligand in the list. Contains a reference to
    the analyze.Ligand object.

    Instances are ordered by entry id first and then by the sort key of the
    underlying found ligand; two instances are equal when both match.
    """

    # Defining `__eq__` already makes the class unhashable; state it
    # explicitly so that nobody relies on hashing these objects by accident.
    __hash__ = None

    def __init__(self, found_lig, proj_row):
        """
        :param found_lig: The ligand object found using `schrodinger.
            structutils.analyze.AslLigandSearcher`
        :type found_lig: `schrodinger.structutils.analyze.Ligand`
        :param proj_row: The Project Table row
        :type proj_row: `schrodinger.project.ProjectRow`
        """
        self.entry_id = proj_row.entry_id
        self.entry_title = proj_row.title
        self._found_lig = found_lig
        # Atom indices BY ENTRY (they may have different atom numbers in the
        # Workspace structure):
        self.atom_indexes = found_lig.atom_indexes
        # The ligand is "the entire entry" when it contains as many atoms as
        # the entry structure itself.
        self._entire_entry = len(self.atom_indexes) == len(
            proj_row.getStructure().atom)
        # Cached because the name is requested every time the row is shown.
        self._res_names = self._getResNames()

    @property
    def struc(self):
        """
        The structure of the found ligand (`found_lig.st`).
        """
        return self._found_lig.st

    def getName(self, only_one_entry):
        """
        Get the name to use for this ligand in the ligand list.  If the ligand
        is the entire entry, the entry id and title will be used.  Otherwise,
        the residue name and number will be used (or multiple residue names
        and numbers for ligands that span multiple residues), followed by the
        entry id and title unless `only_one_entry` is set.

        :param only_one_entry: Whether only one entry is included in the
                workspace (all ligands in the list came from it). In this case,
                the entry info will be excluded from the row name. Note that
                this setting will be ignored when the ligand is the entire
                entry.
        :type only_one_entry: bool
        :return: The properly formatted name
        :rtype: str
        """
        # Need to also show the entry ID, not just title, because it's
        # fairly common for multiple entries to have the same title.
        # Title is reported last because it can be very long
        entry_info = "Entry %s, Title: %s" % (self.entry_id, self.entry_title)
        if self._entire_entry:
            return entry_info
        elif only_one_entry:
            # All ligands came from the same entry, so show only residue info:
            return self._res_names
        else:
            return self._res_names + "; " + entry_info

    def _sortKey(self):
        """
        Return the tuple used for equality and ordering: the entry id
        followed by the found ligand's own sort key.

        NOTE(review): entry ids are compared as-is; if they are strings the
        primary ordering is lexicographic rather than numeric.
        """
        return (self.entry_id, self._found_lig.sort_key())

    def __eq__(self, other):
        if not isinstance(other, Ligand):
            return NotImplemented
        return self._sortKey() == other._sortKey()

    def __lt__(self, other):
        # Compare lexicographically.  An "entry id less OR ligand key less"
        # test is not a valid ordering: two ligands could each compare as
        # less than the other, making the result of sorting ill-defined.
        if not isinstance(other, Ligand):
            return NotImplemented
        return self._sortKey() < other._sortKey()

    def __repr__(self):
        return "%s" % self.entry_title

    def _getResNames(self):
        """
        Create a string containing residue name and numbers for all residues in
        this ligand's structure, prefixed with the chain name of the
        structure's first atom (a blank chain is shown as "_").

        :return: The residue names and numbers
        :rtype: str
        """
        res_names = [
            "%s (%i%s)" % (res.pdbres, res.resnum, res.inscode.strip())
            for res in self.struc.residue
        ]
        chain = self.struc.atom[1].chain
        if chain == " ":
            chain = "_"
        return "%s:%s" % (chain, " - ".join(res_names))
class LigandListView(markers.MarkerMixin, QtWidgets.QListView):
    """
    A list view for ligands.  Note that multiple ligand selection can be
    enabled via
    `ligand_list_view.setSelectionMode(QtWidgets.QListView.ExtendedSelection)`.

    :cvar ligandSelectionChanged: A signal emitted when the selected ligands
        have changed.
    :vartype ligandSelectionChanged: `QtCore.pyqtSignal`
    """

    ligandSelectionChanged = QtCore.pyqtSignal()

    def __init__(self, parent=None):
        # See Qt documentation for argument documentation
        super(LigandListView, self).__init__(parent)
        self._auto_include = True
        self._auto_fit = True
        self._use_markers = False
        self._marker_color = (1, 1, 1)
        # True only while this view is itself changing Workspace inclusion.
        # Not read inside this class; presumably consulted by code that
        # reacts to inclusion changes - verify against users of the view.
        self._including_entries = False

    def setAutoInclude(self, auto_include):
        """
        Specify whether ligands should be included in the Workspace when
        they are selected in the list.  Only relevant when selecting ligands
        that are selected in the Project Table.

        :param auto_include: True if ligands should be included.  False
            otherwise.
        :type auto_include: bool
        """
        self._auto_include = auto_include

    def setAutoFit(self, auto_fit):
        """
        Specify whether ligands should be zoomed in on in the Workspace when
        they are selected in the list.  Requires auto-include.

        :param auto_fit: True if ligands should be zoomed in on.  False
            otherwise.
        :type auto_fit: bool
        """
        self._auto_fit = auto_fit

    def setMarkSelectedLigands(self, use_markers):
        """
        Specify whether ligands that are selected in the ligand list should be
        marked in the workspace.  See `setMarkerColor` to control the marker
        color.

        :param use_markers: True if workspace markers should be used.  False
            otherwise.
        :type use_markers: bool
        """
        self._use_markers = use_markers
        if use_markers:
            self._updateLigandMarkers()
        else:
            self.removeAllMarkers()

    def setMarkerColor(self, color):
        """
        Specify the color of the workspace markers used to mark selected
        ligands.  Only has an effect if `setMarkSelectedLigands` has been set
        to True.

        :param color: A tuple of RGB float values for the marker color.
        :type color: tuple
        """
        self._marker_color = color
        if self._use_markers:
            self._updateLigandMarkers()

    def selectedLigands(self):
        """
        Return a list of `Ligand` objects for the selected rows.
        """
        return [
            index.data(table_helper.ROW_OBJECT_ROLE)
            for index in self.selectedIndexes()
        ]

    def selectLigandsFromAtoms(self, atoms):
        """
        Select all ligands containing the specified atom(s).

        :param atoms:  A list of atoms (`schrodinger.structure._StructureAtom`)
            or a single `schrodinger.structure._StructureAtom`.  This atom must
            be from a Workspace or Project Table structure.
        :type atoms: list or `schrodinger.structure._StructureAtom`
        :raise ValueError: If `atoms` doesn't specify any ligands.  ValueError
            will also be raised if `atoms` specified more than one ligand and
            the view is in SingleSelection selection mode.
        """
        if isinstance(atoms, schrodinger.structure._StructureAtom):
            atoms = [atoms]
        model = self.model()
        # Atoms are identified by (entry id, per-entry atom number), which is
        # the key format of the model's atom-to-row mapping.
        atom_data = [(atom.entry_id, atom.number_by_entry) for atom in atoms]
        to_lig = model.atomToLigNumMapping()
        row_nums = {to_lig.get(atom) for atom in atom_data}
        # Ignore any atoms that don't correspond to a ligand
        row_nums.discard(None)
        if not row_nums:
            raise ValueError("No ligands included in selected atoms")
        if len(row_nums) > 1 and self.selectionMode() == self.SingleSelection:
            raise ValueError("Multiple ligands specified by selected atoms")
        indices = [model.index(row, 0) for row in row_nums]
        sel = QtCore.QItemSelection()
        for index in indices:
            sel.select(index, index)
        sel_model = self.selectionModel()
        # Don't auto-zoom while the user is interacting with the Workspace
        with self._disableAutoFit():
            sel_model.select(sel, sel_model.ClearAndSelect)

    @contextlib.contextmanager
    def _disableAutoFit(self):
        """
        Temporarily disable auto-zooming.  The previous setting is restored
        on exit, even if the body raises.
        """
        old_auto_fit = self._auto_fit
        self._auto_fit = False
        try:
            yield
        finally:
            self._auto_fit = old_auto_fit

    def selectLigandsFromWorkspaceAtomNums(self, atom_nums):
        """
        Select all ligands containing the specified atom(s).

        :param atom_nums:  A list of Workspace atom numbers (ints) or a single
            Workspace atom number.
        :type atom_nums: list or int
        :raise ValueError: If `atom_nums` doesn't specify any ligands.
            ValueError will also be raised if `atom_nums` specified more than
            one ligand and the view is in SingleSelection selection mode.
        """
        if isinstance(atom_nums, int):
            atom_nums = [atom_nums]
        ws_struc = maestro.workspace_get()
        atoms = [ws_struc.atom[i] for i in atom_nums]
        self.selectLigandsFromAtoms(atoms)

    def selectionChanged(self, selected, deselected):
        # See Qt documentation for method documentation
        super(LigandListView, self).selectionChanged(selected, deselected)
        atom_data = self._getSelectedAtomData()
        if self._auto_include:
            self._includeEntries(atom_data.keys())
            # PANEL-5967 - Also exclude the rows that have been de-selected.
            if self.model()._source == LigSource.selected:
                desel_idxs = deselected.indexes()
                eids = [
                    index.data(table_helper.ROW_OBJECT_ROLE).entry_id
                    for index in desel_idxs
                ]
                self._excludeEntries(eids)
        if self._auto_fit:
            self._fitTo(atom_data)
        if self._use_markers:
            self._updateLigandMarkers()
        self.ligandSelectionChanged.emit()

    def _getSelectedAtomData(self):
        """
        Return a dictionary of {entry id: set of atom indices} containing all
        atoms in all ligands currently selected in the ligand list.
        """
        atom_data = {}
        for lig in self.selectedLigands():
            atom_data.setdefault(lig.entry_id, set()).update(lig.atom_indexes)
        return atom_data

    def _updateLigandMarkers(self):
        """
        Update the workspace markers used to mark selected ligands.  All
        existing markers are removed, and a single marker covering the atoms
        of every selected ligand is added if anything is selected.
        """
        self.removeAllMarkers()
        atoms_to_mark = []
        for cur_ligand in self.selectedLigands():
            atoms_to_mark.extend(cur_ligand.struc.atom)
        if atoms_to_mark:
            self.addMarker(atoms_to_mark, self._marker_color)

    def selectIfNoSelection(self):
        """
        When called this function will select the first ligand in the list
        if no other ligand in the list is currently selected.  Nothing is
        done when the view is in NoSelection or MultiSelection mode.
        """
        if self.selectionMode() in (self.NoSelection, self.MultiSelection):
            return
        model = self.model()
        if not self.selectedIndexes() and model.rowCount():
            index = model.index(0, 0)
            sel_model = self.selectionModel()
            sel_model.select(index, sel_model.SelectCurrent)

    def setModel(self, model):
        # See Qt documentation for method documentation
        super(LigandListView, self).setModel(model)
        # Re-establish a default selection whenever the list contents change.
        model.modelReset.connect(self.selectIfNoSelection)
        model.rowsInserted.connect(self.selectIfNoSelection)
        model.rowsRemoved.connect(self.selectIfNoSelection)
        self.selectIfNoSelection()

    def _includeEntries(self, eids):
        """
        Make sure that all specified entry ids are included in the Workspace.

        :param eids: An iterable of entry ids to include
        :type eids: iterable
        """
        self._changeEntryInclusion(eids, project.NOT_IN_WORKSPACE,
                                   project.IN_WORKSPACE)

    def _excludeEntries(self, eids):
        """
        Make sure that the specified eids are excluded from the Workspace.

        :param eids: An iterable of entry ids to exclude.
        :type eids: iterable
        """
        self._changeEntryInclusion(eids, project.IN_WORKSPACE,
                                   project.NOT_IN_WORKSPACE)

    @maestro_required
    def _changeEntryInclusion(self, eids, current_state, new_state):
        """
        Check the workspace inclusion state of each entry id in an iterable.
        If the state matches current_state, change the state to new_state.

        `_including_entries` is True for the duration of the call and is
        always reset to False afterwards, even if changing a row raises.

        :param eids: An iterable of entry ids to check.
        :type eids: iterable
        :param current_state: Inclusion state to check in each entry.
        :type current_state: int, should be `project.IN_WORKSPACE` or
            `project.NOT_IN_WORKSPACE`
        :param new_state: State to change entries that match current_state to.
        :type new_state: int, should be one of `project.IN_WORKSPACE` or
            `project.NOT_IN_WORKSPACE`
        """
        self._including_entries = True
        try:
            try:
                proj = maestro.project_table_get()
            except project.ProjectException:
                # Project may have been closed during operation
                return
            for cur_eid in eids:
                row = proj[cur_eid]
                if row.in_workspace == current_state:
                    row.in_workspace = new_state
        finally:
            # Never leave the flag stuck at True if a row lookup or an
            # inclusion change fails part-way through.
            self._including_entries = False

    @maestro_required
    def _fitTo(self, atom_data):
        """
        Fit the Workspace to all specified atoms.  If no atoms are given, a
        plain "fit" command is issued instead.

        :param atom_data: A dictionary of {entry id: set of atom indices} for
            the atoms to zoom in on.
        :type atom_data: dict
        """
        if atom_data:
            asl = self._createAsl(atom_data)
            maestro.command("fit %s" % asl)
        else:
            maestro.command("fit")

    def _createAsl(self, atom_data):
        """
        Create an ASL specifying the given atoms.  Entries are listed in
        ascending numeric order of entry id and atoms in ascending order
        within each entry.

        :param atom_data: A dictionary of {entry id: set of atom indices} for
            the atoms to include in the ASL.
        :type atom_data: dict
        :return: The ASL
        :rtype: str
        """
        asl_per_eid = []
        for eid in sorted(atom_data, key=int):
            atom_nums = sorted(atom_data[eid])
            joined_nums = ",".join(map(str, atom_nums))
            cur_asl = "(entry.id %s AND atom.entrynum %s)" % (eid, joined_nums)
            asl_per_eid.append(cur_asl)
        return " OR ".join(asl_per_eid)
class LigandListModel(table_helper.RowBasedListModel):
    """
    List model whose rows are `Ligand` objects found in the Project Table
    entries that are either included in the Workspace or selected, depending
    on the configured `LigSource`.
    """

    def __init__(self, parent=None, source=LigSource.included):
        """
        :param parent: The Qt parent widget.
        :type parent: `QtWidgets.QWidget` or NoneType
        :param source: The desired ligand source.
        :type source: `LigSource`
        """
        super(LigandListModel, self).__init__(parent)
        # Must exist before the first update, which is triggered by
        # setLigandSource() below.
        self._only_one_entry = True
        self.setLigandSource(source)

    def setLigandSource(self, source):
        """
        Choose between selected and included Project Table entries as the
        origin of the ligands, then refresh the list.

        :param source: The desired ligand source.
        :type source: `LigSource`
        """
        self._source = source
        self.updateLigandList()

    def ligandSource(self):
        """
        Report which Project Table entries (selected or included) the ligands
        are currently taken from.

        :return: The current ligand source.
        :rtype: `LigSource`
        """
        return self._source

    def updateLigandList(self):
        """
        Rebuild the rows of the model.  Call this whenever project inclusion
        or selection changes.  If reading the project fails, the existing
        rows are left untouched.
        """
        if not maestro:
            # Outside of Maestro the table stays empty, so panels that embed
            # this widget do not need to mock this method.
            return
        searcher = analyze.AslLigandSearcher()
        rows = self._projRows()
        row_count = len(rows)
        # FIXME for "included" mode, this code will ignore the scratch entries
        try:
            maestro.project_table_synchronize()
            ligands = [
                Ligand(found, row)
                for row in rows
                for found in searcher.search(
                    row.getStructure(workspace_sync=False))
            ]
        except (mm.MmException, project.ProjectException):
            return
        ligands.sort()
        # The entry info is only dropped from the names when ligands come
        # from included entries and exactly one entry is included.
        single_entry = (self._source is LigSource.included and
                        row_count == 1)
        self._only_one_entry = single_entry
        self.replaceRows(ligands)

    def _projRows(self):
        """
        Fetch the Project Table rows matching the current source.

        :return: An iterable of `schrodinger.project.ProjectRow` for all
            included or selected Project Table rows.  Empty when running
            outside of Maestro or when the project is unavailable.
        :rtype: iterable
        :raise ValueError: If the current source is not a `LigSource` value.
        """
        if not maestro:
            # unit tests
            return []
        try:
            proj = maestro.project_table_get()
        except project.ProjectException:
            # Project may have been closed during operation
            return []
        source = self._source
        if source is LigSource.included:
            return proj.included_rows
        if source is LigSource.selected:
            return proj.selected_rows
        raise ValueError(
            "source must be a valid LigSource value, not %s" % source)

    @table_helper.data_method(Qt.DisplayRole)
    def _displayData(self, lig):
        # Display text for a row is the ligand's formatted name.
        return lig.getName(self._only_one_entry)

    def atomToLigNumMapping(self):
        """
        Build a lookup of {(entry id, atom number): ligand row number}
        covering every atom of every ligand in the table.
        """
        return {
            (ligand.entry_id, atom_num): row_num
            for row_num, ligand in enumerate(self.rows)
            for atom_num in ligand.atom_indexes
        }