import itertools
import math
import re
from collections import deque
from past.utils import old_div
from schrodinger.Qt import QtCore
from schrodinger.Qt.QtWidgets import QCheckBox
from schrodinger.Qt.QtWidgets import QComboBox
from schrodinger.Qt.QtWidgets import QDoubleSpinBox
from schrodinger.Qt.QtWidgets import QHBoxLayout
from schrodinger.Qt.QtWidgets import QLabel
from schrodinger.Qt.QtWidgets import QVBoxLayout
from schrodinger.Qt.QtWidgets import QWidget
from schrodinger.ui import dendrogram as dendro
def makeListOf(input):
    """
    Coerce *input* into a list.

    A string counts as a single item (it is not split into characters), any
    other iterable is copied into a new list, and a non-iterable value is
    wrapped in a one-element list.
    """
    if isinstance(input, str):
        return [input]
    try:
        iter(input)
    except TypeError:
        return [input]
    return list(input)
class DendrogramSyncer:
    """
    Mirror node layout and selection between several dendrograms.

    The graphic-tree signals of each registered tree are connected to this
    object. When one tree updates its positions or selection, the change is
    copied onto every other registered tree. Nothing is synced while fewer
    than two trees are registered.
    """

    def __init__(self):
        self._dendrograms = []
        self._views = []
        # id(tree) -> (graphic tree whose signals were connected, selection
        # slot), so removeTree can disconnect exactly what addTree connected
        # even if tree.m_graphicTree has been replaced in the meantime.
        self._connections = {}

    def addTree(self, tree):
        """
        Start syncing *tree* with the other registered trees.

        :param tree: object exposing ``m_graphicTree`` (e.g. a `Dendrogram`
            after `Dendrogram.showTree`).
        """
        graphicTree = tree.m_graphicTree
        self._dendrograms.append(tree)
        graphicTree.aboutToUpdatePositions.connect(
            self.respondToUpdatePositions)
        graphicTree.aboutToUpdatePositions.connect(
            self.respondToChangeSelection)

        # selectionChanged does not pass the tree, so supply it explicitly.
        # Keep a reference to the slot so removeTree can disconnect it.
        def selectionSlot():
            self.respondToChangeSelection(tree.m_graphicTree)

        graphicTree.selectionChanged.connect(selectionSlot)
        self._connections[id(tree)] = (graphicTree, selectionSlot)

    def removeTree(self, tree):
        """
        Stop syncing *tree*.

        The signals connected by `addTree` are disconnected. Without this, later
        updates in the removed tree would still be copied onto the remaining
        trees.

        :raise ValueError: if *tree* is not registered.
        """
        self._dendrograms.remove(tree)
        connection = self._connections.pop(id(tree), None)
        if connection is None:
            return
        graphicTree, selectionSlot = connection
        for signalName, slot in (
            ('aboutToUpdatePositions', self.respondToUpdatePositions),
            ('aboutToUpdatePositions', self.respondToChangeSelection),
            ('selectionChanged', selectionSlot),
        ):
            try:
                getattr(graphicTree, signalName).disconnect(slot)
            except (TypeError, RuntimeError):
                # Already disconnected, or the underlying Qt object is gone;
                # either way nothing remains to undo.
                pass

    def _otherGraphicTrees(self, graphicTree):
        """
        Return the graphic trees of all registered dendrograms except
        *graphicTree*.
        """
        return [
            tree.m_graphicTree
            for tree in self._dendrograms
            if tree.m_graphicTree != graphicTree
        ]

    def respondToChangeSelection(self, graphicTree):
        """
        Copy the selection of *graphicTree* onto every other registered tree.
        """
        if len(self._dendrograms) < 2:
            return
        for otherTree in self._otherGraphicTrees(graphicTree):
            otherTree.copySelectionFrom(graphicTree)

    def respondToUpdatePositions(self, graphicTree):
        """
        Copy node position, coordinates and visibility from *graphicTree*
        onto every other registered tree, then refresh their positions.

        Nodes are matched by walking all trees breadth-first in lockstep. This
        assumes (without checking) that all trees have the same topology.
        """
        if len(self._dendrograms) < 2:
            return
        otherTrees = self._otherGraphicTrees(graphicTree)
        sourceIter = dendro.DendrogramNodeBFIterator(
            graphicTree.getFirstNode())
        otherIters = [
            (tree, dendro.DendrogramNodeBFIterator(tree.getFirstNode()))
            for tree in otherTrees
        ]
        while sourceIter.node() is not None:
            graphNode = graphicTree.getNode(sourceIter.node())
            for otherTree, otherIter in otherIters:
                otherGraphNode = otherTree.getNode(otherIter.node())
                otherGraphNode.setPos(graphNode.pos())
                otherGraphNode.setCoordinates(graphNode.getCoordinates())
                otherGraphNode.setVisible(graphNode.isVisible())
                otherGraphNode.getLabel().setVisible(
                    graphNode.getLabel().isVisible())
                otherGraphNode.m_hideAfterAnimation = (
                    graphNode.m_hideAfterAnimation)
            # Advance every iterator together to keep them in lockstep.
            sourceIter.next()  # nopy3migrate
            for _, otherIter in otherIters:
                otherIter.next()  # nopy3migrate
        for tree in otherTrees:
            tree.updatePositions(False)
class NewickNode:
    """
    Lightweight node used only while parsing a Newick string, before the tree
    is converted into `dendro.DendrogramNode` objects.
    """

    def __init__(self):
        self.label = ''
        self.distance = 1
        self.parent = None
        self.children = []
        # Leaf indices belonging to this node's subtree.
        self.leaf = []

    def addChild(self, node):
        """
        Append *node* as the last child of this node and make this node its
        parent.
        """
        node.parent = self
        self.children.append(node)
def assignLeaves(node):
    """
    Propagate leaf indices from the leaves up towards the root.

    Each node's ``leaf`` list is extended with the ``leaf`` lists of all of
    its descendants. The nodes reachable from *node* are collected
    breadth-first and then processed in reverse, so every node is complete
    before its parent absorbs it. If *node* itself has a parent, that parent
    is extended too.
    """
    order = [node]
    cursor = 0
    while cursor < len(order):
        order.extend(order[cursor].children)
        cursor += 1
    for current in order[::-1]:
        parent = current.parent
        if parent:
            parent.leaf += current.leaf
def buildTree(node):
    """
    Convert a `NewickNode` tree rooted at *node* into a hierarchy of
    `dendro.DendrogramNode` nodes and return the dendrogram root node.

    Each dendrogram node receives the source node's distance and label, its
    converted children (in the same order), and one ``setLeaf`` call for each
    entry in the source node's ``leaf`` list.

    The tree is walked with an explicit stack rather than recursion, so very
    deep trees do not hit Python's recursion limit. Nodes are still created
    in post-order (all children, left to right, before their parent), the
    same order a recursive traversal would use.
    """
    # (source node, children already scheduled?) pairs still to process.
    pending = [(node, False)]
    # Finished dendrogram nodes. While a parent is being finished, its
    # converted children sit on top of this stack, in order.
    built = []
    while pending:
        current, expanded = pending.pop()
        if not expanded:
            # Revisit `current` after its children. Push the children in
            # reverse so the first child is converted first.
            pending.append((current, True))
            pending.extend(
                (child, False) for child in reversed(current.children))
            continue
        split = len(built) - len(current.children)
        child_nodes = built[split:]
        del built[split:]
        dendro_node = dendro.DendrogramNode()
        dendro_node.setDistance(current.distance)
        dendro_node.setLabel(current.label)
        for child in child_nodes:
            dendro_node.addChild(child)
        for leaf in current.leaf:
            dendro_node.setLeaf(leaf)
        built.append(dendro_node)
    return built[0]
# Characters that terminate a label or a branch-length field.
_NEWICK_DELIMITERS = re.compile(r'[,);]')


def _findNewickDelimiter(text, start):
    """
    Return the index of the first ',', ')' or ';' in *text* at or after
    *start*, or ``len(text)`` if there is none (the end of the input acts as
    an implicit delimiter).
    """
    match = _NEWICK_DELIMITERS.search(text, start)
    return match.start() if match else len(text)


def parseNewickString(dndstring):
    """
    Parses a Newick-formatted string and generates a dendrogram tree.

    Whitespace between tokens is ignored, including whitespace between a
    closing parenthesis and its ``:distance``. Surrounding whitespace is
    stripped from leaf labels. Leaves are numbered 0, 1, 2, ... in the order
    they appear. Parsing stops at the first ';'.

    NOTE(review): labels on internal nodes (text right after ')') are not
    supported; such text is parsed as an extra leaf.

    :param dndstring: Newick tree, e.g. ``"((A:0.1,B:0.2):0.3,C:0.4);"``.
    :type dndstring: str
    :raise ValueError: if the string contains no parenthesized tree, a leaf
        appears outside any branch, or a distance is not a number.
    :rtype: `dendro.DendrogramNode`
    :return: Root node of the dendrogram tree.
    """
    tree = root = None
    leaf_count = 0
    length = len(dndstring)
    pos = 0
    while pos < length:
        c = dndstring[pos]
        if c.isspace() or c == ',':  # skip whitespace and separators
            pass
        elif c == ';':  # end of tree
            break
        elif c == '(':  # open a branch
            node = NewickNode()
            if root is None:  # store root node
                root = node
            if tree:
                tree.addChild(node)
            node.parent = tree
            tree = node
        elif c == ')':  # close a branch
            # Look past any whitespace for an optional ":distance" that
            # belongs to the branch being closed.
            colon = pos + 1
            while colon < length and dndstring[colon].isspace():
                colon += 1
            if colon + 1 < length and dndstring[colon] == ':':
                end = _findNewickDelimiter(dndstring, colon + 1)
                distance = float(dndstring[colon + 1:end])
                if tree:
                    tree.distance = distance
                pos = end - 1  # resume at the delimiter
            if tree:
                tree = tree.parent  # get a parent node
        else:  # a leaf: "label" or "label:distance"
            end = _findNewickDelimiter(dndstring, pos)
            fields = dndstring[pos:end].split(':')
            label = fields[0].strip()
            if len(fields) > 1:
                # NOTE(review): leaf branch lengths are parsed only to
                # validate them and are not stored, so leaves keep
                # NewickNode's default distance; confirm this is intended.
                float(fields[1])
            if tree is None:
                raise ValueError(
                    'Newick leaf %r is not inside a branch' % label)
            node = NewickNode()
            node.label = label
            node.leaf.append(leaf_count)
            leaf_count += 1
            tree.addChild(node)
            pos = end - 1  # resume at the delimiter
        pos += 1
    if root is None:
        raise ValueError('Newick string contains no tree')
    assignLeaves(root)
    # Convert the temporary NewickNode tree into dendrogram nodes.
    return buildTree(root)
class GuiPanel(QWidget):
    """
    Options panel for tuning the layout of one or more dendrograms.

    Provides controls for linear/log scaling, scale factor, minimum and cutoff
    length, total angle, color threshold, branch thickness, rectangular
    layout and self-reorganization. Changes are applied to every dendrogram
    immediately.
    """

    def __init__(self, dendr):
        """
        :param dendr: a `Dendrogram` or an iterable of them (normalized with
            `makeListOf`).
        """
        super().__init__()
        self._dendro = dendr
        self._preferences = dendro.DendrogramGraphicTreeCoordinatesPreferences()
        self.setLayout(QVBoxLayout())
        self.combobox = QComboBox()
        self.combobox.addItem("linear", "linear")
        self.combobox.addItem("logarithmic", "log")
        self.rectCB = QCheckBox("Rectangular")
        self.sr = QCheckBox("Self-reorganize")
        self.layout().addWidget(self.combobox)
        self.combobox.currentIndexChanged.connect(self.updateCoordinates)
        self.sf = self.addFloatSpinBox("scale factor", 0, 100000)
        self.ml = self.addFloatSpinBox("minimum length")
        self.cl = self.addFloatSpinBox("cutoff length")
        self.ta = self.addFloatSpinBox("total angle", 0, 360)
        self.ct = self.addFloatSpinBox("color threshold", 0, 100000)
        # Each setValue below fires valueChanged -> updateCoordinates.
        self.ta.setValue(360)
        self.sf.setValue(100)
        self.cl.setValue(200)
        self.ml.setValue(30)
        self.bt = self.addFloatSpinBox("branch thickness", 0, 8.0)
        self.bt.valueChanged.connect(self.changeThickness)
        self.bt.setValue(0)
        self.ct.valueChanged.connect(self.colorByThreshold)
        self.layout().addWidget(self.rectCB)
        self.rectCB.stateChanged.connect(self.updateCoordinates)
        self.layout().addWidget(self.sr)
        self.sr.stateChanged.connect(self.updateCoordinates)
        self.changeThickness()
        self.updateCoordinates()

    def addFloatSpinBox(self, name, min=0, max=1000):
        """
        Add a labelled `QDoubleSpinBox` with range [*min*, *max*] to the
        panel. Every value change triggers `updateCoordinates`.

        :return: the new spin box.
        """
        widget = QWidget()
        widget.setLayout(QHBoxLayout())
        self.layout().addWidget(widget)
        widget.layout().addWidget(QLabel(name))
        spinbox = QDoubleSpinBox()
        spinbox.setRange(min, max)
        spinbox.valueChanged.connect(self.updateCoordinates)
        widget.layout().addWidget(spinbox)
        return spinbox

    def changeThickness(self):
        """
        Apply the branch-thickness setting to every node of every dendrogram.

        A value below 1 selects automatic widths: the number of leaves under
        the node divided by 5, clamped to [0.5, 6].
        """
        thickness = self.bt.value()
        minWidth = 0.5
        maxWidth = 6
        for tree in makeListOf(self._dendro):
            graphicTree = tree.m_graphicTree
            nodeIter = dendro.DendrogramNodeBFIterator(
                graphicTree.getFirstNode())
            while nodeIter.node() is not None:
                node = graphicTree.getNode(nodeIter.node())
                if thickness < 1:
                    width = max(len(node.getLeafIndices()) / 5., minWidth)
                    width = min(maxWidth, width)
                else:
                    width = thickness
                node.setBranchWidth(width)
                nodeIter.next()  # nopy3migrate

    def changeScaling(self):
        """
        Copy the current widget values into the coordinate preferences.
        """
        prefs = self._preferences
        # Index 1 is the "logarithmic" entry added in __init__.
        prefs.logScale = (self.combobox.currentIndex() == 1)
        prefs.scaleFactor = self.sf.value()
        prefs.minimumLength = self.ml.value()
        prefs.cutoffLength = self.cl.value()
        prefs.totalAngle = self.ta.value() * math.pi / 180  # degrees -> rad
        prefs.rectangular = self.rectCB.isChecked()
        prefs.selfReorganize = self.sr.isChecked()

    def updateCoordinates(self):
        """
        Recompute coordinates and positions of every dendrogram from the
        current widget values.
        """
        self.changeScaling()
        for tree in makeListOf(self._dendro):
            tree.m_graphicTree.initializeCoordinates(self._preferences)
            tree.m_graphicTree.updatePositions()

    def colorByThreshold(self, threshold):
        """
        Recolor every dendrogram using *threshold* (the color-threshold spin
        box value).
        """
        for tree in makeListOf(self._dendro):
            tree.m_graphicTree.assignColorsWithThreshold(threshold)
class Dendrogram(QtCore.QObject):
    """
    Holds per-leaf data, the dendrogram tree (built from that data or loaded
    from a Newick string), and its graphic tree and view.
    """
    selectionChanged = QtCore.pyqtSignal(set)

    def __init__(self, ViewCls=None):
        """
        :param ViewCls: view class instantiated by `showTree`; defaults to
            `dendro.DendrogramView`.
        """
        super().__init__()
        if ViewCls is None:
            ViewCls = dendro.DendrogramView
        self._ViewCls = ViewCls
        # One property dict per leaf; buildTree uses list positions as the
        # builder's leaf indices.
        self.leaves = []
        # Set by loadTreeFromNewick / buildTree.
        self.m_tree = None
        # Set by showTree.
        self.m_graphicTree = None
        self.m_view = None

    def nodes(self):
        """Return all nodes of the graphic tree (requires `showTree`)."""
        return self.m_graphicTree.listAllNodes()

    def setLeaves(self, key, lis):
        """
        Reset the leaves to one per entry of *lis* and store each entry under
        *key*.
        """
        self.initializeLeaves(lis)
        self.addLeafProperty(key, lis)

    def isVisible(self):
        """Return whether a graphic tree exists and is visible."""
        return self.m_graphicTree is not None and self.m_graphicTree.isVisible()

    def initializeLeaves(self, lis):
        """Replace the leaves with one empty property dict per entry of *lis*."""
        self.leaves = [{} for _ in lis]

    def addLeafProperty(self, key, lis):
        """
        Store ``lis[i]`` under *key* on leaf *i*.

        The call is silently ignored if *lis* does not have exactly one entry
        per leaf.
        """
        if len(lis) != len(self.leaves):
            return
        for leaf, entry in zip(self.leaves, lis):
            leaf[key] = entry

    def getProperty(self, key, index):
        """Return property *key* of leaf *index*."""
        return self.leaves[index][key]

    def loadTreeFromNewick(self, dndstring):
        """Parse *dndstring* (Newick format) into `m_tree`."""
        self.m_tree = parseNewickString(dndstring)

    def buildTree(self, similarityFunction):
        """
        Build `m_tree` from pairwise leaf distances.

        :param similarityFunction: called as ``similarityFunction(leaf_i,
            leaf_j)`` for every pair i < j of leaf property dicts; its result
            is passed to the builder as the distance between leaves i and j.
        """
        builder = dendro.DendrogramTreeBuilder()
        size = len(self.leaves)
        builder.initializeWithSize(size)
        for first, second in itertools.combinations(range(size), 2):
            distance = similarityFunction(self.leaves[first],
                                          self.leaves[second])
            builder.setDistance(first, second, distance)
        self.m_tree = builder.buildTree()

    def runOnEachGraphicNode(self, function):
        """Call ``function(self, node)`` for every graphic-tree node."""
        for node in self.nodes():
            function(self, node)

    def showTree(self, initializeFunction, pref=None):
        """
        Create the graphic tree and a view for `m_tree`, laid out according
        to *pref*.

        :param initializeFunction: called as ``initializeFunction(self, node)``
            on every graphic-tree node before layout.
        :type initializeFunction: callable
        :param pref: options to display the dendrogram; a default
            preferences object is used if not given.
        :type pref: schrodinger.ui.dendrogram.DendrogramGraphicTreeCoordinatesPreferences
        :return: Dendrogram view (an instance of the ``ViewCls`` given to
            the constructor).
        :rtype: schrodinger.application.msv.gui.dendrogram_viewer.MSVDendrogramView
        """
        self.m_graphicTree = dendro.DendrogramGraphicTree(self.m_tree)
        self.runOnEachGraphicNode(initializeFunction)
        if not pref:
            pref = dendro.DendrogramGraphicTreeCoordinatesPreferences()
        self.m_graphicTree.initializeCoordinates(pref)
        self.m_graphicTree.updatePositions(True, False)
        self.m_view = self._ViewCls()
        self.m_view.getScene().drawTree(self.m_graphicTree)
        return self.m_view

    def showLeaves(self, number):
        """
        Expand the tree, in ``listAllNodes`` order, until about *number* nodes
        are visible. A node's children are shown only if showing them all
        keeps the count within *number*; otherwise they are hidden.
        """
        nodes = self.m_graphicTree.listAllNodes()
        nodes[0].show()
        shownLeaves = 1
        for node in nodes:
            children = node.getChildren()
            if not children:
                continue
            # Expanding a node replaces it with its children: net gain is
            # len(children) - 1 visible entries.
            isVisible = (shownLeaves + len(children) - 1 <= number)
            if isVisible:
                shownLeaves += len(children) - 1
            for child in children:
                child.setVisible(isVisible)
        self.m_graphicTree.displayLabels()

    def getoptionsGUI(self):
        """Create, show and return a `GuiPanel` controlling this dendrogram."""
        widget = GuiPanel(self)
        widget.show()
        return widget