"""
Functions and classes for rendering a Route as an image.
"""
import collections
import numpy
from rdkit import Chem
# Layout of a (sub)route image:
#   size -- (width, height) tuple, in pixels.
#   objs -- list of (x, y, obj) 3-tuples positioning each drawable object,
#           where obj is either a QPicture (a molecule depiction) or a string.
_Img = collections.namedtuple('_Img', ['size', 'objs'])

# Sentinel strings that get special-case treatment when an _Img is rendered;
# any other string object is drawn as a text label.
ARROW = '->'
PLUS = '+'

# Logged (at most once per renderer) when a route node has neither a 'mol'
# nor a usable 'smiles_list'.
MISSING_SMILES_WARNING = (
    "Warning: SMILES pattern is missing for at least one node. Molecules with\n"
    "         missing SMILES will be represented by the '*' placeholder")


class RouteRenderer(object):
    """
    A class for rendering a Route as an image.

    The route is laid out as a tree from top to bottom: each molecule sits
    above a vertical arrow, under which its precursors (each laid out the same
    way, recursively) are placed side by side, separated by plus signs.

    Note: this class only supports "specific routes", for which all nodes have a
    'mol' attribute. Generic routes, in which nodes might have only a reagent
    class, are not supported yet.
    """

    def __init__(self,
                 plus_size=50,
                 arrow_length=250,
                 arrowhead_size=25,
                 label_padding=25,
                 label_font=("helvetica", 30),
                 plus_font=("helvetica", 48),
                 max_scale=0.3,
                 mol_padding=50,
                 scale=0.375,
                 logger=None):
        """
        Configure the RouteRenderer with various optional parameters controlling
        the layout of the route. All sizes are in pixels of the unscaled layout;
        the final image is scaled by `scale` when it is rasterized.

        :param plus_size: width (including padding) taken by the plus signs.
        :param arrow_length: length of each reaction arrow, which is also the
            vertical gap between a molecule and its precursors.
        :param arrowhead_size: horizontal and vertical extent of each of the two
            arrowhead strokes.
        :param label_padding: horizontal offset of a reaction label from its
            arrow.
        :param label_font: arguments passed to QtGui.QFont for reaction labels.
        :param plus_font: arguments passed to QtGui.QFont for plus signs.
        :param max_scale: stored as an attribute; not currently used by this
            class.
        :param mol_padding: padding added on every side of each molecule
            depiction.
        :param scale: scale factor applied when rasterizing the layout into the
            output image.
        :param logger: optional logger used to warn about missing or unparsable
            SMILES; nothing is logged when it is None.
        """
        # Import Qt modules here because the import fails if certain libraries
        # are missing and route rendering is an optional feature. They are
        # published as module globals so the other methods can use them.
        global QtCore, QtGui
        from schrodinger.Qt import QtCore
        from schrodinger.Qt import QtGui
        self.plus_size = plus_size
        self.arrow_length = arrow_length
        self.arrowhead_size = arrowhead_size
        self.label_padding = label_padding
        self.label_font = QtGui.QFont(*label_font)
        self.plus_font = QtGui.QFont(*plus_font)
        self.max_scale = max_scale
        self.mol_padding = mol_padding
        self.scale = scale
        self.logger = logger
        # Set once MISSING_SMILES_WARNING has been logged, so it is logged at
        # most once per renderer.
        self.missing_smiles_found = False

    def renderToFile(self, route, filename):
        """
        Render a Route and save the image to a file.

        :type route: Route
        :type filename: str
        :return: whether the image was written successfully (the result of
            QImage.save)
        :rtype: bool
        """
        img = self._arrangeRoute(route)
        # Paste all the objects into the right locations in a QPicture.
        painter = QtGui.QPainter()
        qpic = QtGui.QPicture()
        painter.begin(qpic)
        try:
            for x, y, obj in img.objs:
                if obj == ARROW:
                    self._drawArrow(painter, x, y)
                elif isinstance(obj, str):
                    font = self.plus_font if obj == PLUS else self.label_font
                    painter.setFont(font)
                    painter.drawText(x, y, obj)
                else:
                    # Shift so the picture's bounding box (which need not
                    # start at the origin) lands inside its padded slot.
                    r = obj.boundingRect()
                    painter.drawPicture(x - r.left() + self.mol_padding,
                                        y - r.top() + self.mol_padding, obj)
        finally:
            painter.end()
        # Rasterize into a QImage and write to file.
        image_size = [int(s * self.scale) for s in img.size]
        qimg = QtGui.QImage(QtCore.QSize(*image_size),
                            QtGui.QImage.Format_RGB32)
        qimg.fill(QtCore.Qt.white)
        painter.begin(qimg)
        try:
            painter.scale(self.scale, self.scale)
            qpic.play(painter)
        finally:
            painter.end()
        return qimg.save(filename)

    def _arrangeRoute(self, route):
        """
        Recursively arrange a route as a tree, from top to bottom.

        :type route: Route
        :return: Route image
        :rtype: _Img
        """
        precursor_imgs = [self._arrangeRoute(p) for p in route.precursors]
        pic = mol_to_qpicture(self._getMol(route))
        r = pic.boundingRect()
        mol_size = (r.width() + 2 * self.mol_padding,
                    r.height() + 2 * self.mol_padding)
        size = self._computeSize(precursor_imgs, mol_size)
        img = _Img(size, [])
        # The node's molecule is the only object at y == 0;
        # _arrangePrecursors relies on that to find the immediate precursors.
        img.objs.append(((size[0] - mol_size[0]) // 2, 0, pic))
        if precursor_imgs:
            self._arrangePrecursors(img, precursor_imgs, mol_size)
            img.objs.append((size[0] // 2, mol_size[1], ARROW))
            self._addLabel(img, route.reaction_instance.name, mol_size)
        return img

    def _getMol(self, route):
        """
        Return the RDKit Mol to depict for a route node.

        Uses the node's 'mol' attribute if set, otherwise the first entry of
        its 'smiles_list', otherwise (or if that SMILES cannot be parsed) a '*'
        placeholder molecule.

        :type route: Route
        :rtype: Mol
        """
        if route.mol:
            return route.mol
        smiles_list = getattr(route, 'smiles_list', None)
        if smiles_list:
            mol = Chem.MolFromSmiles(smiles_list[0])
            # MolFromSmiles returns None for unparsable input; fall back to the
            # placeholder instead of crashing in mol_to_qpicture.
            if mol is not None:
                return mol
            if self.logger:
                self.logger.warning(
                    "Warning: could not parse SMILES %r; it will be "
                    "represented by the '*' placeholder", smiles_list[0])
        elif self.logger and not self.missing_smiles_found:
            self.missing_smiles_found = True
            self.logger.warning(MISSING_SMILES_WARNING)
        return Chem.MolFromSmiles('*')

    def _precursorsWidth(self, precursor_imgs):
        """
        Return the width of a row of precursor images separated by plus signs.
        """
        return (sum(i.size[0] for i in precursor_imgs) +
                self.plus_size * (len(precursor_imgs) - 1))

    def _computeSize(self, precursor_imgs, mol_size):
        """
        Compute the (width, height) of the image for one route node.

        :param precursor_imgs: images of the node's precursors (may be empty)
        :type precursor_imgs: list(_Img)
        :param mol_size: padded (width, height) of the node's own molecule
        :rtype: tuple(int, int)
        """
        if not precursor_imgs:
            return mol_size
        precursor_height = max(i.size[1] for i in precursor_imgs)
        # The node's own molecule may be wider than its precursor row; the
        # image must be wide enough for both so nothing lands at negative x.
        width = max(mol_size[0], self._precursorsWidth(precursor_imgs))
        height = precursor_height + mol_size[1] + self.arrow_length
        return (width, height)

    def _arrangePrecursors(self, img, precursor_imgs, mol_size):
        """
        Place the precursor images below the arrow in img, centered
        horizontally, with plus signs between adjacent precursors.

        :param img: image being built; its objs list is extended in place
        :type img: _Img
        :type precursor_imgs: list(_Img)
        :param mol_size: padded (width, height) of the node's own molecule
        """
        # Compute the average height of the immediate precursors (the objects
        # at the top of each precursor image), which will be used to determine
        # the y coordinate of the plus signs.
        mol_heights = [
            obj.boundingRect().height()
            for prec_img in precursor_imgs
            for _, o_y, obj in prec_img.objs
            if o_y == 0
        ]
        mean_mol_height = int(numpy.mean(mol_heights))
        y = mol_size[1] + self.arrow_length
        plus_y = y + mean_mol_height // 2 + self.mol_padding
        # Center the row, in case the node's molecule is wider than it.
        x = (img.size[0] - self._precursorsWidth(precursor_imgs)) // 2
        last = len(precursor_imgs) - 1
        # Translate all precursor objects into the new frame of reference and
        # add the plus signs between them.
        for i, prec_img in enumerate(precursor_imgs):
            for o_x, o_y, obj in prec_img.objs:
                img.objs.append((o_x + x, o_y + y, obj))
            x += prec_img.size[0]
            if i < last:
                img.objs.append((x, plus_y, PLUS))
                x += self.plus_size

    def _drawArrow(self, painter, x, y):
        """
        Draw a vertical arrow whose head is at (x, y) and whose tail extends
        arrow_length pixels downward.
        """
        painter.setRenderHint(QtGui.QPainter.Antialiasing)
        painter.setPen(QtGui.QPen(QtCore.Qt.black, 3))
        painter.drawLine(x, y, x, y + self.arrow_length)
        painter.drawLine(x, y, x + self.arrowhead_size, y + self.arrowhead_size)
        painter.drawLine(x, y, x - self.arrowhead_size, y + self.arrowhead_size)

    def _addLabel(self, img, label, mol_size):
        """
        Add a reaction label to the right of the arrow, at its mid-height.
        """
        x = img.size[0] // 2 + self.label_padding
        y = mol_size[1] + self.arrow_length // 2
        img.objs.append((x, y, label))


def mol_to_qpicture(mol):
    """
    Generate a QPicture from an RDKit Mol.

    The molecule is converted to SMILES, loaded into a canvas2d ChmMol, and
    drawn with a Chm2DRenderer that uses a default ChmRender2DModel.

    :param mol: molecule to render
    :type mol: Mol
    :return: 2D depiction of the molecule
    :rtype: QPicture
    """
    from schrodinger.infra import canvas2d
    smiles = Chem.MolToSmiles(mol)
    chmmol = canvas2d.ChmMol.fromSMILES(smiles)
    render_model = canvas2d.ChmRender2DModel()
    renderer = canvas2d.Chm2DRenderer(render_model)
    return renderer.getQPicture(chmmol)