"""
Contains the class `TestScript`, which provides methods to create, modify,
extract and delete a test script. See also the convenience function
`getTestFromDir`, which can be used to read a test from a directory.
"""
import datetime
import inspect
import numbers
import os
import re
import shutil
from . import common
from . import constants
from . import sysinfo
from . import workup
from .outcomes.failures import READMESyntaxError
# Reuse the logger object exposed by the sibling `common` module for all
# messages emitted from this module.
logger = common.logger
# Matches names of executed-test directories, "<id>_run_YYYY-MM-DD..." (the
# format produced by TestScript.getNewExecuteDirectory, after any "stu_"
# prefix has been removed); group 1 is the test ID.
EXECUTED_DIR = re.compile(r'^(.+)_run_\d\d\d\d-\d\d-\d\d')
# Encoding used when converting file paths between str and bytes.
PATH_ENCODING = 'utf-8'


def get_test_id(directory):
    """
    Get the test ID based on the directory name. Also guesses whether the test
    has already been run.

    A leading "stu_" is ignored, and a "_run_YYYY-MM-DD..." suffix marks the
    directory as an executed copy of the test. IDs that look like integers are
    returned as ints, anything else as the (possibly stripped) name.

    :param directory: path of the test directory; trailing path separators
        are ignored.
    :type directory: str
    :rtype: tuple(int/str, bool)
    :return: (TestID, Was the test executed?)
    """
    # os.path.basename('.../123/') is '', so drop trailing separators first
    # (same convention as the shared_files handling in TestScript).
    basename = os.path.basename(directory.rstrip('/\\'))
    if basename.startswith('stu_'):
        basename = basename[len('stu_'):]
    was_run = EXECUTED_DIR.match(basename)
    test_id = was_run.group(1) if was_run else basename
    executed = bool(was_run)
    try:
        return int(test_id), executed
    except ValueError:
        return test_id, executed
def getTestFromDir(username, directory, filename='README'):
    """
    Build a `TestScript` from the README stored in a test directory.

    The README fields are augmented with the directory itself, the test ID
    and "executed" flag derived from the directory name, and the list of
    files needing ${SHARED}/${CWD} substitution.

    :param username: creator recorded when the README has no ``created_by``
        field.
    :param directory: test directory containing the README.
    :param filename: name of the README file inside `directory`.
    :return: the test described by the README.
    :rtype: TestScript
    """
    subs = TestScript.find_substitution_files(directory)
    fields = TestScript.read_readme(os.path.join(directory, filename))
    test_number, was_executed = get_test_id(directory)
    fields.update(directory=directory,
                  number=test_number,
                  executed=was_executed,
                  substitution_files=subs)
    # The README calls it "created_by"; the constructor calls it "creator".
    fields['creator'] = fields.pop('created_by', username)
    return TestScript._getTestFromDict(fields)
class TestScript:
    """
    A single STU test: the metadata read from its README (or the DB), the
    command to run, the workup used to judge it, and the per-run results
    (`outcome`, `timing`, `exit_status`, ...) filled in when it is executed.
    """

    def __init__(
            self,
            product,
            priority,
            description,
            command,
            workup,
            build_modified=None,
            creator=None,
            number=None,
            directory=None,
            question='',
            product_subfeature=None,
            mpi_enabled=None,
            allowed_cpu="",
            jira_tracking=None,
            disabled_for_bug=False,
            unsupported_platforms=None,
            shared_files=tuple(),  # noqa: M511
            substitution_files=tuple(),  # noqa: M511
            useJC=True,
            resource_uri=None,
            download=None,
            upload=None,
            tags=None,
            minimum_version=None,
            maximum_version=None,
            executed=False,
            **kwargs):
        """
        Create a test from its README/DB fields.

        Extra keyword arguments are checked by `validate`, which only
        tolerates a fixed set of known-but-unused fields.

        :raise READMESyntaxError: if the workup does not compile or the data
            fails `validate`.
        """
        self.id = number
        "Test number"
        self.product = product
        self.product_subfeature = product_subfeature
        self.priority = priority
        self.description = description
        self.question = question
        self.directory = directory
        self.expect_job_failure = False
        # Copy, so that appending require-tags below does not mutate the
        # caller's list (e.g. the dict returned by `read_readme`).
        self.tags = list(tags) if tags else []
        # Execution limitations
        self.mpi_enabled = mpi_enabled
        self.allowed_cpu = allowed_cpu
        if self.allowed_cpu is True:
            self.allowed_cpu = '1'
        if isinstance(self.allowed_cpu, numbers.Number):
            self.allowed_cpu = str(self.allowed_cpu)
        self.unsupported_platforms = unsupported_platforms or []
        self.disabled_for_bug = disabled_for_bug
        self.jira_tracking = jira_tracking
        self.minimum_version = minimum_version
        self.maximum_version = maximum_version
        # Script technicals
        self.build_modified = build_modified or sysinfo.LOCAL.mmshare
        self.creator = creator
        # Drop 'none' placeholders and remove trailing path separators from
        # shared files/directories.
        self.shared_files = [
            name.rstrip('/\\')
            for name in shared_files
            if name.lower() != 'none'
        ]
        self.substitution_files = [
            name for name in substitution_files if name.lower() != 'none'
        ]
        # Automation options. None can be returned from the DB and means the
        # default, which is True.
        self._useJC = True if useJC is None else useJC
        self.command = command
        self.workupstr = workup
        if self.workupstr:
            try:
                compile(self.workupstr, str(self), 'eval')
            except TypeError as e:
                raise READMESyntaxError(
                    f'Test {self} workup causes TypeError: {e}') from e
            except SyntaxError as e:
                raise READMESyntaxError(
                    f"Workup is not valid for test {self}\n\n"
                    f"Workup String: {self.workupstr} \n\nError: {e}") from e
        # Add remote installation tags if they are needed and not present.
        if ('pdxgpu' in self.command and
                'require:pdxgpu_install' not in self.tags):
            self.tags.append('require:pdxgpu_install')
        if ('host bolt' in self.command.lower() and
                'require:bolt_install' not in self.tags):
            self.tags.append('require:bolt_install')
        # expect_job_failure must be handled explicitly. A blank workup
        # (None/''/False) is allowed, so guard the substring test instead of
        # raising TypeError on a non-string.
        if self.workupstr and 'expect_job_failure' in self.workupstr:
            self.expect_job_failure = True
        # Simplify future communication with the web service.
        self.resource_uri = resource_uri
        self.download = download
        self.upload = upload
        # The following will be filled in if/when the script is run.
        self.executed = executed
        self.outcome = None
        "bool : Success or Failure of the test"
        self.workup_messages = ''
        "str : Messages from the workup"
        self.failure_type = ''
        "str : Failure category as per SHARED-3037"
        self.timing = 0
        "float : how long the test took to run (s)"
        self.exit_status = constants.JOB_NOT_STARTED
        "str : exit status"
        self.validate(**kwargs)

    @classmethod
    def _getTestFromDict(cls, data):
        """
        Create a test object from a dict keyed on the field names used in the
        repository READMEs and the DB.

        :param data: field name to value mapping; it is not modified.
        :type data: dict
        :return: the new test
        :rtype: TestScript
        :raise READMESyntaxError: if required fields are missing (or the
            data is otherwise invalid, see `__init__`).
        """
        testdata = data.copy()
        # These variables have historically had a different internal and
        # external name. The long term solution is to make these names match.
        external_to_internal = dict(use_JC='useJC',
                                    automate_cmd='command',
                                    outcome_workup='workup',
                                    priority_level='priority')
        for external, internal in external_to_internal.items():
            if external in testdata:
                testdata[internal] = testdata.pop(external)
        try:
            return cls(**testdata)
        except TypeError as err:
            argspec = inspect.getfullargspec(cls.__init__)
            # The required README fields are the positional parameters that
            # have no default value (excluding `self`).
            n_optional = len(argspec.defaults or ())
            required_args = set(
                argspec.args[1:len(argspec.args) - n_optional])
            unfulfilled_args = required_args - set(testdata)
            if not unfulfilled_args:
                # Something else went wrong.
                raise
            # Sorted so the message is deterministic.
            msg = 'Missing required README fields: '
            msg += ', '.join(sorted(unfulfilled_args))
            if 'workup' in unfulfilled_args:
                # Basically, we want to encourage people to write a workup,
                # but in some situations the exit status really is a sufficient
                # check.
                msg += (' workup can be blank, but the field must be present '
                        'in the README.')
            # Add information about the test to the error message.
            if 'product' in data and 'number' in data:
                msg += f' {data["product"]} test {data["number"]}.'
            raise READMESyntaxError(msg) from err

    # ************************************************************************

    @classmethod
    def read_readme(cls, readme):
        """
        Read README and extract script information.

        The format is one ``keyword = value`` pair per line; blank lines are
        ignored and only the first occurrence of a keyword is kept. Values
        "yes"/"true"/"1" and "no"/"false"/"0" (case-insensitive) become
        booleans, other integer-looking values become ints, and everything
        else stays a str. The list-valued fields (``unsupported_platforms``,
        formerly ``unsupported_plats``; ``shared_files``; ``tags``) are kept
        as text and split on commas.

        :param readme: path to the README file
        :type readme: str
        :return: keyword to value mapping; an empty dict if the file cannot
            be opened or a non-blank line has no "=".
        :rtype: dict
        """
        # List fields skip bool/int conversion so they can always be split
        # (e.g. "shared_files = no" used to become False and crash re.split).
        list_fields = ('unsupported_plats', 'unsupported_platforms',
                       'shared_files', 'tags')
        readme_data = {}
        try:
            inh = open(readme)
        except (OSError, TypeError, ValueError):
            logger.critical("WARNING:  Failed to read the %s." % readme)
            return {}
        # `with` closes the file on every exit path, including the early
        # return for a malformed line below.
        with inh:
            for line in inh:
                # Discard '\r' characters (shouldn't be necessary), then strip
                # the newline and surrounding whitespace.
                line = line.replace('\r', '').strip()
                # Skip blank lines.
                if not line:
                    continue
                # Read data for line
                data = line.split('=', 1)
                if len(data) < 2:
                    logger.warning(
                        "WARNING: Failed to read line in , \"%s\"\n%s" %
                        (readme, line))
                    return {}
                keyword = data[0].strip()
                value = data[1].strip()
                if keyword not in list_fields:
                    # If possible, convert the field to a bool or an int.
                    # If this is not possible, we want a str anyway.
                    lowered = value.lower()
                    if lowered in ["yes", "true", "1"]:
                        value = True
                    elif lowered in ["no", "false", "0"]:
                        value = False
                    else:
                        try:
                            value = int(value)
                        except ValueError:
                            pass
                if keyword in readme_data:
                    logger.info(
                        "WARNING: keyword, %s, exists multiple times in ,"
                        "%s. Using first occurrence in file." %
                        (keyword, readme))
                else:
                    readme_data[keyword] = value
        # Check "unsupported_plats" for compatibility with old READMEs.
        if 'unsupported_plats' in readme_data:
            readme_data["unsupported_platforms"] = readme_data.pop(
                "unsupported_plats")
        for key in ('unsupported_platforms', 'shared_files', 'tags'):
            if key in readme_data:
                readme_data[key] = re.split(r"\s*,\s*", readme_data[key])
        return readme_data

    def write_readme(self, fileobj=None):
        """
        Render the README text and either return it or write it out.

        The data is validated first. The returned text is exactly what would
        be written to a file.

        :param fileobj: if falsy, the README text is returned. A str is
            treated as a path and that file is (over)written. Anything else
            is assumed to be a writable text file object and the lines are
            written to it.
        :return: the README text when `fileobj` is falsy, otherwise None.
        :raise READMESyntaxError: if the data fails `validate`.
        """
        self.validate()
        rm = [
            'product = %s\n' % self.product,
            'priority = %s\n' % self.priority,
            'description = %s\n' % self.description,
            'command  = %s\n' % self.command,
            'workup = %s\n' % self.workupstr,
            'created_by = %s\n' % self.creator
        ]
        if self.question:
            rm.append('question = %s\n' % self.question)
        if self.product_subfeature:
            rm.append('product_subfeature = %s\n' % self.product_subfeature)
        if self.shared_files:
            rm.append('shared_files = %s\n' % ', '.join(self.shared_files))
        if self.mpi_enabled:
            rm.append('mpi_enabled = %s\n' % self.mpi_enabled)
        if self.unsupported_platforms:
            rm.append('unsupported_platforms = %s\n' %
                      ', '.join(self.unsupported_platforms))
        if self._useJC is not None:
            rm.append('useJC = %s\n' % self._useJC)
        if self.disabled_for_bug:
            rm.append('disabled_for_bug = %s\n' % self.disabled_for_bug)
        if self.jira_tracking:
            rm.append('jira_tracking = %s\n' % self.jira_tracking)
        if self.allowed_cpu:
            rm.append('allowed_cpu = %s\n' % self.allowed_cpu)
        if self.tags:
            rm.append('tags = %s\n' % ', '.join(self.tags))
        if self.minimum_version:
            rm.append('minimum_version = %d\n' % self.minimum_version)
        if self.maximum_version:
            rm.append('maximum_version = %d\n' % self.maximum_version)
        if not fileobj:
            # Each entry already ends with '\n'; joining with '\n' used to
            # double-space the returned text.
            return ''.join(rm)
        if isinstance(fileobj, str):
            with open(fileobj, 'w') as fh:
                fh.writelines(rm)
        else:
            # Previously an already-open file object was silently ignored.
            fileobj.writelines(rm)

    def validate(self, **kwargs):
        """
        Validate the data stored in a TestScript object. Should be done when
        instantiating one or dumping it to file.

        :param kwargs: extra fields passed to `__init__`; only a fixed set of
            known-but-unused field names is allowed.
        :raise READMESyntaxError: on unknown fields, malformed tags or
            versions, disabled_for_bug without jira_tracking, or a command
            containing disallowed arguments.
        """
        if kwargs:
            extra = set(kwargs) - {
                'date_created', 'date_modified', 'modifier', 'component',
                'interested_users', 'test_directory'
            }
            if extra:
                raise READMESyntaxError('Unrecognized fields in README: '
                                        f'{extra} for {self}')
        for tag in self.tags:
            if ':' in tag and not tag.startswith('require:'):
                raise READMESyntaxError(
                    f'":" is not allowed in tag names (found "{tag}") {self}')
        if self.minimum_version:
            if not (isinstance(self.minimum_version, int) and
                    len(str(self.minimum_version)) == 5):
                raise READMESyntaxError(
                    'The minimum version needs to be a 5 digit integer '
                    f'(not {self.minimum_version}) {self}')
        if self.maximum_version:
            if not (isinstance(self.maximum_version, int) and
                    len(str(self.maximum_version)) == 5):
                raise READMESyntaxError(
                    'The maximum version needs to be a 5 digit integer '
                    f'(not {self.maximum_version}) {self}')
        if self.disabled_for_bug and not self.jira_tracking:
            raise READMESyntaxError(
                'If you are disabling a test with "disabled_for_bug", you '
                'must provide a JIRA ID and explanation in the '
                f'"jira_tracking" field. {self}')
        if not self.description.startswith('SciVal'):
            # SciVal STU tests are not actually run by STU, so they don't need
            # this validation.
            validate_command_for_host(self.command, self.tags, self,
                                      self.product)
        if '-LOCAL' in self.command:
            raise READMESyntaxError(
                '-LOCAL is not allowed STU command argument')
        if self.useJC():
            if '-NOJOBID' in self.command:
                raise READMESyntaxError(
                    '-NOJOBID argument is not allowed for STU tests with useJC=True'
                )
            elif '-WAIT' in self.command:
                raise READMESyntaxError(
                    '-WAIT argument is not allowed for STU tests with useJC=True'
                )

    # ************************************************************************

    def runWorkup(self, job=None, registered_workups=None):
        """
        Run my workup.

        :param job: optional job object; when given, the workup runs in the
            job's command directory instead of `self.directory`.
        :param registered_workups: forwarded to `workup.workup_outcome`.
        :return: `self.outcome` after the workup (presumably set by
            `workup.workup_outcome` -- confirm against that module).
        """
        if job:
            directory = job.getCommandDir()
        else:
            directory = self.directory
        workup.workup_outcome(self,
                              directory,
                              registered_workups=registered_workups,
                              job_dj_job=job)
        return self.outcome

    # ************************************************************************

    def getNewExecuteDirectory(self, attempts=120):
        """
        Return a not-yet-existing directory path in which to run the test.

        The name is ``<name>_run_<YYYY-MM-DD>_<NNN>``. When the test has an
        id, ``<name>`` is ``stu_<id>`` inside the current working directory;
        otherwise it is the basename of the (original) test directory, placed
        next to it.

        :param attempts: number of indices (000, 001, ...) to try.
        :return: the first candidate path that does not exist.
        :raise Exception: if every candidate path already exists.
        """
        if self.id:
            name = f'stu_{self.id}'
            basedir = os.getcwd()
        else:
            # original_directory is only set by copyToScratch; fall back to
            # the current directory instead of raising AttributeError.
            directory = (getattr(self, 'original_directory', None) or
                         self.directory)
            basedir, name = os.path.split(directory)
        # Fill in name and date now; the doubled braces leave an {index}
        # field for the loop below.
        exedir_format = "{name}_run_{date:%Y-%m-%d}_{{index:0=3}}"
        exedir_format = exedir_format.format(name=name,
                                             date=datetime.datetime.now())
        exedir_format = os.path.join(basedir, exedir_format)
        for i in range(attempts):
            path = exedir_format.format(index=i)
            if not os.path.exists(path):
                return path
        raise Exception(
            f'No unique new directory found after {attempts} attempts.')

    # ************************************************************************

    def copyToScratch(self):
        """
        Copy the test directory to a new scratch directory and switch
        `self.directory` to it (the old path is kept in
        `self.original_directory`).

        :return: False if the copy failed, otherwise True. Copy errors are
            ignored on Darwin (see QA-646), which also returns True.
        """
        self.original_directory = self.directory
        self.directory = self.getNewExecuteDirectory()
        try:
            shutil.copytree(self.original_directory, self.directory)
        except Exception as err:
            # Ignore errors on Darwin, see QA-646
            if not sysinfo.LOCAL.isDarwin:
                logger.debug('WARNING: Failed to copy "%s" to "%s"' %
                             (self.original_directory, self.directory))
                import traceback
                logger.debug(traceback.format_exc())
                logger.debug(err)
                return False
        return True

    # ************************************************************************

    def recoverFromScratch(self, get_license=True):
        """
        Remove the scratch directory and point `self.directory` back at the
        original test directory, to prepare for adding/modifying the test.

        :param get_license: if True, first copy the license-check file from
            the scratch directory back as the reference (see
            `_getReferenceLicense`).
        """
        if get_license:
            self._getReferenceLicense()
        shutil.rmtree(self.directory)
        self.directory = self.original_directory

    # ************************************************************************

    def _getReferenceLicense(self):
        """
        Get the license file created in the scratch directory and move it to
        the test directory.  Assumes both self.directory and
        self.original_directory are set, should be run from
        `recoverFromScratch`.

        An existing reference file is never overwritten.
        """
        license_file = os.path.join(self.directory,
                                    constants.LICENSE_CHECK_FILE)
        reference = os.path.join(self.original_directory,
                                 constants.REF_LICENSE_CHECK_FILE)
        if os.path.isfile(license_file) and not os.path.isfile(reference):
            logger.info('Creating %s for %s' %
                        (constants.REF_LICENSE_CHECK_FILE, self.id))
            shutil.copy(license_file, reference)

    # ************************************************************************

    @classmethod
    def find_substitution_files(cls, directory):
        """
        Find the files under `directory` that contain "${SHARED}" or
        "${CWD}".

        The README and files whose names end in "gz" are skipped. Files are
        scanned line by line in binary mode.

        :param directory: test directory, searched recursively.
        :type directory: str
        :return: paths of the matching files, relative to `directory`.
        :rtype: list(str)
        """
        subfiles = []
        for root, _dirs, files in os.walk(directory):
            for filename in files:
                if filename == 'README' or filename.endswith('gz'):
                    continue
                fullpath = os.path.join(root, filename)
                with open(fullpath.encode(PATH_ENCODING), 'rb') as inh:
                    # any() stops at the first matching line.
                    needs_substitution = any(
                        b'${SHARED}' in line or b'${CWD}' in line
                        for line in inh)
                if not needs_substitution:
                    continue
                # relpath replaces the old prefix-replace + separator strip,
                # which could also delete a backslash from inside the path.
                relpath = os.path.relpath(fullpath, directory)
                logger.debug(f"{relpath} requires substitution "
                             f"(test {directory}).")
                subfiles.append(relpath)
        if not subfiles:
            logger.debug("DEBUG: ${SHARED} or ${CWD} not found in any files "
                         "in \"%s\"" % (directory))
        return subfiles

    # *************************************************************************

    def substituteFiles(self):
        """
        Replace substitution expressions in files that require it.  Requires
        that the substitution files already be identified.

        :return: True on success; False (after logging the exception) if any
            file could not be processed.
        """
        try:
            for filename in self.substitution_files:
                self._substituteInPlace(filename)
        except Exception:
            logger.exception("Failed to substitute files for %s:\n" % (self))
            return False
        return True

    # *************************************************************************

    def _substituteInPlace(self, filename):
        """
        Replace ${SHARED} and ${CWD} in `filename` with the shared directory
        (``<current working directory>/shared``) and the absolute test
        directory, respectively. The file is rewritten in place.

        :param filename: path relative to `self.directory`.
        :return: True
        :raise Exception: any error is re-raised after logging diagnostics.
        """
        path = os.path.join(self.directory, filename)
        shared = os.path.join(os.getcwd(), 'shared').encode(PATH_ENCODING)
        cwd_str = os.path.abspath(self.directory)
        cwd = cwd_str.encode(PATH_ENCODING)
        try:
            with open(path, 'rb') as fh:
                outlines = [
                    line.replace(b"${SHARED}", shared).replace(b"${CWD}", cwd)
                    for line in fh
                ]
            with open(path, 'wb') as fh:
                fh.writelines(outlines)
        except Exception:
            logger.critical('REPORTING INFO FOR SHARED-2572')
            logger.critical('Current directory: %s' %
                            os.path.abspath(os.getcwd()))
            logger.critical(f'Test directory: {self.directory} ({cwd})')
            logger.critical('{} exists: {}'.format(self.directory,
                                                   os.path.exists(cwd)))
            logger.critical('{} exists: {}'.format(path, os.path.exists(path)))
            # List via the str path: listdir(bytes) returns bytes names, and
            # joining them raised TypeError that masked the real error.
            try:
                listing = ', '.join(os.listdir(cwd_str))
            except OSError as list_err:
                listing = f'<could not list directory: {list_err}>'
            logger.critical('All files in test directory: %s' % listing)
            raise
        return True

    # *************************************************************************

    def useJC(self):
        """
        Determines whether the script will be run under jobcontrol, uses
        self._useJC as a default value.

        :return: Should this script be run under jobcontrol?
        :rtype: bool
        """
        return bool(self._useJC)

    # *************************************************************************

    def runsRemotely(self):
        """
        A job is available to run on a remote host if:
            * It is a jobcontroljob that doesn't have the require:localhost tag
            * It is not a jobcontroljob, but it has ${HOST} in the command.

        :rtype: bool
        """
        useJC = self.useJC()
        return (useJC and 'require:localhost' not in self.tags or
                not useJC and '${HOST}' in self.command)

    # *************************************************************************

    def toDict(self):
        """
        Dump the test object as a dict (after validating it).

        Run-time state and local paths are omitted. The keys use the
        external names ``useJC``, ``number`` and ``workup``.

        :rtype: dict
        """
        self.validate()
        rdict = self.__dict__.copy()
        unserialized_fields = {
            'executed', 'exit_status', 'workup_messages', 'failure_type',
            'expect_job_failure', 'timing', 'original_directory', 'directory',
            'outcome', 'id'
        }
        for field in unserialized_fields:
            rdict.pop(field, None)
        rdict['useJC'] = rdict.pop('_useJC', None)
        rdict['number'] = self.id
        rdict['workup'] = rdict.pop('workupstr')
        return rdict

    def __str__(self):
        return f'{self.product} test {self.id}'
def validate_command_for_host(command, tags, test_id, product):
    """
    Check that `command` does not hard-code host information incorrectly.

    In general there should be no -HOST information, so that stu can decide
    which hosts to run with. Job control tests, tests tagged
    "require:specific_host" and commands without "-HOST" are not checked.
    Otherwise -HOST must be followed (somewhere) by localhost together with
    the "require:localhost" tag, or appear as ``-HOST ${HOST}`` or
    ``-HOST "${HOST}:${NCPU}"``.

    NOTE: This is same as code in forms validation for stu server.
    Use this code once we integrate stu server to use from mmshare.

    :param command: commandline which will be executed
    :type command: str
    :param tags: tags associated with the test
    :type tags: list(str) or set(str)
    :param test_id: name of test (used for error reporting)
    :type test_id: str
    :param product: name of Product test is associated with
    :type product: str
    :raise READMESyntaxError: if the -HOST usage is not allowed
    """
    not_checked = (product == 'Job control' or
                   'require:specific_host' in tags or
                   '-HOST' not in command)
    if not_checked:
        return
    if re.search('-HOST.*localhost', command):
        if 'require:localhost' in tags:
            return
        raise READMESyntaxError(
            'Tests which specify "-HOST localhost" in the '
            f'command must use the tag "require:localhost". {test_id} ')
    # Placeholders that let stu pick the host itself.
    allowed_forms = ('-HOST ${HOST}', '-HOST "${HOST}:${NCPU}"')
    if any(form in command for form in allowed_forms):
        return
    raise READMESyntaxError(
        'Tests which contain "-HOST" in the command must '
        f'use the tag "require:specific_host". {test_id}')
README2TEST = {}
"""Correspondence between README values and test values."""
TEST2README = {v: k for k, v in README2TEST.items()}
"""Inverse of `README2TEST`: correspondence between test values and README
values."""
LIST_VALUES = []
"""Values that are lists."""
# *****************************************************************************
# This module is intended to be imported and won't typically be
# run.  To ensure the "program" is only executed when run as a script
# and not run when imported, do the following check for __main__:
if __name__ == "__main__":
    print("Nothing to see here.")