"""
Steps for deduplicating and randomly sampling data through AWS (currently
limited
to use with Mols).
"""
import gzip
import time
import itertools
import os
import re
import uuid
import more_itertools
from schrodinger.application.steps import basesteps
from schrodinger.application.steps import env_keys
from schrodinger.application.steps import utils
from schrodinger.models import parameters
from schrodinger.stepper.cloud import aws_client
from schrodinger import stepper
from schrodinger.stepper import logger
from .basesteps import MolMolMixin
try:
    import boto3
except ImportError:
    boto3 = None
# ==============================================================================
# Error messages
# ==============================================================================
INSUFFICIENT_DATABASE_SETTINGS = (
    'Cannot create new Redshift table with insufficient database settings:'
    '\ncluster_id={0}\ndatabase={1}\ndatabase_user={2}')
MISSING_BUCKET = ('AWS Redshift requires use of S3 buckets and bucket name '
                  'must be set through environment variables. See '
                  '`schrodinger.application.steps.env_keys` for more details.')
QUERY_ERROR = "Errors while running query {0}\n{1}"
INVALID_TABLE_NAME = ("Invalid table name specified. Table names must only "
                      "contain alphanumeric values and underscores, and must "
                      "also begin with a non-digit value.")
# ==============================================================================
# CONSTANTS
# ==============================================================================
REDSHIFT_CLIENT = None
S3_CLIENT = None
# Maximum number of keys allowed per S3 delete_objects request (AWS limit)
_MAX_FILES = 1000
# Default maximum size of one file when exporting from Redshift in MB
_MAX_FILE_SIZE = 5.0  # MB
# ==============================================================================
# AWS Clients
# ==============================================================================
def _get_redshift_client():
    global REDSHIFT_CLIENT
    if REDSHIFT_CLIENT is None:
        REDSHIFT_CLIENT = aws_client.get_client('redshift-data')
    return REDSHIFT_CLIENT
def _get_s3_client():
    global S3_CLIENT
    if S3_CLIENT is None:
        S3_CLIENT = aws_client.get_client('s3')
    return S3_CLIENT
# ==============================================================================
# REDSHIFT CONNECTION METHODS
# ==============================================================================
def _get_service_credentials():
    """
    Retrieves aws credentials for authentication across services
    (ex: Redshift <-> S3) by checking the following settings in order:
    1. A valid IAM role set through an environment variable.
    2. Access keys set through environment variables.
    3. Credentials set under ~/.aws, with the profile chosen through an
       environment variable.
    :return: credentials formatted according to redshift service query
        statement structure.
    :rtype: str
    """
    if env_keys.REDSHIFT_S3_IAM_ROLE:
        return f"IAM_ROLE '{env_keys.REDSHIFT_S3_IAM_ROLE}'"
    elif env_keys.SCHRODINGER_AWS_KEY:
        return f"""
            ACCESS_KEY_ID '{env_keys.SCHRODINGER_AWS_KEY}'
            SECRET_ACCESS_KEY '{env_keys.SCHRODINGER_AWS_SECRET_KEY}'"""
    else:
        # note that this is typically used when access keys are set under ~/.aws
        aws_credentials = aws_client.get_credentials()
        return f"""
            ACCESS_KEY_ID '{aws_credentials.access_key}'
            SECRET_ACCESS_KEY '{aws_credentials.secret_key}'
            SESSION_TOKEN '{aws_credentials.token}'"""
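# Illustrative sketch of the clause returned by _get_service_credentials()
# when REDSHIFT_S3_IAM_ROLE is set (the ARN below is a made-up placeholder):
#
#     IAM_ROLE 'arn:aws:iam::123456789012:role/my-redshift-s3-role'
#
# The clause is spliced verbatim into the UNLOAD/COPY statements built below.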
def _execute_redshift_statement(sql):
    """
    Execute a redshift sql query with appropriate database connection settings.
    :param sql: query to execute
    :type sql: str
    :return: query ID
    :rtype: str
    """
    response = _get_redshift_client().execute_statement(
        ClusterIdentifier=env_keys.REDSHIFT_CLUSTER_ID,
        Database=env_keys.REDSHIFT_DATABASE,
        DbUser=env_keys.REDSHIFT_DB_USER,
        Sql=sql)
    return response['Id']
def _wait_for_redshift_query(query_id, raise_error=True):
    """
    Wait for a redshift query to complete. Raises a RuntimeError if the query
    fails and `raise_error` is enabled.
    :param query_id: unique ID of query to monitor
    :type query_id: str
    :param raise_error: whether to raise an exception if the query returns an
        error; default behavior is True (to raise an exception).
    :type raise_error: bool
    :return: the query job response
    :rtype: dict
    """
    client = _get_redshift_client()
    # poll until the query finishes or fails
    while _get_query_status(query_id) not in ['FINISHED', 'FAILED']:
        time.sleep(5)
    # check for errors
    response = client.describe_statement(Id=query_id)
    if response.get('Error') and raise_error:
        msg = QUERY_ERROR.format(response['QueryString'], response['Error'])
        raise RuntimeError(msg)
    return response
def _get_query_status(query_id):
    """
    Helper method to retrieve a query's status by its ID.
    :param query_id: unique ID of query to retrieve status
    :type query_id: str
    :return: 'SUBMITTED'|'PICKED'|'STARTED'|'FINISHED'|'ABORTED'|'FAILED'|'ALL'
    :rtype: str
    """
    client = _get_redshift_client()
    return client.describe_statement(Id=query_id)['Status']
def run_query(sql, raise_error=True):
    """
    Helper method to execute an SQL query and wait for its completion.
    :param sql: query to execute
    :type sql: str
    :param raise_error: whether to raise an exception for failed query - see
        `_wait_for_redshift_query` docstring for more details.
    :type raise_error: bool
    :return: the query job response
    :rtype: dict
    """
    query_id = _execute_redshift_statement(sql)
    return _wait_for_redshift_query(query_id, raise_error) 
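# Usage sketch (not executed): run an ad-hoc query and fetch its result rows
# through the Redshift Data API. The table name is hypothetical, and the
# get_statement_result call/response shape is assumed from boto3's
# redshift-data client documentation.
#
#     response = run_query('SELECT COUNT(*) FROM "my_table_1a2b";')
#     result = _get_redshift_client().get_statement_result(Id=response['Id'])
#     num_rows = result['Records'][0][0]['longValue']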
# ==============================================================================
# REDSHIFT DATABASE METHODS
# ==============================================================================
def _create_table(table_id):
    """
    Creates a new table in the default database specified by the redshift
    execution settings.
    :param table_id: unique table ID
    :type table_id: str
    """
    logger.debug(f"creating table... {table_id}")
    sql = f'CREATE TABLE "{table_id}" (data varchar(max));'
    _execute_redshift_statement(sql)
def _drop_table(table_id):
    """
    Deletes the requested table.
    :param table_id: table to delete by ID.
    :type table_id: str
    """
    logger.debug(f"dropping table... {table_id}")
    sql = f'DROP TABLE "{table_id}";'
    _execute_redshift_statement(sql)
def _table_exists(table_id):
    """
    Check if table exists in the database.
    :param table_id: table to check status for.
    :type table_id: str
    :return: whether the requested table exists in database.
    :rtype: bool
    """
    sql = f'SELECT EXISTS (SELECT 1 FROM "{table_id}");'
    response = run_query(sql, raise_error=False)
    if response.get('Error'):
        err_msg = response.get('Error')
        if f'relation "{table_id.lower()}" does not exist' in err_msg:
            return False
        # everything else should raise an error
        raise RuntimeError(QUERY_ERROR.format(sql, err_msg))
    return True
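# Usage sketch (hypothetical table name): guard a create call so reruns of a
# workflow do not fail on an already-existing table.
#
#     if not _table_exists('mychain_dedup_1a2b3c4d'):
#         _create_table('mychain_dedup_1a2b3c4d')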
def _deduplicate_table(table_id, max_file_size=_MAX_FILE_SIZE):
    """
    Deduplicates the requested table and exports to s3.
    :param table_id: table to deduplicate.
    :type table_id: str
    :param max_file_size: maximum file size for exporting before batching.
    :type max_file_size: float
    """
    query_str = f"""
            SELECT data FROM \"{table_id}\" GROUP BY data"""
    destination = _execute_unload_query(table_id, query_str, max_file_size)
    logger.debug(f"deduplicated table {table_id} and exported to: "
                 f"{destination}")
def _random_sample_table(table_id, n, max_file_size=_MAX_FILE_SIZE):
    """
    Samples the requested table randomly and exports the results to s3.
    :param table_id: table to randomly sample.
    :type table_id: str
    :param n: number of samples requested
    :type n: int
    :param max_file_size: maximum file size for exporting before batching.
    :type max_file_size: float
    """
    query_str = f"""
            SELECT data FROM \"{table_id}\"
            WHERE RANDOM() < {n}/(SELECT COUNT(*) FROM \"{table_id}\")::float"""
    destination = _execute_unload_query(table_id, query_str, max_file_size)
    logger.debug(f"randomly sampled {n} entries from table {table_id} and"
                 f"exported to: {destination} ")
def _deduplicate_and_random_sample_table(table_id,
                                         n,
                                         max_file_size=_MAX_FILE_SIZE):
    """
    The requested table is first deduplicated, then randomly sampled for `n`
    entries, with the results exported to s3.
    :param table_id: table to randomly sample.
    :type table_id: str
    :param n: number of samples requested
    :type n: int
    :param max_file_size: maximum file size for exporting before batching.
    :type max_file_size: float
    """
    query_str = f"""
            SELECT data FROM \"{table_id}\" GROUP BY data
            HAVING
            RANDOM() < {n}/(SELECT COUNT(DISTINCT data)
                            FROM \"{table_id}\")::float"""
    destination = _execute_unload_query(table_id, query_str, max_file_size)
    logger.debug(f"deduplicated and random sampled table {table_id} and "
                 f"exported to: {destination}")
def _export_table(table_id, max_file_size=_MAX_FILE_SIZE):
    """
    Generic export call to unload the table into s3.
    :param table_id: table to export.
    :type table_id: str
    :param max_file_size: maximum file size for exporting before batching.
    :type max_file_size: float
    """
    query_str = f"""
            SELECT * FROM \"{table_id}\""""
    destination = _execute_unload_query(table_id, query_str, max_file_size)
    logger.debug(f"exported table.. {table_id} to: {destination}")
def _execute_unload_query(table_id, query, max_file_size):
    """
    Helper method to add s3 destination and authentication credentials to query.
    :param table_id: table to unload into s3.
    :type table_id: str
    :param query: SQL query.
    :type query: str
    :param max_file_size: maximum file size for exporting before batching.
    :type max_file_size: float
    :return: newly created s3 folder where the results are exported.
    :rtype: str
    """
    bucket_name = env_keys.S3_BUCKET_NAME
    destination = f's3://{bucket_name}/{table_id}_output/'
    query_str = f"""
            UNLOAD ('{query}')
            TO '{destination}'
            MAXFILESIZE {max_file_size} MB
            {_get_service_credentials()};"""
    run_query(query_str)
    return destination
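# Illustrative sketch of the statement _execute_unload_query issues for a
# hypothetical table 'mychain_dedup_1a2b3c4d' with the default 5.0 MB file
# size and IAM-role credentials (bucket name and role ARN are placeholders,
# whitespace cleaned up for readability):
#
#     UNLOAD ('SELECT data FROM "mychain_dedup_1a2b3c4d" GROUP BY data')
#     TO 's3://my-bucket/mychain_dedup_1a2b3c4d_output/'
#     MAXFILESIZE 5.0 MB
#     IAM_ROLE 'arn:aws:iam::123456789012:role/my-redshift-s3-role';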
def _import_table_from_s3(table_id, s3_path):
    """
    Import data into a redshift table by copying the requested data from s3 to
    redshift.
    :param table_id: the newly generated table inside the redshift database.
    :type table_id: str
    :param s3_path: the absolute path to an s3 folder to copy data from.
    :type s3_path: str
    """
    origin = s3_path
    query_str = f"""
            COPY \"{table_id}\"
            FROM '{origin}'
            {_get_service_credentials()}
            CSV GZIP;"""
    run_query(query_str)
    logger.debug(f"imported data to... {table_id}")
# ==============================================================================
# S3 BUCKET METHODS
# ==============================================================================
def _upload_to_s3(gen, serializer, s3_folder, chunk_size=100_000):
    """
    Upload the given data to s3 under a new folder named after `s3_folder`.
    :param gen: iterable of entries to upload to s3 - must be serializable to
        string.
    :type gen: iter
    :param serializer: should be able to serialize data entries into strings.
    :type serializer: Serializer
    :param s3_folder: used as the folder name in s3.
    :type s3_folder: str
    :param chunk_size: number of lines to upload per file.
    :type chunk_size: int
    """
    s3_client = _get_s3_client()
    bucket_name = env_keys.S3_BUCKET_NAME
    output_generator = more_itertools.peekable(gen)
    def get_chunk_of_rows(gen):
        return [
            serializer.toString(output)
            for output in itertools.islice(gen, chunk_size)
        ]
    chunk_idx = 0
    suffix = str(uuid.uuid4())[:8]
    while rows_to_upload := get_chunk_of_rows(output_generator):
        fname = f'{chunk_idx}_{suffix}.csv.gz'
        content = gzip.compress('\n'.join(rows_to_upload).encode())
        s3_client.put_object(Body=content,
                             Bucket=bucket_name,
                             Key='/'.join([s3_folder, fname]))
        chunk_idx += 1
    destination = get_s3_absolute_path(s3_folder, s3_file='')
    logger.debug(f"uploaded data to s3 bucket... {destination}")
def _download_from_s3(table_id):
    """
    All s3 files under the `{table_id}_output/` prefix are downloaded. A
    line-by-line iterator over their contents is returned.
    :param table_id: download s3 files containing this prefix in their names.
    :type table_id: str
    :return: an iterator yielding one line at a time.
    :rtype: iter[str]
    """
    output_files = _get_s3_outputs(table_id)
    for output_file in output_files:
        for line in _get_s3_file(output_file):
            yield line
def _get_s3_outputs(table_id):
    return _list_s3_folder(f'{table_id}_output/')
def _list_s3_folder(prefix):
    """
    List all files under an s3 folder.
    :param prefix: only list files whose keys begin with this prefix.
    :type prefix: str
    :return: iterable of s3 file names.
    :rtype: iter[str]
    """
    s3_client = _get_s3_client()
    next_token = None
    while True:
        args = {
            'Bucket': env_keys.S3_BUCKET_NAME,
            'Prefix': prefix,
        }
        if next_token:
            args['ContinuationToken'] = next_token
        response = s3_client.list_objects_v2(**args)
        # 'Contents' is absent from the response when no keys match the prefix
        for entry in response.get('Contents', []):
            yield entry['Key']
        next_token = response.get('NextContinuationToken')
        if not response['IsTruncated']:
            break
def _get_s3_file(filepath):
    """
    Get s3 file contents.
    :param filepath: full file path to s3 object to obtain file content.
    :type filepath: str
    :return: an iterator over the lines in the requested file.
    :rtype: iter[str]
    """
    s3_client = _get_s3_client()
    response = s3_client.get_object(Bucket=env_keys.S3_BUCKET_NAME,
                                    Key=filepath)
    for line in response['Body'].iter_lines():
        yield line.decode()
def _delete_bucket_folder(s3_folder):
    """
    Remove all files within the folder `s3_folder`, i.e. all s3 objects whose
    keys carry the `s3_folder` prefix.
    :param s3_folder: files containing this prefix.
    :type s3_folder: str
    """
    s3_client = _get_s3_client()
    bucket_name = env_keys.S3_BUCKET_NAME
    files = _list_s3_folder(s3_folder)
    num_files = 0
    for batch in more_itertools.chunked(files, _MAX_FILES):
        s3_client.delete_objects(
            Bucket=bucket_name,
            Delete={'Objects': [{
                'Key': file
            } for file in batch]})
        num_files += len(batch)
    logger.debug(f'deleted {num_files} files from bucket: {bucket_name}')
def get_s3_absolute_path(s3_folder, s3_file=None):
    """
    :return: the absolute s3 path to `s3_folder` (and `s3_file`, if given).
    :rtype: str
    """
    components = ['s3:/', env_keys.S3_BUCKET_NAME, s3_folder]
    if s3_file:
        components.append(s3_file)
    return '/'.join(components) 
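# Example (bucket name is a placeholder): with S3_BUCKET_NAME set to
# 'my-bucket',
#
#     get_s3_absolute_path('mytable_output', 'part_000.csv.gz')
#
# returns 's3://my-bucket/mytable_output/part_000.csv.gz'; omitting `s3_file`
# returns the folder path 's3://my-bucket/mytable_output'.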
# ==============================================================================
# SETTINGS
# ==============================================================================
class RSTable(parameters.CompoundParam):
    table_name: str = None
    def getFullTableId(self):
        if not self.table_name:
            raise ValueError("Table name is not specified.")
        # Note: step names are typically used for table names so as a courtesy
        # we make a simple replacement
        name = self.table_name.replace('.', '_')
        if env_keys.CLOUD_ARTIFACT_PREFIX:
            name = '/'.join([env_keys.CLOUD_ARTIFACT_PREFIX, name])
        if not re.match(r'^[a-zA-Z][\w/]+$', name):
            raise ValueError(f'{INVALID_TABLE_NAME} - {name=}')
        return name  
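# Example sketch (hypothetical names, CLOUD_ARTIFACT_PREFIX unset): a step ID
# such as 'MyChain.DedupStep_1a2b' yields the table ID
# 'MyChain_DedupStep_1a2b', whereas a name beginning with a digit, e.g.
# '1table', fails the regex above and raises a ValueError.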
class S3Folder(parameters.CompoundParam):
    folder_name: str = None
    def getAbsolutePath(self):
        """
        :return: the absolute path to folder in s3, e.g.
            s3://my_bucket/my_folder
        :rtype: str
        """
        return get_s3_absolute_path(self.folder_name)  
class S3File(parameters.CompoundParam):
    filename: str = None
class RSTableExportSettings(parameters.CompoundParam):
    max_file_size: float = 5  # MB (Redshift MAXFILESIZE range: 5 MB to 6.2 GB)
class RSFilterSettings(parameters.CompoundParam):
    table: RSTable
    s3_folder: S3Folder
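# Usage sketch (hypothetical table name), assuming the usual CompoundParam
# attribute assignment and mirroring how RSFilter._setUpTable ties the s3
# folder to the table ID:
#
#     settings = RSFilterSettings()
#     settings.table.table_name = 'ligand_dedup_2024'
#     settings.s3_folder.folder_name = settings.table.getFullTableId()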
# ==============================================================================
# STEPS
# ==============================================================================
class _UploadToS3Step(basesteps.UploadStep):
    """
    First part of the upload step: serializes the input mols to strings and
    uploads them to s3.
    """
    Settings = S3Folder
    Output = S3Folder
    def reduceFunction(self, inps):
        s3_folder = self.settings.folder_name
        _upload_to_s3(inps, self._getInputSerializer(), s3_folder)
        yield self.settings
class _ExportFromS3ToRS(stepper.ReduceStep):
    """
    Second part of the upload step: copies the s3 folder contents into the
    redshift database under the requested table ID.
    """
    Settings = S3Folder
    Input = S3Folder
    Output = RSTable
    def reduceFunction(self, s3_folders):
        s3_paths = set([inp.getAbsolutePath() for inp in s3_folders])
        for path in s3_paths:
            rs_table = self._s3FolderToRSTable(path)
            _import_table_from_s3(rs_table.getFullTableId(), path)
            yield rs_table
    def _s3FolderToRSTable(self, s3_path):
        return RSTable(table_name=s3_path.split('/')[-1])
class _EnumerateS3Folder(stepper.MapStep):
    """
    This step only maps a given table name to its appropriate folder in S3 and
    enumerates over the list of files there that were batch exported from the
    previous filtering step.
    """
    Input = RSTable
    Output = S3File
    def mapFunction(self, table):
        table_id = table.getFullTableId()
        for batch_file in _get_s3_outputs(table_id):
            yield S3File(filename=batch_file)
class _DownloadFromS3Step(basesteps.DownloadStep):
    """
    Download results by reading the s3 folder newly created by either the
    deduplication or random sampling step. The output lines are deserialized
    into mols.
    """
    Input = S3File
    def mapFunction(self, inp):
        op_serializer = self.getOutputSerializer()
        for line in _get_s3_file(inp.filename):
            yield op_serializer.fromString(line)
class _DeduplicateStep(basesteps.TableReduceStep):
    """
    Deduplicates given Redshift table and batch exports the results to S3.
    """
    Settings = RSTableExportSettings
    Input = RSTable
    Output = RSTable
    def _actOnTable(self, table_id):
        _deduplicate_table(table_id, self.settings.max_file_size)
class _RandomSampleStep(basesteps.TableReduceStep):
    """
    Randomly samples given Redshift table and batch exports the results to S3.
    """
    class Settings(RSTableExportSettings):
        n: int = 5000
    Input = RSTable
    Output = RSTable
    def _actOnTable(self, table_id):
        _random_sample_table(table_id, self.settings.n,
                             self.settings.max_file_size)
class _DeduplicateAndRandomSampleStep(_RandomSampleStep):
    """
    Deduplicates and randomly samples given Redshift table and batch exports the
    results to S3.
    """
    def _actOnTable(self, table_id):
        _deduplicate_and_random_sample_table(table_id, self.settings.n,
                                             self.settings.max_file_size)
class _DropTableStep(basesteps.TableReduceStep):
    """
    Drops the requested Redshift table, and outputs the table ID.
    """
    Input = RSTable
    Output = RSTable
    def _actOnTable(self, table_id):
        _drop_table(table_id)
# ==============================================================================
# CHAINS
# ==============================================================================
class RSFilter(MolMolMixin, basesteps.CloudFilterChain):
    """
    Generic Redshift filter with table setup and validation defined. Classes
    inheriting from `RSFilter` need to define `addFilterSteps` for filter steps.
    """
    Settings = RSFilterSettings
    def setUp(self):
        super(RSFilter, self).setUp()
        self[0].setSettings(**self.settings.s3_folder.toDict())
        # This is a bit of a hack to fix a bug when running this chain
        # with two levels of jobcontrol. See AD-359
        self._setConfig(self._getCanonicalizedConfig()) 
    def _setUpTable(self):
        table = self.settings.table
        if table.table_name is None:
            name = utils.generate_stepid_and_random_suffix(self)
            table.table_name = name.replace('-', '_')
        table_id = table.getFullTableId()
        _create_table(table_id)
        self.settings.s3_folder.folder_name = table_id
    def _validateTable(self):
        errs = []
        cluster_id = env_keys.REDSHIFT_CLUSTER_ID
        database = env_keys.REDSHIFT_DATABASE
        db_user = env_keys.REDSHIFT_DB_USER
        if not (cluster_id and database and db_user):
            errs.append(
                stepper.SettingsError(
                    self,
                    INSUFFICIENT_DATABASE_SETTINGS.format(
                        cluster_id, database, db_user)))
        if not env_keys.S3_BUCKET_NAME:
            errs.append(stepper.SettingsError(self, MISSING_BUCKET))
        return errs
    def buildChain(self):
        self.addStep(_UploadToS3Step(**self.settings.s3_folder.toDict()))
        self.addStep(_ExportFromS3ToRS())
        self.addFilterSteps()
        self.addDropTableStepInProduction()
        self.addStep(_EnumerateS3Folder())
        self.addStep(_DownloadFromS3Step()) 
    def addFilterSteps(self):
        raise NotImplementedError
    def addDropTableStepInProduction(self):
        if int(os.environ.get('SCHRODINGER_STEPPER_DEBUG', 0)):
            return
        self.addStep(_DropTableStep())  
class RSUniqueSmilesFilter(RSFilter):
    """
    A Chain that takes in Mols, uploads them to Redshift, and deduplicates
    them. To use, set the table name you'd like to use in the step settings.
    """
    def addFilterSteps(self):
        self.addStep(_DeduplicateStep())  
class RSRandomSampleFilter(RSFilter):
    """
    A Chain that takes in Mols, uploads them to Redshift, and outputs a random
    sample of them. To use, set the table name you'd like to use in the step
    settings.
    The settings also include an `n` setting that determines roughly how many
    rows are sampled. Note that this is an approximate count and a few more or
    fewer rows may be output.
    """
    class Settings(RSFilterSettings):
        n: int = 5000
    def addFilterSteps(self):
        self.addStep(_RandomSampleStep(n=self.settings.n))  
class RSDeduplicateAndRandomSampleFilter(RSRandomSampleFilter):
    """
    Same as RSRandomSampleFilter except the data is deduplicated before being
    randomly sampled.
    """
    def addFilterSteps(self):
        self.addStep(_DeduplicateAndRandomSampleStep(n=self.settings.n))
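# End-to-end usage sketch (assumptions flagged): the RDKit input, the table
# name, and the setInputs()/getOutputs() calls are assumptions based on the
# usual schrodinger.stepper step interface, not something this module
# guarantees. Database and bucket environment variables must be set.
#
#     from rdkit import Chem
#     mols = [Chem.MolFromSmiles(smi) for smi in ('CCO', 'CCO', 'c1ccccc1')]
#     chain = RSUniqueSmilesFilter(table=RSTable(table_name='demo_dedup'))
#     chain.setInputs(mols)
#     unique_mols = list(chain.getOutputs())  # expect two unique molecules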