import csv
from past.utils import old_div
from schrodinger.Qt import QtCore
from schrodinger.Qt import QtGui
from schrodinger.Qt import QtWidgets
from schrodinger.ui.qt import filedialog
from schrodinger.utils import csv_unicode
from . import compare_sequences_ui
from .sequence import Sequence
[docs]class CompareSequencesDialog(QtWidgets.QDialog):
    """
    Create a panel.
    """
[docs]    def __init__(self, parent):
        """
        The parent is expected to be an instance of SequenceViewer.
        """
        QtWidgets.QDialog.__init__(self, parent)
        self.viewer = parent
        self.ui = compare_sequences_ui.Ui_Dialog()
        self.ui.setupUi(self)
        self.createColorTable()  #widget displays a color legend
        self.ui.refreshButton.clicked.connect(self.populateTable)
        self.ui.closeButton.clicked.connect(self.close)
        self.ui.sequenceTable.verticalHeader().sectionClicked.connect(
            self.headerClicked)
        self.ui.sequenceTable.horizontalHeader().sectionClicked.connect(
            self.headerClicked)
        self.ui.identityRadioButton.clicked.connect(self.populateTable)
        self.ui.similarityRadioButton.clicked.connect(self.populateTable)
        self.ui.homologyRadioButton.clicked.connect(self.populateTable)
        self.ui.columnsCheckBox.clicked.connect(self.populateTable)
        self.ui.exportCSVButton.clicked.connect(self.exportCSV)
        self.ui.exportImageButton.clicked.connect(self.exportImage) 
[docs]    def show(self):
        """
        Show the dialog and populate its contents.
        """
        self.populateTable()
        QtWidgets.QDialog.show(self) 
[docs]    def createColorTable(self):
        """
        Create color scale (blue to red, half saturation).
        """
        # Color table widget displays a color legend
        #self.color_table = QtWidgets.QTableWidget(1, 21)
        self.color_scale = []
        for test_id in range(21):
            color = QtGui.QColor()
            color.setHsvF(0.66 * float(20 - test_id) / 20.0, 0.5, 1.0)
            self.color_scale.append(color)
            item = QtWidgets.QTableWidgetItem("%3d" % int(5.0 * test_id))
            self.ui.colorTable.setItem(0, test_id, item)
            item.setBackground(color)
        # Set color table properties
        self.ui.colorTable.horizontalHeader().hide()
        self.ui.colorTable.verticalHeader().hide()
        self.ui.colorTable.resizeColumnsToContents()
        self.ui.colorTable.resizeRowsToContents()
        self.ui.colorTable.setHorizontalScrollBarPolicy(
            QtCore.Qt.ScrollBarAlwaysOff)
        width = 21 * self.ui.colorTable.columnWidth(0)
        height = self.ui.colorTable.rowHeight(0)
        self.ui.colorTable.setMinimumWidth(width)
        self.ui.colorTable.setMaximumWidth(width)
        self.ui.colorTable.setMinimumHeight(height)
        self.ui.colorTable.setMaximumHeight(height) 
[docs]    def calculatePairwiseIdentity(self, id_function, column_mode):
        """
        Iterates over sequence pairs and calculate
        sequence identity between the sequences.
        :type id_function: method of Sequence class
        :param id_function: Method that calculate sequence identity
            between two instances of Sequence class.
        :type column_mode: boolean
        :param column_mode: Optional column mode parameter.
        """
        # Iterate over sequence pairs and calculate
        # sequence identity between the sequences
        for idx1, seq1 in enumerate(self.sequences):
            for idx2 in range(idx1):
                seq2 = self.sequences[idx2]
                id_value = id_function(seq1, seq2, True, column_mode)
                item1 = QtWidgets.QTableWidgetItem("%3.0f" %
                                                   float(100.0 * id_value))
                item1.setBackground(
                    QtGui.QBrush(self.color_scale[int(id_value * 20.0)]))
                item1.setTextAlignment(QtCore.Qt.AlignCenter)
                self.ui.sequenceTable.setItem(idx1, idx2, item1)
                # Create a symmetrical table item
                item2 = QtWidgets.QTableWidgetItem(item1)
                item2.setTextAlignment(QtCore.Qt.AlignCenter)
                self.ui.sequenceTable.setItem(idx2, idx1, item2)
        # Populate a diagonal
        if column_mode:
            for idx, seq in enumerate(self.sequences):
                id_value = id_function(seq, seq, True, True)
                item = QtWidgets.QTableWidgetItem("%3.0f" %
                                                  float(100.0 * id_value))
                item.setBackground(
                    QtGui.QBrush(self.color_scale[int(id_value * 20.0)]))
                item.setTextAlignment(QtCore.Qt.AlignCenter)
                self.ui.sequenceTable.setItem(idx, idx, item)
        else:
            for idx in range(len(self.sequences)):
                item = QtWidgets.QTableWidgetItem("100")
                item.setBackground(QtGui.QBrush(self.color_scale[20]))
                item.setTextAlignment(QtCore.Qt.AlignCenter)
                self.ui.sequenceTable.setItem(idx, idx, item)
        self.ui.sequenceTable.resizeColumnsToContents()
        self.ui.sequenceTable.resizeRowsToContents() 
[docs]    def populateTable(self):
        """
        Rebuilds the contents of the table.
        """
        self.createHeaderLabels()
        # Determine calculation mode (identity, similarity or homology)
        id_function = Sequence.calcIdentity
        if self.ui.similarityRadioButton.isChecked():
            id_function = Sequence.calcSimilarity
        elif self.ui.homologyRadioButton.isChecked():
            id_function = Sequence.calcHomology
        column_mode = False
        if self.ui.columnsCheckBox.isChecked():
            column_mode = True
        self.calculatePairwiseIdentity(id_function, column_mode) 
[docs]    def exportCSV(self):
        """
        Exports table contents as CSV file.
        """
        if not self.ui.sequenceTable.columnCount():
            return
        file_name = filedialog.get_save_file_name(self, "Export as CSV", ".",
                                                  "CSV Files (*.csv)")
        if file_name:
            table = self.ui.sequenceTable
            with csv_unicode.writer_open(file_name) as csv_file:
                writer = csv.writer(csv_file)
                names = []
                for column in range(table.columnCount()):
                    names.append(str(table.horizontalHeaderItem(column).text()))
                if self.ui.similarityRadioButton.isChecked():
                    line = ['Similarity']
                elif self.ui.homologyRadioButton.isChecked():
                    line = ['Homology']
                else:
                    line = ['Identity']
                line.extend(names)
                writer.writerow(line)
                for row in range(table.rowCount()):
                    line = [names[row]]
                    for column in range(table.columnCount()):
                        line.append(table.item(row, column).text())
                    writer.writerow(line) 
[docs]    def exportImage(self):
        """
        Exports table contents as 2x supersampled PNG image.
        """
        if not self.ui.sequenceTable.columnCount():
            return
        file_name = filedialog.get_save_file_name(self, "Export as PNG", ".",
                                                  "PNG Files (*.png)")
        if file_name:
            table = self.ui.sequenceTable
            total_width, total_height = self.resizeTable()
            # table to render to image
            image_size = QtCore.QSize(total_width, total_height)
            pixmap = QtGui.QPixmap(image_size)
            rect = self.ui.sequenceTable.geometry()
            table.setGeometry(0, 0, total_width, total_height)
            table.render(pixmap)
            table.setGeometry(rect)
            self.restoreTable()
            # save image to PNG file
            image = pixmap.toImage()
            image.save(file_name, "PNG") 
[docs]    def resizeTable(self):
        """
        Temporarily resizes the table by rescaling width and height by 2.
        :return: (width, height) tuple of rescaled table dimensions.
        """
        table = self.ui.sequenceTable
        self.original_widths = []
        self.original_heights = []
        # Resize table headers, cells and fonts
        table.verticalHeader().setFixedWidth(
            2 * table.verticalHeader().sizeHint().width())
        font = table.verticalHeader().font()
        font.setPointSize(2 * font.pointSize())
        table.verticalHeader().setFont(font)
        table.horizontalHeader().setFixedHeight(
            2 * table.horizontalHeader().sizeHint().height())
        font = table.horizontalHeader().font()
        font.setPointSize(2 * font.pointSize())
        table.horizontalHeader().setFont(font)
        total_width = table.verticalHeader().width()
        for column in range(table.columnCount()):
            width = table.columnWidth(column)
            self.original_widths.append(width)
            table.setColumnWidth(column, 2 * width)
            total_width += 2 * width
        total_height = table.horizontalHeader().height()
        for row in range(table.rowCount()):
            height = table.rowHeight(row)
            self.original_heights.append(height)
            table.setRowHeight(row, 2 * height)
            total_height += 2 * height
        for row in range(table.rowCount()):
            for column in range(table.columnCount()):
                item = table.item(row, column)
                font = item.font()
                font.setPointSize(2 * font.pointSize())
                item.setFont(font)
        total_width += 1
        total_height += 1
        return (total_width, total_height) 
[docs]    def restoreTable(self):
        """
        Restores original table dimensions.
        """
        table = self.ui.sequenceTable
        # restore original cell dimensions
        for column in range(table.columnCount()):
            table.setColumnWidth(column, self.original_widths[column])
        for row in range(table.rowCount()):
            table.setRowHeight(row, self.original_heights[row])
        # restore table items font size
        for row in range(table.rowCount()):
            for column in range(table.columnCount()):
                item = table.item(row, column)
                font = item.font()
                font.setPointSize(old_div(font.pointSize(), 2))
                item.setFont(font)
        # restore original header dimensions and font
        table.verticalHeader().setFixedWidth(
            old_div(table.verticalHeader().sizeHint().width(), 2))
        font = table.verticalHeader().font()
        font.setPointSize(old_div(font.pointSize(), 2))
        table.verticalHeader().setFont(font)
        table.horizontalHeader().setFixedHeight(
            old_div(table.horizontalHeader().sizeHint().height(), 2))
        font = table.horizontalHeader().font()
        font.setPointSize(old_div(font.pointSize(), 2))
        table.horizontalHeader().setFont(font)