"""
Support for discovering and running test executables from pytest.
"""
import contextlib
import glob
import os
import pathlib
import sys
import tempfile
import time
from subprocess import STDOUT
from subprocess import list2cmdline
from typing import Dict
from typing import List
from typing import Optional
import psutil
import pytest
from schrodinger.test import memtest
# $SCHRODINGER/run launcher; SCHRODINGER must be set when this module is imported
SCHRODINGER_RUN = os.path.join(os.environ['SCHRODINGER'], 'run')
# Exit status valgrind is told to use (via --error-exitcode) when it finds errors
VALGRIND_ERROR_CODE = 29
# Valgrind log file name template; '{}' is filled with the test's base path
MEMTEST_LOG = '{}_valgrind.log'
# Valgrind command prefix for memtests. The trailing '--log-file={}' entry is a
# template that is formatted in place by its last index, so it must stay last.
MEMTEST = (
    'valgrind',
    '--tool=memcheck',
    '--time-stamp=yes',
    '--num-callers=20',
    '--gen-suppressions=all',
    '--leak-check=yes',
    '--keep-debuginfo=yes',
    f'--error-exitcode={VALGRIND_ERROR_CODE}',
    '--log-file={}',
)
class ExecutableFile(pytest.File):
    """A compiled executable file that should be executed as a test."""

    def collect(self):
        """Yield the one ExecutableTest item for this file. (pytest method)"""
        item = ExecutableTest.from_parent(parent=self,
                                          name=self.fspath.basename)
        yield item
class ProcessDied(RuntimeError):
    """Raised when a test process terminates with a non-zero exit code."""

    def __init__(self, command: List[str], output: Optional[str], retcode: int):
        """
        Build the failure message from the details of the dead process.
        :param command: command line arguments of process
        :param output: stdout/stderr of process, if found
        :param retcode: exit code of process
        """
        lines = [
            list2cmdline(command),
            f'Process died with return code {retcode}',
        ]
        # Only append the output section when there is something to show
        if output:
            lines.extend(('Test output:', output))
        super().__init__('\n'.join(lines))
class ProcessKilled(ProcessDied):
    """Raised when a test process is killed after exceeding its timeout."""

    def __init__(self, command: List[str], output: Optional[str], timeout: int):
        """
        Build the failure message from the details of the killed process.
        :param command: command line arguments of process
        :param output: stdout/stderr of process, if found
        :param timeout: number of seconds command run before timeout
        """
        parts = [
            list2cmdline(command),
            f'Killed process after timeout of {timeout}s',
        ]
        if output:
            parts += ['Test output:', output]
        # Deliberately skip ProcessDied.__init__ (which formats a return-code
        # message) and initialise the RuntimeError base with our own message.
        RuntimeError.__init__(self, '\n'.join(parts))
class ExecutableTest(pytest.Item):
    """Use $SCHRODINGER/run to execute a file with no arguments."""

    def __init__(self, name, parent=None, **kw):
        """
        :param name: ignored; the item is always named after the executable's
            base name without extension (``parent.fspath.purebasename``).
        :param parent: the collector that owns this item.
        """
        name = parent.fspath.purebasename
        super().__init__(name, parent=parent, **kw)
        self.output = ""
        # By default the test is just the executable itself, with no arguments
        self.command = [str(self.fspath)]

    def _get_memtest_log_files(self):
        """
        Build the name of the memtest log file for this test,
        and return it wrapped inside a list.
        :return: list
        """
        # Record results to the correct log file. Convert to str up front: the
        # py.path object returned by dirpath() does not support ``+`` with a
        # str, which the branch below needs.
        basename = str(self.fspath.dirpath(self.fspath.purebasename))
        if self.name != self.fspath.purebasename:
            basename += '__' + self.name
        return [MEMTEST_LOG.format(basename)]

    def getCommand(self):
        """
        Build the list of command arguments to execute the test.

        The test command is prefixed with ``$SCHRODINGER/run -FROM <product>``
        when the ``from_product`` option names a product other than mmshare,
        and with a valgrind command line when the ``memtest`` option is set.
        :return: list
        """
        cmd = []
        if self.config.getvalue('from_product'):
            product, _product_exec = self.config.option.from_product
            if product != 'mmshare':
                cmd = [SCHRODINGER_RUN, '-FROM', product]
        if getattr(self.config.option, 'memtest', None):
            memtest_cmd = list(MEMTEST)
            # runtest() sets memtest_log_files before calling this; fall back
            # to computing them so getCommand() also works when called alone.
            log_files = getattr(self, 'memtest_log_files', None)
            if log_files is None:
                log_files = self._get_memtest_log_files()
            # In this class memtests only generate a single log file; the last
            # MEMTEST entry is the '--log-file={}' template.
            memtest_cmd[-1] = memtest_cmd[-1].format(log_files[0])
            src_dirname = self.config.getini('src_dirname')
            suppressions = self.findValgrindSuppressionsFiles(src_dirname)
            memtest_cmd.extend(f'--suppressions={f}' for f in suppressions)
            cmd.extend(memtest_cmd)
        return cmd + self.command

    def findValgrindSuppressionsFiles(self, src_dirname):
        """
        Collect valgrind suppression files relevant to this test.

        Searches $SCHRODINGER_SRC/mmshare/build_tools, then each directory
        from the test's directory in the source tree up to and including
        $SCHRODINGER_SRC/<src_dirname>, for files whose names contain
        "suppressions". A test run from the build tree under $SCHRODINGER is
        mapped to the corresponding source directory first.

        :param src_dirname: name of the source repository directory under
            $SCHRODINGER_SRC (the 'src_dirname' pytest.ini setting).
        :return: list of suppression file paths
        :raises RuntimeError: if src_dirname is empty, or the test file is in
            neither the build tree nor the source tree.
        """
        if not src_dirname:
            raise RuntimeError(
                "Must add 'src_dirname' to pytest.ini, eg 'maestro-src'.")

        def directories():
            schro = os.path.abspath(os.environ['SCHRODINGER'])
            schro_src = os.path.abspath(os.environ['SCHRODINGER_SRC'])
            yield os.path.join(schro_src, 'mmshare', 'build_tools')
            src_rootdir = os.path.join(schro_src, src_dirname)
            cur_dir = os.path.realpath(self.fspath.dirname)
            # NOTE(review): the containment checks below are substring tests,
            # not path-prefix tests.
            if schro in cur_dir:
                # running test from build tree, find corresponding directory in
                # source tree (the first component under $SCHRODINGER is the
                # product directory, which is dropped)
                cur_dir = pathlib.Path(cur_dir)
                cur_dir = cur_dir.relative_to(schro)
                cur_dir = os.path.join(schro_src, src_dirname,
                                       *cur_dir.parts[1:])
            if src_rootdir not in cur_dir:
                raise RuntimeError(
                    f"Pytest file {self.fspath} is not in build or source "
                    "directories, cannot find suppressions files.")
            # Walk upwards until we leave the source repository root
            while src_rootdir in cur_dir:
                yield cur_dir
                cur_dir = os.path.dirname(cur_dir)

        suppression_files = []
        for dirpath in directories():
            suppression_files.extend(
                glob.glob(os.path.join(dirpath, '*suppressions*')))
        return suppression_files

    def runtest(self,
                env: Optional[Dict[str, str]] = None,
                capture: Optional[bool] = None,
                stdout=None,
                stderr=None):
        """
        Executed for each test. (pytest method)
        Relegate this to a function that is easier to test.
        :param env: Shell environment for subprocess.Popen
        :type env: dict
        :param capture: Should output be captured and stored? If not, it goes
            to stdout. Defaults to pytest's capture setting.
        :param stdout: file descriptor, file object, or subprocess special
            variable for use as stdout argument to subprocess call. Overrides
            capture if not None.
        :type stdout: file-like object
        :param stderr: file descriptor, file object, or subprocess special
            variable for use as stderr argument to subprocess call. Overrides
            capture if not None.
        :type stderr: file-like object
        """
        if capture is None:
            capture = self.config.getvalue('capture') != 'no'
        kwargs = {}
        if getattr(self.config.option, 'memtest', None):
            # Called twice on purpose: subclasses may return a one-shot
            # iterator (e.g. glob.iglob), so getCommand() and the subprocess
            # runner each get their own.
            self.memtest_log_files = self._get_memtest_log_files()
            kwargs['memtest_log_files'] = self._get_memtest_log_files()
        run_subprocess_test(cmd=self.getCommand(),
                            capture=capture,
                            timeout=_get_subprocess_timeout(
                                self.config.option.faulthandler_timeout),
                            verbose=self.config.option.verbose,
                            env=env,
                            stdout=stdout,
                            stderr=stderr,
                            **kwargs)

    def repr_failure(self, excinfo):
        """Called when self.runtest() raises an exception. (pytest method)"""
        # Our own process errors already carry a readable message; avoid a
        # Python traceback pointing into this plugin.
        if isinstance(excinfo.value, ProcessDied):
            return str(excinfo.value)
        return super().repr_failure(excinfo)

    def reportinfo(self):
        """The short and long names of the test. (pytest method)"""
        return self.fspath, 0, f"C++: $SCHRODINGER/run {self.name}"
def _kill_process_and_children(process: psutil.Popen, timeout: int = 5):
    """
    Kill a process and all of its children, then wait for them to exit.

    The process itself is reaped with communicate(). Its children, once
    killed, are waited on for at most `timeout` seconds.
    This is required when we kill a process on Windows because we need to wait
    for the process to complete before cleaning up its files.

    :param process: process started with psutil.Popen
    :param timeout: maximum number of seconds to wait for killed children
    """
    try:
        children = process.children(recursive=True)
    except psutil.NoSuchProcess:
        # The process exited on its own right after the timeout fired
        children = []
    killed = []
    for child in children:
        try:
            child.kill()
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            continue
        killed.append(child)
    # use kill() + communicate() as per subprocess.TimeoutExpired docs to wait
    # for process termination and close file descriptors
    try:
        process.kill()
    except psutil.NoSuchProcess:
        # Already gone; communicate() below still reaps it
        pass
    process.communicate()
    # Bounded wait for the children; wait_procs returns (gone, alive) rather
    # than raising if some are still alive after `timeout`.
    psutil.wait_procs(killed, timeout=timeout)
def _remove_logfile(logfile):
    """
    Delete the temporary log file left behind by a subprocess test.

    On Windows a failed removal is retried once after a pause; on any other
    platform the OSError propagates immediately.
    :param logfile: path to logfile
    :type logfile: str
    """
    try:
        os.remove(logfile)
    except OSError:
        if sys.platform != "win32":
            raise
        # The file may still be held open by sh.exe processes that take a
        # few seconds to exit; give them 30 seconds and try once more.
        time.sleep(30)
        os.remove(logfile)
def _get_memtest_report(memtest_log_file):
    """
    Parse a valgrind memtest log file, extract any detected errors,
    and summarize them into a minimal report.
    :param memtest_log_file: valgrind log file
    :type memtest_log_file: str
    :return: str; empty if no errors were found by valgrind. Otherwise a header
        line followed by one "* <error type>" line per distinct error type,
        in sorted order.
    """
    log_name = os.path.basename(memtest_log_file)
    with open(memtest_log_file) as fh:
        leaks = memtest.read_valgrind_log(log_name, fh)
    if not leaks:
        return ""
    # presumably deduplicates `leaks` in place (its return value is unused);
    # the count below is taken afterwards
    memtest.uniquify_leaks(leaks)
    # Sort so the report is stable across runs: iterating a set of strings
    # gives an order that depends on the interpreter's hash seed.
    error_types = sorted({leak.short_description for leak in leaks}, key=str)
    lines = [f'{len(leaks)} memory errors detected in {log_name}, of types:']
    lines.extend(f'* {error}' for error in error_types)
    return '\n'.join(lines) + '\n'
@contextlib.contextmanager
def _get_temporary_logfile(capture: bool):
    """
    Context manager providing a temporary file for captured test output.

    When `capture` is true, yields an open NamedTemporaryFile that is closed
    and deleted on exit. When `capture` is false, yields None and does nothing
    else.
    """
    if capture:
        handle = tempfile.NamedTemporaryFile(delete=False)
        try:
            yield handle
        finally:
            handle.close()
            _remove_logfile(handle.name)
    else:
        yield
def run_subprocess_test(cmd: List[str],
                        timeout: Optional[int],
                        capture: bool,
                        verbose: bool,
                        env: Optional[Dict[str, str]],
                        *,
                        memtest_log_files: List[str] = None,
                        stdout=None,
                        stderr=None):
    """
    Run the test in a subprocess.
    :param cmd: subprocess.Popen command argument
    :param timeout: passed to the process's wait()
    :param capture: capture test output into a temporary file; the captured
        output is included in the exception raised on failure
    :param verbose: pytest verbosity; when truthy the command line is printed
    :param env: Shell environment for subprocess.Popen
    :param memtest_log_files: iterable with the names of the
            log files generated by the memtest, or None if
            not running a memtest.
    :type memtest_log_files: iterable
    :param stdout: file descriptor, file object, or subprocess special variable
        for use as stdout argument to subprocess call. Overrides capture if not
        None.
    :type stdout: file-like object
    :param stderr: file descriptor, file object, or subprocess special variable
        for use as stderr argument to subprocess call. Overrides capture if not
        None.
    :type stderr: file-like object
    :raises ProcessKilled: if the process is still running after `timeout`;
        it is killed together with its children first.
    :raises ProcessDied: if the process exits with a non-zero return code, or
        exits with 0 but the memtest logs report errors.
    """
    def get_output(logfile=None) -> str:
        """
        Return the output captured in `logfile` (a temporary file object), or
        '' when output was not captured.
        """
        if logfile is None:
            return ''
        # Close the write handle so buffered output is flushed before the file
        # is reopened by name for reading.
        logfile.close()
        # "backslashreplace" renders undecodable bytes as escapes instead of
        # raising while the failure message is being built.
        with open(logfile.name, errors="backslashreplace") as fh:
            return fh.read()

    kwargs = {}
    if stdout is not None:
        kwargs['stdout'] = stdout
    if stderr is not None:
        kwargs['stderr'] = stderr
    if verbose:
        print("\n executing: " + list2cmdline(cmd))
    with _get_temporary_logfile(capture) as logfile:
        if capture:
            # Explicit stdout/stderr arguments take precedence over capture
            kwargs.setdefault('stdout', logfile)
            kwargs.setdefault('stderr', STDOUT)
        process = psutil.Popen(cmd, env=env, **kwargs)
        try:
            process.wait(timeout)
        except psutil.TimeoutExpired:
            _kill_process_and_children(process)
            output = get_output(logfile)
            raise ProcessKilled(cmd, output, timeout)
        if process.returncode:
            # It would be nice to distinguish between test failures,
            # crashes, hangs, and memtest failures here
            output = get_output(logfile)
            raise ProcessDied(cmd, output, process.returncode)
        elif memtest_log_files is not None:
            # If this is a memcheck, parse the log to see if any errors
            # were found, but not reported (valgrind exits with status
            # 0 if the leaks happen in a non-joined pthread)
            reports = []
            for log_file in memtest_log_files:
                report = _get_memtest_report(log_file)
                if report:
                    reports.append(report)
            if reports:
                raise ProcessDied(cmd, '\n'.join(reports), VALGRIND_ERROR_CODE)
class Makefile(pytest.File):
    """A Makefile whose directory is run as a single MakefileTest."""

    def collect(self):
        """Yield the one MakefileTest item for this Makefile. (pytest method)"""
        item = MakefileTest.from_parent(parent=self, name=self.fspath.basename)
        yield item
class MakefileTest(ExecutableTest):
    """Run the `make test` target in a directory."""

    def __init__(self, name, parent, **kw):
        """
        :param name: ignored; the item is named after the directory that
            contains the Makefile.
        :param parent: the Makefile collector that owns this item.
        """
        name = parent.fspath.dirpath().basename
        # Bypass ExecutableTest.__init__: self.command / self.output are not
        # needed because getCommand() is overridden below.
        pytest.Item.__init__(self, name, parent=parent, **kw)

    def _get_memtest_log_files(self):
        """
        Provides a generator that finds valgrind log files created
        by the makefile inside the local directory.
        Actual scanning for files happens only when we start
        iterating over the generator (so not here).
        The globbing expression comes from common.mk
        :return: iterable
        """
        # NOTE(review): this pattern is relative to the current working
        # directory, while `make -C` runs in the test directory; confirm that
        # common.mk writes the logs where this glob looks.
        return glob.iglob('*_valgrind.log')

    def getCommand(self):
        """
        Build the make command for this test directory.
        :return: list; runs the `memtest_automated` target when memtest is
            enabled, otherwise the `test` target.
        """
        test_directory = self.fspath.dirname
        # Same lookup as ExecutableTest: tolerate the memtest option not being
        # registered (config.getvalue() would raise ValueError).
        if getattr(self.config.option, 'memtest', None):
            target = 'memtest_automated'
        else:
            target = 'test'
        return ['make', '-C', test_directory, target]

    def runtest(self, env=None, capture=None, stdout=None, stderr=None):
        """
        Run the make target for this test. (pytest method)

        Make on Windows requires the path separator to be "/", but toplevel
        turns the SCHRODINGER path into a valid path for the current os.
        "unfix" the SCHRODINGER path separator. Make's MAKEFLAGS, MAKELEVEL
        and MFLAGS variables are also removed, presumably so that settings from
        an enclosing make invocation do not leak into the test's make.

        :param env: environment to run make in; defaults to os.environ. A copy
            is modified, never the mapping passed in.
        :param capture: forwarded to ExecutableTest.runtest
        :param stdout: forwarded to ExecutableTest.runtest
        :param stderr: forwarded to ExecutableTest.runtest
        """
        env = dict(os.environ if env is None else env)
        if sys.platform == "win32" and 'SCHRODINGER' in env:
            env['SCHRODINGER'] = env['SCHRODINGER'].replace('\\', '/')
        for var in ("MAKEFLAGS", "MAKELEVEL", "MFLAGS"):
            env.pop(var, None)
        super().runtest(env, capture=capture, stdout=stdout, stderr=stderr)
def _get_subprocess_timeout(faulthandler_timeout) -> Optional[float]:
    """
    Derive subprocess timeout from faulthandler timeout.

    :param faulthandler_timeout: faulthandler timeout in seconds, or None.
        Values <= 0 are treated as "no timeout"; passing them through would
        make the subprocess wait time out (and be killed) immediately.
    :return: timeout in seconds for the test subprocess, or None for no timeout.
    """
    if faulthandler_timeout is None or faulthandler_timeout <= 0:
        return None
    # make subprocess timeout 10s faster than faulthandler timeout, so we
    # only see test, presuming we can terminate process in that time. This
    # makes logs easier to read, since there would not be misleading traceback.
    subprocess_timeout = faulthandler_timeout - 10
    if subprocess_timeout > 0:
        return subprocess_timeout
    # If the timeout is 10s or less, use faulthandler_timeout unchanged
    return faulthandler_timeout