# -*- coding: utf-8 -*-
"""
A FEP configuration dialog.  It should not be run directly, but is instead used
by the specific Desmond scripts.
Copyright Schrodinger, LLC. All rights reserved.
"""
import sys
from schrodinger.application.desmond import feputils
from schrodinger.application.desmond import platforms
from schrodinger.Qt import QtGui
from schrodinger.Qt import QtWidgets
from schrodinger.ui.qt import config_dialog
from schrodinger.ui.qt.appframework2 import af2
from schrodinger.ui.qt.appframework2.settings import generate_preference_key
# Default subjob resource counts pre-filled by
# FEPConfigDialog.updateNumCPUsLabel(): the CPU count is used for hosts with no
# GPUs, the GPU count for hosts that have GPUs.
REST_DEFAULT_NUM_CPUS, REST_DEFAULT_NUM_GPUS = 12, 4
# Page indices of the CPU/GPU count stacked widget shown in the config dialog.
CPU_LAYOUT, GPU_LAYOUT = range(2)
class FEPConfigDialog(config_dialog.ConfigDialog):
    """
    Job configuration dialog for FEP jobs.

    On top of the main-job options provided by `config_dialog.ConfigDialog`
    (retitled "Main Job"), this dialog adds:

    - a "Maximum simultaneous FEP subjobs" field next to the main host menu,
    - a "Subjob" group with a GPU-only host menu and a CPU/GPU count field,
      plus (FEP/REST jobs only) a disabled "Replica" field,
    - a "Force Field Builder Job" group, visible only if enabled.
    """
    HOST_LABEL_TEXT = "CPU Host:"
    MAX_SUBJOBS_LABEL_TEXT = "Maximum simultaneous FEP subjobs:"
    FFB_SUBJOBS_LABEL_TEXT = 'Maximum number of concurrent FFBuilder subjobs:'

    def __init__(self,
                 parent,
                 title="",
                 jobname="",
                 checkcommand=None,
                 use_rest=True,
                 per_subjob=None,
                 single_gpu=False,
                 ff_builder_enabled=False,
                 **kw):
        """
        :param parent: Parent widget. Also used to build the preference keys
                under which the subjob and FFBuilder host selections are
                remembered.
        :param title: Dialog title, passed to `ConfigDialog`.
        :type title: str
        :param jobname: Job name, passed to `ConfigDialog`.
        :type jobname: str
        :param checkcommand: Passed through to `ConfigDialog`.
        :param use_rest: Specifies whether this is a FEP/REST job or not.
                Setting this to False will hide "Replica" options.
                (This was done for backwards compatibility).
        :type use_rest: bool
        :param per_subjob: Whether to show per job CPUs field, or Total CPUs
                field. By default, per_subjob = not use_rest.
        :type per_subjob: bool
        :param single_gpu: Whether to allow only a single GPU
        :type single_gpu: bool
        :param ff_builder_enabled: Whether to show the Force Field Builder job
                options.
        :type ff_builder_enabled: bool
        """
        # Get host list
        self.hosts = self.getHosts()
        self.subhost_menu = QtWidgets.QComboBox(parent)
        self._single_gpu = single_gpu
        self.setupSubHostCombo(self.subhost_menu)
        # Establish key for saving subhost selection persistently
        self.last_subhost_prefkey = generate_preference_key(
            parent, 'last-subhost')
        self.use_rest = bool(use_rest)
        # Must be defined before calling the base initializer: updateCPULimits()
        # and updateMaxjobsDefault() skip the CPU/GPU widget while it is unset.
        self.num_cpus_sw = None
        super().__init__(parent, title, jobname, checkcommand, **kw)
        self.subhost_menu.currentIndexChanged.connect(self.updateMaxjobsDefault)
        self.job_group.setTitle("Main Job")
        self.host_menu_layout.addStretch()
        self.maxjobs_ef = self.addNumericLineEdit(
            self.host_menu_layout,
            value=0,
            prelabel=self.MAX_SUBJOBS_LABEL_TEXT)
        self.host_menu_layout.addStretch()
        self.setupSubHostLayout()
        # Create a label and keep a reference to it so that
        # updateNumCPUsLabel() can update it
        self.num_cpus_sb_postlabel = QtWidgets.QLabel("")
        # If per_subjob is not specified, default to total #cpus if use_rest
        # is True, and per job #cpus otherwise.
        if per_subjob is not None:
            self.per_subjob = per_subjob
        else:
            self.per_subjob = not use_rest
        if use_rest:
            self.replica_layout = QtWidgets.QHBoxLayout()
            self.replica_lbl = QtWidgets.QLabel("Replica:")
            self.replica_ef = self.addNumericLineEdit(self.replica_layout,
                                                      prelabel=self.replica_lbl)
            self.replica_layout.addStretch()
            cpus_layout = self.replica_layout
            # The replica count is shown but not user-editable.
            self.replica_ef.setDisabled(True)
            self.replica_lbl.setDisabled(True)
            self.replica_ef.setText('12')
        else:
            cpus_layout = self.subhost_layout
        # Use the resolved self.per_subjob (the raw argument may be None) so
        # the prelabel agrees with the "per subjob" postlabel set in
        # updateNumCPUsLabel().
        prelabel = "Use:" if self.per_subjob else "Total:"
        self.num_cpus_sw = self.addSubprocessStackedWidget(
            cpus_layout,
            prelabel=prelabel,
            postlabel=self.num_cpus_sb_postlabel)
        self.num_cpus_sb = self.num_cpus_sw.currentWidget()
        if use_rest:
            self.num_cpus_sw.widget(CPU_LAYOUT).setValue(REST_DEFAULT_NUM_CPUS)
        self.updateNumCPUsLabel()
        self.updateMaxjobsDefault()
        self.ffb_enabled = ff_builder_enabled
        self.last_ffb_host_prefkey = generate_preference_key(parent, 'ffb-host')
        self.setupFFBuilderOptions()
        self.subjob_group = QtWidgets.QGroupBox("Subjob", self.dialog)
        self.subjob_layout = QtWidgets.QVBoxLayout(self.subjob_group)
        self.subjob_layout.addLayout(self.subhost_layout)
        if use_rest:
            self.subjob_layout.addLayout(self.replica_layout)
        # Insert just before the last item of the main layout.
        count = self.main_layout.count()
        self.main_layout.insertWidget(count - 1, self.subjob_group)
        self.updateCPULimits()

    def _maxGPUs(self, host):
        """
        Return the largest GPU count selectable for host: 1 in single-GPU
        mode, otherwise host.num_gpus.
        """
        return 1 if self._single_gpu else host.num_gpus

    def updateCPULimits(self):
        """
        This method is called whenever host selection is changed. It updates
        maximum number of allowed CPUs as well as GPUs.
        """
        host = self.currentHost()
        if host and self.num_cpus_sw:
            self.num_cpus_sw.widget(CPU_LAYOUT).setMaximum(host.processors)
            self.num_cpus_sw.widget(GPU_LAYOUT).setMaximum(self._maxGPUs(host))

    def updateNumCPUsLabel(self):
        """
        Show the CPU or GPU page of the count widget, depending on
        isGPUHost(), and update its postlabel ("processors"/"GPUs", followed
        by "per subjob" when per_subjob is set). Then reset the count to
        REST_DEFAULT_NUM_CPUS for hosts without GPUs, or to
        REST_DEFAULT_NUM_GPUS otherwise.

        Does nothing until the count widget and its label have been created.
        """
        if (not hasattr(self, 'num_cpus_sb_postlabel') or
                self.num_cpus_sw is None):
            return
        if self.isGPUHost():
            label_text = "GPUs"
            self.num_cpus_sw.setCurrentIndex(GPU_LAYOUT)
        else:
            label_text = "processors"
            self.num_cpus_sw.setCurrentIndex(CPU_LAYOUT)
        self.num_cpus_sb = self.num_cpus_sw.currentWidget()
        if not self.per_subjob:
            self.num_cpus_sb_postlabel.setText(label_text)
        else:
            self.num_cpus_sb_postlabel.setText("%s per subjob" % label_text)
        host = self.currentHost()
        if host is None:
            return
        if host.num_gpus == 0:
            self.num_cpus_sb.setValue(REST_DEFAULT_NUM_CPUS)
        else:
            self.num_cpus_sb.setValue(REST_DEFAULT_NUM_GPUS)

    def setupSubHostCombo(self, combo):
        """
        Add only GPU Hosts to the combo box input. The combo box menu will be
        cleared first.
        :param combo: combo box to append to.
        :type combo: QtWidgets.QComboBox
        """
        combo.clear()
        all_hosts = self.getHosts(excludeGPGPUs=False)
        for h in all_hosts:
            if h.hostType() == config_dialog.Host.GPUTYPE:
                combo.addItem(h.label(), h)

    def setupFFBuilderOptions(self):
        """
        Set up options for ffbuilder portion of the config dialog.
        Hidden if we are not enabling ffbuilder.
        """
        self.ffbuilder_layout = QtWidgets.QVBoxLayout()
        self.ffb_host_layout = QtWidgets.QHBoxLayout()
        host_label = QtWidgets.QLabel(self.HOST_LABEL_TEXT)
        self.ffb_host_layout.addWidget(host_label)
        self.ffb_host_cb = QtWidgets.QComboBox()
        self.setupHostCombo(self.ffb_host_cb)
        self.ffb_host_layout.addWidget(self.ffb_host_cb)
        self.ffb_host_layout.addSpacerItem(
            QtWidgets.QSpacerItem(40, 0,
                                  QtWidgets.QSizePolicy.MinimumExpanding))
        self.ffbuilder_layout.addLayout(self.ffb_host_layout)
        self.ffb_subjobs_layout = QtWidgets.QHBoxLayout()
        self.ffb_subjobs_le = self.addNumericLineEdit(
            self.ffb_subjobs_layout, 0, self.FFB_SUBJOBS_LABEL_TEXT)
        self.ffb_subjobs_layout.addSpacerItem(
            QtWidgets.QSpacerItem(40, 0,
                                  QtWidgets.QSizePolicy.MinimumExpanding))
        self.ffbuilder_layout.addLayout(self.ffb_subjobs_layout)
        self.ffbuilder_group = QtWidgets.QGroupBox("Force Field Builder Job",
                                                   self.dialog)
        self.ffbuilder_group.setLayout(self.ffbuilder_layout)
        count = self.main_layout.count()
        self.main_layout.insertWidget(count - 1, self.ffbuilder_group)
        # Apply previous host
        if self.options['save_host']:
            key = self.last_ffb_host_prefkey
            host = self._app_preference_handler.get(key, None)
            if host:
                self._selectComboText(self.ffb_host_cb, host)
        self.ffbuilder_group.setVisible(self.ffb_enabled)

    def setupSubHostLayout(self):
        """
        Build self.subhost_layout: a "GPU Host:" label followed by the subjob
        host menu. If the 'save_host' option is set, the previously saved
        subjob host is re-selected.
        """
        self.subhost_layout = QtWidgets.QHBoxLayout()
        self.subhost_layout.setContentsMargins(0, 0, 0, 0)
        self.subhost_layout.setSpacing(3)
        host_label = QtWidgets.QLabel("GPU Host:")
        self.subhost_layout.addWidget(host_label)
        self.subhost_layout.addWidget(self.subhost_menu)
        self.subhost_layout.addStretch()
        if self.options['save_host']:
            key = self.last_subhost_prefkey
            subhost = self._app_preference_handler.get(key, None)
            if subhost:
                self._selectComboText(self.subhost_menu, subhost)

    def validate(self):
        """
        Validate the subjob fields, the platform and the subjob host, then
        defer to `ConfigDialog.validate`.

        :return: True if all checks pass.
        :rtype: bool
        """
        if not self.validateSubjobs():
            return False
        if not self.validatePlatform():
            return False
        if not self.validateSubHost():
            return False
        return super().validate()

    def validateSubjobs(self):
        """
        Validates subjob fields are populated with values that can be
        cast into an int. A warning naming the first invalid field is posted.

        :return: True if both subjob fields hold integers.
        :rtype: bool
        """
        subjob_warning = 'Use a valid integer for the "{}" field'
        # Derive the field names from the visible labels so the warning
        # matches what the user sees.
        le_to_field = {
            self.maxjobs_ef: self.MAX_SUBJOBS_LABEL_TEXT.rstrip(':'),
            self.ffb_subjobs_le: self.FFB_SUBJOBS_LABEL_TEXT.rstrip(':'),
        }
        for le, field in le_to_field.items():
            try:
                int(le.text())
            except ValueError:
                self.warning(subjob_warning.format(field))
                return False
        return True

    def _replicaCount(self, silent=False):
        """
        Parse the replica field.

        :param silent: If True, do not post a warning for an invalid value.
        :type silent: bool
        :return: The number of replicas, or None if the field does not hold a
                positive integer.
        :rtype: int or None
        """
        try:
            replica = int(self.replica_ef.text())
        except ValueError:
            replica = 0
        if replica <= 0:
            if not silent:
                self.warning('Number of replica must be a positive integer.')
            return None
        return replica

    def validateNumCpus(self, host, editfield, silent=False):
        """
        Extend the base-class CPU validation: for FEP/REST jobs, the CPU
        count must also be a multiple of a positive replica count.

        :param silent: If True, do not post warnings.
        :type silent: bool
        :return: True if the CPU count is valid.
        :rtype: bool
        """
        if not super().validateNumCpus(host, editfield, silent):
            return False
        if not self.use_rest:
            return True
        num = int(editfield.text())
        replica = self._replicaCount(silent)
        if replica is None:
            return False
        if num % replica != 0:
            if not silent:
                self.warning('Number of CPUs must be an integer multiple of '
                             'number of replica.')
            return False
        return True

    def validateNumGpus(self, host, editfield, silent=False):
        """
        Extend the base-class GPU validation: for FEP/REST jobs, the GPU
        count must be a positive factor of a positive replica count.

        :param silent: If True, do not post warnings.
        :type silent: bool
        :return: True if the GPU count is valid.
        :rtype: bool
        """
        if not super().validateNumGpus(host, editfield, silent):
            return False
        if not self.use_rest:
            return True
        num = int(editfield.text())
        replica = self._replicaCount(silent)
        if replica is None:
            return False
        # num <= 0 is checked first so the modulo can never divide by zero.
        if num <= 0 or replica % num != 0:
            if not silent:
                self.warning('Number of GPUs must be an integer factor of '
                             'number of replica.')
            return False
        return True

    def validateSubHost(self):
        """
        Checks if the current SUBJOB Host is None - if so a warning dialog is
        posted to the user.
        :return: True if a subjob host is chosen, False if not.
        :rtype: bool
        """
        subjob_host = self.currentHost(self.subhost_menu)
        if subjob_host is None:
            self.warning('No GPU host available. FEP+ is only '
                         'supported on GPUs - please add a GPU host to your '
                         'schrodinger.hosts file.')
            return False
        return True

    def currentHost(self, menu=None):
        """
        See ConfigDialog.currentHost() docstring.

        Defaults to the subjob (GPU) host menu. If the menu has no selection,
        it is repopulated with GPU hosts. If it is still empty, None is
        returned and the replica field is disabled.
        """
        if menu is None:
            menu = self.subhost_menu
        if menu.currentIndex() == -1:
            self.setupSubHostCombo(menu)
            # Return early if menu is empty - this is necessary as the super
            # class would otherwise add CPU Hosts to the menu if it's empty.
            if menu.count() == 0:
                if self.use_rest and hasattr(self, 'replica_ef'):
                    self.replica_ef.setDisabled(True)
                    self.replica_lbl.setDisabled(True)
                return None
        return super().currentHost(menu)

    def addNumericLineEdit(self,
                           layout,
                           value=1,
                           prelabel=None,
                           postlabel=None):
        """Creates a standard line edit used for input, adds it to the provided
        layout, and then returns the line edit so that it can be stored and its
        value accessed later.
        :param value: the initial value for the line edit
        :type value: int
        :type prelabel: str or QLabel
        :type postlabel: str or QLabel
        If prelabel or postlabel are strings, QLabels with the textual value
        will be created.
        """
        self.buildLabel(layout, prelabel)
        line_edit = self.buildLineEdit(value=value)
        layout.addWidget(line_edit)
        self.buildLabel(layout, postlabel)
        return line_edit

    def buildLabel(self, layout, label):
        """
        Build a new QLabel if label is not already a QLabel, and add the
        widget to the given layout. A QLabel is also built (and added) when
        label is None.
        :param layout: layout to which the label should be added.
        :type layout: QtWidgets.QLayout
        :param label: the text or widget to add to layout.
        :type label: string or QLabel
        """
        if isinstance(label, QtWidgets.QLabel):
            layout.addWidget(label)
        else:
            layout.addWidget(QtWidgets.QLabel(label))

    def buildLineEdit(self, value=1):
        """
        Build a 40px-wide QLineEdit initialized to value that accepts
        integers in the range 0-10000.
        """
        line_edit = QtWidgets.QLineEdit()
        line_edit.setText(str(value))
        line_edit.setValidator(QtGui.QIntValidator(0, 10000))
        line_edit.setFixedWidth(40)
        return line_edit

    def buildComboBox(self):
        """
        Build a CustomGPUComboBox offering 1, 2 and 4, with 1 selected.
        """
        combo_box = CustomGPUComboBox()
        items = ['1', '2', '4']
        combo_box.addItems(items)
        combo_box.setText('1')  # default
        return combo_box

    def getSettings(self, extra_kws=None):
        """
        Collect the subjob settings and pass them on to
        `ConfigDialog.getSettings`.

        The settings collected are the CPU/GPU count, the subjob host, maxjobs,
        the replica count and processors per replica (FEP/REST only), and the
        FFBuilder subjob count and host. If the 'save_host' option is set, the
        host selections are also saved as preferences.

        :param extra_kws: Additional settings to include. This dict is copied,
                not modified.
        :type extra_kws: dict or None
        """
        kw = dict(extra_kws) if extra_kws else {}
        kw['cpus'] = int(self.num_cpus_sb.value())
        kw['subjob_host_text'] = self.subhost_menu.currentText()
        subhost = self.currentHost()
        subhost_name = subhost.name if subhost else ''
        # The GPU entry for the local machine is submitted as plain localhost.
        if subhost_name == 'localhost-gpu':
            kw['subjob_host'] = 'localhost'
        else:
            kw['subjob_host'] = subhost_name
        kw['maxjobs'] = int(self.maxjobs_ef.text())
        if self.use_rest:
            replica = int(self.replica_ef.text())
            kw['replica'] = replica
            # The field validator accepts 0; avoid dividing by zero.
            if replica > 0:
                proc_per_replica = max(1, kw['cpus'] // replica)
            else:
                proc_per_replica = 1
            kw['processors_per_replica'] = proc_per_replica
        kw[feputils.FFBUILDER_SUBJOBS_SETTING] = int(self.ffb_subjobs_le.text())
        ffb_host = self.currentHost(self.ffb_host_cb)
        ffb_host_name = ffb_host.name if ffb_host else ''
        kw[feputils.FFBUILDER_HOST_SETTING] = ffb_host_name
        if self.options['save_host']:
            subhost_pref = kw['subjob_host_text']
            self._app_preference_handler.set(self.last_subhost_prefkey,
                                             subhost_pref)
            if self.ffb_enabled:
                ffb_host_pref = self.ffb_host_cb.currentText()
                self._app_preference_handler.set(self.last_ffb_host_prefkey,
                                                 ffb_host_pref)
        # Will add the "gpus" option:
        return super().getSettings(extra_kws=kw)

    def applySettings(self, settings):
        """
        See parent class docstring

        Additionally restores the subjob host, the replica count (FEP/REST
        only), maxjobs and the CPU/GPU count.
        """
        super().applySettings(settings)
        if hasattr(settings, 'subjob_host_text'):
            subhost_text = settings.subjob_host_text
            self._selectComboText(self.subhost_menu, subhost_text)
        if self.use_rest:
            self._applySetting(self.replica_ef.setText, settings, 'replica')
        # Important! Maxjobs settings need to be applied after subjob
        # host since changing the latter would reset maxjobs to zero.
        # This was done to fix PANEL-2172.
        self._applySetting(self.maxjobs_ef.setText, settings, 'maxjobs')
        self._applySetting(self.num_cpus_sb.setValue, settings, 'cpus')

    def updateMaxjobsDefault(self):
        """
        Slot for subjob host menu changes. Resets the maximum simultaneous
        subjobs to 0 when the host has a queue, and updates the maximum
        selectable GPU count.
        """
        host = self.currentHost()
        if host is None:
            return
        if host.queue:
            self.maxjobs_ef.setText("0")
        if self.num_cpus_sw:
            self.num_cpus_sw.widget(GPU_LAYOUT).setMaximum(self._maxGPUs(host))
class CustomGPUComboBox(QtWidgets.QComboBox):
    """
    Combo box of GPU counts with a line-edit/spin-box-like API (text,
    setText, value, setValue, setMaximum). Item texts are expected to be
    integer strings.
    """

    def text(self):
        """
        Wrapper for currentText().
        """
        return self.currentText()

    def value(self):
        """
        Get the int value of the current text
        :return: int value of current text
        :rtype: int
        """
        return int(self.text())

    def setText(self, text):
        """
        Sets text as selected entry in combo box if found, otherwise,
        the text is added to combo box and set as selected.
        :param text: set either existing or new entry with given text
        :type text: str
        """
        index = self.findText(text)
        if index == -1:
            self.addItem(text)
            index = self.findText(text)
        self.setCurrentIndex(index)

    def setValue(self, val):
        """
        Set the current value of the combobox to the specified value.
        :param val: Value to be set
        :type val: int
        """
        self.setText(str(val))

    def setMaximum(self, value):
        """
        Disable combo box entries that are larger than value and add value as
        an entry if it wasn't present. If the selected entry is too large,
        decrement the index until it is acceptable or the first entry is
        reached. The downward search assumes entries are in ascending order.
        If nothing is selected (including an empty combo box), the selection
        is left empty.
        :param value: the maximum number of GPUs selectable
        :type value: int
        """
        idx_ngpu_map = {
            idx: int(self.itemText(idx)) for idx in range(self.count())
        }
        for idx, n_gpu in idx_ngpu_map.items():
            self.model().item(idx).setEnabled(n_gpu <= value)
        # determine whether value should be added
        current_idx = self.currentIndex()
        if value not in idx_ngpu_map.values():
            self.addItem(str(value))
        # addItem() auto-selects the first item of an empty combo; restore the
        # previous selection.
        self.setCurrentIndex(current_idx)
        if current_idx < 0:
            # No selection: nothing to clamp (and -1 is not a valid key).
            return
        # determine if current index value is too high, go up till we're OK
        while idx_ngpu_map[current_idx] > value and current_idx > 0:
            current_idx -= 1
        self.setCurrentIndex(current_idx)