"""
File containing plot related code used in the Trajectory Plots GUI
"""
import csv
import uuid
from enum import Enum
from enum import auto
import openpyxl
from schrodinger.models import mappers
from schrodinger.Qt import QtCore
from schrodinger.Qt import QtGui
from schrodinger.Qt import QtWidgets
from schrodinger.Qt.PyQt5 import QtChart
from schrodinger.Qt.QtCore import Qt
from schrodinger.ui.qt import basewidgets
from schrodinger.ui.qt import filedialog
from schrodinger.ui.qt.standard.icons import icons
from schrodinger.utils import csv_unicode
from . import advanced_plot_ui
from . import energy_plot_ui
from . import collapsible_plot_ui
from . import shortcut_ui
from . import traj_plot_models
from . import energy_plots
from . import plots as tplots
# Plot Constants
# Upper bound on the tick count applied by _generateAxisSpecifications()
MAX_AXIS_TICKS = 5
# Padding step used by _generateAxisSpecifications() when widening an axis
# range and when deriving its tick count
MIN_AXIS_SPAN = 0.1
# Number of shortcut widgets after which ShortcutRow.hasSpace() reports False
MAX_SHORTCUTS_IN_ROW = 3
# Pixel width of the image written by AbstractTrajectoryPlot.saveImage(); the
# height is derived from the view's aspect ratio
IMAGE_WIDTH = 20000
# NOTE(review): presumably the pen width handed to set_series_width(); it is
# not referenced in the visible part of this file - confirm against callers
SERIES_WIDTH = 1
# Plot context menu actions
# Text labels of the QMenu entries built by the panels/shortcuts below
SHOW = "Show"
HIDE = "Hide"
SAVE_IMG = "Save Image..."
EXPORT_CSV = "Export as CSV..."
EXPORT_EXCEL = "Export to Excel..."
DELETE = "Delete"
VIEW_PLOT = "View Plot..."
# Colors
# NOTE(review): judging by the names these color the RMSF / B-factor series
# and the helix / strand areas; they are not used in the visible code - confirm
RMSF_COLOR = QtGui.QColor.fromRgb(158, 31, 222)
TEMP_COLOR = QtGui.QColor.fromRgb(207, 105, 31)
HELIX_COLOR = QtGui.QColor.fromRgb(253, 236, 232)
STRAND_COLOR = QtGui.QColor.fromRgb(229, 246, 250)
# Series Names
B_FACTOR_SERIES = 'b_factor_series'
SS_HELIX_SERIES = 'secondary_structure_helix_series'
SS_STRAND_SERIES = 'secondary_structure_strand_series'
#############################
# ENUMS
#############################
class TrajectoryPlotType(Enum):
    """
    Enum of plot types to generate.

    Members are grouped by name prefix (``MEASUREMENT_*``, ``INTERACTIONS_*``,
    ``DESCRIPTORS_*``, ``ENERGY_*``). Values come from `auto()`, so they
    follow declaration order: append new members at the end of the list if the
    numeric values matter to any caller.
    """
    MEASUREMENT_WORKSPACE = auto()
    MEASUREMENT_ADD = auto()
    MEASUREMENT_PLANAR_ANGLE = auto()
    MEASUREMENT_CENTROID = auto()
    INTERACTIONS_ALL = auto()
    INTERACTIONS_HYDROGEN_BONDS = auto()
    INTERACTIONS_HALOGEN_BONDS = auto()
    INTERACTIONS_SALT_BRIDGE = auto()
    INTERACTIONS_PI_PI = auto()
    INTERACTIONS_CAT_PI = auto()
    DESCRIPTORS_RMSD = auto()
    DESCRIPTORS_ATOM_RMSF = auto()
    DESCRIPTORS_RES_RMSF = auto()
    DESCRIPTORS_RADIUS_GYRATION = auto()
    DESCRIPTORS_PSA = auto()
    DESCRIPTORS_SASA = auto()
    DESCRIPTORS_MOLECULAR_SA = auto()
    ENERGY_ALL_GROUPED = auto()
    ENERGY_ALL_INDIVIDUAL = auto()
    ENERGY_INDIVIDUAL_MOLECULES = auto()
    ENERGY_CUSTOM_SUBSTRUCTURE_SETS = auto()
    ENERGY_CUSTOM_ASL_SETS = auto()
# The subset of TrajectoryPlotType made up of the ENERGY_* members
ENERGY_PLOT_TYPES = {
    TrajectoryPlotType.ENERGY_ALL_GROUPED,
    TrajectoryPlotType.ENERGY_ALL_INDIVIDUAL,
    TrajectoryPlotType.ENERGY_INDIVIDUAL_MOLECULES,
    TrajectoryPlotType.ENERGY_CUSTOM_SUBSTRUCTURE_SETS,
    TrajectoryPlotType.ENERGY_CUSTOM_ASL_SETS
}
# Kind of data a plot object holds; returned by the getPlotType() methods of
# the plot classes in this file
PlotDataType = Enum('PlotDataType', ('RMSF', 'TRAJECTORY', 'ENERGY'))
#############################
# Plot Formatting Functions
#############################
def handle_chart_legend(chart, is_multiseries_interactions):
    """
    Configure or hide the chart legend depending on the type of the chart.

    Multi-series interaction plots get a bottom-aligned legend with tooltips
    enabled; for every other plot the legend is hidden.

    :param chart: Chart containing legend
    :type chart: QtChart.QChart
    :param is_multiseries_interactions: is this a multiseries interaction plot
    :type is_multiseries_interactions: bool
    """
    legend = chart.legend()
    if not is_multiseries_interactions:
        legend.hide()
        return
    legend.setShowToolTips(True)
    legend.setAlignment(Qt.AlignBottom)
def set_series_width(series, width):
    """
    Set the pen width of the series.

    :param series: Series whose pen is modified
    :type series: QLineSeries
    :param width: Width to set
    :type width: int
    """
    # The pen is fetched, modified and explicitly re-applied: changing the
    # object returned by pen() is not enough on its own.
    series_pen = series.pen()
    series_pen.setWidth(width)
    series.setPen(series_pen)
def slim_chart(chart):
    """
    Remove as much unnecessary padding from a chart as possible.

    Zeroes the layout contents margins and the window frame margins, and
    sets the background roundness to 0.

    :param chart: The chart to slim
    :type chart: QtChart.QChart
    """
    no_margins = (0, 0, 0, 0)
    chart.layout().setContentsMargins(*no_margins)
    chart.setWindowFrameMargins(*no_margins)
    chart.setBackgroundRoundness(0)
def _is_series_ss(series):
    """
    Tell whether a series is one of the Secondary Structure area series.

    Only an exact type match counts: the object must be a
    `SecondaryStructureStrandSeries` or a `SecondaryStructureHelixSeries`
    itself, not an instance of a subclass.

    :param series: Series to check
    :type series: QtChart.QAbstractSeries
    :return: Whether the series represents a Secondary Structure
    :rtype: bool
    """
    series_cls = type(series)
    return (series_cls is SecondaryStructureStrandSeries or
            series_cls is SecondaryStructureHelixSeries)
def _generateAxisSpecifications(data, axis):
    """
    Set the range and tick count of an axis based on the provided data.

    Values are rounded to one decimal before the minimum and maximum are
    taken. The new minimum is lowered by `MIN_AXIS_SPAN` when the axis
    currently starts below the data, and the new maximum is raised by
    `MIN_AXIS_SPAN` when the axis currently ends above it. The tick count is
    the number of `MIN_AXIS_SPAN` steps in the resulting range, limited to
    the interval [2, `MAX_AXIS_TICKS`].

    If `data` is empty the axis is left untouched.

    :param data: Data for series on axis
    :type data: list
    :param axis: Axis to set
    :type axis: QValueAxis
    """
    axis_values = {round(val, 1) for val in data}
    if not axis_values:
        # min()/max() would raise on an empty set; nothing to fit the axis to.
        return
    # set min
    axis_min = min(axis_values)
    if axis.min() < axis_min:
        axis_min -= MIN_AXIS_SPAN
    axis.setMin(axis_min)
    # set max
    axis_max = max(axis_values)
    if axis.max() > axis_max:
        axis_max += MIN_AXIS_SPAN
    axis.setMax(axis_max)
    # set ticks
    # The division yields a float (and, due to rounding error, often one just
    # below a whole number), but setTickCount() takes an int, so round first.
    span_steps = int(round(
        (axis_max - axis_min + MIN_AXIS_SPAN) / MIN_AXIS_SPAN))
    # QValueAxis does not accept a tick count lower than 2.
    num_ticks = max(2, min(MAX_AXIS_TICKS, span_steps))
    axis.setTickCount(num_ticks)
#############################
# TRADITIONAL PLOTS
#############################
class AbstractTrajectoryChartView(QtChart.QChartView):
    """
    QChartView subclass shared by all trajectory plots.

    :cvar plotClicked: Emitted when a point is clicked in the view, i.e. the
        left button is released at the position where it was pressed. The
        argument is that position mapped to chart values.
    :vartype plotClicked: `QtCore.pyqtSignal(QtCore.QPointF)`
    """
    # This signal is emitted when a point is clicked in the view:
    plotClicked = QtCore.pyqtSignal(QtCore.QPointF)

    def __init__(self, *args, **kwargs):
        """
        All arguments are forwarded to `QtChart.QChartView`.
        """
        super().__init__(*args, **kwargs)
        self.setMouseTracking(True)
        # Position of the most recent left-button press, used to tell a
        # click apart from a drag on release.
        self._mouse_press_pos = None

    def mousePressEvent(self, event):
        """
        Remember where a left-button press happened.
        """
        if event.button() == Qt.LeftButton:
            self._mouse_press_pos = event.pos()
        super().mousePressEvent(event)

    def mouseReleaseEvent(self, event):
        """
        Emit `plotClicked` with the chart value under the cursor if the left
        button was released where it was pressed.
        """
        if event.button() == Qt.LeftButton:
            release_pos = event.pos()
            if release_pos == self._mouse_press_pos:
                # User has not dragged the mouse.
                value = self.chart().mapToValue(release_pos)
                self.plotClicked.emit(value)
        super().mouseReleaseEvent(event)
class AbstractTrajectoryPlot(QtCore.QObject):
    """
    Base class for storing plot data for trajectory analysis data. Also holds
    a reference to the chart view.

    Clicks are picked up through the view's `plotClicked` signal. Note that
    we cannot simply use QtChart.QLineSeries.clicked for this because it
    does not appear to trigger on OS X.

    :ivar displayAsl: Display the asl for the corresponding entry id Signal.
        args are (asl, entry_id)
    :type displayAsl: `QtCore.pyqtSignal(str, int)`
    :ivar displayFrameAndAsl: Change frame and show ASL for given entry id
        Signal args are (asl, entry_id, frame_number)
    :type displayFrameAndAsl: `QtCore.pyqtSignal(str, int, int)`
    """
    displayAsl = QtCore.pyqtSignal(str, int)
    displayFrameAndAsl = QtCore.pyqtSignal(str, int, int)

    def __init__(self,
                 *args,
                 parent=None,
                 task=None,
                 trj=None,
                 eid=None,
                 settings_hash=None,
                 **kwargs):
        """
        Extra positional and keyword arguments are passed to the
        `AbstractTrajectoryChartView` constructor.

        :param parent: Parent QObject
        :type parent: QtCore.QObject or None
        :param task: Finished trajectory task
        :type task: tasks.AbstractTask
        :param trj: Trajectory associated with this plot
        :type trj: playertoolbar.EntryTrajectory
        :param eid: Entry ID for this plot
        :type eid: int
        :param settings_hash: Optional additional settings string to further
                              identify a plot as unique beyond its
                              analysis mode and fit ASL.
        :type settings_hash: str
        """
        super().__init__(parent=parent)
        self.view = AbstractTrajectoryChartView(*args, **kwargs)
        self.window = self.view.window()
        self.view.plotClicked.connect(self.onPlotClicked)
        self.task = task
        self.fit_asl = None
        if self.task is not None:
            self.fit_asl = self.task.output.fit_asl
        self.settings_hash = settings_hash
        self.trj = trj
        self.eid = eid
        # Maps frame time (frame.time / 1000) to the 1-based frame index;
        # stays None when no trajectory was given.
        self.time_to_frame_map = None
        # Maps a series title to a dict of that series' values keyed by
        # x-value (see the getDataForExport() implementations).
        self.series_map = {}
        if self.trj:
            # Traj Player widgets use 1-based indexing so we do so here as
            # well.
            self.time_to_frame_map = {
                fr.time / 1000: idx
                for idx, fr in enumerate(self.trj, start=1)
            }

    def chart(self):
        """
        :return: The chart shown by this plot's view
        :rtype: QtChart.QChart
        """
        return self.view.chart()

    def onPlotClicked(self, value):
        """
        Emit `displayFrameAndAsl` for the frame nearest to the clicked time.

        Does nothing when the plot has no task or no time-to-frame mapping,
        or when the clicked time lies outside the trajectory's time range.

        :param value: Clicked position in chart values; x is the time
        :type value: QtCore.QPointF
        """
        if not self.task or not self.time_to_frame_map:
            return
        frame_idx = self._getNearestFrameForTime(value.x())
        if frame_idx is not None:
            self.displayFrameAndAsl.emit(self.fit_asl, self.eid, frame_idx)

    def _getNearestFrameForTime(self, time):
        """
        Given a time value, return the frame nearest to that time.

        :param time: Time to get the nearest frame for
        :type time: float
        :return: 1-based frame index closest to the specified time or None if
                 time is out of range or no frame times are known.
        :rtype: int or None
        """
        time_map = self.time_to_frame_map
        if not time_map:
            return None
        if time < min(time_map) or time > max(time_map):
            return None
        # min() returns the first of several equally near keys.
        nearest_key = min(time_map,
                          key=lambda frame_time: abs(time - frame_time))
        return time_map[nearest_key]

    def getPlotType(self):
        """
        Returns what type of plot this class uses.
        Subclasses must override.
        """
        raise NotImplementedError

    def getDataForExport(self):
        """
        Return a list of row data to export to CSV or Excel.
        Subclasses must override.

        :return: Data to be exported
        :rtype: list(list)
        """
        raise NotImplementedError

    def getDataForExportFromMainPanel(self):
        """
        Most panels export the same data whether export was selected from
        the plot panel or the main panel. Override this method to export
        different type of data when exporting from the parent panel, via
        the "Export Results..." button.

        :return: Data to be exported
        :rtype: list(list)
        """
        return self.getDataForExport()

    def exportToCSV(self):
        """
        Export plot data to a CSV file chosen by the user. Does nothing if
        the file dialog is cancelled.
        """
        fpath = filedialog.get_save_file_name(
            parent=self.window,
            caption="Save as CSV",
            filter="Comma-separated value (*.csv)")
        if not fpath:
            return
        rows = self.getDataForExport()
        with csv_unicode.writer_open(fpath) as fh:
            csv.writer(fh).writerows(rows)

    def exportToExcel(self):
        """
        Export data to an .xls file chosen by the user. Does nothing if the
        file dialog is cancelled.
        """
        # NOTE(review): openpyxl writes the OOXML (.xlsx) format; offering
        # only "*.xls" here yields files whose extension does not match their
        # content - confirm whether the filter should be "*.xlsx".
        fpath = filedialog.get_save_file_name(parent=self.window,
                                              caption="Save as Excel Workbook",
                                              filter='Excel (*.xls)')
        if not fpath:
            return
        wb = openpyxl.Workbook()
        ws = wb.active
        for row in self.getDataForExport():
            ws.append(row)
        wb.save(fpath)

    def saveImage(self):
        """
        Save a .png file of the plot to a path chosen by the user. Does
        nothing if the file dialog is cancelled.
        """
        fpath = filedialog.get_save_file_name(parent=self.window,
                                              caption="Save Image",
                                              filter="PNG (*.png)")
        if not fpath:
            return
        view = self.view
        # Guard against a zero-width view (division by zero) and a
        # zero-height pixmap.
        aspect_ratio = view.height() / max(view.width(), 1)
        image_height = max(int(IMAGE_WIDTH * aspect_ratio), 1)
        # make sure image has high enough resolution for publication use.
        pixmap = QtGui.QPixmap(IMAGE_WIDTH, image_height)
        pixmap.fill(Qt.transparent)
        painter = QtGui.QPainter(pixmap)
        try:
            view.render(painter)
        finally:
            # Finish painting before the pixmap is written out.
            painter.end()
        pixmap.save(fpath)
 
class TrajectoryAnalysisPlot(AbstractTrajectoryPlot):
    """
    Chart class used for graphs with an x-axis of frames
    """

    def getDataForExport(self):
        """
        Return a list of row data to export to CSV or Excel.

        The first row is a header: "Frame", "Time (ns)" and one column per
        series. When there is exactly one series its column is titled with
        the parent widget's plot title instead of the series key. Each
        further row holds the frame index, the frame time and every series'
        value at that time. Without a trajectory only the header is returned.

        :return: Data to be exported
        :rtype: list(list)
        """
        series_keys = list(self.series_map)
        series_titles = series_keys
        # If there is a single series title use plot widget title instead.
        if len(series_keys) == 1:
            series_titles = [self.parent().getPlotTitle()]
        rows = [["Frame", "Time (ns)", *series_titles]]
        # time_to_frame_map is None when the plot was built without a
        # trajectory.
        time_map = self.time_to_frame_map or {}
        for time, frame_idx in time_map.items():
            values = [self.series_map[key][time] for key in series_keys]
            rows.append([frame_idx, time, *values])
        return rows

    def getPlotType(self):
        """
        :return: `PlotDataType.TRAJECTORY`
        """
        return PlotDataType.TRAJECTORY
class BaseRmsfPlot(AbstractTrajectoryPlot):
    """
    Chart class for RMSF data, where the x-axis represents atoms or residues
    rather than trajectory frames.

    These contain callouts describing which point the user is hovering over.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Callout currently shown for a hovered point, if any
        self._callout = None
        # Reference to the chart's series, filled by enableSeriesTracking()
        self._series = None
        self.id = None
        self.initializeCallouts()

    def initializeCallouts(self):
        """
        Initializes the plot to accept events
        """
        chart = self.chart()
        chart.setAcceptHoverEvents(True)
        self.view.scene().addItem(chart)

    def onPlotClicked(self, value):
        """
        Fire a signal to show ASL of selection on left click.

        Nothing is emitted when there is no task, when the click falls
        outside the data, or when the task's analysis mode is neither
        AtomRMSF nor ResRMSF.

        :param value: Clicked position in chart values
        :type value: QtCore.QPointF
        """
        if not self.task:
            return
        # Over-ride base method, because x-axis represents atoms/residues
        # and not trajectory frames.
        data_x = round(value.x())
        if data_x < 0:
            # A negative index would silently wrap around to the end.
            return
        task_input = self.task.input
        mode = task_input.analysis_mode
        try:
            if mode == traj_plot_models.AnalysisMode.AtomRMSF:
                asl = f'atom.n {task_input.atom_numbers[data_x]}'
            elif mode == traj_plot_models.AnalysisMode.ResRMSF:
                res_lbl = self.task.output.residue_info.residue_names[data_x]
                atoms = task_input.residue_atoms[res_lbl]
                asl = f"atom.n {','.join(map(str, atoms))}"
            else:
                return
        except (IndexError, KeyError):
            # Click landed past the last atom/residue.
            return
        self.displayAsl.emit(asl, self.eid)

    def enableSeriesTracking(self):
        """
        Connect the hover signal of every series that is not a Secondary
        Structure series to `onHover`.
        """
        series_list = self.chart().series()
        for line in series_list:
            if not _is_series_ss(line):
                line.hovered.connect(self.onHover)
        # Explicitly save a reference to the series so it doesn't get
        # destroyed (PANEL-18838)
        self._series = self.chart().series()

    def generateCalloutText(self, pos):
        """
        Build the lines of text shown in the callout for a hovered point.

        :param pos: Hovered position in chart values
        :type pos: QtCore.QPointF
        :return: Lines of callout text; empty if the task's analysis mode is
                 neither AtomRMSF nor ResRMSF
        :rtype: list(str)
        """
        rmsf_info = temp_info = ''
        callout_text_list = []
        data_x = round(pos.x())
        for series in self.chart().series():
            if _is_series_ss(series):
                continue
            series_type = type(series)
            data_point = series.at(data_x)
            if series_type is OutputSeries:
                rmsf_info = f'RMSF = {data_point.y():.2f} Å'
            if series_type is BFactorSeries and series.isVisible():
                temp_info = f'B Factor = {round(data_point.y(), 1)}'
        mode = self.task.input.analysis_mode
        if mode == traj_plot_models.AnalysisMode.AtomRMSF:
            atom_info = self.task.input.atom_labels[data_x]
            callout_text_list = [atom_info, rmsf_info]
        elif mode == traj_plot_models.AnalysisMode.ResRMSF:
            res_info = self.task.output.residue_info.residue_names[data_x]
            callout_text_list = [res_info, rmsf_info]
            if temp_info:
                callout_text_list.insert(1, temp_info)
        return callout_text_list

    def onHover(self, pos, enter):
        """
        Show a callout when the cursor enters a series point and remove it
        when the cursor leaves.

        :param pos: Hovered position in chart values
        :type pos: QtCore.QPointF
        :param enter: True when the cursor enters the point, False on leave
        :type enter: bool
        """
        if enter:
            if self._callout is None:
                series = self.sender()
                text_list = self.generateCalloutText(pos)
                callout = Callout(self.chart(), series, pos, text_list)
                callout.setZValue(1)
                self._callout = callout
            self.view.scene().addItem(self._callout)
        elif self._callout is not None:
            # Only remove a callout that was actually shown; a leave event
            # without one needs no work.
            self.view.scene().removeItem(self._callout)
            self._callout = None

    def getPlotType(self):
        """
        :return: `PlotDataType.RMSF`
        """
        return PlotDataType.RMSF
class AtomRmsfPlot(BaseRmsfPlot):
    """
    RMSF plot whose export has one row per atom index.
    """

    def getDataForExport(self):
        """
        Return a list of row data to export to CSV or Excel.

        The header row is 'Atom Index' followed by the series titles. Data
        rows are created from the keys of the first series; the values of
        every series are then appended to the rows in iteration order.

        :return: Data to be exported
        :rtype: list(list)
        """
        series_titles = list(self.series_map)
        rows = [['Atom Index', *series_titles]]
        for title in series_titles:
            values_by_atom = self.series_map[title]
            # Row 0 is the header, so data rows are numbered from 1.
            for row_idx, (atom_idx, value) in enumerate(
                    values_by_atom.items(), start=1):
                if row_idx >= len(rows):
                    rows.append([atom_idx])
                rows[row_idx].append(value)
        return rows
class ResidueRmsfPlot(BaseRmsfPlot):
    """
    RMSF plot whose export has one row per residue.
    """

    def getDataForExport(self):
        """
        Return a list of row data to export to CSV or Excel.

        The header row is 'Residue Index', 'Residue' and the title of the
        plot's only series. Each data row pairs a series entry with the
        residue name at the same position in the task output.

        :return: Data to be exported
        :rtype: list(list)
        """
        res_names = self.task.output.residue_info.residue_names
        # Residue plots always have a single series
        assert len(self.series_map) == 1
        plot_title, values_dict = next(iter(self.series_map.items()))
        rows = [['Residue Index', 'Residue', plot_title]]
        rows.extend([res_idx, res_name, value]
                    for (res_idx, value), res_name in zip(
                        values_dict.items(), res_names))
        return rows
 
class EnergyPlotPlot(AbstractTrajectoryPlot):
    """
    Chart class for energy matrix data.
    The plot data will be populated by the EnergyPlotPanel.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.results = None
        # ASLs of the sets, filled by loadFromTask(); used for export
        self.set_asls = []
        # Frame times in nanoseconds, filled by loadFromTask()
        self.frame_times = None
        # Energies currently plotted, one per frame; set by setPlotData()
        self.energies = None
        self.chart_title = ''
        # Axes are created on the first setPlotData() call
        self.x_axis = None
        self.y_axis = None
        self.set_names = []
        if self.task:
            self.set_names = self.task.input.set_names

    def getPlotType(self):
        """
        :return: `PlotDataType.ENERGY`
        """
        return PlotDataType.ENERGY

    def enableSeriesTracking(self):
        """
        Energy plots do not track hovering over series; this is a no-op.
        """

    def loadFromTask(self, task):
        """
        Load in results from the given task.

        :param task: Task to get result data from.
        :type task: traj_plot_models.TrajectoryEnergyJobTask
        :raise RuntimeError: If the parsed results hold no values
        """
        results, set_asls, frame_times = energy_plots.parse_results_file(
            task.output.results_file)
        if results.val is None:
            # Task failed to complete.
            raise RuntimeError(
                'Energy analysis task failed to produce results.')
        self.set_asls = set_asls  # used for export
        self.results = results
        # Convert picoseconds to nanoseconds:
        self.frame_times = [time / 1000.0 for time in frame_times]

    def setPlotData(self, energies):
        """
        Set self.energies array to the given data, and re-draw the chart.

        :param energies: One energy per frame, or None to clear the plot
        :type energies: sequence of float or None
        """
        self.energies = energies
        chart = self.chart()
        chart.setTitle(self.chart_title)
        chart.removeAllSeries()
        if not chart.axes():
            # Create bottom (time) and left (energy) axes
            self.x_axis = QtChart.QValueAxis()
            self.x_axis.setTitleText('Time (ns)')
            chart.addAxis(self.x_axis, Qt.AlignBottom)
            self.y_axis = tplots.OutputAxis()
            self.y_axis.setLabelFormat('%.0f')
            self.y_axis.setTitleText('Energy (kCal/mol)')
            chart.addAxis(self.y_axis, Qt.AlignLeft)
        if self.energies is None:
            # No sets or terms selected
            return
        # Add data series to the plot:
        series = tplots.OutputSeries()
        for x_time, y_energy in zip(self.frame_times, self.energies):
            series.append(x_time, y_energy)
        _generateAxisSpecifications(self.energies, self.y_axis)
        self.x_axis.setMin(min(self.frame_times))
        self.x_axis.setMax(max(self.frame_times))
        chart.addSeries(series)
        # A series can only be attached to axes once it belongs to the
        # chart, and every freshly created series has to be attached again.
        series.attachAxis(self.x_axis)
        series.attachAxis(self.y_axis)

    def getDataForExport(self):
        """
        Return a list of row data to export to Excel or CSV. Used by the
        export menu in the plot sub-window.

        Only the header row is returned while no frame times or energies
        are available.

        :return: Data to be exported
        :rtype: list(list)
        """
        rows = [['Frame', 'Time (ns)', 'Energy (kCal/mol)']]
        if self.frame_times is None or self.energies is None:
            return rows
        frame_data = zip(self.frame_times, self.energies)
        for idx, (time, energy) in enumerate(frame_data, start=1):
            rows.append([idx, time, energy])
        return rows

    def getDataForExportFromMainPanel(self):
        """
        Return a list of row data to export to Excel.
        Used by the "Export Results..." button of the parent plots panels.

        :return: Data to be exported
        :rtype: list(list)
        """
        all_values_dict = energy_plots.format_results_by_frame(self.results)
        # Print ASLs for each set, above the header row:
        rows = [[f"sel_{i:03}", asl] for i, asl in enumerate(self.set_asls)]
        # Header row:
        rows.append(["Frame", "Time (ns)"] + list(all_values_dict.keys()))
        # Energies, one row per frame:
        energy_lists = all_values_dict.values()
        for idx, time in enumerate(self.frame_times):
            energies_for_frame = [energies[idx] for energies in energy_lists]
            rows.append([idx + 1, time] + energies_for_frame)
        return rows
class Callout(QtWidgets.QGraphicsItem):
    """
    A callout is a rounded rectangle that displays values for a point on a
    QChart
    """

    def __init__(self, chart, series, pos, text_list):
        """
        :param chart: Chart the callout belongs to
        :type chart: QtChart.QChart
        :param series: Series the described point belongs to
        :type series: QtChart.QAbstractSeries
        :param pos: Position of the point, in chart values
        :type pos: QtCore.QPointF
        :param text_list: Lines of text to show
        :type text_list: list(str)
        """
        super().__init__()
        self.font = QtGui.QFont()
        self.chart = chart
        self.series = series
        # NOTE(review): this instance attribute hides QGraphicsItem.pos() on
        # Callout objects; kept because the attribute name is part of the
        # class's interface.
        self.pos = pos
        self.text_list = text_list
        # Computed lazily by boundingRect()
        self.bounding_rect = None

    def paint(self, painter, option, widget):
        """
        Draw the rounded rectangle and the callout text inside it.
        """
        rect = self.boundingRect()
        light_blue = QtGui.QColor(220, 220, 255)
        painter.setBrush(light_blue)
        painter.drawRoundedRect(rect, 5, 5)
        text_rect = rect.adjusted(5, 5, 5, 5)
        text = '\n'.join(self.text_list)
        painter.drawText(text_rect, Qt.AlignLeft, text)

    def generateBoundingRect(self):
        """
        Creates a bounding rect based on text length/height and chart position

        :return: Rectangle with one corner on the described point
        :rtype: QtCore.QRectF
        """
        # Generate metrics for callout text
        fm = QtGui.QFontMetrics(self.font)
        buffer = 10
        # default=0 keeps an empty text list from raising in max()
        widest_line = max((fm.width(text) for text in self.text_list),
                          default=0)
        text_width = max(widest_line, 30) + buffer
        text_height = max(fm.height() * len(self.text_list), 30) + buffer
        # Create rectangle, flipping orientation if at risk of escaping chart
        pt = self.chart.mapToPosition(self.pos, self.series)
        x0, y0 = pt.x(), pt.y()
        x1 = x0 - text_width
        if x1 < 0:
            x1 = x0 + text_width
        y1 = y0 - text_height
        if y1 < 0:
            y1 = y0 + text_height
        pt0 = QtCore.QPointF(min(x0, x1), min(y0, y1))
        pt1 = QtCore.QPointF(max(x0, x1), max(y0, y1))
        return QtCore.QRectF(pt0, pt1)

    def boundingRect(self):
        """
        :return: The callout rectangle, generated on first use and cached
        :rtype: QtCore.QRectF
        """
        if self.bounding_rect is None:
            self.bounding_rect = self.generateBoundingRect()
        return self.bounding_rect
class CollapsiblePlot(QtWidgets.QWidget):
    """
    This class defines a collapsible plot. The widget has an
    area for title text, a 'collapse' button and a 'close' button.

    :ivar widgetClosed: Signal emitted when the widget is closed.
                        Emits a str containing the widget's id and
                        whether it is an interaction plot.
    :type widgetClosed: `QtCore.pyqtSignal(str, bool)`
    """
    widgetClosed = QtCore.pyqtSignal(str, bool)

    def __init__(self,
                 parent=None,
                 system_title='',
                 plot_title='',
                 plot=None,
                 is_interactions=False,
                 tooltip=None):
        """
        :param parent: Parent widget
        :type parent: QtWidgets.QWidget or None
        :param system_title: System title for the widget
        :type system_title: str
        :param plot_title: Title to set for this title bar
        :type plot_title: str
        :param plot: Plot to set in the collapsible area
        :type plot: `AbstractTrajectoryPlot`
        :param is_interactions: Whether plot is an interactions plot
        :type is_interactions: bool
        :param tooltip: Optional tooltip for the title
        :type tooltip: str
        """
        super().__init__(parent=parent)
        self.ui = collapsible_plot_ui.Ui_Form()
        self.ui.setupUi(self)
        self.ui.collapse_btn.clicked.connect(self.onCollapseButtonClicked)
        self.ui.close_btn.clicked.connect(self.close)
        self.ui.system_title_label.setText(system_title)
        if len(plot_title) >= 45:
            # Long titles are shortened in the title bar; the full title is
            # moved into the tooltip instead.
            if tooltip is None:
                tooltip = plot_title
            else:
                tooltip = plot_title + ': ' + tooltip
            plot_title = plot_title[:42] + '...'
        self.eid = None
        self.ui.plot_title_le.setText(plot_title)
        if tooltip:
            tooltip += '<br><i>Double-click to edit</i>'
            self.ui.plot_title_le.setToolTip(tooltip)
        # Stays None until setPlot() installs a plot; assigning the plot
        # here first would make setPlot() return before adding its view.
        self.plot = None
        self.system_title = system_title
        self.fit_asl = None
        self.settings_hash = None
        self.id = str(uuid.uuid4())
        self.is_interaction_plot = is_interactions
        if plot is not None:
            self.setPlot(plot)

    def setPlot(self, plot):
        """
        Set the plot in the collapsible area to the specified object.

        A previously set, different plot is removed from the layout and
        scheduled for deletion together with its view.

        :param plot: Plot to set in the collapsible area.
        :type plot: `AbstractTrajectoryPlot`
        """
        self.fit_asl = plot.fit_asl
        self.settings_hash = plot.settings_hash
        previous_plot = self.plot
        if previous_plot is not None:
            if previous_plot == plot:
                return
            # The layout holds the plot's view; the plot itself is a plain
            # QObject and cannot be handed to removeWidget().
            self.ui.widget_layout.removeWidget(previous_plot.view)
            previous_plot.view.deleteLater()
            previous_plot.deleteLater()
        self.plot = plot
        plot.view.setSizePolicy(QtWidgets.QSizePolicy.Expanding,
                                QtWidgets.QSizePolicy.Expanding)
        self.eid = self.plot.eid
        self.ui.widget_layout.addWidget(plot.view)
        plot.view.setVisible(True)

    def setPlotTitle(self, plot_title):
        """
        :param plot_title: Text to show in the title field
        :type plot_title: str
        """
        self.ui.plot_title_le.setText(plot_title)

    def getPlotTitle(self):
        """
        Returns plot title.

        :return: plot title
        :rtype: str
        """
        return self.ui.plot_title_le.text()

    def close(self):
        """
        Close and remove this widget.
        """
        if self.layout():
            self.widgetClosed.emit(self.id, self.is_interaction_plot)
            self.layout().removeWidget(self)
            self.deleteLater()
        super().close()

    def mousePressEvent(self, event):
        """
        Show the plot's context menu on right click.
        """
        # NOTE(review): showContextMenu() is not defined on the plot classes
        # visible in this file - confirm the plots used here provide it.
        if event.button() == Qt.RightButton and self.plot is not None:
            self.plot.showContextMenu()
        super().mousePressEvent(event)
#############################
# ADVANCED PLOTS AND SHORTCUTS
#############################
class BasePlotPanel(basewidgets.Panel):
    """
    Common base class of the plot panels defined in this module. Adds
    nothing to `basewidgets.Panel`.
    """
class BaseAdvancedPlotPanel(BasePlotPanel):
    """
    Base class for plot panels that get opened via shortcuts in the
    "Advanced Plots" section of the main plots panel.

    Subclasses are expected to provide the `plot` and `id` attributes and
    the `closeRequested` signal used by the context menu.
    """

    def _showContextMenu(self):
        """
        Show the save/export/delete context menu at the cursor position.
        """
        plot = self.plot
        menu = QtWidgets.QMenu(self)
        menu.addAction(SAVE_IMG, plot.saveImage)
        menu.addAction(EXPORT_CSV, plot.exportToCSV)
        menu.addAction(EXPORT_EXCEL, plot.exportToExcel)
        menu.addSeparator()
        menu.addAction(DELETE, lambda: self.closeRequested.emit(self.id))
        menu.exec_(QtGui.QCursor.pos())
class RmsfPlotPanel(BaseAdvancedPlotPanel):
    """
    Advanced plot panel showing an RMSF plot (per atom or per residue).

    :cvar closeRequested: Signal emitted when the widget is closed.
                          Emits a str containing the widget's id
    :type closeRequested: `QtCore.pyqtSignal(str)`
    """
    ui_module = advanced_plot_ui
    model_class = traj_plot_models.RmsfPlotModel
    closeRequested = QtCore.pyqtSignal(str)

    def __init__(self, plot, mode, parent=None):
        """
        :param plot: Plot to show in the panel
        :type plot: `BaseRmsfPlot`
        :param mode: Analysis mode of the plot
        :type mode: traj_plot_models.AnalysisMode
        :param parent: Parent widget
        :type parent: QtWidgets.QWidget or None
        """
        # NOTE(review): the attributes are assigned before
        # super().__init__(), presumably because the base constructor runs
        # initSetUp(), which reads them - confirm before reordering.
        self.plot = plot
        self.chart = plot.chart()
        self.mode = mode
        self.id = plot.id
        super().__init__(parent)

    def initSetUp(self):
        super().initSetUp()
        ui = self.ui
        # The residue-only widgets are hidden for atom RMSF plots.
        residue_mode = self.mode is traj_plot_models.AnalysisMode.ResRMSF
        ui.residue_info_wdg.setVisible(residue_mode)
        ui.residue_options_wdg.setVisible(residue_mode)
        ui.plot_layout.addWidget(self.plot.view)
        ui.options_link.clicked.connect(self._onOptionsToggle)
        ui.close_btn.clicked.connect(self.close)

    def defineMappings(self):
        M = self.model_class
        ui = self.ui
        b_factor_trg = mappers.TargetSpec(ui.pdb_b_factor_cb,
                                          slot=self._onBFactorToggle)
        ss_trg = mappers.TargetSpec(ui.secondary_st_color_cb,
                                    slot=self._onSecondaryStructureToggle)
        return [
            (ss_trg, M.secondary_structure_colors),
            (b_factor_trg, M.b_factor_plot),
        ]  # yapf: disable

    def _onBFactorToggle(self):
        """
        Show or hide the B factor series and axis according to the model.
        Only applies to residue RMSF plots.
        """
        if self.mode is not traj_plot_models.AnalysisMode.ResRMSF:
            return
        visible = self.model.b_factor_plot
        for series in self.chart.series():
            if isinstance(series, BFactorSeries):
                series.setVisible(visible)
        for axis in self.chart.axes():
            if isinstance(axis, BFactorAxis):
                axis.setVisible(visible)

    def _onSecondaryStructureToggle(self):
        """
        Show or hide the secondary structure series according to the model.
        Only applies to residue RMSF plots.
        """
        if self.mode is not traj_plot_models.AnalysisMode.ResRMSF:
            return
        visible = self.model.secondary_structure_colors
        for series in self.chart.series():
            if _is_series_ss(series):
                series.setVisible(visible)

    def _onOptionsToggle(self):
        """
        Toggle the visibility of the residue options widget.
        """
        options_wdg = self.ui.residue_options_wdg
        options_wdg.setVisible(not options_wdg.isVisible())

    def mousePressEvent(self, event):
        """
        Show the context menu on right click.
        """
        if event.button() == Qt.RightButton:
            self._showContextMenu()
        super().mousePressEvent(event)
class AdvancedPlotShortcut(basewidgets.BaseWidget):
    """
    Shortcut icon that opens an advanced plot (RMSF and Energy plots).

    :cvar widgetClosed: Signal emitted when the widget is closed.
                        Emits a str containing the widget's id
    :type widgetClosed: `QtCore.pyqtSignal(str)`
    """
    ui_module = shortcut_ui
    widgetClosed = QtCore.pyqtSignal(str)

    def __init__(self, plot, shortcut_title='', window_title='', parent=None):
        """
        :param plot: Plot panel opened by this shortcut; it must provide
                     `closeRequested`, `id`, `show`, `raise_` and `close`
        :type plot: `BaseAdvancedPlotPanel`
        :param shortcut_title: Text shown under the shortcut icon
        :type shortcut_title: str
        :param window_title: Window title to set on the plot panel
        :type window_title: str
        :param parent: Parent widget
        :type parent: QtWidgets.QWidget or None
        """
        super().__init__(parent)
        self.plot = plot
        self.plot.setWindowTitle(window_title)
        self.plot.closeRequested.connect(self._closeEvent)
        icon = QtGui.QPixmap(":/trajectory_gui_dir/icons/adv_plot.png")
        self.ui.icon_lbl.setPixmap(icon)
        self.ui.shortcut_lbl.setText(shortcut_title)

    def mousePressEvent(self, event):
        """
        Open the plot on left click, show the context menu on right click.
        """
        super().mousePressEvent(event)
        button = event.button()
        if button == QtCore.Qt.LeftButton:
            self._openPlot()
        if button == QtCore.Qt.RightButton:
            self._showContextMenu()

    def _openPlot(self):
        """
        Show the plot panel and bring it to the front.
        """
        self.plot.show()
        self.plot.raise_()

    def _showContextMenu(self):
        """
        Show the view/delete context menu at the cursor position and act on
        the chosen entry.
        """
        menu = QtWidgets.QMenu(self)
        menu.addAction(VIEW_PLOT)
        menu.addSeparator()
        menu.addAction(DELETE)
        chosen = menu.exec_(QtGui.QCursor.pos())
        if not chosen:
            # Menu was dismissed without a selection.
            return
        chosen_text = chosen.text()
        if chosen_text == VIEW_PLOT:
            self._openPlot()
        elif chosen_text == DELETE:
            self.deleteShortcut()

    def deleteShortcut(self):
        """
        Remove this shortcut, and the plot associated with it.
        """
        self._closeEvent(self.plot.id)

    def _closeEvent(self, plot_id):
        """
        Hide and close this shortcut and its plot, announcing the closed
        plot id through `widgetClosed`.

        :param plot_id: Id of the plot being closed
        :type plot_id: str
        """
        self.setVisible(False)
        self.widgetClosed.emit(plot_id)
        self.plot.close()
        self.close()
class ShortcutRow(basewidgets.BaseWidget):
    """
    This class represents a row of advanced plot shortcuts
    """

    def initLayOut(self):
        """
        Create the horizontal row layout, seeded with an expanding spacer
        item, and add it to the main layout.
        """
        super().initLayOut()
        spacer = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding,
                                       QtWidgets.QSizePolicy.Minimum)
        self.row_layout = QtWidgets.QHBoxLayout()
        self.row_layout.addItem(spacer)
        self.main_layout.addLayout(self.row_layout)

    def hasSpace(self):
        """
        Returns whether the shortcut row has space for another widget

        :return: True if the row holds fewer than `MAX_SHORTCUTS_IN_ROW`
                 widgets
        :rtype: bool
        """
        # NOTE(review): widgetCount() is not defined in the part of this
        # class shown here; presumably it is provided elsewhere - confirm.
        return self.widgetCount() < MAX_SHORTCUTS_IN_ROW
 
#############################
# Custom Series and Axes
#############################
class OutputAxis(QtChart.QValueAxis):
    """
    Value axis subclass that adds no behavior of its own.

    NOTE(review): presumably exists so that the output axis can be told
    apart from other chart axes by type - confirm against usage.
    """
class BFactorAxis(QtChart.QValueAxis):
    """
    Value axis subclass that adds no behavior of its own.

    Gives B-factor axes a distinct type, so that they can be picked out of
    a chart's axes with an exact type check.
    """
class SecondaryStructureAxis(QtChart.QValueAxis):
    """
    Value axis subclass that adds no behavior of its own.

    NOTE(review): presumably exists so that the secondary structure axis can
    be told apart from other chart axes by type - confirm against usage.
    """
class OutputSeries(QtChart.QLineSeries):
    """
    Line series subclass that adds no behavior of its own.

    NOTE(review): presumably exists so that the output series can be told
    apart from other chart series by type - confirm against usage.
    """
class BFactorSeries(QtChart.QLineSeries):
    """
    Line series subclass that adds no behavior of its own.

    Gives B-factor series a distinct type, so that they can be picked out
    of a chart's series with an exact type check.
    """
class SecondaryStructureHelixSeries(QtChart.QAreaSeries):
    """
    Area series subclass that adds no behavior of its own.

    NOTE(review): presumably marks the area series drawn for helix regions
    of the secondary structure - confirm against usage.
    """
class SecondaryStructureStrandSeries(QtChart.QAreaSeries):
    """
    Area series subclass that adds no behavior of its own.

    NOTE(review): presumably marks the area series drawn for strand regions
    of the secondary structure - confirm against usage.
    """
class EnergyPlotPanel(BaseAdvancedPlotPanel):
    """
    Plot for energy analysis.

    :cvar closeRequested: Signal emitted when the widget is closed.
                          Emits a str containing the widget's id
    :type closeRequested: `QtCore.pyqtSignal(str)`
    """
    ui_module = energy_plot_ui
    model_class = traj_plot_models.EnergyPlotModel
    closeRequested = QtCore.pyqtSignal(str)

    def __init__(self, plot_view, parent=None):
        """
        :param plot_view: energy plot shown in this panel. This class uses
                          its `id`, `view`, `set_names` and `results`
                          attributes and its `chart()` and `setPlotData()`
                          methods.
        :param parent: parent widget
        """
        # Assigned before the base class initializer is called, since
        # initSetUp() and initFinalize() read self.plot.
        # NOTE(review): presumably the base initializer runs those hooks -
        # confirm before reordering.
        self.plot = plot_view
        self.id = plot_view.id
        self.chart = plot_view.chart()
        super().__init__(parent)
        self.setWindowTitle('Review Energy Plots')

    def initSetUp(self):
        """
        Wire up the close and options buttons, embed the plot view, and
        configure the sets table header.
        """
        super().initSetUp()
        self.ui.close_btn.clicked.connect(self.close)
        self.ui.plot_layout.addWidget(self.plot.view)
        hheader = self.ui.sets_table.view.horizontalHeader()
        hheader.setStretchLastSection(True)
        hheader.hide()
        pixmap = QtGui.QPixmap(icons.MORE_ACTIONS_DB)
        self.ui.options_btn.setIcon(QtGui.QIcon(pixmap))
        self.ui.options_btn.setIconSize(QtCore.QSize(30, 15))
        self.ui.options_btn.setStyleSheet('border: none;')
        self.ui.options_btn.clicked.connect(self._showContextMenu)
        self.resize(self.width(), 800)  # make taller

    def initFinalize(self):
        """
        Fill the sets table with one row per set name of the plot, and
        select all of the sets.
        """
        super().initFinalize()
        # Populate the sets PLPTableWidget with sets from our model:
        sets = []
        for name in self.plot.set_names:
            row = traj_plot_models.SetRow()
            row.name = name
            sets.append(row)
        self.model.sets = sets
        spec = self.ui.sets_table.makeAutoSpec(self.model.sets)
        self.ui.sets_table.setSpec(spec)
        self.ui.sets_table.setPLP(self.model.sets)
        # By default select all sets:
        self.model.selected_sets = list(sets)

    def defineMappings(self):
        """
        Return the pairs of UI widget and model parameter that are kept in
        sync with each other.
        """
        M = self.model_class
        ui = self.ui
        return [
            (ui.sets_table, M.sets),
            (ui.sets_table.selection_target, M.selected_sets),
            (ui.exclude_self_terms_cb, M.exclude_self_terms),
            (ui.coulomb_cb, M.coulomb),
            (ui.van_der_waals_cb, M.van_der_waals),
            (ui.bond_cb, M.bond),
            (ui.angle_cb, M.angle),
            (ui.dihedral_cb, M.dihedral),
        ]  # yapf: disable

    def getSignalsAndSlots(self, model):
        """
        Return the pairs of model signal and slot to connect. A change to
        the selected sets or to any energy option updates the plot.
        """
        return [
            (model.selected_setsChanged, self.updatePlotValues),
            (model.exclude_self_termsChanged, self.updatePlotValues),
            (model.coulombChanged, self.updatePlotValues),
            (model.van_der_waalsChanged, self.updatePlotValues),
            (model.bondChanged, self.updatePlotValues),
            (model.angleChanged, self.updatePlotValues),
            (model.dihedralChanged, self.updatePlotValues),
        ]  # yapf: disable

    def _updateEnergyToggles(self):
        """
        Update check box states for different energy terms based on current UI state.
        """
        ui = self.ui
        # Excluding self terms is not offered when exactly one set is
        # selected, or when neither Coulomb nor van der Waals is checked.
        disable = (len(self.model.selected_sets) == 1 or
                   (not ui.coulomb_cb.isChecked() and
                    not ui.van_der_waals_cb.isChecked()))
        self._enableEnergyToggle(ui.exclude_self_terms_cb, not disable)
        # The bond, angle and dihedral toggles are only available while
        # self terms are not excluded.
        enable = not self.model.exclude_self_terms
        self_term_toggles = [ui.bond_cb, ui.angle_cb, ui.dihedral_cb]
        for cb in self_term_toggles:
            self._enableEnergyToggle(cb, enable)

    def _enableEnergyToggle(self, check_box, enable):
        """
        Enable or disable given check box depending on the 'enable' argument.
        A check box that gets disabled is also unchecked.

        :param check_box: energy term check box
        :type check_box: QtWidgets.QCheckBox
        :param enable: defines whether check box should be enabled or not
        :type enable: bool
        """
        check_box.setEnabled(enable)
        if not enable:
            check_box.setChecked(False)

    def updatePlotValues(self):
        """
        Slot for updating the chart based on current UI selection.
        """
        # Update state of energy check boxes.
        self._updateEnergyToggles()
        m = self.model
        term_name_map = {
            'Coulomb': m.coulomb,
            'van der Waals': m.van_der_waals,
            'Bond': m.bond,
            'Angle': m.angle,
            'Dihedral': m.dihedral,
        }
        terms_used = [name for name, param in term_name_map.items() if param]
        num_terms_used = len(terms_used)
        if num_terms_used == 0:
            # Not shown: the title is left blank when no term is used.
            term_str = ''
        elif num_terms_used == len(term_name_map):
            term_str = 'Total Energy'
        elif num_terms_used == 1:
            term_str = terms_used[0] + ' Energy'
        elif num_terms_used == 2:
            term_str = ' and '.join(terms_used) + ' Energies'
        else:
            term_str = ', '.join(
                terms_used[:-1]) + ' and ' + terms_used[-1] + ' Energies'
        if num_terms_used == 0 or not m.selected_sets:
            title = ''
        else:
            title = ' - '.join(setrow.name for setrow in m.selected_sets)
            if m.exclude_self_terms:
                title += ' Interactions'
            title += ': ' + term_str
        self.ui.plot_title_lbl.setText(title)
        self.plot.setPlotData(self.getEnergyValues())

    def getEnergyValues(self):
        """
        Return the energy values based on the current panel settings.

        :return: the value returned by `energy_plots.sum_results` for the
                 selected sets and the checked energy terms, or None if no
                 set is selected or no energy term is checked
        """
        m = self.model
        # A set is referred to by the id 'sel_NNN', where NNN is the
        # zero-padded position of the set in the model's list of sets.
        use_sets = [
            f'sel_{i:03}'
            for i, set_row in enumerate(m.sets)
            if set_row in m.selected_sets
        ]
        if not use_sets:
            return None
        checked_by_term = {
            'elec': m.coulomb,
            'vdw': m.van_der_waals,
            'stretch': m.bond,
            'angle': m.angle,
            'dihedral': m.dihedral,
        }
        use_terms = [
            name for name, checked in checked_by_term.items() if checked
        ]
        if not use_terms:
            return None
        include_self = not m.exclude_self_terms
        return energy_plots.sum_results(self.plot.results, use_sets, use_terms,
                                        include_self)