"""
Steps for deduplicating and random sampling data through AWS (currently limited
to use with Mols).
"""
import gzip
import time
import itertools
import os
import re
import uuid
import more_itertools
from schrodinger.application.steps import basesteps
from schrodinger.application.steps import env_keys
from schrodinger.application.steps import utils
from schrodinger.models import parameters
from schrodinger.stepper.cloud import aws_client
from schrodinger import stepper
from schrodinger.stepper import logger
from .basesteps import MolMolMixin
try:
import boto3
except ImportError:
boto3 = None
# ==============================================================================
# Error messages
# ==============================================================================
INSUFFICIENT_DATABASE_SETTINGS = (
'Cannot create new Redshift table with insufficient database settings:'
'\ncluster_id={0}\ndatabase={1}\ndatabase_user={2}')
MISSING_BUCKET = ('AWS Redshift requires the use of S3 buckets, and the bucket '
                  'name must be set through environment variables. See '
                  '`schrodinger.application.steps.env_keys` for more details.')
QUERY_ERROR = "Errors while running query {0}\n{1}"
INVALID_TABLE_NAME = ("Invalid table name specified. Table names must only "
                      "contain alphanumeric characters and underscores, and "
                      "must begin with a letter.")
# ==============================================================================
# CONSTANTS
# ==============================================================================
REDSHIFT_CLIENT = None
S3_CLIENT = None
_MAX_FILES = 1000  # maximum number of keys per s3 delete_objects request
# Default maximum size of one file when exporting from Redshift in MB
_MAX_FILE_SIZE = 5.0 # MB
# ==============================================================================
# AWS Clients
# ==============================================================================
def _get_redshift_client():
global REDSHIFT_CLIENT
if REDSHIFT_CLIENT is None:
REDSHIFT_CLIENT = aws_client.get_client('redshift-data')
return REDSHIFT_CLIENT
def _get_s3_client():
global S3_CLIENT
if S3_CLIENT is None:
S3_CLIENT = aws_client.get_client('s3')
return S3_CLIENT
# ==============================================================================
# REDSHIFT CONNECTION METHODS
# ==============================================================================
def _get_service_credentials():
"""
Retrieves aws credentials for authentication across services
(ex: Redshift <-> S3) by checking the following settings in order:
1. a valid IAM Role set through environment variable.
2. Access keys set through environment variables.
3. Credentials set under ~/.aws with given profile set through environment
variable.
:return: credentials formatted according to redshift service query
statement structure.
:rtype: str
"""
if env_keys.REDSHIFT_S3_IAM_ROLE:
return f"IAM_ROLE '{env_keys.REDSHIFT_S3_IAM_ROLE}'"
elif env_keys.SCHRODINGER_AWS_KEY:
return f"""
ACCESS_KEY_ID '{env_keys.SCHRODINGER_AWS_KEY}'
SECRET_ACCESS_KEY '{env_keys.SCHRODINGER_AWS_SECRET_KEY}'"""
else:
# note that this is typically used when access keys are set under ~/.aws
aws_credentials = aws_client.get_credentials()
return f"""
ACCESS_KEY_ID '{aws_credentials.access_key}'
SECRET_ACCESS_KEY '{aws_credentials.secret_key}'
SESSION_TOKEN '{aws_credentials.token}'"""
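# For reference, the IAM role form of the credentials renders to a single
# clause that is embedded directly in UNLOAD/COPY statements; the ARN below is
# hypothetical:
#
#     IAM_ROLE 'arn:aws:iam::123456789012:role/my-redshift-s3-role'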
def _execute_redshift_statement(sql):
"""
Execute a redshift sql query with appropriate database connection settings.
:param sql: query to execute
:type sql: str
:return: query ID
:rtype: str
"""
response = _get_redshift_client().execute_statement(
ClusterIdentifier=env_keys.REDSHIFT_CLUSTER_ID,
Database=env_keys.REDSHIFT_DATABASE,
DbUser=env_keys.REDSHIFT_DB_USER,
Sql=sql)
return response['Id']
def _wait_for_redshift_query(query_id, raise_error=True):
"""
    Wait for a redshift query to complete. Raises a RuntimeError if the query
    fails and `raise_error` is enabled.
:param query_id: unique ID of query to monitor
:type query_id: str
:param raise_error: whether to raise an exception if the query returns an
error; default behavior is True (to raise an exception).
:type raise_error: bool
:return: the query job response
:rtype: dict
"""
client = _get_redshift_client()
    # sleep until the query reaches a terminal state
    while _get_query_status(query_id) not in ('FINISHED', 'FAILED', 'ABORTED'):
time.sleep(5)
# check for errors
response = client.describe_statement(Id=query_id)
if response.get('Error') and raise_error:
msg = QUERY_ERROR.format(response['QueryString'], response['Error'])
raise RuntimeError(msg)
return response
def _get_query_status(query_id):
"""
Helper method to retrieve query's status by ID.
:param query_id: unique ID of query to retrieve status
:type query_id: str
:return: 'SUBMITTED'|'PICKED'|'STARTED'|'FINISHED'|'ABORTED'|'FAILED'|'ALL'
:rtype: str
"""
client = _get_redshift_client()
return client.describe_statement(Id=query_id)['Status']
def run_query(sql, raise_error=True):
"""
Helper method to execute an SQL query and wait for its completion.
:param sql: query to execute
:type sql: str
:param raise_error: whether to raise an exception for failed query - see
`_wait_for_redshift_query` docstring for more details.
:type raise_error: bool
:return: the query job response
:rtype: dict
"""
query_id = _execute_redshift_statement(sql)
return _wait_for_redshift_query(query_id, raise_error)
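# Illustrative usage sketch (assumes the Redshift settings in `env_keys` are
# populated; the table name below is hypothetical):
#
#     response = run_query('SELECT COUNT(*) FROM "my_table";')
#     if response['Status'] == 'FINISHED':
#         logger.debug('count query completed')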
# ==============================================================================
# REDSHIFT DATABASE METHODS
# ==============================================================================
def _create_table(table_id):
"""
Creates a new table under the default database set during redshift
execution settings.
:param table_id: unique table ID
:type table_id: str
"""
logger.debug(f"creating table... {table_id}")
sql = f'CREATE TABLE "{table_id}" (data varchar(max));'
_execute_redshift_statement(sql)
def _drop_table(table_id):
"""
    Drop the requested table.
:param table_id: table to delete by ID.
:type table_id: str
"""
logger.debug(f"dropping table... {table_id}")
sql = f'DROP TABLE "{table_id}";'
_execute_redshift_statement(sql)
def _table_exists(table_id):
"""
Check if table exists in the database.
:param table_id: table to check status for.
:type table_id: str
:return: whether the requested table exists in database.
:rtype: bool
"""
sql = f'SELECT EXISTS (SELECT 1 FROM "{table_id}");'
response = run_query(sql, raise_error=False)
if response.get('Error'):
err_msg = response.get('Error')
if f'relation "{table_id.lower()}" does not exist' in err_msg:
return False
# everything else should raise an error
raise RuntimeError(QUERY_ERROR.format(sql, err_msg))
return True
def _deduplicate_table(table_id, max_file_size=_MAX_FILE_SIZE):
"""
Deduplicates the requested table and exports to s3.
:param table_id: table to deduplicate.
:type table_id: str
:param max_file_size: maximum file size for exporting before batching.
:type max_file_size: float
"""
query_str = f"""
SELECT data FROM \"{table_id}\" GROUP BY data"""
destination = _execute_unload_query(table_id, query_str, max_file_size)
logger.debug(f"deduplicated table {table_id} and exported to: "
f"{destination}")
def _random_sample_table(table_id, n, max_file_size=_MAX_FILE_SIZE):
"""
Samples the requested table randomly and exports the results to s3.
:param table_id: table to randomly sample.
:type table_id: str
:param n: number of samples requested
:type n: int
:param max_file_size: maximum file size for exporting before batching.
:type max_file_size: float
"""
query_str = f"""
SELECT data FROM \"{table_id}\"
WHERE RANDOM() < {n}/(SELECT COUNT(*) FROM \"{table_id}\")::float"""
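    # Each row is kept independently with probability n / COUNT(*), so the
    # expected number of rows returned is n; the actual count varies slightly
    # from run to run.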
destination = _execute_unload_query(table_id, query_str, max_file_size)
logger.debug(f"randomly sampled {n} entries from table {table_id} and"
f"exported to: {destination} ")
def _deduplicate_and_random_sample_table(table_id,
n,
max_file_size=_MAX_FILE_SIZE):
"""
The requested table is first deduplicated, then randomly sampled for `n`
entries, with the results exported to s3.
:param table_id: table to randomly sample.
:type table_id: str
:param n: number of samples requested
:type n: int
:param max_file_size: maximum file size for exporting before batching.
:type max_file_size: float
"""
query_str = f"""
SELECT data FROM \"{table_id}\" GROUP BY data
HAVING
RANDOM() < {n}/(SELECT COUNT(DISTINCT data)
FROM \"{table_id}\")::float"""
destination = _execute_unload_query(table_id, query_str, max_file_size)
logger.debug(f"deduplicated and random sampled table {table_id} and "
f"exported to: {destination}")
def _export_table(table_id, max_file_size=_MAX_FILE_SIZE):
"""
Generic export call to unload the table into s3.
:param table_id: table to export.
:type table_id: str
:param max_file_size: maximum file size for exporting before batching.
:type max_file_size: float
"""
query_str = f"""
SELECT * FROM \"{table_id}\""""
destination = _execute_unload_query(table_id, query_str, max_file_size)
logger.debug(f"exported table.. {table_id} to: {destination}")
def _execute_unload_query(table_id, query, max_file_size):
"""
Helper method to add s3 destination and authentication credentials to query.
:param table_id: table to unload into s3.
:type table_id: str
:param query: SQL query.
:type query: str
:param max_file_size: maximum file size for exporting before batching.
:type max_file_size: float
:return: newly created s3 folder where the results are exported.
:rtype: str
"""
bucket_name = env_keys.S3_BUCKET_NAME
destination = f's3://{bucket_name}/{table_id}_output/'
query_str = f"""
UNLOAD ('{query}')
TO '{destination}'
MAXFILESIZE {max_file_size} MB
{_get_service_credentials()};"""
run_query(query_str)
return destination
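# For reference, the rendered statement looks roughly like the following
# (bucket, table, and credential values are hypothetical):
#
#     UNLOAD ('SELECT data FROM "my_table" GROUP BY data')
#     TO 's3://my-bucket/my_table_output/'
#     MAXFILESIZE 5.0 MB
#     IAM_ROLE 'arn:aws:iam::123456789012:role/my-redshift-s3-role';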
def _import_table_from_s3(table_id, s3_path):
"""
Import data into a redshift table by copying the requested data from s3 to
redshift.
:param table_id: the newly generated table inside the redshift database.
:type table_id: str
:param s3_path: the absolute path to an s3 folder to copy data from.
:type s3_path: str
"""
origin = s3_path
query_str = f"""
COPY \"{table_id}\"
FROM '{origin}'
{_get_service_credentials()}
CSV GZIP;"""
run_query(query_str)
logger.debug(f"imported data to... {table_id}")
# ==============================================================================
# S3 BUCKET METHODS
# ==============================================================================
def _upload_to_s3(gen, serializer, s3_folder, chunk_size=100_000):
"""
Upload the given data to s3 under a new folder named after `s3_folder`.
:param gen: iterable of entries to upload to s3 - must be serializable to
string.
:type gen: iter
    :param serializer: must be able to serialize data entries into strings.
:type serializer: Serializer
:param s3_folder: used as the folder name in s3.
:type s3_folder: str
:param chunk_size: number of lines to upload per file.
:type chunk_size: int
"""
s3_client = _get_s3_client()
bucket_name = env_keys.S3_BUCKET_NAME
output_generator = more_itertools.peekable(gen)
def get_chunk_of_rows(gen):
return [
serializer.toString(output)
for output in itertools.islice(gen, chunk_size)
]
chunk_idx = 0
suffix = str(uuid.uuid4())[:8]
while rows_to_upload := get_chunk_of_rows(output_generator):
fname = f'{chunk_idx}_{suffix}.csv.gz'
content = gzip.compress('\n'.join(rows_to_upload).encode())
s3_client.put_object(Body=content,
Bucket=bucket_name,
Key='/'.join([s3_folder, fname]))
chunk_idx += 1
destination = get_s3_absolute_path(s3_folder, s3_file='')
logger.debug(f"uploaded data to s3 bucket... {destination}")
def _download_from_s3(table_id):
"""
    Download all s3 files whose names contain the `table_id` prefix and return
    a line-by-line iterator over their contents.
    :param table_id: download s3 files containing this prefix in their names.
    :type table_id: str
    :return: an iterator yielding the downloaded contents line by line.
    :rtype: iter[str]
"""
output_files = _get_s3_outputs(table_id)
for output_file in output_files:
for line in _get_s3_file(output_file):
yield line
def _get_s3_outputs(table_id):
return _list_s3_folder(f'{table_id}_output/')
def _list_s3_folder(prefix):
"""
List all files under an s3 folder.
:param prefix: filter files containing the following prefix.
:type prefix: str
:return: iterable of s3 file names.
:rtype: iter[str]
"""
s3_client = _get_s3_client()
next_token = None
while True:
args = {
'Bucket': env_keys.S3_BUCKET_NAME,
'Prefix': prefix,
}
if next_token:
args['ContinuationToken'] = next_token
response = s3_client.list_objects_v2(**args)
        for entry in response.get('Contents', []):
yield entry['Key']
next_token = response.get('NextContinuationToken')
if not response['IsTruncated']:
break
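# Illustrative usage (prefix hypothetical); pagination via ContinuationToken is
# handled internally, so callers can simply iterate:
#
#     for key in _list_s3_folder('my_table_output/'):
#         logger.debug(key)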
def _get_s3_file(filepath):
"""
Get s3 file contents.
:param filepath: full file path to s3 object to obtain file content.
:type filepath: str
:return: an iterator over the lines in requested file.
:rtype: iter[str]
"""
s3_client = _get_s3_client()
response = s3_client.get_object(Bucket=env_keys.S3_BUCKET_NAME,
Key=filepath)
for line in response['Body'].iter_lines():
yield line.decode()
def _delete_bucket_folder(s3_folder):
"""
    Remove all files within the folder `s3_folder`, i.e. all s3 objects whose
    keys begin with the `s3_folder` prefix.
:param s3_folder: files containing this prefix.
:type s3_folder: str
"""
s3_client = _get_s3_client()
bucket_name = env_keys.S3_BUCKET_NAME
files = _list_s3_folder(s3_folder)
num_files = 0
for batch in more_itertools.chunked(files, _MAX_FILES):
s3_client.delete_objects(
Bucket=bucket_name,
Delete={'Objects': [{
'Key': file
} for file in batch]})
num_files += len(batch)
logger.debug(f'deleted {num_files} files from bucket: {bucket_name}')
def get_s3_absolute_path(s3_folder, s3_file=None):
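    """
    Build the absolute path to a folder (and optionally a file) in s3.
    :param s3_folder: folder name inside the configured s3 bucket.
    :type s3_folder: str
    :param s3_file: optional file name to append to the path.
    :type s3_file: str
    :return: the absolute s3 path, e.g. s3://my_bucket/my_folder/my_file
    :rtype: str
    """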
components = ['s3:/', env_keys.S3_BUCKET_NAME, s3_folder]
if s3_file:
components.append(s3_file)
return '/'.join(components)
# ==============================================================================
# SETTINGS
# ==============================================================================
class RSTable(parameters.CompoundParam):
    table_name: str = None
    def getFullTableId(self):
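        """
        :return: the full table ID: ``table_name`` with '.' replaced by '_',
            prefixed with `CLOUD_ARTIFACT_PREFIX` when that environment
            variable is set.
        :rtype: str
        :raise ValueError: if the table name is unset or does not form a valid
            identifier.
        """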
if not self.table_name:
raise ValueError("Table name is not specified.")
# Note: step names are typically used for table names so as a courtesy
# we make a simple replacement
name = self.table_name.replace('.', '_')
if env_keys.CLOUD_ARTIFACT_PREFIX:
name = '/'.join([env_keys.CLOUD_ARTIFACT_PREFIX, name])
if not re.match(r'^[a-zA-Z][\w/]+$', name):
raise ValueError(f'{INVALID_TABLE_NAME} - {name=}')
return name
class S3Folder(parameters.CompoundParam):
    folder_name: str = None
    def getAbsolutePath(self):
"""
:return: the absolute path to folder in s3, e.g.
s3://my_bucket/my_folder
:rtype: str
"""
return get_s3_absolute_path(self.folder_name)
class S3File(parameters.CompoundParam):
    filename: str = None
class RSTableExportSettings(parameters.CompoundParam):
    max_file_size: float = 5  # MB; valid range: 5 MB to 6.2 GB
class RSFilterSettings(parameters.CompoundParam):
table: RSTable
s3_folder: S3Folder
# ==============================================================================
# STEPS
# ==============================================================================
class _UploadToS3Step(basesteps.UploadStep):
"""
First part of the upload step works by uploading the input mols serialized
to strings to s3.
"""
Settings = S3Folder
Output = S3Folder
def reduceFunction(self, inps):
s3_folder = self.settings.folder_name
_upload_to_s3(inps, self._getInputSerializer(), s3_folder)
yield self.settings
class _ExportFromS3ToRS(stepper.ReduceStep):
"""
Second part of the upload step works by copying over the S3 folder to the
redshift database under the requested table ID.
"""
Settings = S3Folder
Input = S3Folder
Output = RSTable
def reduceFunction(self, s3_folders):
        s3_paths = {inp.getAbsolutePath() for inp in s3_folders}
for path in s3_paths:
rs_table = self._s3FolderToRSTable(path)
_import_table_from_s3(rs_table.getFullTableId(), path)
yield rs_table
def _s3FolderToRSTable(self, s3_path):
return RSTable(table_name=s3_path.split('/')[-1])
class _EnumerateS3Folder(stepper.MapStep):
"""
This step only maps a given table name to it's appropriate folder in S3 and
enumerates over the list of files there that were batch exported from the
previous filtering step.
"""
Input = RSTable
Output = S3File
def mapFunction(self, table):
table_id = table.getFullTableId()
for batch_file in _get_s3_outputs(table_id):
yield S3File(filename=batch_file)
class _DownloadFromS3Step(basesteps.DownloadStep):
"""
Download results by looking for the newly created folder in s3 by either the
deduplication or random sampling step. The output lines are serialized into
mols.
"""
Input = S3File
def mapFunction(self, inp):
op_serializer = self.getOutputSerializer()
for line in _get_s3_file(inp.filename):
yield op_serializer.fromString(line)
class _DeduplicateStep(basesteps.TableReduceStep):
"""
Deduplicates given Redshift table and batch exports the results to S3.
"""
Settings = RSTableExportSettings
Input = RSTable
Output = RSTable
def _actOnTable(self, table_id):
_deduplicate_table(table_id, self.settings.max_file_size)
class _RandomSampleStep(basesteps.TableReduceStep):
"""
Randomly samples given Redshift table and batch exports the results to S3.
"""
class Settings(RSTableExportSettings):
n: int = 5000
Input = RSTable
Output = RSTable
def _actOnTable(self, table_id):
_random_sample_table(table_id, self.settings.n,
self.settings.max_file_size)
class _DeduplicateAndRandomSampleStep(_RandomSampleStep):
"""
Deduplicates and randomly samples given Redshift table and batch exports the
results to S3.
"""
def _actOnTable(self, table_id):
_deduplicate_and_random_sample_table(table_id, self.settings.n,
self.settings.max_file_size)
class _DropTableStep(basesteps.TableReduceStep):
"""
Drops the requested Redshift table, and outputs the table ID.
"""
Input = RSTable
Output = RSTable
def _actOnTable(self, table_id):
_drop_table(table_id)
# ==============================================================================
# CHAINS
# ==============================================================================
class RSFilter(MolMolMixin, basesteps.CloudFilterChain):
"""
Generic Redshift filter with table setup and validation defined. Classes
inheriting from `RSFilter` need to define `addFilterSteps` for filter steps.
"""
Settings = RSFilterSettings
    def setUp(self):
super(RSFilter, self).setUp()
self[0].setSettings(**self.settings.s3_folder.toDict())
# This is a bit of a hack to fix a bug when running this chain
# with two levels of jobcontrol. See AD-359
self._setConfig(self._getCanonicalizedConfig())
def _setUpTable(self):
table = self.settings.table
if table.table_name is None:
name = utils.generate_stepid_and_random_suffix(self)
table.table_name = name.replace('-', '_')
table_id = table.getFullTableId()
_create_table(table_id)
self.settings.s3_folder.folder_name = table_id
def _validateTable(self):
errs = []
cluster_id = env_keys.REDSHIFT_CLUSTER_ID
database = env_keys.REDSHIFT_DATABASE
db_user = env_keys.REDSHIFT_DB_USER
if not (cluster_id and database and db_user):
errs.append(
stepper.SettingsError(
self,
INSUFFICIENT_DATABASE_SETTINGS.format(
cluster_id, database, db_user)))
if not env_keys.S3_BUCKET_NAME:
errs.append(stepper.SettingsError(self, MISSING_BUCKET))
return errs
    def buildChain(self):
self.addStep(_UploadToS3Step(**self.settings.s3_folder.toDict()))
self.addStep(_ExportFromS3ToRS())
self.addFilterSteps()
self.addDropTableStepInProduction()
self.addStep(_EnumerateS3Folder())
self.addStep(_DownloadFromS3Step())
    def addFilterSteps(self):
        raise NotImplementedError
    def addDropTableStepInProduction(self):
if int(os.environ.get('SCHRODINGER_STEPPER_DEBUG', 0)):
return
self.addStep(_DropTableStep())
class RSUniqueSmilesFilter(RSFilter):
    """
    A Chain that takes in Mols, uploads them to Redshift, and deduplicates
    them. To use, set the table name you'd like to use in the step settings.
"""
    def addFilterSteps(self):
self.addStep(_DeduplicateStep())
class RSRandomSampleFilter(RSFilter):
    """
    A Chain that takes in Mols, uploads them to Redshift, and outputs a random
    sample of them. To use, set the table name you'd like to use in the step
    settings.
    The settings also include an `n` option that determines roughly how many
    rows are sampled. Note that this is an approximate number and a few more
    or fewer rows may be output.
    """
    class Settings(RSFilterSettings):
n: int = 5000
    def addFilterSteps(self):
self.addStep(_RandomSampleStep(n=self.settings.n))
class RSDeduplicateAndRandomSampleFilter(RSRandomSampleFilter):
    """
    Same as RSRandomSampleFilter except the data is deduplicated before being
    randomly sampled.
"""
    def addFilterSteps(self):
self.addStep(_DeduplicateAndRandomSampleStep(n=self.settings.n))
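# Illustrative configuration sketch (values are hypothetical). A filter chain
# is configured through its settings and then driven by the usual stepper
# machinery; see `schrodinger.stepper` for how chains are executed:
#
#     sampler = RSDeduplicateAndRandomSampleFilter()
#     sampler.settings.n = 10_000
#     sampler.settings.table.table_name = 'my_project_mols'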