"""
Base class for STU utilities.  Each utility should probably implement
`getParser` and `_doCommand`.
@copyright: Schrodinger, Inc. All rights reserved.
"""
# Add _strptime to sys.modules before anything else happens to avoid threading
# bug in requests module.  See SHARED-2608.
import argparse
import datetime
import future.utils
import os.path
import sys
from schrodinger.utils import cmdline
from schrodinger.utils import log
from . import client
from . import common
from . import constants
from . import run
from . import testscripts
from . import workup
# Add encodings.idna to sys.modules before anything else happens to avoid an
# import bug in the requests module.  See SHARED-2618.
# This creates the logger.  If we ever need to log with a logger,
# we must import this module.
# `logger` is a module-level alias for `common.logger`; the classes below log
# through it.
logger = common.logger
class store_server(argparse.Action):
    """
    Argparse action to monkeypatch the server address.

    The value supplied on the command line is given an ``http://`` scheme if
    it does not already start with ``http://`` or ``https://``.  The result
    is written to `common.BASE_URL` and is also stored on the namespace under
    the option's ``dest`` so that the parsed options agree with the address
    actually in use.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """
        :param parser: The parser that invoked this action (unused).
        :param namespace: Object receiving the parsed option values.
        :param values: Server address given on the command line.
        :type values: str
        :param option_string: Option string that triggered the action
            (unused).
        """
        # Check for the full scheme rather than a bare 'http' prefix so that
        # host names such as 'httpd.example.com' still get a scheme added.
        if not values.startswith(('http://', 'https://')):
            values = 'http://' + values
        common.BASE_URL = values
        # Without this the namespace keeps the default captured when the
        # option was declared, even though a different server is in use.
        setattr(namespace, self.dest, values)
def _check_dir(string):
    """Check whether a directory exists. Used in parser."""
    if not os.path.isdir(string):
        raise argparse.ArgumentTypeError("Directory %s is not accessible." %
                                         string)
    return os.path.abspath(string)
class TestRunnerParser(argparse.ArgumentParser):
    """
    Subclass that allows easy adding of options.

    Each ``add*`` method registers one group of related command-line
    options.  The verb describing the utility is supplied in two tenses and
    is substituted into the help text of several options.
    """

    def __init__(self, present_tense, past_tense, *args, **kwargs):
        """
        :param present_tense: Verb describing what the utility does; stored
            capitalized as `cap_present_tense`.
        :type present_tense: str
        :param past_tense: Past tense of the same verb; stored lower-cased
            as `past_tense`.
        :type past_tense: str

        All other arguments are forwarded to `argparse.ArgumentParser`.
        """
        super().__init__(*args, **kwargs)
        self.past_tense = past_tense.lower()
        self.cap_present_tense = present_tense.capitalize()

    def addSkipArg(self):
        """
        Add ``-t/--test_skip``, which stores False in ``verify``.
        """
        self.add_argument("-t",
                          "--test_skip",
                          action="store_false",
                          dest="verify",
                          help=("Skip test execution of job before adding "
                                "to database."))

    def addTestIDArg(self):
        """
        Add ``--id`` and its hidden alias ``-s``; both store into
        ``test_ids`` after conversion with `common.str2list`.
        """
        self.add_argument(
            "--id",
            dest="test_ids",
            type=common.str2list,
            help=("%s one or more tests with ID numbers indicated. "
                  "Format: 1,7,10-13 or '1 7 10-13' will choose tests "
                  "1, 7, 10, 11, 12, and 13." % self.cap_present_tense))
        self.add_argument("-s",
                          dest="test_ids",
                          type=common.str2list,
                          help=argparse.SUPPRESS)

    def addDirectories(self, required=False, noscratch=False):
        """
        Add the test directory argument.

        :param required: If True the directories are a positional argument
            (at least one must be given); otherwise they are supplied via
            the ``--directories`` option.
        :type required: bool
        :param noscratch: If True also add ``--noscratch``, which stores
            False in ``scratch``.
        :type noscratch: bool
        """
        if noscratch:
            unless = (' Tests will be run in a scratch directory unless the '
                      '--noscratch option is specified.')
        else:
            unless = ''
        help_msg = 'One or more directories to be {}.{}'.format(
            self.past_tense, unless)

        # The positional and optional forms differ only in the argument name.
        name = 'directories' if required else '--directories'
        self.add_argument(name,
                          metavar='DIRECTORY',
                          type=_check_dir,
                          nargs='+',
                          help=help_msg)

        if noscratch:
            self.add_argument('--noscratch',
                              dest='scratch',
                              action='store_false',
                              help=('%s in the test directory instead of a '
                                    'scratch directory.' %
                                    self.cap_present_tense))

    def addExpandArg(self):
        """
        Add the ``--expand_variables`` flag.
        """
        self.add_argument("--expand_variables",
                          action="store_true",
                          dest="expand_variables",
                          help=("Expand ${CWD} and ${SHARED} variables upon "
                                "test extraction. (Default: False)"))

    def addTestLimitations(self):
        """
        Add the options that restrict which tests are selected: the mutually
        exclusive ``--remote_only``/``--local_only`` flags, ``--allow``,
        ``--mpi`` (with hidden alias ``-mpi``) and the options from
        `addForce`.
        """
        launch_location_mutex = self.add_mutually_exclusive_group()
        launch_location_mutex.add_argument(
            "--remote_only",
            action="store_true",
            help=("Only run tests that can be executed on a remote host."))
        launch_location_mutex.add_argument(
            "--local_only",
            action="store_true",
            help=("Only run tests that must be run on the local host."))
        self.add_argument('--allow',
                          metavar='tag1,tag2,tag3',
                          dest='allow_tags',
                          type=common.str2strlist,
                          help="Don't exclude tests from being %s that "
                          "\"require:\" the requested resources." %
                          self.past_tense)
        self.add_argument("--mpi",
                          action="store_true",
                          dest="mpi_enabled",
                          help=("Execute *only* MPI-enabled jobs from "
                                "selected test tests.  Running without "
                                "this option runs only non-MPI-enabled "
                                "tests"))
        self.add_argument("-mpi",
                          action="store_true",
                          dest="mpi_enabled",
                          help=argparse.SUPPRESS)
        self.addForce()

    def addCPUArg(self):
        """
        Add ``--cpu`` and its hidden alias ``-cpu``; both store an int into
        ``ncpu`` (default 1).
        """
        self.add_argument("--cpu",
                          dest="ncpu",
                          metavar="<number>",
                          default=1,
                          type=int,
                          help=("Number of CPUs each subjob will run on. "
                                "(Default: 1)"))
        self.add_argument("-cpu",
                          dest="ncpu",
                          metavar="<number>",
                          default=1,
                          type=int,
                          help=argparse.SUPPRESS)

    def addForce(self):
        """
        Add the ``-f/--force`` flag.
        """
        self.add_argument("-f",
                          "--force",
                          dest="force",
                          action="store_true",
                          help=("Force execution of tests, overriding "
                                "any limitations such as unsupported "
                                "platforms, bug disabled, \"require:\" "
                                "labels, # of processor limitations, MPI "
                                "requirements, etc."))

    def addExecuteOptions(self):
        """
        Add ``--additionalArgs`` and ``-timeout/--timeout``.
        """
        self.add_argument(
            "--additionalArgs",
            default='',
            help=("Additional arguments to append to every test's command.  "
                  "Specify as a (quoted) space separated list"))
        self.add_argument(
            "-timeout",
            "--timeout",
            type=int,
            help=("Max duration of a test in seconds. Jobs longer than this "
                  "will be killed."))

    def addVerbosityArgs(self):
        """
        Add the ``--quiet`` and ``--verbose`` flags.
        """
        self.add_argument('--quiet',
                          action='store_true',
                          help="Print only the most important information")
        self.add_argument(
            '--verbose',
            action='store_true',
            help=("Print additional information. Note that -DEBUG implies "
                  "-verbose."))

    def addServerArg(self, user_option=True):
        """
        Add the hidden ``--server`` option and the ``--user`` option.

        :param user_option: Accepted for interface compatibility; it is not
            currently consulted, and ``--user`` is always added.
        :type user_option: bool
        """
        # Specify the STU server to use for running. Suppressed, because it is
        # only used in testing STU.
        self.add_argument('--server',
                          dest='server',
                          default=common.BASE_URL,
                          action=store_server,
                          help=argparse.SUPPRESS)
        self.add_argument('--user',
                          dest='username',
                          default=client.get_stu_username(),
                          help='Username on the STU server.')

    def addJobControlOptions(self):
        """
        Add the job control host and debug options via
        `cmdline.add_jobcontrol_options`.
        """
        cmdline.add_jobcontrol_options(self,
                                       options=[cmdline.HOST, cmdline.DEBUG])
class TestUtility:
    """
    Base class for the STU command-line utilities.

    Subclasses are expected to set `present_tense` and `past_tense` and to
    implement `_doCommand`; `getParser` is the hook for adding options.
    `__call__` also invokes `printHeader` and `printFooter`, which are not
    defined in this class.
    """
    # Verb describing the utility's action; also used to build its name.
    present_tense = None
    # Past tense of the same verb, used in messages and help text.
    past_tense = None

    def __init__(self, arguments=None):
        """
        :param arguments: Command-line arguments; defaults to
            ``sys.argv[1:]``.
        :type arguments: list(str) or None
        """
        self.utility = '$SCHRODINGER/utilities/stu_%s' % self.present_tense
        self.starttime = None
        if arguments is None:
            arguments = sys.argv[1:]
        # Record the command as invoked.
        self.cmd_run = self.utility + ' ' + ' '.join(arguments)
        self.getOptions(arguments)

    def getParser(self, prog=None, description=None):
        """
        Build the argument parser for this utility.

        :param prog: Program name shown in help; defaults to `utility`.
        :param description: Help description; defaults to the class
            docstring.
        :raise NotImplementedError: If the tense attributes are not set.
        :rtype: argparse.ArgumentParser
        """
        if not self.present_tense or not self.past_tense:
            raise NotImplementedError("This is a blank base class.")
        if not description:
            description = self.__doc__
        if not prog:
            prog = self.utility
        return TestRunnerParser(present_tense=self.present_tense,
                                past_tense=self.past_tense,
                                prog=prog,
                                description=description)

    def getOptions(self, arguments=None, user_option=True):
        """
        Parse the arguments onto this object and set the logging level.

        :param arguments: Command-line arguments to parse.
        :param user_option: Forwarded to `TestRunnerParser.addServerArg`.
        :return: The parser that was used.
        :rtype: TestRunnerParser
        """
        parser = self.getParser()
        parser.addServerArg(user_option=user_option)
        # Using self as the namespace stores every option as an attribute.
        parser.parse_args(arguments, self)
        # The base parser does not add --verbose/--quiet itself, so read
        # them defensively instead of assuming the attributes exist.
        if getattr(self, 'verbose', False) or getattr(self, 'debug', False):
            logger.setLevel(log.DEBUG)
            self.verbosity = 'verbose'
        elif getattr(self, 'quiet', False):
            logger.setLevel(log.WARNING)
            self.verbosity = 'quiet'
        else:
            logger.setLevel(log.INFO)
            self.verbosity = 'normal'
        return parser

    def getTests(self):
        """
        Get the TestScript objects required based on the mode.

        If directories were given, one test is yielded per directory.
        Otherwise tests are retrieved through `client.retrieve` using the
        selection options present on this object.  Requested IDs that were
        not returned are reported through the logger afterwards.
        """
        if getattr(self, 'directories', None):
            for directory in self.directories:
                yield testscripts.getTestFromDir(self.username, directory)
            return

        test_ids = getattr(self, 'test_ids', [])
        products = getattr(self, 'products', [])
        components = getattr(self, 'components', [])
        priorities = getattr(self, 'priority', [])
        tags = getattr(self, 'tags', [])

        failed_sources = getattr(self, 'failed_products', None) or tuple()
        not_products = set(failed_sources)
        for source in failed_sources:
            if source.lower() == 'none':
                continue
            not_products.update(constants.source2product[source])
        not_products.update(getattr(self, 'skip_products', None) or tuple())
        not_components = getattr(self, 'skip_components', tuple())
        not_tags = getattr(self, 'skip_tags', tuple())

        for test in client.retrieve(self.username,
                                    test_ids,
                                    products,
                                    components=components,
                                    priorities=priorities,
                                    tags=tags,
                                    not_products=not_products,
                                    not_components=not_components,
                                    not_tags=not_tags):
            # Tick off each requested ID as it arrives.  Note this removes
            # entries from the `test_ids` list itself.  AttributeError covers
            # both `test_ids` being None and a test without an `id`.
            try:
                test_ids.remove(test.id)
            except (ValueError, AttributeError):
                pass
            yield test

        if test_ids:
            logger.info(
                '\nThese test IDs were requested but not available: ' +
                ', '.join(str(test_id) for test_id in test_ids))

    def runJobs(self, inscratch=False):
        """
        Queue every selected test on a `run.Runner` and execute them.

        :param inscratch: Should the tests be copied to a scratch directory
                before execution?
        :return: The runner, and whether test execution completed as
                expected (False if `run.JobDJError` was raised).
        :rtype: tuple(run.Runner, bool)
        """
        status = True
        if inscratch:
            self.tests = []
        self.runner = run.Runner(self)
        for test in self.getTests():
            if inscratch:
                # copy test to scratch directory
                # set test.directory attribute to scratch in a copy of test.
                test.copyToScratch()
                self.tests.append(test)
            test.substituteFiles()
            self.runner.addScript(test)
        logger.debug("%s tests queued." % self.runner.job_runner.total_added)
        try:
            self.runner()
        except run.JobDJError as err:
            logger.error(err)
            status = False
        return self.runner, status

    def printRunSummary(self, runner):
        """
        Print the summary of a run.

        :param runner: Runner whose ``tests`` are summarized.
        :return: True if no tests were run, if no test failed, or if this
            object has a false ``return_code`` attribute.
        :rtype: bool
        """
        if not runner.tests:
            return True
        if not getattr(self, 'quiet', False):
            workup.print_summary(runner.tests)
            # Add a blank line between brief and extended summary.
            logger.info('')
        success = self.printOnlyFailed(runner.tests)
        return success or not getattr(self, 'return_code', True)

    def printOnlyFailed(self, tests):
        """
        Log the pass/fail counts and the IDs of the failed tests.

        :param tests: Mapping of test ID to test object; a test counts as
            failed when its ``outcome`` is false.
        :return: True if no test failed.
        :rtype: bool
        """
        failed_tests = [
            test_id for (test_id, test) in tests.items() if not test.outcome
        ]
        if failed_tests:
            logger.warning("{} tests run, {} failed.".format(
                len(tests), len(failed_tests)))
            logger.warning('Failed test IDs: ' + ' '.join(
                str(test_id)
                for test_id in workup.sort_mixed_list(failed_tests)))
        else:
            logger.warning("{} tests run, all passed.".format(len(tests)))
        return not failed_tests

    def printTestNumbers(self, test_ids):
        """
        Print the test IDs in right-aligned columns, eight per line.

        :param test_ids: IDs to print.
        """
        for count, test in enumerate(test_ids, start=1):
            print("%8s" % test, end=' ')
            # Break the line after every eighth entry.  Parentheses matter:
            # `not i + 1 % 8` parses as `not (i + (1 % 8))` and never fires.
            if not count % 8:
                print()

    def __call__(self):
        """
        Run `_doCommand` on every selected test and report the results.

        :return: True if `_doCommand` succeeded for every test.
        :rtype: bool
        """
        self.printHeader()
        successes = []
        success = True
        for test in self.getTests():
            if self._doCommand(test):
                if hasattr(test, 'id'):
                    successes.append(test.id)
                else:
                    successes.append(test)
            else:
                success = False
        if successes:
            logger.info("\n%d tests %s:" % (len(successes), self.past_tense))
            self.printTestNumbers(successes)
        else:
            logger.warning("\nWARNING: No tests %s" % self.past_tense)
        self.printFooter()
        return success

    def _doCommand(self, test):
        """
        Perform the utility's action on one test.  Must be overridden.

        :return: A true value on success (see `__call__`).
        """
        raise NotImplementedError("This is a blank base class.")

    def requireApiKey(self, parser):
        """
        Exit through ``parser.error`` if no API key can be obtained.

        :param parser: Parser used to report the failure.
        :type parser: argparse.ArgumentParser
        """
        try:
            common.get_api_key()
        except (OSError, RuntimeError) as err:
            parser.error(str(err))