"""
Client for communication with STU server. The core functionality is in
`TestClient` and `ResultReporter`, but most is accessible from the module
level convenience functions `create`, `retrieve`, `update`, `delete`, and
`download`, which allow modification of the tests in the database.
@copyright: (c) Schrodinger, LLC. All rights reserved.
"""
import getpass
import io
import json
import os
import os.path
import sys
import tempfile
import time # noqa: F401, mocked in test
import urllib
import zipfile
import backoff
import requests
# Disable InsecurePlatformWarning until we update to > Python 2.7.9
import requests.packages.urllib3
from http import HTTPStatus
from schrodinger.test.stu import common
from schrodinger.test.stu import testscripts
from schrodinger.utils import fileutils
from schrodinger.utils import machid
# Hack for SHARED-4292: turn off SNI support in the ssl module on Windows and
# macOS only.
if sys.platform == "win32" or sys.platform.startswith("darwin"):
    import ssl
    ssl.HAS_SNI = False
# NOTE(review): indentation was lost in the copy under review; this call is
# assumed to be unconditional (the warning-suppression comment next to the
# requests.packages.urllib3 import is not platform specific) - confirm.
requests.packages.urllib3.disable_warnings()
# Logger taken from the STU `common` module; used for all client logging.
logger = common.logger
_VERSION = 1
"""API version to use."""
# Default request headers: bodies are sent as JSON unless a caller overrides
# the `headers` argument of `_BaseClient._request`.
_JSON_FORMAT = {'content-type': 'application/json'}
# Workup messages longer than this many characters are truncated by
# `ResultReporter._getTestData` before being reported.
MAX_MESSAGE_LENGTH = 10000
# Tag names that mark a test as a scival test.
# (presumably consulted by `has_scival_tags`, which is not visible here)
SCIVAL_TAGS = {'scival', 'require:scival'}
# Adjust the default number of allowed requests to the same URL.
# NOTE(review): this rebinds a module attribute after `requests` has been
# imported; whether adapters created later actually pick up the new value
# depends on how requests reads DEFAULT_RETRIES - confirm it has an effect.
requests.adapters.DEFAULT_RETRIES = 3
# Requests exceptions on which backoff should retry
# NOTE(review): RequestException looks like the base class of the other two
# entries, which would make them redundant - confirm before simplifying.
RETRIABLE_EXCEPTIONS = (requests.exceptions.ConnectionError,
requests.exceptions.RequestException,
requests.exceptions.SSLError)
# RequestException status codes that should be retried
# (checked by `fatal_status`, the `giveup` predicate of the retry decorator)
RETRIABLE_STATUS_CODES = (
HTTPStatus.REQUEST_TIMEOUT, # 408
HTTPStatus.LENGTH_REQUIRED, # 411
499, # not available in HTTPStatus
HTTPStatus.INTERNAL_SERVER_ERROR, # 500
HTTPStatus.BAD_GATEWAY, # 502
HTTPStatus.SERVICE_UNAVAILABLE, # 503
)
# Connection timeout should be "slightly larger than a multiple of 3, which is
# the default TCP packet retransmission window."
# (see http://docs.python-requests.org/en/master/user/advanced/#timeouts)
CONNECTION_TIMEOUT_SECONDS = 60.1
# Read timeout, passed as the second element of the `timeout` tuple.
READ_TIMEOUT_SECONDS = 3600
# Name of the environment variable read by `get_stu_username`.
STU_REMOTE_USER = "STU_REMOTE_USER"
class ClientError(Exception):
    """Generic STU client error."""
class ClientValueError(ValueError):
    """
    Arguments to a client function are incorrect, or not fully determined.
    """
#########################
# Convenience functions
#########################
def create(username, test, directory=None, upload=True):
    """
    Create `test` on the server and, optionally, upload its files.

    `test.id` is reset to None before the request and set to the number
    returned by the server. If anything fails after the test has been
    created, the new test is deleted again and the original error is
    re-raised.

    :param username: User on whose behalf the request is made.
    :param test: Test object to create; its `id` attribute is overwritten.
    :param directory: Directory to upload, forwarded to `TestClient.upload`.
    :param upload: If True, upload the test files after creation.
    :return: The number assigned to the new test by the server.
    """
    myclient = TestClient.instance(username)
    test.id = None
    number = myclient.create(test)
    try:
        if number:
            test.id = number
            # NOTE(review): indentation was lost in the copy under review;
            # uploading is assumed to require a test number - confirm.
            if upload:
                myclient.upload(test, directory)
    except BaseException:
        # Roll back the half-created test. A failure of the rollback itself
        # is logged rather than raised so it cannot hide the original error.
        if number:
            try:
                myclient.delete(number)
            except Exception:
                logger.exception('Failed to delete test %s during rollback',
                                 number)
        raise
    return number
def retrieve(username, *args, **kwargs):
    """
    Retrieve tests from the server as `username`.

    All other arguments are forwarded unchanged to `TestClient.retrieve`,
    whose result is returned as is.
    """
    return TestClient.instance(username).retrieve(*args, **kwargs)
def update(username, test, directory=None, upload=True):
    """
    Update the metadata of `test` and, optionally, re-upload its files.

    :param username: User on whose behalf the request is made.
    :param test: Test object to update.
    :param directory: Directory to upload, forwarded to `TestClient.upload`.
    :param upload: If True, upload the files after the metadata update
        succeeded.
    :return: Status of the last operation performed.
    """
    myclient = TestClient.instance(username)
    status = myclient.update(test)
    if status and upload:
        status = myclient.upload(test, directory)
    return status
def delete(username, test):
    """
    Delete `test` (a test object or a test number) from the server as
    `username`. Returns the result of `TestClient.delete`.
    """
    return TestClient.instance(username).delete(test)
def download(username, test, directory=None, overwrite=True):
    """
    Download the files of `test` as `username`.

    :param directory: Target directory, forwarded to `TestClient.download`.
    :param overwrite: Forwarded to `TestClient.download`.
    :return: The directory the test was extracted to.
    """
    return TestClient.instance(username).download(test, directory, overwrite)
class ResultReporter:
    """
    Reporter that will upload results and files for a specific test run.
    (This is the `report` method.) Also marks runs as complete, with the
    option to send an email about test failures to the interested users (the
    `completeRun` method).

    Constructing a reporter contacts the server: it looks up or creates the
    build and the hosts and creates a new run, whose resource URI is stored
    in `self.run`.
    """

    def __init__(self,
                 buildtype,
                 build_id,
                 mmshare,
                 local_system,
                 remote_system,
                 username=None,
                 release=None,
                 build_log_address=None,
                 comment=None,
                 base_url=None,
                 api_version=_VERSION):
        """
        :param buildtype: Build type used to identify the build.
        :param build_id: Build id used to identify the build.
        :param mmshare: mmshare value used to identify the build.
        :param local_system: Description of the local system; must provide
            `release`, `platform`, `schrodinger` and `toDict()`.
        :param remote_system: Description of the remote system; must provide
            `toDict()`.
        :param username: User the requests are authenticated as.
        :param release: Release name; defaults to `local_system.release`.
        :param build_log_address: Stored on the run.
        :param comment: Stored on the run as its comments.
        :param base_url: Server URL; the client default is used when empty.
        :param api_version: API version to talk to.
        """
        self.client = _BaseClient(username, base_url, api_version)
        release = release or local_system.release
        self._local_system = local_system
        self._remote_system = remote_system
        # Host resource URIs are resolved lazily by the properties below.
        self._localhost = None
        self._remotehost = None
        self.run = self._createRun(mmshare, release, buildtype, build_id,
                                   build_log_address, comment)

    def report(self, test, upload=True, files=None):
        """
        Report the result of ONE test and upload its files.

        :param test: Test whose outcome is reported.
        :param upload: If True, zip the test directory and upload it after
            the outcome has been recorded.
        :param files: Optional subset of files to zip; the whole test
            directory is zipped when None.
        """
        data = self._getTestData(test)
        with tempfile.TemporaryFile('w+b') as content:
            size = self._zipResults(test, files, content, data) if upload else 0
            send_files = bool(upload and size)
            if send_files:
                # BLDMGR-3781 apparently requests doesn't read the HTTP
                # response until it finishes the upload. This is a problem
                # when it tries to upload a file larger than 2000 MB (the
                # body limit set in nginx) because it just comes back with
                # connection reset by peer, rather than HTTP 413 (Request
                # Entity Too Large). Check file size here and fail the test
                # if over the limit.
                limit = 2000 * 1024 * 1024
                if size > limit:
                    err_message = (
                        'Test results are too large, they must not '
                        'exceed 2000 MB when compressed to a ZIP file. '
                        'Current ZIP size is {:0.2f} MB, {} bytes too '
                        'large. Edit your test to decrease the result '
                        'size.'.format(
                            float(size) / (1024.0 * 1024.0), size - limit))
                    err_name = 'Result exceeds size limit'
                    set_failure_data(data, err_name, err_message)
                    # The server would reject the body anyway; only the
                    # failed outcome is reported.
                    send_files = False
            data = json.dumps(data)
            try:
                response = self.client.post(self.client.item_uri('outcome'),
                                            data=data)
            except requests.exceptions.HTTPError as err:
                # `err.response` is None when the error was raised without a
                # response object (as `check_status` does).
                if getattr(err.response, 'status_code', None) == 410:
                    logger.debug(f' {test} was deleted. Skipping report.')
                    return
                raise
            except Exception:
                msg = 'Failed to report for {}. Requested data: {}'.format(
                    test, data)
                logger.exception(msg)
                raise
            if not send_files:
                return
            address = response.headers['location'] + 'upload/'
            # NOTE(review): `content` is a file object; if this request is
            # retried the file position is not rewound - confirm the retry
            # does not send an empty body.
            self.client.post(address,
                             data=content,
                             headers={
                                 'content-encoding': 'application/zip',
                             })

    def _zipResults(self, test, files, content, data):
        """
        Zip the results of `test` into the open file `content`, verify the
        archive and rewind the file.

        A `common.ZipError` marks the outcome in `data` as failed instead of
        propagating.

        :return: Size of the archive in bytes, or 0 if zipping failed before
            the size was known.
        """
        size = 0
        try:
            if files is None:
                common.zip_directory(test.directory, content)
            else:
                common.zip_files(content, test.directory, files)
            size = content.tell()
            # Verify zip contents
            content.seek(0)
            common.verify_zip(content)
        except common.ZipError as err:
            # Roundtrip string to utf-8 for prep for server. Fall back to the
            # class name when the exception class has no docstring.
            err_name = err.__doc__ or type(err).__name__
            err_name = err_name.encode('utf-8',
                                       errors="replace").decode('utf-8')
            set_failure_data(data, err_name, str(err))
        content.seek(0)
        return size

    def userRunURL(self):
        """Return the URL where users can go to see details about this run."""
        url = self.client._base + self.run
        # Drop the API prefix (e.g. 'api/v1/') for the configured version.
        return url.replace(self.client._api.lstrip('/'), '')

    def completeRun(self, duration, email=False):
        """
        Record that the run is complete, and include the total duration. If
        email is True, trigger an email about test failures.
        """
        data = json.dumps(dict(total_time=duration, complete=True))
        if email:
            # Doesn't actually create a resource, so status code is not 201.
            self.client.post(self.run + 'email/',
                             data=data,
                             required_statuses=None)
        else:
            self.client.patch(self.run, data=data)

    def _getOrCreateBuild(self, mmshare, release, buildtype, build_id):
        """
        Return the resource URI of the build identified by the arguments,
        creating the build on the server when it does not exist yet.

        :raise ClientValueError: If any identifying value is blank, or if
            more than one matching build exists.
        """
        data = dict(mmshare=mmshare,
                    release=release,
                    buildtype=buildtype,
                    build_id=build_id)
        if not all(data.values()):
            raise ClientValueError('Build is not fully identified - some '
                                   'required values are blank. Values: %s' %
                                   data)
        build_api = self.client.item_uri('build')
        response = self.client.get(build_api, params=data, data=data)
        found = response.json()
        total_count = found['meta']['total_count']
        if total_count > 1:
            msg = ('More than one build found: %s' %
                   ', '.join(val['resource_uri'] for val in found['objects']))
            raise ClientValueError(msg)
        if total_count:
            return found['objects'][0]['resource_uri']
        # No build yet: create one, recording the git hashes of all products.
        flat_git_hashes = {}
        for product_hashes in machid.get_product_git_hashes().values():
            flat_git_hashes.update(product_hashes)
        data['_git_hashes'] = flat_git_hashes
        response = self.client.post(build_api,
                                    data=json.dumps(data),
                                    required_statuses=(201, 303))
        return response.json()['resource_uri']

    def _createRun(self, mmshare, release, buildtype, build_id,
                   build_log_address, comments):
        """
        Create a new, incomplete run for the identified build and return its
        API rooted resource URI.
        """
        build = self._getOrCreateBuild(mmshare, release, buildtype, build_id)
        data = dict(build=build,
                    build_log_address=build_log_address,
                    build_results_dir=None,
                    platform=self._local_system.platform,
                    executor=self.client._username,
                    comments=comments,
                    schrodinger=self._local_system.schrodinger,
                    total_time=0,
                    complete=False,
                    localhost=self.localhost,
                    remotehost=self.remotehost)
        response = self.client.post(self.client.item_uri('run'),
                                    data=json.dumps(data))
        return self.client.getResourceURI(response)

    @property
    def localhost(self):
        """Resource URI of the local host, resolved on first access."""
        if not self._localhost:
            self._localhost = self._getOrCreateHost(
                **self._local_system.toDict())
        return self._localhost

    @property
    def remotehost(self):
        """
        Resource URI of the remote host, resolved on first access. Reuses the
        local host URI when it is already known and both systems are equal.
        """
        if not self._remotehost:
            if self._localhost and self._local_system == self._remote_system:
                self._remotehost = self._localhost
            else:
                self._remotehost = self._getOrCreateHost(
                    **self._remote_system.toDict())
        return self._remotehost

    def _getOrCreateHost(self, **kwargs):
        """
        Return the resource URI of the host matching `kwargs`, creating the
        host on the server when no match exists.
        """
        response = self.client.get('host', params=dict(limit=1, **kwargs))
        objects = response.json()['objects']
        if objects:
            return objects[0]['resource_uri']
        response = self.client.post('host', data=json.dumps(kwargs))
        return self.client.getResourceURI(response)

    def _getTestData(self, test):
        """
        Build the outcome payload for `test`. Workup messages longer than
        MAX_MESSAGE_LENGTH are truncated and a notice is appended.
        """
        workup_messages = test.workup_messages
        if len(workup_messages) > MAX_MESSAGE_LENGTH:
            suffix = '\n\nTRUNCATED due to message length > {}.\n'
            suffix = suffix.format(MAX_MESSAGE_LENGTH)
            workup_messages = workup_messages[:MAX_MESSAGE_LENGTH] + suffix
        post_data = dict(test=test.resource_uri,
                         passed=test.outcome,
                         run_time_sec=test.timing,
                         run=self.run,
                         workup_messages=workup_messages,
                         failure_type=dict(name=test.failure_type))
        # Report the host the test actually ran on.
        if test.runsRemotely():
            post_data['host'] = self.remotehost
        else:
            post_data['host'] = self.localhost
        return post_data
def set_failure_data(data, error_name, error_message):
    """
    Modify test data in-place to indicate a test failure to the STU API.

    :type data: dict
    :param data: Outcome payload; must already contain 'workup_messages'.
    :type error_name: str
    :param error_name: Stored as the failure type.
    :type error_message: str
    :param error_message: Prepended to the existing workup messages.
    """
    data['passed'] = False
    data['workup_messages'] = '{}\n\n\n{}'.format(error_message,
                                                  data['workup_messages'])
    # NOTE(review): `ResultReporter._getTestData` stores the failure type as
    # a dict with a 'name' key while a bare string is stored here - confirm
    # the server accepts both shapes.
    data['failure_type'] = error_name
def raise_for_status(response):
    """
    If `response` has a bad status, raise an Exception. First, however,
    be sure to print any data available from the exception.

    The error message is extended with the error reported in the response
    body (or the raw body text when it is not JSON), followed by the request
    URL and the response headers.

    :type response: requests.models.Response
    :param response: Response to check for exit status problems.
    :raise requests.exceptions.HTTPError: If the response has a bad status.
    """
    try:
        response.raise_for_status()
    except requests.exceptions.HTTPError as err:
        err_args = list(err.args) or ['']
        message = str(err_args[0])
        try:
            data = response.json()
        except ValueError:
            # Response does not contain valid json
            description = getattr(response, 'text', False)
            if description and not message.endswith(description):
                message = '{}, {}'.format(message, description)
        else:
            # A JSON body that is not an object carries no error details;
            # treating it as empty keeps the HTTP error from being replaced
            # by an AttributeError.
            if not isinstance(data, dict):
                data = {}
            error = data.get('error_message', None)
            error = error or data.get('error', None)
            error = error or response.reason
            if error:
                error = str(error)
                if not message.endswith(error):
                    message = '{}, {}'.format(message, error)
            # The traceback may be missing or null.
            server_traceback = data.get('traceback') or ''
            print('Server ' + str(server_traceback).replace('\n\n', '\n'))
        message += f" ({response.url} : {response.headers} )"
        err_args[0] = message
        err.args = tuple(err_args)
        raise
def check_status(response, required_statuses):
    """
    Raise an error if the status is not one of `required_statuses`.

    :type response: requests.models.Response
    :param response: Response whose status code is checked.
    :type required_statuses: container of int
    :param required_statuses: Acceptable status codes.
    :rtype: bool
    :return: True when the status code is acceptable.
    :raise requests.exceptions.HTTPError: If the status code is not in
        `required_statuses`. The error carries no response object.
    """
    if response.status_code in required_statuses:
        return True
    msg = ('Response %s (%s) did not match required status "%s"' %
           (response.reason, response.status_code, required_statuses))
    # Match exception class of raise_for_status
    raise requests.exceptions.HTTPError(msg)
def fatal_status(exception):
    """
    This method should return True only if exception is an HTTPError, and the
    status code is NOT retriable.

    Used as the `giveup` predicate of the retry decorator on
    `_BaseClient._request`.

    :param exception: Exception raised by a request.
    :rtype: bool
    """
    if not isinstance(exception, requests.exceptions.HTTPError):
        return False
    # HTTPErrors should normally have a response, but if the calling code does
    # "raise HTTPError()", the response member is initialized to None
    response = exception.response
    if not hasattr(response, "status_code"):
        return False
    return response.status_code not in RETRIABLE_STATUS_CODES
class ApiKeyAuth(requests.auth.AuthBase):
    """
    An authorization method that uses an api key.
    """

    def __init__(self, username):
        """
        :param username: User name combined with the api key obtained from
            `common.get_api_key()`.
        """
        api_key = common.get_api_key()
        self._auth = f'ApiKey {username}:{api_key}'

    def __call__(self, r):
        """Add the Authorization header to the outgoing request `r`."""
        r.headers['Authorization'] = self._auth
        return r
class _BaseClient:
    """
    Base client class for submitting requests to a server.
    Adds defaults for server url, authentication, and retrying.
    """
    _INSTANCE = None
    """
    Singleton instance of the client to simplify access for apps that always
    use the defaults.
    """

    def __init__(self, username=None, base_url=None, api_version=_VERSION):
        """
        :param username: User the requests are authenticated as.
        :param base_url: Server URL; `common.BASE_URL` is used when empty.
        :param api_version: Version inserted into the API path.
        """
        if not base_url:
            base_url = common.BASE_URL
        self._base = base_url.rstrip('/')
        self._api = '/api/v%s/' % api_version
        # Created lazily by the `auth` property.
        self._auth = None
        self._username = username
        self._session = requests.Session()

    @classmethod
    def instance(cls, username, base_url=None, api_version=_VERSION):
        """
        Return the shared client of this class, creating it on first use.

        `base_url` and `api_version` are only used when the client is
        created. Asking for a different `username` switches the shared
        client to that user.
        """
        client = cls._INSTANCE
        # A subclass inherits `_INSTANCE` from its base class; never hand out
        # a cached object that is not an instance of the requested class.
        if not isinstance(client, cls):
            client = cls(username, base_url, api_version)
            cls._INSTANCE = client
        if username != client._username:
            client._auth = None
            client._username = username
        return client

    @property
    def auth(self):
        """Authentication object for the current user, created on demand."""
        if not self._auth:
            self._auth = ApiKeyAuth(self._username)
        return self._auth

    @property
    def fullapi(self):
        """Server URL followed by the versioned API path."""
        return self._base + self._api

    def item_uri(self, item):
        """Return the full URI of the API endpoint named `item`."""
        return f"{self.fullapi}{item}/"

    def list2str(self, items):
        """Join `items` into a comma separated string."""
        return ','.join(str(x) for x in items)

    def safename(self, oldname):
        """
        Ensure that the name is safe and that path separators are consistent on
        Linux and Windows.

        :type oldname: str
        :param oldname: String to be protected
        :rtype: str
        :return: String with with all path separators replaced by / and all
            non-URL-safe characters protected.
        """
        # Import the submodule explicitly instead of relying on another
        # package having imported it already.
        import urllib.parse
        newname = oldname.replace(os.path.sep, '/')
        newname = newname.replace('\\', '/')
        return urllib.parse.quote(newname, '')

    @backoff.on_exception(backoff.expo,
                          RETRIABLE_EXCEPTIONS,
                          max_tries=4,
                          giveup=fatal_status,
                          logger=logger,
                          factor=30)
    def _request(self,
                 method,
                 uri,
                 headers=_JSON_FORMAT,
                 auth=None,
                 required_statuses=None,
                 verify=False,
                 **kwargs):
        """
        Send one request and return the response, retrying on retriable
        errors.

        :param method: Bound session method to call (e.g. `session.get`).
        :param uri: Full URL, server rooted path (leading '/'), or endpoint
            name that is expanded with `item_uri`.
        :param headers: Request headers; JSON content type by default.
        :param auth: Authentication object; the client's own when None.
        :param required_statuses: If given, statuses the response must have.
        :param verify: Forwarded to requests.
        :raise requests.exceptions.HTTPError: On a bad or unexpected status.
        """
        # NOTE(review): `verify=False` disables TLS certificate verification
        # by default - confirm this is still required.
        if auth is None:
            auth = self.auth
        if not uri.startswith('http'):
            if uri.startswith('/'):
                uri = self._base + uri
            else:
                uri = self.item_uri(uri)
        response = method(uri,
                          headers=headers,
                          auth=auth,
                          verify=verify,
                          timeout=(CONNECTION_TIMEOUT_SECONDS,
                                   READ_TIMEOUT_SECONDS),
                          **kwargs)
        raise_for_status(response)
        if required_statuses:
            check_status(response, required_statuses)
        return response

    def post(self, uri, required_statuses=(201,), **kwargs):
        """POST to `uri`; a 201 response is required by default."""
        return self._request(self._session.post,
                             uri,
                             required_statuses=required_statuses,
                             **kwargs)

    def put(self, uri, **kwargs):
        """PUT to `uri`."""
        return self._request(self._session.put, uri, **kwargs)

    def patch(self, uri, required_statuses=(202,), **kwargs):
        """PATCH `uri`; a 202 response is required by default."""
        return self._request(self._session.patch,
                             uri,
                             required_statuses=required_statuses,
                             **kwargs)

    def get(self, uri, **kwargs):
        """GET `uri`."""
        return self._request(self._session.get, uri, **kwargs)

    def delete(self, uri, **kwargs):
        """DELETE `uri`."""
        return self._request(self._session.delete, uri, **kwargs)

    def getResourceURI(self, response=None):
        """Get an API rooted address from a full address."""
        location = response.headers['location']
        return location.replace(self._base, '')
class TestClient(_BaseClient):
    """
    Interact with the Test server. Create, Retrieve, Update, and Delete tests.
    Also upload and download the associated files.
    """

    def __init__(self, username=None, base_url=None, api_version=_VERSION):
        """See `_BaseClient.__init__` for the meaning of the arguments."""
        super().__init__(username=username,
                         base_url=base_url,
                         api_version=api_version)
        self._testapi = self.item_uri('test')
        self._sharedapi = self.item_uri('sharedfile')

    def create(self, test):
        """
        Create a test.

        :return: The number the server assigned to the new test.
        """
        post_data = json.dumps(test.toDict())
        response = self.post(self._testapi, data=post_data)
        return response.json()['number']

    def retrieve(
            self,
            test_ids=None,
            products=None,
            components=None,
            priorities=None,
            tags=None,
            not_products=tuple(),  # noqa: M511
            not_components=tuple(),  # noqa: M511
            not_tags=tuple()):  # noqa: M511
        """
        Retrieve tests based on some limiting criteria. The only method that
        operates on more than one test at a time. A bit weird?

        The positive criteria are sent to the server; the `not_*` criteria
        are applied locally and compared case-insensitively.

        This is a generator: nothing is requested until it is iterated.

        :return: Generator of `testscripts.TestScript` objects.
        :raise ClientError: If the server returned fewer tests than it
            reported as matching.
        """
        query_dict = dict(limit=0)
        if test_ids:
            query_dict['number__in'] = self.list2str(test_ids)
        if products:
            # Use the custom STU query that searches JIRA names too
            query_dict['products'] = self.list2str(products)
        if components:
            query_dict['component__in'] = self.list2str(components)
        if priorities:
            query_dict['priority__in'] = self.list2str(priorities)
        if tags:
            query_dict['tags__in'] = self.list2str(tags)
        skip_products = {name.lower() for name in not_products or tuple()}
        skip_components = {name.lower() for name in not_components or tuple()}
        skip_tags = {name.lower() for name in not_tags or tuple()}
        response = self.get(self._testapi, params=query_dict)
        data = response.json()
        if data['meta']['total_count'] != len(data['objects']):
            raise ClientError(
                'Did not download all tests! (%s requested. %s '
                'found)' % (data['meta']['total_count'], len(data['objects'])))
        for test_data in data['objects']:
            # Filter tests that use products, tags and subfeatures from the
            # skip lists.
            if test_data['product'].lower() in skip_products:
                continue
            component = test_data['component']
            if component and component.lower() in skip_components:
                continue
            # Compare tags case-insensitively, like products and components.
            if any(tag.lower() in skip_tags for tag in test_data['tags']):
                if has_scival_tags(test_data['tags']):
                    msg = f'Skipping scival test: {test_data["number"]}'
                    # Be loud when the user asked for this test explicitly.
                    if test_ids and test_data['number'] in test_ids:
                        logger.critical(msg)
                    else:
                        logger.debug(msg)
                continue
            # Backwards compatibility for SHARED-3063
            if 'jira_tracking' not in test_data and test_data.get(
                    'bug_disabled'):
                test_data['disabled_for_bug'] = True
                test_data['jira_tracking'] = test_data.pop('bug_disabled')
            yield testscripts.TestScript(**test_data)

    def find_one(self, criteria):
        """
        Find a single STU test that matches the search criteria.

        :param criteria: Search criteria, uses the exact names of parameters on
            the server. (does not interpret them in the way that retrieve does)
        :type criteria: dict
        :raise ClientError: If more than one test is found.
        :raise IndexError: If no tests are found.
        """
        response = self.get(self._testapi, params=criteria)
        data = response.json()
        total_count = data['meta']['total_count']
        if total_count > 1:
            raise ClientError(
                'Found too many tests matching criteria {}. Found {} tests'.
                format(criteria, total_count))
        if not data['objects']:
            raise IndexError(
                'Did not find a test matching the criteria: {}'.format(
                    criteria))
        return testscripts.TestScript(**data['objects'][0])

    def get_or_create(self, search_criteria, creation_data=None):
        """
        Find a test corresponding to search_criteria. If no test is found,
        create a test using creation_data.

        :param search_criteria: Search criteria.
        :type search_criteria: dict
        :param creation_data: Data to be used to create the test if none exists.
            Defaults to `search_criteria`.
        :type creation_data: dict
        """
        try:
            test = self.find_one(search_criteria)
        except IndexError:
            # Doesn't exist yet.
            if creation_data is None:
                creation_data = search_criteria
            test = testscripts.TestScript(**creation_data)
            # Copy so the caller's dict is not modified.
            creation_data = creation_data.copy()
            creation_data['number'] = self.create(test)
            # Need to get again to have access to the URL for the test.
            test = self.find_one(creation_data)
            assert test.id
        return test

    def update(self, test):
        """
        Update the metadata of a test. Always returns True.
        """
        post_data = json.dumps(test.toDict())
        self.put(self._testapi + '%s/' % test.id, data=post_data)
        return True

    def delete(self, test):
        """
        Delete a test metadata and files.

        :param test: Test object with an `id` attribute, or a test number.
        :return: True
        """
        if hasattr(test, 'id'):
            testid = test.id
        else:
            testid = int(test)
        super().delete(self._testapi + "%s/" % testid)
        return True

    def download(self, test, directory=None, overwrite=True):
        """
        Download and extract the files associated with `test`. Overwrites
        existing contents of the directory.

        :return: The directory the test was extracted to.
        :raise ClientValueError: If `test` is a scival test.
        """
        if has_scival_tags(test.tags):
            raise ClientValueError(f'Cannot download scival test: {test}.')
        directory = self._getTestDir(test, directory, overwrite)
        self._getSharedFiles(test.shared_files, overwrite)
        return directory

    def upload(self, test, directory=None):
        """
        Upload a directory to the server. If the test has a directory
        attribute, use that. Otherwise, based on the test number.
        Shared files are uploaded first. Always returns True.
        """
        self._setSharedFiles(test.shared_files)
        self._setTestDir(test, directory)
        return True

    def _setTestDir(self, test, directory=None):
        """
        Zip the test directory (without its README) and upload it.
        `test.directory` is set to the directory that was used.
        """
        path = directory or test.directory or str(test.id)
        test.directory = path
        buffer = io.BytesIO()
        common.zip_directory(path, fileobj=buffer, skipped_files={'README'})
        # Send the archive as bytes rather than as a file object: a file
        # object is exhausted after the first attempt, so a retried request
        # would upload an empty body.
        archive_name = '%s.zip' % test.id
        files = {archive_name: (archive_name, buffer.getvalue())}
        if getattr(test, 'upload', None):
            uri = self._base + test.upload
        else:
            response = self.get(self._testapi, params={'number': test.id})
            uri = self._base + response.json()['objects'][0]['upload']
        self.post(uri,
                  files=files,
                  headers={'content-encoding': 'application/zip'},
                  required_statuses=None)

    def _setSharedFiles(self, shared_files):
        """
        Upload shared files to server. Files are read from the local
        'shared' directory; files already on the server are skipped.

        :raise RuntimeError: If a shared file name refers to a directory.
        """
        for filename in shared_files:
            fullname = os.path.join('shared', filename)
            if os.path.isdir(fullname):
                raise RuntimeError('"%s" is a directory. Only files can be '
                                   'shared.' % fullname)
            # Check to see if it already exists on the server:
            # This metadata check should be much cheaper than sending files.
            uri = self._sharedapi + '%s' % self.safename(filename)
            data = self.get(uri).json()
            if data['thefile']:
                continue
            # Read the content up front: the handle is always closed, and a
            # retried request sends the full content again.
            with open(fullname, 'rb') as fileobj:
                payload = fileobj.read()
            uri = self._base + data['upload']
            self.post(uri,
                      files={'file': (os.path.basename(fullname), payload)},
                      headers=None,
                      required_statuses=None)

    def _getTestDir(self, test, directory=None, overwrite=True):
        """
        Download the zipped test file and extract it to a directory.

        :return: The directory the test was extracted to; also stored in
            `test.directory`.
        """
        if not directory:
            directory = test.directory
        if overwrite and os.path.isdir(directory):
            fileutils.force_rmtree(directory)
        test.directory = directory
        if test.download:
            uri = self._base + test.download
        else:
            response = self.get(self._testapi, params={'number': test.id})
            uri = self._base + response.json()['objects'][0]['download']
        zipped = self._getZip(uri)
        if zipped.namelist():
            zipped.extractall(path=directory)
        else:
            # empty zip file
            os.makedirs(directory, exist_ok=True)
        test.write_readme(os.path.join(directory, 'README'))
        return directory

    # BLDMGR-3710 It seems that under high load, the server might return a
    # partial zip file with status 200. Retry twice before raising.
    # FIXME: remove retry logic once the server issue is fixed
    @backoff.on_exception(backoff.expo,
                          zipfile.BadZipfile,
                          max_tries=3,
                          logger=logger,
                          factor=60)
    def _getZip(self, uri):
        """Download `uri` and return its content as an open ZipFile."""
        response = self.get(uri)
        return zipfile.ZipFile(io.BytesIO(response.content))

    def _getSharedFiles(self, shared_files, overwrite=False):
        """
        Download the shared files into the local 'shared' directory.
        Existing files are kept unless `overwrite` is True.
        """
        # Don't create the shared directory unless it is needed.
        if not shared_files:
            return None
        for filename in shared_files:
            local_name = os.path.join('shared', filename)
            if os.path.isfile(local_name) and not overwrite:
                continue
            directory = os.path.dirname(local_name)
            if directory and not os.path.isdir(directory):
                os.makedirs(directory)
            # NOTE(review): the upload path quotes the name with `safename`
            # but this one does not - confirm the server expects that.
            uri = self._sharedapi + '%s/download/' % filename
            response = self.get(uri)
            with open(local_name, 'wb') as fh:
                fh.write(response.content)
        return None
def get_stu_username():
    """
    Return the STU user name for the caller: the value of the
    STU_REMOTE_USER environment variable if it is set, otherwise the login
    name of the current user.
    """
    username = os.environ.get(STU_REMOTE_USER)
    if username is None:
        # Only look up the login name when it is needed; the lookup can fail
        # on systems without an account entry for the current user.
        username = getpass.getuser()
    return username