# Source code for schrodinger.application.desmond.mplchart
"""
Tools for using matplotlib charts.
Copyright Schrodinger, LLC. All rights reserved.
"""
# Contributors: Dave Giesen
import atexit
import os
import shutil
import tempfile
from past.utils import old_div
import matplotlib
import matplotlib.font_manager as font_manager
import numpy
from matplotlib import cm as mpl_cmap
# Even though we don't use Axes3D directly, we have to import it so that
# matplotlib recognizes 3D as a valid projection.
from schrodinger.utils import fileutils
def remove_temp_dir(path):
    """
    Delete the temporary directory tree at path, created by this module.

    Errors (including a path that no longer exists) are ignored so this is
    safe to call from an atexit handler.

    :type path: str
    :param path: the directory to delete
    """
    shutil.rmtree(path, ignore_errors=True)
def _get_home():
    """
    Find user's home directory if possible.
    Otherwise raise error.
    """
    path = ''
    try:
        path = os.path.expanduser("~")
    except:
        pass
    if not os.path.isdir(path):
        for evar in ('HOME', 'USERPROFILE', 'TMP'):
            try:
                path = os.environ[evar]
                if os.path.isdir(path):
                    break
            except:
                pass
    return path
def _is_writable_dir(path):
    """
    path is a string pointing to a putative writable dir -- return True p
    is such a string, else False
    """
    if path:
        try:
            with tempfile.TemporaryFile(mode='w', dir=path) as t:
                t.write('1')
        except OSError:
            return False
        else:
            return True
    return False
def _get_configdir():
    """
    Figure out where matplotlib is going to try to write its configuration
    directory (``<home>/.matplotlib``).

    If the directory does not exist yet but the home directory is writable,
    creation is attempted; a failure to create it is ignored.

    :rtype: str or bool
    :return: the configuration directory path when it is (or should become)
        usable, or False when it exists but is not writable, or does not exist
        and the home directory is not writable
    """
    home = _get_home()
    cpath = os.path.join(home, '.matplotlib')
    if not os.path.exists(cpath):
        if not _is_writable_dir(home):
            return False
        try:
            os.mkdir(cpath)
        except OSError:
            # DESMOND-7025, race condition: another process may have created
            # the directory between the existence check and mkdir.
            pass
        return cpath
    return cpath if _is_writable_dir(cpath) else False
# This module can be run in situations where the normal matplotlib directory is
# not writable.  We'll use the exact method matplotlib uses to determine this to
# see if we need to redirect to a temp directory.
# NOTE(review): matplotlib is imported at the top of this module, before
# MPLCONFIGDIR is set below; confirm the installed matplotlib still honors a
# value set after import.
if not _get_configdir():
    error, tempdir = fileutils.get_directory(fileutils.TEMP)
    if error or not tempdir or not _is_writable_dir(tempdir):
        tempdir = tempfile.gettempdir()
    newconfigdir = os.path.join(tempdir, '.matplotlib')
    try:
        os.mkdir(newconfigdir)
    except FileExistsError:
        # Already present, possibly created concurrently by another process
        # (same race as DESMOND-7025); reuse it and leave its cleanup to
        # whoever created it.
        pass
    else:
        # Only remove the directory at exit if this process created it
        atexit.register(remove_temp_dir, newconfigdir)
    os.environ['MPLCONFIGDIR'] = newconfigdir
# Default color names, cycled through for successive data series
COLOR_NAME = [
    "black",
    "red",
    "green",
    "blue",
    "purple",
    "yellow",
    "orange",
    "violet",
    "skyblue",
    "gold",
    "grey",
]
# Maps descriptive marker names to matplotlib marker symbols
MARKER_TRANSLATION = {
    "cross": "+",
    "rectangle": "s",
    "diamond": "d",
    "circle": "o",
    "square": "s",
    "x": "x",
    "arrow": "^"
}
DEFAULT_FONTSIZE = 'x-small'
# Default font size in points.  Read the rc parameter directly instead of
# constructing a new FontManager, which rebuilds its font list (a slow scan of
# system fonts) just to report this same rc value.
default_size = matplotlib.rcParams['font.size']
# Absolute point sizes for the named relative font sizes
FONT_SCALINGS = {
    'xx-small': 0.579 * default_size,
    'x-small': 0.694 * default_size,
    'small': 0.833 * default_size,
    'medium': 1.0 * default_size,
    'large': 1.200 * default_size,
    'x-large': 1.440 * default_size,
    'xx-large': 1.728 * default_size
}
def prevent_overlapping_x_labels(canvas, axes_number=0):
    """
    Given a canvas that contains a figure that contains at least one axes
    instance, checks the x-axis tick labels to make sure they don't overlap.  If
    they do, the number of ticks is reduced until there is no overlap.

    :type canvas: matplotlib canvas object
    :param canvas: the canvas that contains the figure/axes objects
    :type axes_number: int
    :param axes_number: the index of the axes on the figure to examine.  Default
        is 0, which is the first set of axes added to the figure.
    """
    # An initial canvas.draw() would guarantee the tick marks are determined,
    # but it is intentionally not done here because of
    # https://github.com/matplotlib/matplotlib/issues/10874
    xaxis = canvas.figure.get_axes()[axes_number].get_xaxis()
    labels = xaxis.get_majorticklabels()
    # matplotlib will insist on keeping a minimum number of labels (some of
    # which may be blank), so to be safe we need to set a maximum number of
    # times we'll try to reduce the number of labels.
    for _ in range(len(labels) - 1):
        if not _x_tick_labels_overlap(labels):
            break
        # Number of bins = number of ticks - 1, so to get one fewer tick, we
        # need two fewer bins than the current number of ticks.
        nbins = len(labels) - 2
        if nbins < 1:
            # Zero bins is invalid (MaxNLocator divides the axis range by
            # nbins), so no further reduction is possible.
            break
        xaxis.get_major_locator().set_params(nbins=nbins)
        # Must reset the ticks - matplotlib leaves an extra tick at the end
        # when it figures out new ticks. This at least makes sure the last
        # tick is blank.
        xaxis.reset_ticks()
        canvas.draw()
        labels = xaxis.get_majorticklabels()


def _x_tick_labels_overlap(labels):
    """
    Check whether any two adjacent tick labels overlap horizontally.

    :type labels: list of matplotlib Text objects
    :param labels: the tick labels, in axis order
    :rtype: bool
    :return: True if some non-empty label's left edge is less than 1 display
        unit to the right of the previous label's right edge.  Pairs in which
        either label has zero width are ignored.
    """
    for prev_label, label in zip(labels, labels[1:]):
        prev_pos = prev_label.get_window_extent().get_points()
        pos = label.get_window_extent().get_points()
        # For unknown reason, sometimes there are labels of zero width.
        # Exclude such labels from comparison.
        if prev_pos[0][0] == prev_pos[1][0] or pos[0][0] == pos[1][0]:
            continue
        # Does the left side of this label overlap the right side of the
        # previous label?
        if label.get_text() and pos[0][0] - prev_pos[1][0] < 1:
            return True
    return False
def get_xy_plot_widget(xvals, *ylists, **kw):
    """
    Create a scatter or line chart.  The line chart may optionally have error
    bars associated with it.  Multiple series can be plotted by passing in more
    than one list of y values, i.e. get_xy_plot(x, y1, y2, chart_type='scatter')
    The plot is returned in a QFrame widget.

    :type xvals: list
    :param xvals: the x values to plot
    :type ylists: one or more lists
    :keyword ylists: Each y series to plot should be given as an argument to the
        function, and each should be the same length as x
    :type err_y: list of lists
    :keyword err_y: the i'th item of err_y is a list of numerical error bars,
        one for each item in the i'th y list
    :type chart_type: str
    :keyword chart_type: type of chart to produce
            - scatter: scatterplot
            - line: line (default)
    :type marker: tuple
    :keyword marker: tuple of (symbol, color, size), only used for scatter
        plots:
        - symbol (1-character str):
            - s - square ('square', rectangle accepted)
            - o - circle ('circle' accepted)
            - ^ - triangle up ('arrow' accepted)
            - > - triangle right
            - < - triangle left
            - v - triangle down
            - d - diamond ('diamond' accepted)
            - p - pentagon
            - h - hexagon
            - 8 - octagon
            - + - plus ('cross' accepted)
            - x - x
        - color (str):
            - black
            - red
            - green
            - blue
            - purple
            - yellow
            - orange
            - violet
            - skyblue
            - gold
            - grey
        - size (int)
    :type size: tuple
    :keyword size: (width, height) plot size in pixels; default (300, 200)
    :type x_label: str
    :keyword x_label: X-axis label
    :type y_label: str
    :keyword y_label: Y-axis label
    :type x_range: tuple
    :keyword x_range: (min, max) values for the X-axis
    :type y_range: tuple
    :keyword y_range: (min, max) values for the Y-axis
    :type color: list
    :keyword color: list of color names to cycle through.  See marker:color for
        some color names.
    :type bg: str
    :keyword bg: color name for the plot background.  See marker:color for some
        color names.
    :type legend: list
    :keyword legend: list of strings, each item is the name of a y series in the
        legend
    :type title: str
    :keyword title: the title of the plot
    :type dpi: int
    :keyword dpi: dots per inch for the plot
    :type fontsize: int or str
    :keyword fontsize: size in points, or one of the following -
            - xx-small
            - x-small
            - small
            - medium
            - large
            - x-large
            - xx-large

    :rtype: QFrame
    :return: The QFrame widget that contains the plot
    """
    from schrodinger.ui.qt import smatplotlib
    from schrodinger.ui.qt import swidgets
    # Check for a PyQt application instance and create one if needed:
    if not QtWidgets.QApplication.instance():
        # NOTE(review): this is only a local reference, so the application
        # object may be garbage collected once this function returns; confirm
        # callers keep their own QApplication alive.
        app = QtWidgets.QApplication([])
    # Grab some of the keywords, the rest will be passed on
    dpi = int(kw.get('dpi', 100))
    plotsize = kw.get('size', (300, 200))
    # size is in pixels but the canvas wants inches.  Use true division:
    # integer floor division (old_div on ints) truncated sizes that are not a
    # multiple of dpi, e.g. 350 px at 100 dpi became 3 in.
    width = plotsize[0] / float(dpi)
    height = plotsize[1] / float(dpi)
    # Create the plot and return it
    frame = QtWidgets.QFrame()
    layout = swidgets.SVBoxLayout(frame)
    canvas = smatplotlib.SmatplotlibCanvas(width=width,
                                           height=height,
                                           dpi=dpi,
                                           layout=layout)
    create_mpl_plot_on_figure(canvas.figure, xvals, *ylists, **kw)
    # Make sure the tick labels don't overlap
    prevent_overlapping_x_labels(canvas)
    return frame
def get_xy_plot(xvals, *ylists, **kw):
    """
    Create a scatter or line chart.  The line chart may optionally have error
    bars associated with it.  Multiple series can be plotted by passing in more
    than one list of y values, i.e. get_xy_plot(x, y1, y2, chart_type='scatter')
    The plot is saved as an image file; the filename parameter should contain
    the path to the file to write.  filename is written and the value of
    filename is simply returned.

    :type xvals: list
    :param xvals: the x values to plot
    :type ylists: one or more lists
    :keyword ylists: Each y series to plot should be given as an argument to the
        function, and each should be the same length as x
    :type err_y: list of lists
    :keyword err_y: the i'th item of err_y is a list of numerical error bars,
        one for each item in the i'th y list
    :type chart_type: str
    :keyword chart_type: type of chart to produce
            - scatter: scatterplot
            - line: line (default)
    :type marker: tuple
    :keyword marker: tuple of (symbol, color, size), only used for scatter plots
        - symbol (1-character str):
            - s - square ('square', rectangle accepted)
            - o - circle ('circle' accepted)
            - ^ - triangle up ('arrow' accepted)
            - > - triangle right
            - < - triangle left
            - v - triangle down
            - d - diamond ('diamond' accepted)
            - p - pentagon
            - h - hexagon
            - 8 - octagon
            - + - plus ('cross' accepted)
            - x - x
        - color (str):
            - black
            - red
            - green
            - blue
            - purple
            - yellow
            - orange
            - violet
            - skyblue
            - gold
            - grey
        - size (int)
    :type size: tuple
    :keyword size: (width, height) plot size in pixels; default (300, 200)
    :type x_label: str
    :keyword x_label: X-axis label
    :type y_label: str
    :keyword y_label: Y-axis label
    :type x_range: tuple
    :keyword x_range: (min, max) values for the X-axis
    :type y_range: tuple
    :keyword y_range: (min, max) values for the Y-axis
    :type color: list
    :keyword color: list of color names to cycle through.  See marker:color for
        some color names.
    :type bg: str
    :keyword bg: color name for the plot background.  See marker:color for some
        color names.
    :type legend: list
    :keyword legend: list of strings, each item is the name of a y series in the
        legend
    :type title: str
    :keyword title: the title of the plot
    :type dpi: int
    :keyword dpi: dots per inch for the plot
    :type fontsize: int or str
    :keyword fontsize: size in points, or one of the following -
            - xx-small
            - x-small
            - small
            - medium
            - large
            - x-large
            - xx-large
    :type filename: str
    :keyword filename: The path to the file that the image of this plot should
        be saved in.
    :type format: str
    :keyword format: the image format to save the chart in.  Must be a
        matplotlib-recognized format argument to Figure.savefig(format=nnn).
        Default is nnn='png'

    :rtype: str
    :return: The filename the image of the plot was saved into (same string as
        passed in with the filename keyword).
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure
    # Grab some of the keywords, the rest will be passed on
    dpi = int(kw.get('dpi', 100))
    plotsize = kw.get('size', (300, 200))
    # size is in pixels but Figure wants inches.  Use true division: integer
    # floor division (old_div on ints) truncated sizes that are not a multiple
    # of dpi, e.g. 350 px at 100 dpi became 3 in.
    width = plotsize[0] / float(dpi)
    height = plotsize[1] / float(dpi)
    # Create the basic plot area
    figure = Figure(figsize=(width, height), dpi=dpi)
    canvas = FigureCanvasAgg(figure)
    create_mpl_plot_on_figure(figure, xvals, *ylists, **kw)
    # Make sure the tick labels don't overlap
    prevent_overlapping_x_labels(canvas)
    # Save the plot to an image file
    filename = kw.get('filename')
    image_format = kw.get('format', 'png')
    figure.savefig(filename,
                   dpi=dpi,
                   orientation='landscape',
                   format=image_format)
    return filename
def create_mpl_plot_on_figure(figure, xvals, *ylists, **kw):
    """
    Create a scatter or line chart.  The line chart may optionally have error
    bars associated with it.  Multiple series can be plotted by passing in more
    than one list of y values, i.e. get_xy_plot(x, y1, y2, chart_type='scatter')

    :type figure: matplotlib Figure object
    :param figure: the Figure object the plot should be created on.
    :type xvals: list
    :param xvals: the x values to plot
    :type ylists: one or more lists
    :keyword ylists: Each y series to plot should be given as an argument to the
        function, and each should be the same length as x
    :type err_y: list of lists
    :keyword err_y: the i'th item of err_y is a list of numerical error bars,
        one for each item in the i'th y list.  Bars are only drawn (line charts
        only) for points with a nonzero error.  The list is not modified.
    :type chart_type: str
    :keyword chart_type: type of chart to produce
            - scatter: scatterplot
            - line: line (default)
    :type marker: tuple
    :keyword marker: tuple of (symbol, color, size), only used for scatter
        plots.  Currently only the symbol is applied; series colors come from
        the color keyword and the point size is fixed.
        - symbol (1-character str):
            - s - square ('square', rectangle accepted)
            - o - circle ('circle' accepted)
            - ^ - triangle up ('arrow' accepted)
            - > - triangle right
            - < - triangle left
            - v - triangle down
            - d - diamond ('diamond' accepted)
            - p - pentagon
            - h - hexagon
            - 8 - octagon
            - + - plus ('cross' accepted)
            - x - x
        - color (str):
            - black
            - red
            - green
            - blue
            - purple
            - yellow
            - orange
            - violet
            - skyblue
            - gold
            - grey
        - size (int)
    :type x_label: str
    :keyword x_label: X-axis label
    :type y_label: str
    :keyword y_label: Y-axis label
    :type x_range: tuple
    :keyword x_range: (min, max) values for the X-axis
    :type y_range: tuple
    :keyword y_range: (min, max) values for the Y-axis.  If not given, the range
        spans all y values, widened to include any error bars.
    :type color: list
    :keyword color: list of color names to cycle through.  See marker:color for
        some color names.
    :type bg: str
    :keyword bg: color name for the plot background.  See marker:color for some
        color names.
    :type legend: list
    :keyword legend: list of strings, each item is the name of a y series in the
        legend.  Series without a name are left out of the legend.
    :type title: str
    :keyword title: the title of the plot
    :type fontsize: int or str
    :keyword fontsize: size in points, or one of the following -
            - xx-small
            - x-small
            - small
            - medium
            - large
            - x-large
            - xx-large
    """
    # matplotlib leaves artists whose label starts with an underscore out of
    # the legend
    no_legend = '_nolegend_'
    title = kw.get('title')
    bg = kw.get('bg', 'white')
    legend = kw.get('legend', [])
    x_label = kw.get('x_label')
    y_label = kw.get('y_label')
    colors = kw.get('color', COLOR_NAME[:])
    fontsize = kw.get('fontsize', DEFAULT_FONTSIZE)
    figure.set_facecolor(bg)
    chart_type = kw.get('chart_type', 'line')
    if chart_type == "scatter":
        symbol, _marker_color, _marker_size = kw.get('marker',
                                                     ('o', 'skyblue', 0))
        symbol = MARKER_TRANSLATION.get(symbol, symbol)
    xstart = 0.20
    # Leave extra room on the right-hand side for the legend
    xextent = 1.0 - (xstart + (.20 if legend else .10))
    plot = figure.add_axes([xstart, .2, xextent, .7])
    plot.set_facecolor(bg)
    plot.tick_params(labelsize=fontsize)
    # Remove the axis lines and ticks on the top and right
    plot.spines['right'].set_color('none')
    plot.spines['top'].set_color('none')
    plot.xaxis.set_ticks_position('bottom')
    plot.yaxis.set_ticks_position('left')
    # Add labels
    if title:
        plot.set_title(title, size=fontsize)
    if x_label:
        plot.set_xlabel(x_label, size=fontsize)
    if y_label:
        plot.set_ylabel(y_label, size=fontsize)
    # Axes ranges - defaults are only computed when not supplied
    if 'x_range' in kw:
        x_range = kw['x_range']
    else:
        x_range = (min(xvals), max(xvals))
    if 'y_range' in kw:
        y_range = kw['y_range']
    else:
        y_range = (min(min(y) for y in ylists), max(max(y) for y in ylists))
    err_y = kw.get('err_y')
    if err_y and 'y_range' not in kw:
        # Widen the automatic Y range so the error bars fit.  err_y is only
        # read; the caller's list is not modified.
        bounds = []
        for this_y, this_err_y in zip(ylists, err_y):
            for yval, err in zip(this_y, this_err_y):
                bounds.extend((yval + err, yval - err))
        if bounds:
            y_range = (min(y_range[0], min(bounds)),
                       max(y_range[1], max(bounds)))
    plot.set_xlim(x_range)
    plot.set_ylim(y_range)
    # Plot each series
    for index, yvals in enumerate(ylists):
        color = colors[index % len(colors)]
        label = legend[index] if index < len(legend) else no_legend
        if chart_type == 'scatter':
            plot.scatter(xvals,
                         yvals,
                         c=color,
                         marker=symbol,
                         edgecolors='none',
                         s=9,
                         label=label)
            continue
        plot.plot(xvals, yvals, c=color, label=label)
        if err_y and index < len(err_y):
            # Only draw bars for points with a nonzero error
            points = [
                point for point in zip(xvals, yvals, err_y[index]) if point[2]
            ]
            if points:
                xvals_err, yvals_err, yerr_err = zip(*points)
                # The line above already carries this series' legend entry
                plot.errorbar(xvals_err,
                              yvals_err,
                              yerr=yerr_err,
                              c=color,
                              ls='none',
                              label=no_legend)
    # Put on a legend
    if legend:
        anchor = (1.0 / (xstart + xextent) + .07, 0.50)
        common = dict(prop={'size': fontsize},
                      frameon=False,
                      loc='center right',
                      bbox_to_anchor=anchor)
        if chart_type == 'scatter':
            plot.legend(scatterpoints=1,
                        handlelength=0.1,
                        handletextpad=0.5,
                        **common)
        else:
            plot.legend(numpoints=1,
                        handletextpad=0.3,
                        handlelength=1.0,
                        **common)
def get_2var_plot(data, **kw):
    """
    Create a 2d contour or 3d surface plot.
    The plot is saved as an image file; the filename parameter should contain
    the path to the file to write.  filename is written and the value of
    filename is simply returned.

    :type data: list of tuples
    :param data: the (x, y, z) values to plot.  List items should be arranged so
        they are sorted by X and then by Y (so that the Y varies faster than X), and
        there is a Z value for every X/Y combination.
    :type chart_type: str
    :keyword chart_type: type of chart to produce
            - contour: 2d contour (default)
            - surface: 3d surface
            - wireframe: 3d wireframe surface
    :type size: tuple
    :keyword size: (width, height) plot size in pixels; default (300, 200)
    :type x_label: str
    :keyword x_label: X-axis label
    :type y_label: str
    :keyword y_label: Y-axis label
    :type z_label: str
    :keyword z_label: Z-axis label (surface and wireframe only)
    :type x_range: tuple
    :keyword x_range: (min, max) values for the X-axis.  (contour plots only)
    :type y_range: tuple
    :keyword y_range: (min, max) values for the Y-axis  (contour plots only)
    :type x_reverse: bool
    :keyword x_reverse: True if the X axis should be reversed, False (default)
        if not
    :type y_reverse: bool
    :keyword y_reverse: True if the Y axis should be reversed, False (default)
        if not
    :type z_reverse: bool
    :keyword z_reverse: True if the Z axis should be reversed, False (default)
        if not (surface and wireframe only)
    :type color_map: str
    :keyword color_map: Name of a matplotlib color map
    :type color_range: tuple
    :keyword color_range: (min, max) of the color range.  Values of min and
        below will get the minimum color, values of max and above will get the
        maximum color.  Setting all the contour levels with the levels keyword may
        be the preferred way of accomplishing this.
    :type bg: str
    :keyword bg: color name for the plot background.
    :type legend: bool
    :keyword legend: True if a colorbar legend that shows the contour levels
        should be included, False if not (False is default)
    :type legend_format: str
    :keyword legend_format: String format specifier for the colorbar legend.
        This format is also used for contour labels.
    :type legend_orientation: str
    :keyword legend_orientation: Either 'vertical' (default) or 'horizontal'
    :type title: str
    :keyword title: the title of the plot
    :type dpi: int
    :keyword dpi: dots per inch for the plot
    :type fontsize: int or str
    :keyword fontsize: size in points, or one of the following -
            - xx-small
            - x-small
            - small
            - medium
            - large
            - x-large
            - xx-large
    :type filename: str
    :keyword filename: The path to the file that the image of this plot should
        be saved in.
    :type format: str
    :keyword format: the image format to save the chart in.  Must be a
        matplotlib-recognized format argument to Figure.savefig(format=nnn).
        Default is nnn='png'
    :type fill: bool
    :keyword fill: True if the contours should be filled, False (default) if
        lines on a white background (contour plots only)
    :type labels: bool
    :keyword labels: True (default) if the contour lines should be labeled,
        False if not (contour plots only)
    :type contours: int
    :keyword contours: The number of contour lines to draw (default of 0
        indicates that matplotlib should choose the optimal number) (contour plots
        only)
    :type levels: list
    :keyword levels: list of values to place contour lines at (contour plots
        only)
    :type viewpoint: tuple of 2 float
    :keyword viewpoint: (a, b) describing the point of view from which the plot
        is viewed, where a is the elevation and b is the rotation of the plot.
        Default is (45, 45). (surface and wireframe only)

    :rtype: str
    :return: The filename the image of the plot was saved into (same string as
        passed in with the filename keyword).
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure
    # Grab some of the keywords, the rest will be passed on
    dpi = int(kw.get('dpi', 100))
    plotsize = kw.get('size', (300, 200))
    # size is in pixels but Figure wants inches.  Use true division: integer
    # floor division (old_div on ints) truncated sizes that are not a multiple
    # of dpi, e.g. 350 px at 100 dpi became 3 in.
    width = plotsize[0] / float(dpi)
    height = plotsize[1] / float(dpi)
    # Create the basic plot area
    figure = Figure(figsize=(width, height), dpi=dpi)
    canvas = FigureCanvasAgg(figure)
    chart_type = kw.get('chart_type', 'contour')
    if chart_type in ('surface', 'wireframe'):
        create_surface_on_figure(figure, data, **kw)
    else:
        create_contour_plot_on_figure(figure, data, **kw)
    # Make sure the tick labels don't overlap
    prevent_overlapping_x_labels(canvas)
    # Save the plot to an image file
    filename = kw.get('filename')
    image_format = kw.get('format', 'png')
    figure.savefig(filename,
                   dpi=dpi,
                   orientation='landscape',
                   format=image_format)
    return filename
def _get_legend_format(values):
    """
    Determine the format for legend values.  The format is based on the type and
    span of the items in values:
        - If values is empty, or values[0] is an integer (a Python int/bool or a
          numpy integer scalar), format will be integer ('%d')
        - If span `(abs(max(values)-min(values))) < 1` format is 3 decimal
          places
        - If span < 10 format is 2 decimal places
        - If span < 100 format is 1 decimal place
        - Else format is integer ('%.0f')
    :type values: sequence
    :param values: The values to check for type and span.  Must support len()
        and indexing, e.g. a list, tuple or 1-D numpy array.
    :rtype: str
    :return: A Python format string (such as '%.1f') for the legend values.
    """
    # len() instead of truthiness so a numpy array does not raise "truth value
    # of an array is ambiguous"; numpy.integer covers numpy int scalars, which
    # are not subclasses of int.
    if len(values) == 0 or isinstance(values[0], (int, numpy.integer)):
        return '%d'
    span = abs(max(values) - min(values))
    if span < 1.0:
        return '%.3f'
    elif span < 10.0:
        return '%.2f'
    elif span < 100.0:
        return '%.1f'
    else:
        return '%.0f'
def create_contour_plot_on_figure(figure, data, **kw):
    """
    Draw a 2D contour plot of gridded (x, y, z) data on an existing Figure.

    This function only draws on figure; it does not save anything.

    :type figure: matplotlib Figure object
    :param figure: the Figure object the plot should be created on.
    :type data: list of tuples
    :param data: the (x, y, z) values to plot.  List items should be arranged so
        they are sorted by X and then by Y (so that the Y varies faster than X), and
        there is a Z value for every X/Y combination.
    :type x_label: str
    :keyword x_label: X-axis label
    :type y_label: str
    :keyword y_label: Y-axis label
    :type x_range: tuple
    :keyword x_range: (min, max) values for the X-axis (default is the span of
        the X data)
    :type y_range: tuple
    :keyword y_range: (min, max) values for the Y-axis (default is the span of
        the Y data)
    :type x_reverse: bool
    :keyword x_reverse: True if the X axis should be reversed, False (default)
        if not
    :type y_reverse: bool
    :keyword y_reverse: True if the Y axis should be reversed, False (default)
        if not
    :type color_map: str
    :keyword color_map: Name of a matplotlib color map.  Used for the filled
        contours when fill is True, otherwise for the contour lines.
    :type color_range: tuple
    :keyword color_range: (min, max) of the color range.  Values of min and
        below will get the minimum color, values of max and above will get the
        maximum color.  Setting all the contour levels with the levels keyword may
        be the preferred way of accomplishing this.
    :type bg: str
    :keyword bg: color name for the plot background (default 'white').
    :type legend: bool
    :keyword legend: True if a colorbar legend that shows the contour levels
        should be included, False if not (False is default)
    :type legend_format: str
    :keyword legend_format: String format specifier for the colorbar legend.
        This format is also used for contour labels.  By default it is chosen
        from the type and span of the Z values.
    :type legend_orientation: str
    :keyword legend_orientation: Either 'vertical' (default) or 'horizontal'
    :type title: str
    :keyword title: the title of the plot
    :type fontsize: int or str
    :keyword fontsize: size in points, or one of the following -
            - xx-small
            - x-small
            - small
            - medium
            - large
            - x-large
            - xx-large
    :type fill: bool
    :keyword fill: True if the contours should be filled (black contour lines
        are drawn on top of the fill), False (default) if colored lines on the
        background color
    :type labels: bool
    :keyword labels: True (default) if the contour lines should be labeled,
        False if not
    :type contours: int
    :keyword contours: The number of contour lines to draw (default of 0
        indicates that matplotlib should choose the optimal number)
    :type levels: list
    :keyword levels: list (or array) of values to place contour lines at.
        Takes precedence over contours.
    :raise ValueError: If data is empty, or the number of points is not MxN,
        where M and N are the numbers of unique X and Y values.
    :raise IndexError: If an item of data has fewer than three values.
    """
    title = kw.get('title')
    bg = kw.get('bg', 'white')
    legend = kw.get('legend', False)
    legend_format = kw.get('legend_format', None)
    legend_orientation = kw.get('legend_orientation', 'vertical')
    x_label = kw.get('x_label')
    y_label = kw.get('y_label')
    color_map = kw.get('color_map', None)
    color_range = kw.get('color_range', None)
    fontsize = kw.get('fontsize', DEFAULT_FONTSIZE)
    fill = kw.get('fill', False)
    labels = kw.get('labels', True)
    contours = kw.get('contours', 0)
    levels = kw.get('levels')
    # len() rather than truthiness so a numpy array of levels is accepted
    has_levels = levels is not None and len(levels) > 0
    xreverse = kw.get('x_reverse', False)
    yreverse = kw.get('y_reverse', False)
    figure.set_facecolor(bg)
    plot = figure.add_axes([.13, .13, .8, .8])
    plot.set_facecolor(bg)
    plot.tick_params(labelsize=fontsize)
    plot.xaxis.set_ticks_position('bottom')
    plot.yaxis.set_ticks_position('left')
    # Add labels
    if title:
        plot.set_title(title, size=fontsize)
    if x_label:
        plot.set_xlabel(x_label, size=fontsize)
    if y_label:
        plot.set_ylabel(y_label, size=fontsize)
    # Unroll the data list.  We need a list of X of length M, a list
    # of Y of length N, and a 2-D array of Z sized MxN
    try:
        xvals = []
        yvals = []
        ztemp = []
        firstx = True
        for point in data:
            pointx = point[0]
            if not xvals:
                xvals.append(pointx)
            elif pointx != xvals[-1]:
                xvals.append(pointx)
                # We've moved on to the next x-value, so have cycled through all
                # Y points
                firstx = False
            if firstx:
                yvals.append(point[1])
            ztemp.append(point[2])
    except IndexError:
        print('The data parameter should be tuples of (x, y, z) values')
        raise
    if not ztemp:
        raise ValueError('The data parameter contains no (x, y, z) values')
    if not legend_format:
        legend_format = _get_legend_format(ztemp)
    # Create a 2-D array of zvalues
    numx = len(xvals)
    numy = len(yvals)
    if len(ztemp) != numx * numy:
        raise ValueError('The number of data points should be MxN, where M '
                         'is the number of\nunique X values and N is the '
                         'number of unique Y values')
    # Y varies fastest in data, so the flat Z list reshapes to (M, N) indexed
    # [x, y]; transpose to the (N, M) [y, x] layout contour() expects.
    zvals = numpy.array(ztemp, numpy.double).reshape(numx, numy).T
    # Axes ranges
    x_range = kw.get('x_range', (min(xvals), max(xvals)))
    y_range = kw.get('y_range', (min(yvals), max(yvals)))
    if xreverse:
        x_range = (x_range[1], x_range[0])
    if yreverse:
        y_range = (y_range[1], y_range[0])
    # How the contour levels are chosen: explicit levels win over a requested
    # number of levels; otherwise matplotlib picks them.  The same choice is
    # used for both the fill and the boundary lines.
    level_args = ()
    level_kw = {}
    if has_levels:
        level_kw['levels'] = levels
    elif contours:
        level_args = (contours,)
    # Filled contours
    contour_fill = None
    if fill:
        contour_fill = plot.contourf(xvals,
                                     yvals,
                                     zvals,
                                     *level_args,
                                     extend='both',
                                     cmap=color_map,
                                     **level_kw)
        if color_range:
            contour_fill.set_clim(color_range)
        # matplotlib rejects colors and cmap together, so the black boundary
        # lines drawn over the fill get only colors.
        line_kw = {'colors': 'k'}
    else:
        line_kw = {'cmap': color_map}
    if not has_levels and not contours:
        line_kw['extend'] = 'both'
    line_kw.update(level_kw)
    # Plot the contour boundaries
    contour = plot.contour(xvals, yvals, zvals, *level_args, **line_kw)
    if color_range and not fill:
        # Unfilled lines are colored by the colormap, so the color range applies
        # to them (the filled case was handled above).
        contour.set_clim(color_range)
    # Add boundary labels
    if labels:
        label_kw = {'colors': 'k'} if fill else {}
        plot.clabel(contour,
                    fmt=legend_format,
                    fontsize=FONT_SCALINGS[fontsize],
                    **label_kw)
    # Add the color bar
    if legend:
        mappable = contour_fill if fill else contour
        colorbar = figure.colorbar(mappable,
                                   shrink=0.8,
                                   format=legend_format,
                                   orientation=legend_orientation)
        if not fill:
            # Thicken the level lines drawn in the colorbar.  Colorbar.lines is
            # a list of LineCollections in current matplotlib; also accept a
            # single collection.
            cb_lines = colorbar.lines
            if not isinstance(cb_lines, (list, tuple)):
                cb_lines = [cb_lines]
            for line_collection in cb_lines:
                line_collection.set_linewidth(6)
        colorbar.ax.tick_params(labelsize=fontsize)
    plot.set_xlim(x_range)
    plot.set_ylim(y_range)
def create_surface_on_figure(figure, data, **kw):
    """
    Draw a 3D surface or wireframe plot of gridded (x, y, z) data on an
    existing Figure.

    This function only draws on figure; it does not save anything.  It calls
    ``matplotlib.rc('xtick', labelsize=fontsize)``, which changes the global
    matplotlib settings for plots created afterwards.

    :type figure: matplotlib Figure object
    :param figure: the Figure object the plot should be created on.
    :type data: list of tuples
    :param data: the (x, y, z) values to plot.  List items should be arranged so
        they are sorted by X and then by Y (so that the Y varies faster than X), and
        there is a Z value for every X/Y combination.
    :type chart_type: str
    :keyword chart_type: 'wireframe' for a wireframe plot; any other value
        (default 'surface') draws a surface.
    :type x_label: str
    :keyword x_label: X-axis label
    :type y_label: str
    :keyword y_label: Y-axis label
    :type z_label: str
    :keyword z_label: Z-axis label
    :type x_reverse: bool
    :keyword x_reverse: Accepted for compatibility but currently has no effect,
        because axis limits are not applied to 3D plots.
    :type y_reverse: bool
    :keyword y_reverse: Accepted for compatibility but currently has no effect.
    :type z_reverse: bool
    :keyword z_reverse: Accepted for compatibility but currently has no effect.
    :type bg: str
    :keyword bg: color name for the plot background (default 'white').
    :type legend: bool
    :keyword legend: True if a colorbar legend that shows the surface levels
        should be included, False if not (False is default).  Surface only.
    :type legend_format: str
    :keyword legend_format: String format specifier for the colorbar legend.
    :type legend_orientation: str
    :keyword legend_orientation: Either 'vertical' (default) or 'horizontal'
    :type title: str
    :keyword title: the title of the plot
    :type fontsize: int or str
    :keyword fontsize: size in points, or one of the following -
            - xx-small
            - x-small
            - small
            - medium
            - large
            - x-large
            - xx-large
    :type viewpoint: tuple of 2 float
    :param viewpoint: (a, b) describing the point of view from which the plot is
        viewed, where a is the elevation and b is the rotation of the plot.  Default
        is (45, 45).
    :raise ValueError: If any of x_range, y_range or z_range is given, if data
        is empty, or if the number of points is not MxN, where M and N are the
        numbers of unique X and Y values.
    :raise IndexError: If an item of data has fewer than three values.
    """
    title = kw.get('title')
    bg = kw.get('bg', 'white')
    legend = kw.get('legend', False)
    legend_format = kw.get('legend_format', None)
    legend_orientation = kw.get('legend_orientation', 'vertical')
    x_label = kw.get('x_label')
    y_label = kw.get('y_label')
    z_label = kw.get('z_label')
    fontsize = kw.get('fontsize', DEFAULT_FONTSIZE)
    # Note that changing xtick changes all axis font sizes
    matplotlib.rc('xtick', labelsize=fontsize)
    elevation, azimuth = kw.get('viewpoint', (45, 45))
    chart_type = kw.get('chart_type', 'surface')
    figure.set_facecolor(bg)
    plot = figure.add_axes([.13, .13, .8, .8], projection='3d')
    plot.set_facecolor(bg)
    plot.tick_params(labelsize=fontsize)
    plot.xaxis.set_ticks_position('bottom')
    plot.yaxis.set_ticks_position('left')
    # Add labels
    if title:
        plot.set_title(title, size=fontsize)
    if x_label:
        plot.set_xlabel(x_label, size=fontsize)
    if y_label:
        plot.set_ylabel(y_label, size=fontsize)
    if z_label:
        plot.set_zlabel(z_label, size=fontsize)
    # Unroll the data list.  We need a list of X of length M, a list
    # of Y of length N, and a 2-D array of Z sized MxN
    try:
        xtemp = []
        ytemp = []
        ztemp = []
        firstx = True
        for point in data:
            pointx = point[0]
            if not xtemp:
                xtemp.append(pointx)
            elif pointx != xtemp[-1]:
                xtemp.append(pointx)
                # We've moved on to the next x-value, so have cycled through all
                # Y points
                firstx = False
            if firstx:
                ytemp.append(point[1])
            ztemp.append(point[2])
    except IndexError:
        print('The data parameter should be tuples of (x, y, z) values')
        raise
    if not legend_format:
        legend_format = _get_legend_format(ztemp)
    if 'x_range' in kw or 'y_range' in kw or 'z_range' in kw:
        # Note that setting the Z-axis range doesn't appear to work at all,
        # while setting the X & Y ranges makes the plot go haywire.
        raise ValueError('Modifying axis ranges is not allowed for surface or'
                         '\n   wireframe plots due to matplotlib limitations')
    if not ztemp:
        raise ValueError('The data parameter contains no (x, y, z) values')
    # Create a 2-D array of zvalues
    numx = len(xtemp)
    numy = len(ytemp)
    if len(ztemp) != numx * numy:
        raise ValueError('The number of data points should be MxN, where M '
                         'is the number of\nunique X values and N is the '
                         'number of unique Y values')
    # Y varies fastest in data, so the flat Z list reshapes to (M, N) indexed
    # [x, y]; transpose to the (N, M) [y, x] layout that matches meshgrid.
    zvals = numpy.array(ztemp, numpy.double).reshape(numx, numy).T
    xvals, yvals = numpy.meshgrid(xtemp, ytemp)
    # Plot the surface
    if chart_type == 'wireframe':
        plot.plot_wireframe(xvals, yvals, zvals, cmap=mpl_cmap.jet)
    else:
        surface = plot.plot_surface(xvals, yvals, zvals, cmap=mpl_cmap.jet)
        if legend:
            colorbar = figure.colorbar(surface,
                                       shrink=0.8,
                                       format=legend_format,
                                       orientation=legend_orientation)
            colorbar.ax.tick_params(labelsize=fontsize)
    # Axis limits (and therefore the *_reverse keywords) are intentionally not
    # applied: set_xlim3d/set_ylim3d/set_zlim3d misbehave for these plots.
    plot.view_init(elevation, azimuth)
if __name__ == "__main__":
    import sys
    # Manual test driver: reads whitespace-separated "x y z" rows from the data
    # file given on the command line (blank lines and lines starting with '#'
    # are skipped), and writes a surface plot of them to contour.png.
    if len(sys.argv) != 2:
        print("Usage: $SCHRODINGER/run %s <data-file>" % sys.argv[0])
        sys.exit(1)
    if not os.path.isfile(sys.argv[1]):
        print("Data file not found: %s" % sys.argv[1])
        sys.exit(1)
    with open(sys.argv[1], "r") as data_file:
        lines = data_file.read().split("\n")
    data = []
    for line in lines:
        line = line.strip()
        if line != "" and line[0] != "#":
            linelist = []
            for atoken in line.split():
                # Keep integer tokens as int so integer legend formatting is
                # used for integer data.
                try:
                    linelist.append(int(atoken))
                except ValueError:
                    linelist.append(float(atoken))
            data.append(tuple(linelist))
    plot = get_2var_plot(data,
                         fill=True,
                         legend=True,
                         labels=False,
                         legend_orientation='horizontal',
                         x_label='bob',
                         y_label='joe',
                         title='jim',
                         viewpoint=(45, 45),
                         size=(500, 500),
                         z_label='jeff',
                         filename='contour.png',
                         chart_type='surface')