"""
Functions and classes for plotting enrichment job outputs, e.g. basic
Sensitivity vs. 1-Specificity (ROC) plots and %Actives Found vs. %Screen plots.
Copyright Schrodinger, LLC. All rights reserved.
"""
import warnings
from scipy import integrate
import schrodinger.utils.log as log
import schrodinger.utils.moduleproxy as moduleproxy
from schrodinger.analysis.enrichment import metrics
# Plotting modules are loaded through moduleproxy.try_import.
# NOTE(review): presumably try_import tolerates a missing module (e.g. a
# headless install without matplotlib); confirm in
# schrodinger.utils.moduleproxy.
# Any warnings raised while importing pylab are suppressed.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    pylab = moduleproxy.try_import("pylab")
matplotlib = moduleproxy.try_import("matplotlib")
# Module-level logger used for the debug output of the curve helpers below.
logger = log.get_output_logger(__file__)
def calc_sensitivity_with_fraction(total_actives, total_ligands, active_ranks,
                                   fraction_of_screen):
    """
    Calculates sensitivity at a particular fraction of screen, i.e. the
    fraction of all actives ranked within the top
    ``round(total_ligands * fraction_of_screen)`` positions.

    :param total_actives: The number of all active ligands in the screen, ranked
                          and unranked.
    :type total_actives: int
    :param total_ligands: The number of the total number of ligands (actives and
                          unknowns/decoys) used in the screen.
    :type total_ligands: int
    :param active_ranks: List of *unadjusted* integer ranks for the actives
                    found in the screen. For example, a screen result that
                    placed three actives as the first three ranks has an
                    active_ranks list of = [1, 2, 3].
    :type active_ranks: list(int)
    :param fraction_of_screen: Fraction of screen at which to calculate
                               the sensitivity.
    :type fraction_of_screen: float
    :return: Sensitivity of the screen at a given fraction of screen.
    :rtype: float
    """
    threshold = round(total_ligands * fraction_of_screen)
    # Ranks are 1-based: the top `threshold` ligands occupy ranks
    # 1..threshold, so an active at rank == threshold must be counted.
    ranked_actives = sum(1 for rank in active_ranks if rank <= threshold)
    return ranked_actives / float(total_actives)
def calc_sensitivity_with_rank(total_actives, active_ranks, rank):
    """
    Calculates sensitivity at a particular rank, defined as:
        Se(rank) = found_actives / total_actives
    where found_actives is one more than the position of `rank` within
    active_ranks.

    :param total_actives: The number of all active ligands in the screen, ranked
                          and unranked.
    :type total_actives: int
    :param active_ranks: List of *unadjusted* integer ranks for the actives
                    found in the screen. For example, a screen result that
                    placed three actives as the first three ranks has an
                    active_ranks list of = [1, 2, 3].
    :type active_ranks: list(int)
    :param rank: Active rank at which to calculate the sensitivity; must be
                 an element of active_ranks.
    :type rank: int
    :return: Sensitivity of the screen at a given rank.
    :rtype: float
    :raise ValueError: If rank is not present in active_ranks.
    """
    # With active_ranks in ascending order, the 0-based position of `rank`
    # is one less than the number of actives found up to that rank.
    position = active_ranks.index(rank)
    return (position + 1) / float(total_actives)
def calc_specificity_with_rank(total_actives, total_ligands, active_ranks,
                               rank):
    """
    Calculates specificity at a particular rank, defined as:
        Sp(rank) = discarded_decoys / total_decoys
    where discarded_decoys is the number of decoys ranked after `rank`.

    :param total_actives: The number of all active ligands in the screen, ranked
                          and unranked.
    :type total_actives: int
    :param total_ligands: The number of the total number of ligands (actives and
                          unknowns/decoys) used in the screen.
    :type total_ligands: int
    :param active_ranks: List of *unadjusted* integer ranks for the actives
                    found in the screen. For example, a screen result that
                    placed three actives as the first three ranks has an
                    active_ranks list of = [1, 2, 3].
    :type active_ranks: list(int)
    :param rank: Active rank at which to calculate the specificity; must be
                 an element of active_ranks.
    :type rank: int
    :return: Specificity of the screen at a given rank.
    :rtype: float
    :raise ValueError: If there are no decoys (total_ligands == total_actives)
                       or if rank is not present in active_ranks.
    """
    n_decoys = total_ligands - total_actives
    if n_decoys == 0:
        raise ValueError("calculateSpecificity caught ZeroDivisionError")
    # Actives at or before `rank` (active_ranks is expected to be ascending).
    actives_so_far = active_ranks.index(rank) + 1
    # Everything else in the top `rank` positions is a decoy that was kept.
    decoys_so_far = rank - actives_so_far
    return (n_decoys - decoys_so_far) / float(n_decoys)
def get_percent_screen_curve_points(total_actives, total_ligands, active_ranks):
    """
    Calculates one (%Screen, %Actives Found) point per active rank.

    The n-th entry of active_ranks (counting from 1) yields the point
    (100 * rank / total_ligands, 100 * n / total_actives).

    :param total_actives: The number of all active ligands in the screen, ranked
                          and unranked.
    :type total_actives: int
    :param total_ligands: The number of the total number of ligands (actives and
                          unknowns/decoys) used in the screen.
    :type total_ligands: int
    :param active_ranks: List of *unadjusted* integer ranks for the actives
                    found in the screen. For example, a screen result that
                    placed three actives as the first three ranks has an
                    active_ranks list of = [1, 2, 3].
    :type active_ranks: list(int)
    :return: List of (%Screen, %Actives Found) tuples for the active ranks.
    """
    n_actives = float(total_actives)
    n_ligands = float(total_ligands)
    points = []
    for position, active_rank in enumerate(active_ranks):
        found = float(position + 1)
        pct_actives = found / n_actives * 100.0
        pct_screen = float(active_rank) / n_ligands * 100.0
        point = (pct_screen, pct_actives)
        logger.debug("getPercentScreenCurvePoints: %s" % str(point))
        points.append(point)
    return points
def get_plot_data(total_actives, total_ligands, active_ranks, title_ranks):
    """
    Builds a table of per-active enrichment statistics.

    :param total_actives: The number of all active ligands in the screen, ranked
                          and unranked.
    :type total_actives: int
    :param total_ligands: The number of the total number of ligands (actives and
                          unknowns/decoys) used in the screen.
    :type total_ligands: int
    :param active_ranks: List of *unadjusted* integer ranks for the actives
                    found in the screen. For example, a screen result that
                    placed three actives as the first three ranks has an
                    active_ranks list of = [1, 2, 3].
    :type active_ranks: list(int)
    :param title_ranks: Mapping (e.g. dict) from active rank to active title.
                        Its keys are sorted to order the rows; every key must
                        also appear in active_ranks.
    :type title_ranks: dict
    :return: A header row followed by one row per active, with columns
             Title, Rank, Specificity, 1-Specificity, Sensitivity, %Screen,
             %Actives Found.
    :rtype: list(list)
    :raise ValueError: If there are no decoys (total_ligands == total_actives),
                       because specificity is then undefined.
    :note: This function should only be accessed via Calculator unless we
           figured how to get title_ranks without an additional O(n) work.
    :note: This list may grow, but the relative order of the columns
           should remain fixed.
    """
    headers = [
        'Title',
        'Rank',
        'Specificity',
        '1-Specificity',
        'Sensitivity',
        '%Screen',
        '%Actives Found',
    ]
    rows = [headers]
    for rank_idx, rank in enumerate(sorted(title_ranks), start=1):
        # The Specificity columns must come from the specificity calculation,
        # not the sensitivity one.
        sp = calc_specificity_with_rank(total_actives, total_ligands,
                                        active_ranks, rank)
        se = calc_sensitivity_with_rank(total_actives, active_ranks, rank)
        prcnt_act = float(rank_idx) / float(total_actives) * 100.0
        prcnt_scrn = float(rank) / float(total_ligands) * 100.0
        rows.append([
            title_ranks[rank],  # Title
            rank,
            sp,
            1.0 - sp,
            se,
            prcnt_scrn,
            prcnt_act,
        ])
    return rows
def get_roc_curve_points(total_actives, total_ligands, active_ranks):
    """
    Calculates the points of the ROC curve, one per active rank.

    :param total_actives: The number of all active ligands in the screen, ranked
                          and unranked.
    :type total_actives: int
    :param total_ligands: The number of the total number of ligands (actives and
                          unknowns/decoys) used in the screen.
    :type total_ligands: int
    :param active_ranks: List of *unadjusted* integer ranks for the actives
                    found in the screen. For example, a screen result that
                    placed three actives as the first three ranks has an
                    active_ranks list of = [1, 2, 3].
    :type active_ranks: list(int)
    :return: List of (1 - specificity, sensitivity, rank) tuples, in the
             order of active_ranks.
    :rtype: list of tuples
    :raise ValueError: If there are no decoys (propagated from
                       calc_specificity_with_rank).
    """

    def _point_at(active_rank):
        # x: false positive rate; y: true positive rate.
        fpr = 1 - calc_specificity_with_rank(total_actives, total_ligands,
                                             active_ranks, active_rank)
        tpr = calc_sensitivity_with_rank(total_actives, active_ranks,
                                         active_rank)
        return (fpr, tpr, active_rank)

    return [_point_at(active_rank) for active_rank in active_ranks]
def get_roc_area_romberg(lower_limit=0.0,
                         upper_limit=1.0,
                         total_actives=None,
                         total_ligands=None,
                         active_ranks=None):
    """
    Integrates the sensitivity (see calc_sensitivity_with_fraction) over the
    fraction-of-screen interval [lower_limit, upper_limit].

    scipy.integrate.romberg is used when the installed SciPy provides it.
    Newer SciPy releases deprecated and then removed it; in that case
    scipy.integrate.quad is used instead.

    :param lower_limit: Lower integration limit, within 0-1.
    :type lower_limit: float
    :param upper_limit: Upper integration limit, within 0-1 and not smaller
                        than lower_limit.
    :type upper_limit: float
    :param total_actives: The number of all active ligands in the screen,
                          ranked and unranked. Required.
    :type total_actives: int
    :param total_ligands: The total number of ligands (actives and
                          unknowns/decoys) used in the screen. Required.
    :type total_ligands: int
    :param active_ranks: List of *unadjusted* integer ranks for the actives
                         found in the screen. Required.
    :type active_ranks: list(int)
    :return: Area under the sensitivity curve between the two limits.
    :rtype: float
    :raise ValueError: If a limit lies outside 0-1 or lower_limit exceeds
                       upper_limit.
    :raise TypeError: If total_actives, total_ligands or active_ranks is not
                      supplied.
    :note: The integration variable is the fraction of the whole screen. It
           approximates the 1-Specificity axis of a true ROC curve only when
           decoys greatly outnumber actives.
    """
    if not (0.0 <= lower_limit <= 1.0 and 0.0 <= upper_limit <= 1.0):
        msg = "getROCArea:  limit out of 0-1 range: %.2f, %.2f" % (
            lower_limit, upper_limit)
        raise ValueError(msg)
    if lower_limit > upper_limit:
        msg = "getROCArea:  invalid limits: %.2f, %.2f  " % (lower_limit,
                                                             upper_limit)
        raise ValueError(msg)
    if total_actives is None or total_ligands is None or active_ranks is None:
        # The integrand needs the screen data; fail clearly instead of deep
        # inside scipy with a missing-argument error.
        raise TypeError("get_roc_area_romberg requires total_actives, "
                        "total_ligands and active_ranks")

    def _sensitivity_at(fraction):
        return calc_sensitivity_with_fraction(total_actives, total_ligands,
                                              active_ranks, fraction)

    romberg = getattr(integrate, 'romberg', None)
    if romberg is not None:
        return romberg(_sensitivity_at, lower_limit, upper_limit)
    # The integrand is a step function, so allow extra subdivisions.
    area, _abserr = integrate.quad(_sensitivity_at,
                                   lower_limit,
                                   upper_limit,
                                   limit=200)
    return area
def save_plot(total_actives,
              total_ligands,
              active_ranks,
              adjusted_active_ranks,
              total_ranked,
              png_file="plot.png",
              title='Screen Results',
              xlabel='1-Specificity',
              ylabel='Sensitivity'):
    """
    Saves an image of the ROC plot, Sensitivity v 1-Specificity,
    to a png file.

    The shaded area under the curve and the BEDROC, RIE, ROC and AUAC
    metrics (from the metrics module) are drawn on the plot.

    :param total_actives: The number of all active ligands in the screen, ranked
                          and unranked.
    :type total_actives: int
    :param total_ligands: The number of the total number of ligands (actives and
                          unknowns/decoys) used in the screen.
    :type total_ligands: int
    :param active_ranks: List of *unadjusted* integer ranks for the actives
                    found in the screen. For example, a screen result that
                    placed three actives as the first three ranks has an
                    active_ranks list of = [1, 2, 3].
    :type active_ranks: list(int)
    :param adjusted_active_ranks: Modified active ranks; each rank is improved
                              by the number of preceding actives. For
                              example, a screen result that placed three
                              actives as the first three ranks, [1, 2, 3],
                              has adjusted ranks of [1, 1, 1]. In this way,
                              actives are not penalized by being outranked
                              by other actives.
    :type adjusted_active_ranks: list(int)
    :param total_ranked: The number of unique ranked ligands. Deduced from
                    results_file or merged_file.
    :type total_ranked: int
    :param png_file: Path to output file, default is 'plot.png'.
    :type png_file: str
    :param title: Plot title, default is 'Screen Results'.
    :type title: str
    :param xlabel: x-axis label, default is '1-Specificity'.
    :type xlabel: str
    :param ylabel: y-axis label, default is 'Sensitivity'.
    :type ylabel: str
    """
    # FIXME: this method shares a lot of code with Plotter.savePlot()
    #        Determine if it is more useful to return a 'figure'
    #        rather than serializing the image.
    y_points = [0]
    x_points = [0]
    for x, y, rank in get_roc_curve_points(total_actives, total_ligands,
                                           active_ranks):
        y_points.append(y)
        x_points.append(x)
        msg = "plotROC: rank %d, Se %f, 1-Sp %.2f" % (rank, y, x)
        logger.debug(msg)
    pylab.xlabel(xlabel)
    pylab.ylabel(ylabel)
    pylab.plot(x_points, y_points, 'bo-')
    # Reference curve for random performance.
    sp_ref = [0, 1.0]
    se_ref = [0, 1.0]
    pylab.plot(sp_ref, se_ref, 'k-')
    # Close the shaded polygon: top right corner at the last sensitivity
    # reached (0 when there are no active ranks), then bottom right corner.
    last_sensitivity = y_points[-1]
    x_points.extend([1.0, 1.0])
    y_points.extend([last_sensitivity, 0])
    pylab.grid(True)
    pylab.fill(
        x_points,
        y_points,
        facecolor='blue',
        edgecolor='r',
        linewidth=5.0,
        alpha=0.2  # Transparency.
    )
    pylab.title(title)
    # Computed once; index 0 is reported as BEDROC and index 1 as alpha*Ra.
    bedroc = metrics.calc_BEDROC(total_actives,
                                 total_ligands,
                                 active_ranks,
                                 alpha=20)
    # Raw string with mathtext markers: in a plain string "\a" is the BEL
    # control character, not the start of "\alpha".
    pylab.text(
        0.30, 0.20,
        r"BEDROC($\alpha$=20, $\alpha$*Ra=%.4f): %.3f" % (bedroc[1], bedroc[0]))
    pylab.text(
        0.30, 0.15, "RIE: %.3f" %
        metrics.calc_RIE(total_actives, total_ligands, active_ranks, alpha=20))
    pylab.text(
        0.30, 0.10, "ROC: %.3f" %
        metrics.calc_ROC(total_actives, total_ligands, adjusted_active_ranks))
    pylab.text(
        0.30, 0.05, "AUAC: %.3f" % metrics.calc_AUAC(
            total_actives, total_ligands, total_ranked, active_ranks))
    pylab.savefig(png_file)
    pylab.close()
################################################################################
# Classes
################################################################################
class _BasePlotter(object):
    """
    Class on which Plotter and PercentScreenPlotter are based.
    """
    def __init__(self, calculators, title, xlabel, ylabel, xmax, ymax, legend):
        """
        :param calculators: List of Calculator instances.
        :type calculators: list(calculator.Calculator)
        :param title: The plot title.
        :type title: string
        """
        self.calculators = calculators
        self.title = title
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.legend = legend
        self.xmax = xmax
        self.ymax = ymax
        self.styles = ['bo-', 'bx-', 'b^-', 'b<-', 'b>-', 'bv-']
        self.series_labels = []
        self.legend_location = 4  # Lower right.
        self.legend_font_size = 8
        self.alpha = 0.1  # Fill transparency.
    def getPointsFromCalc(self, calculator):
        """
        Returns points for this metric from the given Calculator instance.
        Implement this method in subclasses.
        :param calculator: Calculator instance.
        :type calculator: calculator.Calculator
        :return List of (x, y) points
        :rtype: list
        """
        raise NotImplementedError()
    def plot(self, indexes=None):
        """
        Launch interactive matplotlib viewer loaded with the plot.
        :param indexes: List of indexes into self.calcs to include in the plot.
                        If the argument is None then all series are plotted.
        :type indexes:  list
        """
        plot_figure = self.getPlotFigure(indexes)
        pylab.figure(num=plot_figure.number)
        pylab.show()
    def showPlotWindow(self, win_title):
        """
        Open a window with this plot.
        :param win_title: Title for the plot window.
        :type win_title: str
        """
        # These imports are here and not at the top of the module because
        # many clients of the module don't need this functionality, and because
        # the import of navtoolbar itself produces annoying warnings when
        # Maestro is not available.
        from matplotlib.backends.backend_qt5agg import \
            FigureCanvasQTAgg as FigureCanvas
        import schrodinger.ui.qt.navtoolbar as navtoolbar
        plot_figure = self.getPlotFigure()
        graph = FigureCanvas(plot_figure)
        nav_toolbar = navtoolbar.NavToolbar(graph, graph.window(), False)
        nav_toolbar.lastDir = '.'
        nav_toolbar.show()
        graph.setWindowTitle(win_title)
        graph.show()
    def savePlot(self, png_file='plot.png'):
        """
        Serialized figure to a png format file.
        :param png_file: Path to output png file.
        :type png_file: string
        :return: None
        """
        plot_figure = self.getPlotFigure()
        pylab.figure(num=plot_figure.number)
        pylab.savefig(png_file)
        pylab.close()
    def getPlotFigure(self, indexes=None):
        """
        Returns a new pylab figure the plot.
        :param indexes: List of indexes into self.calcs to include in the plot.
                        If the argument is None then all series are plotted.
        :type indexes: list
        :return: a pylab figure.
        """
        plot_figure = pylab.figure()
        self.series_labels = []  # Clear any previous labels.
        if indexes is None:
            for calculator in self.calculators:
                self.addSeries(calculator, plot_figure=plot_figure)
        else:
            for index in indexes:
                self.addSeries(self.calcs[index], plot_figure=plot_figure)
        self.addReferenceSeries(plot_figure=plot_figure)
        pylab.title(self.title)
        pylab.legend(self.series_labels, loc=self.legend_location)
        ltext = pylab.gca().get_legend().get_texts()
        pylab.setp(ltext[0], fontsize=self.legend_font_size)
        return plot_figure
    def addSeries(self, calculator, style=None, plot_figure=None):
        """
        :param calculator: Calculator instance.
        :type calculator: calculator.Calculator
        :param style: The matplotlib linestyle.
                      If style is None then a style is selected by fetching
                      the next element from self.styles.
        :type style:  string
        :param plot_figure: A pylab figure.
                            If None then the current figure is used.
        :type plot_figure:  pylab.Figure
        :return: None
        """
        if plot_figure is None:
            plot_figure = pylab.figure()
        pylab.figure(num=plot_figure.number)
        if style is None:
            self.styles = self.styles[1:] + [self.styles[0]]
            style = self.styles[-1]
        # Unzip points to list of x and y coordinates, adding endpoints
        points = [(0.0, 0.0)]
        points += self.getPointsFromCalc(calculator)
        points += [(self.xmax, self.ymax)]
        x, y = list(zip(*points))
        # PANEL-8740: Points are changes in True Positive Rate (y-coordinate),
        # must use "steps-post" to get correct step plot.
        pylab.plot(x, y, style, drawstyle="steps-post", marker=None)
        pylab.xlim([0.0, self.xmax * 1.05])
        pylab.ylim([0.0, self.ymax * 1.05])
        pylab.xlabel(self.xlabel)
        pylab.ylabel(self.ylabel)
        pylab.grid(True)
        self.series_labels.append(self.legend)
    def addReferenceSeries(self, style=None, plot_figure=None):
        """
        Adds a diagonal representing random performance.
        :param style: The matplotlib linestyle.
                      If style is None then 'k-' is used.
        :type style:  string
        :param plot_figure: A pylab figure.
                            If None then the current figure is used.
        :type plot_figure:  pylab.Figure
        :return: None
        """
        if style is None:
            style = 'k-'
        if plot_figure is None:
            plot_figure = pylab.figure()
        pylab.figure(num=plot_figure.number)
        # Reference curve for random performance.
        pylab.plot([0, self.xmax], [0, self.ymax], '#BFBFBF')
        self.series_labels.append('Random')
class Plotter(_BasePlotter):
    """
    Plots one or more Calculator instances as ROC curves
    (Sensitivity vs 1-Specificity).

    API example where `enrich_calc1` and `enrich_calc2` are instances of
    Calculator::

        enrich_plotter = plotter.Plotter([enrich_calc1, enrich_calc2])
        enrich_plotter.plot() # Launch interactive plot window.
        enrich_plotter.savePlot('my_plot.png') # Save plot to file.

    There are six line styles defined by default.  Plotting more than
    six results cycles through the styles.
    """

    def __init__(self,
                 calculators,
                 title='Screen Results',
                 xlabel='1-Specificity',
                 ylabel='Sensitivity',
                 legend_label='Screen Results'):
        """
        :param calculators: List of Calculator instances.
        :type calculators: list(calculator.Calculator)
        :param title: The plot title.
        :type title:  string
        :param xlabel: The x-axis label.
        :type xlabel: str
        :param ylabel: The y-axis label.
        :type ylabel: str
        :param legend_label: The legend label.
        :type legend_label: str
        """
        # Both ROC axes are rates, so the plot spans 0-1 on each.
        super(Plotter, self).__init__(calculators, title, xlabel, ylabel,
                                      xmax=1.0, ymax=1.0,
                                      legend=legend_label)

    def getPointsFromCalc(self, calculator):
        """
        Returns the ROC points (1-Specificity, Sensitivity) for the given
        Calculator instance.

        :param calculator: Calculator instance.
        :type calculator: calculator.Calculator
        :return: List of (x, y) points
        :rtype: list
        """
        roc_points = get_roc_curve_points(calculator.total_actives,
                                          calculator.total_ligands,
                                          calculator.active_ranks)
        # Keep (1-Sp, Se) and drop the trailing rank of each triple.
        return [point[:2] for point in roc_points]
class PercentScreenPlotter(_BasePlotter):
    """
    Plots one or more Calculator instances as %Actives Found vs %Screen.

    API example where `enrich_calc1` and `enrich_calc2` are instances of
    Calculator::

        enrich_plotter = plotter.PercentScreenPlotter([enrich_calc1,
                                                       enrich_calc2])
        enrich_plotter.plot() # Launch interactive plot window.
        enrich_plotter.savePlot('my_plot.png') # Save plot to file.

    There are six line styles defined by default.  Plotting more than
    six results cycles through the styles.
    """

    def __init__(self,
                 calculators,
                 title='Screen Results',
                 xlabel='Percent Screen',
                 ylabel='Percent Actives Found',
                 legend_label='Screen Results'):
        """
        :param calculators: List of Calculator instances.
        :type calculators: list(calculator.Calculator)
        :param title: The plot title.
        :type title: string
        :param xlabel: The x-axis label.
        :type xlabel: str
        :param ylabel: The y-axis label.
        :type ylabel: str
        :param legend_label: The legend label.
        :type legend_label: str
        """
        # Both axes are percentages, so the plot spans 0-100 on each.
        super(PercentScreenPlotter, self).__init__(calculators, title, xlabel,
                                                   ylabel, xmax=100.0,
                                                   ymax=100.0,
                                                   legend=legend_label)

    def getPointsFromCalc(self, calculator):
        """
        Returns the (%Screen, %Actives Found) points for the given Calculator
        instance.

        :param calculator: Calculator instance.
        :type calculator: calculator.Calculator
        :return: List of (x, y) points
        :rtype: list
        """
        total_actives = calculator.total_actives
        total_ligands = calculator.total_ligands
        active_ranks = calculator.active_ranks
        return get_percent_screen_curve_points(total_actives, total_ligands,
                                               active_ranks)