"""
Client for communication with the STU server.  The core functionality is in
`TestClient` and `ResultReporter`, but most of it is accessible from the
module-level convenience functions `create`, `retrieve`, `update`, `delete`,
and `download`, which allow modification of the tests in the database.
@copyright: (c) Schrodinger, LLC. All rights reserved.
"""
import getpass
import io
import json
import os
import os.path
import sys
import tempfile
import time  # noqa: F401, mocked in test
import urllib
import zipfile
import backoff
import requests
# Disable InsecurePlatformWarning until we update to > Python 2.7.9
import requests.packages.urllib3
from http import HTTPStatus
from schrodinger.test.stu import common
from schrodinger.test.stu import testscripts
from schrodinger.utils import fileutils
from schrodinger.utils import machid
# Hack for SHARED-4292
# On Windows and macOS, mark SNI as unsupported in the ssl module.
# NOTE(review): the reason is tracked in SHARED-4292; confirm it is still
# needed before removing this workaround.
if sys.platform == "win32" or sys.platform.startswith("darwin"):
    import ssl
    ssl.HAS_SNI = False
# Silence urllib3 warnings. Requests made by _BaseClient._request default to
# verify=False, which urllib3 would otherwise warn about on every request.
requests.packages.urllib3.disable_warnings()
logger = common.logger
_VERSION = 1
"""API version to use."""
# Default request headers: the request body is JSON.
_JSON_FORMAT = {'content-type': 'application/json'}
# Maximum number of characters of workup messages sent when reporting a test
# outcome; longer messages are truncated by ResultReporter._getTestData.
MAX_MESSAGE_LENGTH = 10000
# Tags that identify a scival test.
SCIVAL_TAGS = {'scival', 'require:scival'}
# Adjust the default number of allowed requests to the same URL.
# NOTE(review): requests' HTTPAdapter binds DEFAULT_RETRIES as a keyword
# default when requests is imported, so reassigning the module attribute here
# probably does not give adapters 3 retries; verify against the installed
# requests version before relying on it.
requests.adapters.DEFAULT_RETRIES = 3
# Requests exceptions on which backoff should retry.
# Per the requests documentation, RequestException is the base class of the
# other two (and of HTTPError), so every requests error is retried unless the
# `fatal_status` giveup predicate rejects it.
RETRIABLE_EXCEPTIONS = (requests.exceptions.ConnectionError,
                        requests.exceptions.RequestException,
                        requests.exceptions.SSLError)
# HTTP status codes that should be retried; an HTTPError with any other status
# code is treated as fatal by `fatal_status`.
RETRIABLE_STATUS_CODES = (
    HTTPStatus.REQUEST_TIMEOUT,  # 408
    HTTPStatus.LENGTH_REQUIRED,  # 411
    499,  # not available in HTTPStatus
    HTTPStatus.INTERNAL_SERVER_ERROR,  # 500
    HTTPStatus.BAD_GATEWAY,  # 502
    HTTPStatus.SERVICE_UNAVAILABLE,  # 503
)
# Connection timeout should be "slightly larger than a multiple of 3, which is
# the default TCP packet retransmission window."
# (see http://docs.python-requests.org/en/master/user/advanced/#timeouts)
CONNECTION_TIMEOUT_SECONDS = 60.1
# Read timeout, in seconds, passed to every request.
READ_TIMEOUT_SECONDS = 3600
# Name of the environment variable that overrides the STU username
# (see get_stu_username).
STU_REMOTE_USER = "STU_REMOTE_USER"
class ClientError(Exception):
    """
    Base error raised when the STU client cannot complete a request.
    """
class ClientValueError(ValueError):
    """
    Raised when the arguments given to a client function are wrong, or do not
    fully determine what should be requested.
    """
#########################
#   Convenience functions
#########################
def create(username, test, directory=None, upload=True):
    """
    Create `test` on the STU server and optionally upload its files.

    ``test.id`` is cleared before creation and set to the new test number
    afterwards. If anything fails after the test was created, the test is
    deleted again and the original error is re-raised.

    :param username: STU user performing the requests.
    :param test: Test to create; its ``id`` attribute is overwritten.
    :param directory: Directory to upload, passed to `TestClient.upload`.
    :param upload: Whether to upload the test files after creating the test.
    :return: The test number returned by the server.
    """
    myclient = TestClient.instance(username)
    test.id = None
    number = myclient.create(test)
    try:
        if number:
            test.id = number
            if upload:
                myclient.upload(test, directory)
    except BaseException:
        # Roll back the half-created test. A failing rollback is logged
        # rather than raised so it cannot hide the error that caused it.
        try:
            myclient.delete(number)
        except Exception:
            logger.exception('Failed to delete test %s after a failed '
                             'creation.', number)
        raise
    return number
def retrieve(username, *args, **kwargs):
    """
    Retrieve tests from the STU server with the shared client for `username`.

    All other arguments are forwarded unchanged to `TestClient.retrieve`.
    """
    return TestClient.instance(username).retrieve(*args, **kwargs)
def update(username, test, directory=None, upload=True):
    """
    Update the metadata of `test` on the server and optionally re-upload its
    files from `directory`.

    :return: The result of `TestClient.update`, or the result of
        `TestClient.upload` when the update succeeded and `upload` is True.
    """
    client = TestClient.instance(username)
    updated = client.update(test)
    if not (updated and upload):
        return updated
    return client.upload(test, directory)
def delete(username, test):
    """
    Delete `test` (a test object with an ``id``, or a test number) using the
    shared client for `username`.
    """
    client = TestClient.instance(username)
    result = client.delete(test)
    return result
def download(username, test, directory=None, overwrite=True):
    """
    Download and extract the files of `test` using the shared client for
    `username`.

    :return: The directory the files were extracted into.
    """
    client = TestClient.instance(username)
    return client.download(test, directory, overwrite)
class ResultReporter:
    """
    Reporter that will upload results and files for a specific test run.
    (This is the `report` method.) Also marks runs as complete, with the
    option to send an email about test failures to the interested users (the
    `completeRun` method).

    Constructing a reporter looks up (or creates) the build on the server and
    creates a new run for it.
    """

    # BLDMGR-3781: largest compressed result archive that may be uploaded
    # (the request body limit configured in nginx).
    _MAX_UPLOAD_BYTES = 2000 * 1024 * 1024

    def __init__(self,
                 buildtype,
                 build_id,
                 mmshare,
                 local_system,
                 remote_system,
                 username=None,
                 release=None,
                 build_log_address=None,
                 comment=None,
                 base_url=None,
                 api_version=_VERSION):
        """
        :param buildtype: Build type; part of the build identification.
        :param build_id: Build id; part of the build identification.
        :param mmshare: mmshare version; part of the build identification.
        :param local_system: Description of the local host; its ``release``,
            ``platform`` and ``schrodinger`` attributes and ``toDict()`` are
            used.
        :param remote_system: Description of the remote host (``toDict()``).
        :param username: STU user used to authenticate requests.
        :param release: Release; defaults to ``local_system.release``.
        :param build_log_address: Build log address stored on the run.
        :param comment: Comment stored on the run.
        :param base_url: Server URL (the client falls back to
            ``common.BASE_URL``).
        :param api_version: STU API version.
        :raise ClientValueError: If the build is not fully identified.
        """
        self.client = _BaseClient(username, base_url, api_version)
        release = release or local_system.release
        self._local_system = local_system
        self._remote_system = remote_system
        # Host resource URIs, resolved lazily by the properties below.
        self._localhost = None
        self._remotehost = None
        self.run = self._createRun(mmshare, release, buildtype, build_id,
                                   build_log_address, comment)

    def report(self, test, upload=True, files=None):
        """
        Report the result of ONE test and upload its files.

        :param test: Test whose outcome is reported.
        :param upload: Whether to zip and upload the test's result files.
        :param files: Files (relative to ``test.directory``) to upload; if
            None, the whole test directory is zipped.

        If zipping fails or the archive is larger than the upload limit, the
        outcome is reported as a failure. An oversize archive is not uploaded.
        A deleted test (HTTP 410) is skipped silently.
        """
        data = self._getTestData(test)
        with tempfile.TemporaryFile('w+b') as content:
            # Number of bytes to upload; 0 means "do not upload".
            size = 0
            if upload:
                try:
                    if files is None:
                        common.zip_directory(test.directory, content)
                    else:
                        common.zip_files(content, test.directory, files)
                    size = content.tell()
                    # Verify zip contents
                    content.seek(0)
                    common.verify_zip(content)
                except common.ZipError as err:
                    # Roundtrip string to utf-8 for prep for server
                    description = err.__doc__ or type(err).__name__
                    err_name = description.encode(
                        'utf-8', errors="replace").decode('utf-8')
                    set_failure_data(data, err_name, str(err))
                content.seek(0)
            limit = self._MAX_UPLOAD_BYTES
            if size > limit:
                # BLDMGR-3781 apparently requests doesn't read the HTTP response
                # until it finishes the upload. This is a problem when it tries to
                # upload a file larger than the body limit set in nginx because
                # it just comes back with connection reset by peer, rather than
                # HTTP 413 (Request Entity Too Large). Check file size here,
                # fail the test if over the limit and skip the upload.
                err_message = (
                    'Test results are too large, they must not '
                    'exceed 2000 MB when compressed to a ZIP file. '
                    'Current ZIP size is {:0.2f} MB, {} bytes too '
                    'large. Edit your test to decrease the result '
                    'size.'.format(
                        float(size) / (1024.0 * 1024.0), size - limit))
                set_failure_data(data, 'Result exceeds size limit',
                                 err_message)
                size = 0
            payload = json.dumps(data)
            try:
                response = self.client.post(self.client.item_uri('outcome'),
                                            data=payload)
            except requests.exceptions.HTTPError as err:
                # HTTPErrors raised by check_status carry no response, so the
                # status code may be unavailable.
                status_code = getattr(err.response, 'status_code', None)
                if status_code == 410:
                    logger.debug(f' {test} was deleted. Skipping report.')
                    return
                raise
            except Exception:
                msg = 'Failed to report for {}. Requested data: {}'.format(
                    test, payload)
                logger.exception(msg)
                raise
            if not size:
                return
            address = response.headers['location'] + 'upload/'
            self.client.post(address,
                             data=content,
                             headers={
                                 'content-encoding': 'application/zip',
                                 'Content-Length': str(size),
                             })

    def userRunURL(self):
        """Return the URL where users can go to see details about this run."""
        url = self.client._base + self.run
        url = url.replace('api/v1/', '')
        return url

    def completeRun(self, duration, email=False):
        """
        Record that the run is complete, and include the total duration.  If
        email is True, trigger an email about test failures.

        :param duration: Total run time, stored as ``total_time``.
        :param email: If True, post to the run's ``email/`` endpoint instead
            of patching the run.
        """
        data = dict(total_time=duration, complete=True)
        data = json.dumps(data)
        if email:
            # Doesn't actually create a resource, so status code is not 201.
            self.client.post(self.run + 'email/',
                             data=data,
                             required_statuses=None)
        else:
            self.client.patch(self.run, data=data)

    def _getOrCreateBuild(self, mmshare, release, buildtype, build_id):
        """
        Return the resource URI of the build matching the given values,
        creating it (with the git hashes from
        `machid.get_product_git_hashes`) if none exists.

        :raise ClientValueError: If a value is blank or several builds match.
        """
        data = dict(mmshare=mmshare,
                    release=release,
                    buildtype=buildtype,
                    build_id=build_id)
        if not all(data.values()):
            raise ClientValueError('Build is not fully identified - some '
                                   'required values are blank. Values: %s' %
                                   data)
        build_api = self.client.item_uri('build')
        response = self.client.get(build_api, params=data, data=data)
        response = response.json()
        if response['meta']['total_count'] > 1:
            msg = (
                'More than one build found: %s' %
                ', '.join(val['resource_uri'] for val in response['objects']))
            raise ClientValueError(msg)
        elif not response['meta']['total_count']:
            flat_git_hashes = {}
            for product_hashes in machid.get_product_git_hashes().values():
                flat_git_hashes.update(product_hashes)
            data['_git_hashes'] = flat_git_hashes
            response = self.client.post(build_api,
                                        data=json.dumps(data),
                                        required_statuses=(201, 303))
            response = response.json()
            return response['resource_uri']
        else:
            return response['objects'][0]['resource_uri']

    def _createRun(self, mmshare, release, buildtype, build_id,
                   build_log_address, comments):
        """Create a run for the build and return its API-rooted URI."""
        build = self._getOrCreateBuild(mmshare, release, buildtype, build_id)
        data = dict(build=build,
                    build_log_address=build_log_address,
                    build_results_dir=None,
                    platform=self._local_system.platform,
                    executor=self.client._username,
                    comments=comments,
                    schrodinger=self._local_system.schrodinger,
                    total_time=0,
                    complete=False,
                    localhost=self.localhost,
                    remotehost=self.remotehost)
        data = json.dumps(data)
        response = self.client.post(self.client.item_uri('run'), data=data)
        return self.client.getResourceURI(response)

    @property
    def localhost(self):
        """Resource URI of the local host, looked up or created on first use."""
        if not self._localhost:
            self._localhost = self._getOrCreateHost(
                **self._local_system.toDict())
        return self._localhost

    @property
    def remotehost(self):
        """
        Resource URI of the remote host. Reuses the local host URI when it is
        already known and both systems compare equal.
        """
        if not self._remotehost:
            if self._localhost and self._local_system == self._remote_system:
                self._remotehost = self._localhost
            else:
                self._remotehost = self._getOrCreateHost(
                    **self._remote_system.toDict())
        return self._remotehost

    def _getOrCreateHost(self, **kwargs):
        """Return the URI of the host matching `kwargs`, creating it if needed."""
        response = self.client.get('host', params=dict(limit=1, **kwargs))
        objects = response.json()['objects']
        if objects:
            return objects[0]['resource_uri']
        response = self.client.post('host', data=json.dumps(kwargs))
        location = self.client.getResourceURI(response)
        return location

    def _getTestData(self, test):
        """
        Build the outcome payload for `test`. Workup messages longer than
        MAX_MESSAGE_LENGTH are truncated and a note is appended.
        """
        local = not test.runsRemotely()
        if len(test.workup_messages) > MAX_MESSAGE_LENGTH:
            suffix = '\n\nTRUNCATED due to message length > {}.\n'
            suffix = suffix.format(MAX_MESSAGE_LENGTH)
            workup_messages = test.workup_messages[:MAX_MESSAGE_LENGTH]
            workup_messages += suffix
        else:
            workup_messages = test.workup_messages
        post_data = dict(test=test.resource_uri,
                         passed=test.outcome,
                         run_time_sec=test.timing,
                         run=self.run,
                         workup_messages=workup_messages,
                         failure_type=dict(name=test.failure_type))
        if local:
            post_data['host'] = self.localhost
        else:
            post_data['host'] = self.remotehost
        return post_data
def set_failure_data(data, error_name, error_message):
    """
    Mark the result payload `data` as failed for the STU API, in place.

    `error_message` is put in front of the existing workup messages and the
    failure type becomes `error_name`.
    """
    data['passed'] = False
    data.update(
        workup_messages=f"{error_message}\n\n\n{data['workup_messages']}",
        failure_type=error_name,
    )
def raise_for_status(response):
    """
    If `response` has a bad status, raise an Exception. First, however,
    be sure to print any data available from the exception.

    The message of the raised HTTPError is extended with the server's error
    description, the request URL and the response headers. The description
    comes from the JSON ``error_message`` or ``error`` field, or the reason
    phrase; if the body is not a JSON object, the raw body text is used. A
    server traceback included in a JSON body is printed.

    :type response: requests.models.Response
    :param response: Response to check for exit status problems.
    :raise requests.exceptions.HTTPError: If `response` has an error status.
    """
    try:
        response.raise_for_status()
    except requests.exceptions.HTTPError as err:
        err_args = list(err.args)
        message = str(err_args[0]) if err_args else ''
        try:
            data = response.json()
        except ValueError:
            # Response does not contain valid json
            data = None
        if isinstance(data, dict):
            error = (data.get('error_message') or data.get('error') or
                     response.reason)
            traceback = data.get('traceback')
            if traceback:
                print('Server ' + str(traceback).replace('\n\n', '\n'))
        else:
            # Not JSON, or JSON that is not an object: use the raw body.
            error = getattr(response, 'text', False)
        if error:
            error = str(error)
            if not message.endswith(error):
                message = '{}, {}'.format(message, error)
        message += f" ({response.url} : {response.headers} )"
        err.args = (message, *err_args[1:])
        raise
def check_status(response, required_statuses):
    """
    Raise an error if the status code of `response` is not one of
    `required_statuses`.

    :type response: requests.models.Response
    :param response: Response whose status code is checked.
    :type required_statuses: container of int, or int
    :param required_statuses: Acceptable status code(s).
    :rtype: bool
    :return: True if the status code is acceptable.
    :raise requests.exceptions.HTTPError: If the status code does not match.
        The error is raised without a response object attached.
    """
    # Accept a single status code as well as a collection of them.
    if isinstance(required_statuses, int):
        required_statuses = (required_statuses,)
    if response.status_code not in required_statuses:
        msg = ('Response %s (%s) did not match required status "%s"' %
               (response.reason, response.status_code, required_statuses))
        # Match exception class of raise_for_status
        raise requests.exceptions.HTTPError(msg)
    return True
def fatal_status(exception):
    """
    Decide whether a failed request must NOT be retried.

    This is the ``giveup`` predicate of the backoff decorator on
    `_BaseClient._request`. It returns True only for an HTTPError whose
    response has a status code outside RETRIABLE_STATUS_CODES. It returns
    False for every other exception, and for an HTTPError raised without a
    response (such as the one raised by `check_status`).
    """
    if isinstance(exception, requests.exceptions.HTTPError):
        # A hand-raised HTTPError has response=None, which has no status code.
        response = exception.response
        if hasattr(response, "status_code"):
            return response.status_code not in RETRIABLE_STATUS_CODES
    return False
class ApiKeyAuth(requests.auth.AuthBase):
    """
    requests authentication handler that sends an STU API key.

    Each request gets an ``Authorization: ApiKey <username>:<key>`` header,
    where the key is read once from `common.get_api_key`.
    """

    def __init__(self, username):
        self._auth = 'ApiKey {}:{}'.format(username, common.get_api_key())

    def __call__(self, r):
        r.headers.update({'Authorization': self._auth})
        return r
class _BaseClient:
    """
    Base client class for submitting requests to a server.
    Adds defaults for server url, authentication, and retrying.
    """
    # Singleton instance of the client to simplify access for apps that
    # always use the defaults (see `instance`).
    _INSTANCE = None

    def __init__(self, username=None, base_url=None, api_version=_VERSION):
        """
        :param username: STU user used to build the API-key authentication.
        :param base_url: Server URL; defaults to ``common.BASE_URL``.
        :param api_version: Version used in the ``/api/v<N>/`` path.
        """
        if not base_url:
            base_url = common.BASE_URL
        self._base = base_url.rstrip('/')
        self._api = '/api/v%s/' % api_version
        self._auth = None  # built lazily by the `auth` property
        self._username = username
        self._session = requests.Session()

    @classmethod
    def instance(cls, username, base_url=None, api_version=_VERSION):
        """
        Return the shared client of this class, creating it on first use.

        `base_url` and `api_version` are only used when the instance is
        created. When `username` differs from the previous user, the cached
        authentication is dropped so it is rebuilt for the new user.
        """
        if not cls._INSTANCE:
            cls._INSTANCE = cls(username, base_url, api_version)
        if username != cls._INSTANCE._username:
            cls._INSTANCE._auth = None
        cls._INSTANCE._username = username
        return cls._INSTANCE

    @property
    def auth(self):
        """API-key authentication for the current user, built on first use."""
        if not self._auth:
            self._auth = ApiKeyAuth(self._username)
        return self._auth

    @property
    def fullapi(self):
        """Full URL of the API root, ending with a slash."""
        return self._base + self._api

    def item_uri(self, item):
        """Return the full URL of the API collection `item`."""
        return f"{self.fullapi}{item}/"

    def list2str(self, items):
        """Join `items` into a comma-separated string for query parameters."""
        return ','.join(str(x) for x in items)

    def safename(self, oldname):
        """
        Ensure that the name is safe and that path separators are consistent on
        Linux and Windows.
        :type oldname: str
        :param oldname: String to be protected
        :rtype: str
        :return: String with with all path separators replaced by / and all
                non-URL-safe characters protected.
        """
        # `import urllib` alone does not guarantee urllib.parse is loaded.
        import urllib.parse
        newname = oldname.replace(os.path.sep, '/')
        newname = newname.replace('\\', '/')
        return urllib.parse.quote(newname, '')

    @backoff.on_exception(backoff.expo,
                          RETRIABLE_EXCEPTIONS,
                          max_tries=4,
                          giveup=fatal_status,
                          logger=logger,
                          factor=30)
    def _request(self,
                 method,
                 uri,
                 headers=_JSON_FORMAT,
                 auth=None,
                 required_statuses=None,
                 verify=False,
                 **kwargs):
        """
        Send a request and return the response, retrying with exponential
        backoff (up to 4 attempts) unless `fatal_status` rejects the error.

        :param method: Bound session method, e.g. ``self._session.get``.
        :param uri: Full URL, server-rooted path starting with '/', or an
            item name relative to the API root.
        :param headers: Request headers; JSON content type by default (the
            shared default dict is never mutated here). Pass None to send
            no extra headers.
        :param auth: Authentication; defaults to `auth`.
        :param required_statuses: If given, acceptable status codes.
        :param verify: TLS certificate verification flag for requests.
        :raise requests.exceptions.HTTPError: On an error or unexpected
            status once retrying stops.
        """
        if auth is None:
            auth = self.auth
        if not uri.startswith('http'):
            if uri.startswith('/'):
                uri = self._base + uri
            else:
                uri = self.item_uri(uri)
        response = method(uri,
                          headers=headers,
                          auth=auth,
                          verify=verify,
                          timeout=(CONNECTION_TIMEOUT_SECONDS,
                                   READ_TIMEOUT_SECONDS),
                          **kwargs)
        raise_for_status(response)
        if required_statuses:
            check_status(response, required_statuses)
        return response

    def post(self, uri, required_statuses=(201,), **kwargs):
        """POST to `uri`; by default a 201 (Created) status is required."""
        return self._request(self._session.post,
                             uri,
                             required_statuses=required_statuses,
                             **kwargs)

    def put(self, uri, **kwargs):
        """PUT to `uri`."""
        return self._request(self._session.put, uri, **kwargs)

    def patch(self, uri, required_statuses=(202,), **kwargs):
        """PATCH `uri`; by default a 202 (Accepted) status is required."""
        return self._request(self._session.patch,
                             uri,
                             required_statuses=required_statuses,
                             **kwargs)

    def get(self, uri, **kwargs):
        """GET `uri`."""
        return self._request(self._session.get, uri, **kwargs)

    def delete(self, uri, **kwargs):
        """DELETE `uri`."""
        return self._request(self._session.delete, uri, **kwargs)

    def getResourceURI(self, response=None):
        """
        Get an API rooted address from a full address.

        :param response: Response whose ``location`` header holds the full
            address of a resource.
        :return: The location with the server base URL prefix removed (the
            location is returned unchanged if it does not start with it).
        """
        location = response.headers['location']
        if location.startswith(self._base):
            location = location[len(self._base):]
        return location
class TestClient(_BaseClient):
    """
    Interact with the Test server.  Create, Retrieve, Update, and Delete tests.
    Also upload and download the associated files.
    """

    def __init__(self, username=None, base_url=None, api_version=_VERSION):
        """
        :param username: STU user used to authenticate requests.
        :param base_url: Server URL; defaults to ``common.BASE_URL``.
        :param api_version: STU API version.
        """
        super().__init__(username=username,
                         base_url=base_url,
                         api_version=api_version)
        self._testapi = self.item_uri('test')
        self._sharedapi = self.item_uri('sharedfile')

    def create(self, test):
        """
        Create a test from ``test.toDict()``.

        :return: The number the server assigned to the new test.
        """
        test_dict = test.toDict()
        post_data = json.dumps(test_dict)
        response = self.post(self._testapi, data=post_data)
        return response.json()['number']

    def retrieve(
        self,
        test_ids=None,
        products=None,
        components=None,
        priorities=None,
        tags=None,
        not_products=tuple(),  # noqa: M511
        not_components=tuple(),  # noqa: M511
        not_tags=tuple()):  # noqa: M511
        """
        Retrieve tests based on some limiting criteria.  The only method that
        operates on more than one test at a time.

        This is a generator: the server is queried once iteration starts and
        `testscripts.TestScript` objects are yielded one by one.

        :param test_ids: Test numbers to include.
        :param products: Products to include (the server also searches JIRA
            names).
        :param components: Components to include.
        :param priorities: Priorities to include.
        :param tags: Tags to include.
        :param not_products: Products to skip, compared case-insensitively.
        :param not_components: Components to skip, compared case-insensitively.
        :param not_tags: Tags to skip, compared case-insensitively.
        :raise ClientError: If the server did not return every matching test.
        """
        query_dict = dict(limit=0)
        if test_ids:
            query_dict['number__in'] = self.list2str(test_ids)
        if products:
            # Use the custom STU query that searches JIRA names too
            query_dict['products'] = self.list2str(products)
        if components:
            query_dict['component__in'] = self.list2str(components)
        if priorities:
            query_dict['priority__in'] = self.list2str(priorities)
        if tags:
            query_dict['tags__in'] = self.list2str(tags)
        # Skip lists are compared case-insensitively against the test data.
        not_products = {product.lower() for product in not_products or ()}
        not_components = {
            component.lower() for component in not_components or ()
        }
        not_tags = {tag.lower() for tag in not_tags or ()}
        response = self.get(self._testapi, params=query_dict)
        data = response.json()
        if data['meta']['total_count'] != len(data['objects']):
            raise ClientError(
                'Did not download all tests! (%s requested. %s '
                'found)' % (data['meta']['total_count'], len(data['objects'])))
        for test_data in data['objects']:
            # Filter tests that use products, tags and subfeatures from the
            # skip lists.
            if test_data['product'].lower() in not_products:
                continue
            component = test_data['component']
            if component and component.lower() in not_components:
                continue
            if any(tag.lower() in not_tags for tag in test_data['tags']):
                if has_scival_tags(test_data['tags']):
                    msg = f'Skipping scival test: {test_data["number"]}'
                    # Skipping an explicitly requested test is worth a loud
                    # message; otherwise it is routine.
                    if test_ids and test_data['number'] in test_ids:
                        logger.critical(msg)
                    else:
                        logger.debug(msg)
                continue
            # Backwards compatibility for SHARED-3063
            if 'jira_tracking' not in test_data and test_data.get(
                    'bug_disabled'):
                test_data['disabled_for_bug'] = True
                test_data['jira_tracking'] = test_data.pop('bug_disabled')
            yield testscripts.TestScript(**test_data)

    def find_one(self, criteria):
        """
        Find a single STU test that matches the search criteria.
        :param criteria: Search criteria, uses the exact names of parameters on
            the server. (does not interpret them in the way that retrieve does)
        :type criteria: dict
        :raise ClientError: If more than one test is found.
        :raise IndexError: If no tests are found.
        """
        response = self.get(self._testapi, params=criteria)
        data = response.json()
        total_count = data['meta']['total_count']
        if total_count > 1:
            raise ClientError(
                'Found too many tests matching criteria {}. Found {} tests'.
                format(criteria, total_count))
        try:
            test_data = data['objects'][0]
        except IndexError:
            raise IndexError(
                'Did not find a test matching the criteria: {}'.format(
                    criteria)) from None
        return testscripts.TestScript(**test_data)

    def get_or_create(self, search_criteria, creation_data=None):
        """
        Find a test corresponding to search_criteria. If no test is found,
        create a test using creation_data.
        :param search_criteria: Search criteria.
        :type search_criteria: dict
        :param creation_data: Data to be used to create the test if none exists.
            Defaults to `search_criteria`. It is not modified.
        :type creation_data: dict
        :raise ClientError: If more than one test matches.
        """
        try:
            test = self.find_one(search_criteria)
        except IndexError:
            # Doesn't exist yet.
            if creation_data is None:
                creation_data = search_criteria
            test = testscripts.TestScript(**creation_data)
            # Copy so the caller's dict is not modified.
            creation_data = creation_data.copy()
            creation_data['number'] = self.create(test)
            # Need to get again to have access to the URL for the test.
            test = self.find_one(creation_data)
            assert test.id
        return test

    def update(self, test):
        """
        Update files and metadata for a test.

        :return: True once the server accepted the update.
        """
        test_dict = test.toDict()
        post_data = json.dumps(test_dict)
        self.put(self._testapi + '%s/' % test.id, data=post_data)
        return True

    def delete(self, test):
        """
        Delete a test metadata and files.

        :param test: Test object with an ``id``, or a test number.
        :return: True once the server deleted the test.
        """
        if hasattr(test, 'id'):
            testid = test.id
        else:
            testid = int(test)
        super().delete(self._testapi + "%s/" % testid)
        return True

    def download(self, test, directory=None, overwrite=True):
        """
        Download and extract the files associated with `test`. Overwrites
        existing contents of the directory.

        :raise ClientValueError: If `test` is a scival test.
        :return: The directory the test files were extracted into.
        """
        if has_scival_tags(test.tags):
            raise ClientValueError(f'Cannot download scival test: {test}.')
        directory = self._getTestDir(test, directory, overwrite)
        self._getSharedFiles(test.shared_files, overwrite)
        return directory

    def upload(self, test, directory=None):
        """
        Upload a directory to the server.  If the test has a directory
        attribute, use that.  Otherwise, based on the test number.

        Shared files are uploaded first, then the test directory.
        """
        self._setSharedFiles(test.shared_files)
        self._setTestDir(test, directory)
        return True

    def _setTestDir(self, test, directory=None):
        """
        Zip the test directory and upload it as ``<test id>.zip``.

        The directory is `directory`, else ``test.directory``, else one named
        after the test id; the choice is stored on ``test.directory``.
        'README' is passed to `common.zip_directory` as a skipped file.
        """
        path = directory or test.directory or str(test.id)
        test.directory = path
        content = io.BytesIO()
        common.zip_directory(path, fileobj=content, skipped_files={'README'})
        # Fresh buffer positioned at the start of the archive for upload.
        content = io.BytesIO(content.getvalue())
        files = {'%s.zip' % test.id: content}
        if getattr(test, 'upload', None):
            uri = self._base + test.upload
        else:
            response = self.get(self._testapi, params={'number': test.id})
            uri = self._base + response.json()['objects'][0]['upload']
        self.post(uri,
                  files=files,
                  headers={'content-encoding': 'application/zip'},
                  required_statuses=None)

    def _setSharedFiles(self, shared_files):
        """
        Upload shared files to server.

        Files are read from the local 'shared' directory; files the server
        already has content for are skipped.

        :raise RuntimeError: If a shared entry is a directory.
        """
        for filename in shared_files:
            fullname = os.path.join('shared', filename)
            if os.path.isdir(fullname):
                raise RuntimeError('"%s" is a directory. Only files can be '
                                   'shared.' % fullname)
            # Check to see if it already exists on the server:
            # This metadata check should be much cheaper than sending files.
            uri = self._sharedapi + '%s' % self.safename(filename)
            response = self.get(uri)
            data = response.json()
            if data['thefile']:
                continue
            uri = self._base + data['upload']
            # headers=None: no JSON content type for this file upload.
            with open(fullname, 'rb') as fileobj:
                self.post(uri,
                          files={'file': fileobj},
                          headers=None,
                          required_statuses=None)

    def _getTestDir(self, test, directory=None, overwrite=True):
        """
        Download the zipped test file and extract it to a directory.

        :param directory: Target directory; defaults to ``test.directory``.
        :param overwrite: If True, an existing directory is removed first.
        :return: The directory the files were extracted into.
        """
        if not directory:
            directory = test.directory
        if overwrite and os.path.isdir(directory):
            fileutils.force_rmtree(directory)
        test.directory = directory
        if test.download:
            uri = self._base + test.download
        else:
            response = self.get(self._testapi, params={'number': test.id})
            uri = self._base + response.json()['objects'][0]['download']
        zipped = self._getZip(uri)
        if not zipped.namelist():  # empty zip file
            os.makedirs(directory, exist_ok=True)
        else:
            zipped.extractall(path=directory)
        test.write_readme(os.path.join(directory, 'README'))
        return directory

    # BLDMGR-3710 It seems that under high load, the server might return a
    # partial zip file with status 200. Retry twice before raising.
    # FIXME: remove retry logic once the server issue is fixed
    @backoff.on_exception(backoff.expo,
                          zipfile.BadZipfile,
                          max_tries=3,
                          logger=logger,
                          factor=60)
    def _getZip(self, uri):
        """Download `uri` and open the body as an in-memory zip archive."""
        response = self.get(uri)
        content = io.BytesIO(response.content)
        return zipfile.ZipFile(content)

    def _getSharedFiles(self, shared_files, overwrite=False):
        """
        Download the shared files.

        Files are written below the local 'shared' directory; existing files
        are kept unless `overwrite` is True.
        """
        # Don't create the shared directory unless it is needed.
        if not shared_files:
            return None
        for filename in shared_files:
            local_name = os.path.join('shared', filename)
            if os.path.isfile(local_name) and not overwrite:
                continue
            directory = os.path.dirname(local_name)
            if directory:
                os.makedirs(directory, exist_ok=True)
            # NOTE(review): unlike _setSharedFiles, the name is not passed
            # through safename here; confirm the server expects that.
            uri = self._sharedapi + '%s/download/' % filename
            response = self.get(uri)
            with open(local_name, 'wb') as fh:
                fh.write(response.content)
        return None
def get_stu_username():
    """
    Return the STU username to use for the current caller.

    The value of the STU_REMOTE_USER environment variable is returned when it
    is set (even if empty); otherwise the login name of the current user.
    """
    username = os.environ.get(STU_REMOTE_USER)
    if username is None:
        # Only look up the login name when it is needed: per the getpass
        # docs, getuser() raises when no user name can be determined.
        username = getpass.getuser()
    return username