"""
Steps for deduplicating and random sampling data (currently limited to
use with Mols).
See the chains at the bottom of the module for the steps you'll most likely want
to use.
"""
import copy
import time
import itertools
import os
import uuid
import more_itertools
from schrodinger.application.steps import basesteps
from schrodinger.models import parameters
from schrodinger import stepper
from schrodinger.stepper import logger
from .basesteps import MolMolMixin
# The Google Cloud client libraries are an optional dependency. When they
# cannot be imported, every name this block would have provided is bound to
# None so that the module itself can still be imported.
try:
    from google.auth import compute_engine
    from google.cloud import bigquery
    from google.cloud import exceptions
    from google.cloud import storage
    from google.oauth2 import service_account
except ImportError:
    compute_engine = None
    bigquery = None
    storage = None
    exceptions = None
    # Must be bound as well: leaving it undefined turns a missing dependency
    # into a NameError inside _generate_credentials().
    service_account = None
# OAuth scopes requested when credentials are built from a service account
# key file (see _generate_credentials).
BIGQUERY_SCOPE = "https://www.googleapis.com/auth/bigquery"
CLOUD_SCOPE = "https://www.googleapis.com/auth/cloud-platform"
# Error messages
# Reported by _validate_table when a table has no dataset.
NO_DATASET = 'No bigquery dataset set'
# Raised/reported whenever PROJECT (below) is None.
_NO_PROJECT_ERR_MSG = (
    "No bigquery project is defined for this run. Set "
    "an environment variable for SCHRODINGER_GCP_PROJECT and try again.")
#===============================================================================
# BigQuery Functions
#===============================================================================
# Both values are read from the environment once, at import time; each is None
# when its variable is unset.
PROJECT = os.environ.get('SCHRODINGER_GCP_PROJECT')
KEY_PATH = os.environ.get('SCHRODINGER_GCP_KEY')  # Service account key
# Module-wide BigQuery client, created on first use by _get_bq_client().
BQ_CLIENT = None
def _create_table(table_id):
    """
    Create a BigQuery table whose only column is a required STRING column
    named "data".

    :param table_id: ID of the table to create. It is passed through
        `_get_fully_qualified_table_id` before use.
    :type  table_id: str
    """
    full_table_id = _get_fully_qualified_table_id(table_id)
    bq_client = _get_bq_client()
    data_column = bigquery.SchemaField("data", "STRING", mode="REQUIRED")
    new_table = bigquery.Table(full_table_id, schema=[data_column])
    bq_client.create_table(new_table)
    # Note that this message is emitted after the create request returns.
    logger.debug(f"creating table... {full_table_id}")
def _is_bigquery_enabled() -> bool:
    """
    Report whether the SCHRODINGER_GCP_ENABLED environment variable is set to
    a non-empty value.

    The value itself is not interpreted: any non-empty string (including
    "0" or "false") counts as enabled.
    """
    flag = os.environ.get('SCHRODINGER_GCP_ENABLED', '')
    return len(flag) > 0
def _wait_for_query_job(job):
    """
    Call `job.result()` and return what it returns.

    If `job.errors` is truthy once the result has been obtained, the errors
    are logged together with the query text; they are not raised here.

    :param job: The query job to wait on. Must provide `result()`, `errors`
        and `query`.
    :return: The value returned by `job.result()`
    """
    query_result = job.result()
    if not job.errors:
        return query_result
    message = f"Errors while running query {job.query}\n{job.errors=}"
    logger.error(message)
    return query_result
def _dedupe_table(input_table_id, output_table_id):
    """
    Write the distinct `data` values of one table into another.

    The query groups the input table by `data` and stores the result in the
    output table with the WRITE_TRUNCATE write disposition. Input and output
    may be the same table.

    :param input_table_id: Table to read from
    :type  input_table_id: str
    :param output_table_id: Table to write the grouped rows to
    :type  output_table_id: str
    """
    client = _get_bq_client()
    src = _get_fully_qualified_table_id(input_table_id)
    dest = _get_fully_qualified_table_id(output_table_id)
    sql = f"""
             SELECT
                 data
             FROM
                 `{src}`
             GROUP BY
                 data"""
    overwrite_dest = bigquery.QueryJobConfig(
        destination=dest, write_disposition='WRITE_TRUNCATE')
    _wait_for_query_job(client.query(sql, job_config=overwrite_dest))
    logger.debug(f"deduped table {src}")
def _generate_credentials():
    """
    Build service account credentials from the key file named by `KEY_PATH`.

    :return: Credentials scoped to `CLOUD_SCOPE` and `BIGQUERY_SCOPE`, or
        None when `KEY_PATH` is None.
    """
    if KEY_PATH is not None:
        requested_scopes = [CLOUD_SCOPE, BIGQUERY_SCOPE]
        return service_account.Credentials.from_service_account_file(
            KEY_PATH, scopes=requested_scopes)
    return None
def _generate_clients():
    """
    Create a new BigQuery client for `PROJECT`.

    The credentials come from `_generate_credentials`, which returns None
    when no key file is configured; that None is passed to the client as is.

    :return: A new `bigquery.Client`
    """
    return bigquery.Client(project=PROJECT,
                           credentials=_generate_credentials())
def _get_bq_client():
    """
    Return the module-wide BigQuery client, creating it on the first call.

    The client is stored in the `BQ_CLIENT` global and reused afterwards.
    """
    global BQ_CLIENT
    if BQ_CLIENT is not None:
        return BQ_CLIENT
    BQ_CLIENT = _generate_clients()
    return BQ_CLIENT
def _get_fully_qualified_table_id(table_id):
    """
    Return `table_id` prefixed with the configured project.

    An ID that already starts with "<PROJECT>." (or the legacy
    "<PROJECT>:" form) is returned unchanged. The check is anchored to the
    start of the ID: a plain substring test would wrongly treat
    "<DATASET>.<TABLE>" as already qualified whenever the dataset or table
    name happened to contain the project string.

    :param table_id: "<DATASET>.<TABLE>" or "<PROJECT>.<DATASET>.<TABLE>"
    :type  table_id: str
    :return: The table ID including the project
    :rtype: str
    :raises ValueError: If no project is configured (`PROJECT` is None)
    """
    if PROJECT is None:
        raise ValueError(_NO_PROJECT_ERR_MSG)
    if table_id.startswith((f'{PROJECT}.', f'{PROJECT}:')):
        return table_id
    return f'{PROJECT}.{table_id}'
def _stream_in_batches(gen,
                       serializer,
                       table_id,
                       chunk_size=10000,
                       skip_sleep=False):
    """
    Stream the items of `gen` into the table specified by `table_id`.

    Items are serialized to strings and inserted into the "data" column,
    at most `chunk_size` rows per insert request. Errors returned by an
    insert request are logged and streaming continues with the next chunk.

    After streaming in the data, this function will sleep for 5 seconds.
    This is to give BigQuery enough time to process the new results, otherwise
    any queries that happen immediately after this function will sometimes
    not process the new data. You can set `skip_sleep` to True if you don't
    expect to make any queries soon after.

    :param gen: A generator of outputs to load into the table
    :type  gen: Iterator
    :param serializer: A serializer to serialize the outputs; its
        `toString` method is called on every item
    :type  serializer: Serializer
    :param table_id: The table to load the outputs into. Should include both
        dataset and table name, i.e. "<DATASET>.<TABLE>"
    :type  table_id: str
    :param chunk_size: Maximum number of rows sent per insert request
    :type  chunk_size: int
    :param skip_sleep: Whether to skip the 5 second sleep at the end
    :type  skip_sleep: bool
    """
    bq_client = _get_bq_client()
    full_table_id = _get_fully_qualified_table_id(table_id)
    # Wrapping gives a single iterator that every islice call below advances.
    pending = more_itertools.peekable(gen)
    schema = [bigquery.SchemaField("data", "STRING", mode="REQUIRED")]
    # Each row is a 1-tuple matching the single-column schema. The loop ends
    # at the first empty chunk, i.e. when the input is exhausted.
    while batch := [(serializer.toString(item),)
                    for item in itertools.islice(pending, chunk_size)]:
        insert_errors = bq_client.insert_rows(full_table_id,
                                              batch,
                                              selected_fields=schema)
        if insert_errors:
            logger.error(f"while streaming in rows: {insert_errors}")
    if not skip_sleep:
        time.sleep(5)
    logger.debug(f"streamed in data to... {full_table_id}")
def _random_sample_table(src_table_id, dest_table_id, n):
    """
    Write a random sample of the rows of one table into another.

    Each row is kept when `RAND() < n / <row count of the source table>`, so
    `n` is the average number of rows kept rather than an exact count. The
    result is stored in the destination table with the WRITE_TRUNCATE write
    disposition. Source and destination may be the same table.

    :param src_table_id: Table to sample from
    :type  src_table_id: str
    :param dest_table_id: Table to write the sampled rows to
    :type  dest_table_id: str
    :param n: Average number of rows to keep
    :type  n: int
    """
    bq_client = _get_bq_client()
    input_table_id = _get_fully_qualified_table_id(src_table_id)
    output_table_id = _get_fully_qualified_table_id(dest_table_id)
    job_config = bigquery.QueryJobConfig(destination=output_table_id,
                                         write_disposition='WRITE_TRUNCATE')
    # Both references use the fully-qualified ID so the row count is always
    # taken from the same table that is being sampled.
    query_str = f"""
            SELECT
                data
            FROM
                `{input_table_id}`
            WHERE
                RAND() < {n}/(SELECT COUNT(*) FROM `{input_table_id}`);"""
    query_job = bq_client.query(query_str, job_config=job_config)
    _wait_for_query_job(query_job)
def _get_table_data(table_id, starting_idx=None, num_rows=None):
    """
    Iterate over the `data` values of the table specified by `table_id`.

    The argument check and table ID resolution happen immediately; rows are
    only requested once the returned generator is iterated.

    :param table_id: The table to read
    :type  table_id: str
    :param starting_idx: Row index to start at, passed on as `start_index`
    :type  starting_idx: int or None
    :param num_rows: Maximum number of rows to return, passed on as
        `max_results`
    :type  num_rows: int or None
    :return: A generator yielding the "data" field of each row
    :raises TypeError: If `table_id` is None
    """
    if table_id is None:
        raise TypeError("table_id must be string, not None")
    client = _get_bq_client()
    full_table_id = _get_fully_qualified_table_id(table_id)

    def _iter_data_values():
        row_iterator = client.list_rows(full_table_id,
                                        start_index=starting_idx,
                                        max_results=num_rows)
        for row in row_iterator:
            yield row['data']

    return _iter_data_values()
def _table_row_count(table_id):
    """
    Return the number of rows in the table specified by `table_id`.

    :param table_id: The table to count rows of
    :type  table_id: str
    :return: The row count reported by a `COUNT(*)` query
    :rtype: int
    """
    # We use a query to get the number of rows. This has higher costs than
    # just checking a property on the table but table properties don't update
    # for streaming inserts.
    table_id = _get_fully_qualified_table_id(table_id)
    bq_client = _get_bq_client()
    # COUNT(*) yields exactly one row holding the count, so the table contents
    # do not have to be selected just to learn how many rows there are.
    query = f"SELECT COUNT(*) AS row_count FROM `{table_id}`"
    query_job = bq_client.query(query)
    results = _wait_for_query_job(query_job)
    count_row = next(iter(results))
    return count_row["row_count"]
#===============================================================================
# Data and Setting Classes
#===============================================================================
class BQTable(parameters.CompoundParam):
    """
    Identifies a BigQuery table by dataset and table name.
    """
    # Dataset that holds the table.
    dataset: str = 'bq_testing_dataset'
    # Name of the table within the dataset; None until one is assigned.
    table_name: str = None

    def getFullTableId(self):
        """
        Return the table ID in the form "<DATASET>.<TABLE>".

        Any '.' inside the dataset or the table name is replaced with '-',
        so the only '.' in the result is the separator between the two.

        :return: The dataset-qualified table ID (without the project)
        :rtype: str
        :raises ValueError: If no table name has been set
        """
        if not self.table_name:
            raise ValueError("Table name is not specified.")
        dataset_part = self.dataset.replace('.', '-')
        table_part = self.table_name.replace('.', '-')
        return dataset_part + '.' + table_part
class TableChunk(parameters.CompoundParam):
    """
    Describes a contiguous range of rows of a BigQuery table.
    """
    # Index of the first row of the chunk.
    start_idx: int
    # Maximum number of rows in the chunk.
    chunk_size: int
    # The table the rows belong to.
    table: BQTable
class _DownloadSettings(BQTable):
    """
    The `BQTable` settings extended with a `chunk_size` field.
    """
    # NOTE(review): presumably the number of rows per download chunk; nothing
    # in the visible part of this module reads it - confirm before relying on it.
    chunk_size: int = 10000
#===============================================================================
# Steps
#===============================================================================
class _UploadToBQStep(basesteps.UploadStep):
    """
    Streams all inputs into the BigQuery table named by the step settings and
    outputs those settings.
    """
    Settings = BQTable
    Output = BQTable

    def reduceFunction(self, inps):
        """
        Upload every input to the settings table, then yield the settings.

        :param inps: The inputs to upload
        :return: A generator yielding the step settings once
        """
        destination = self.settings.getFullTableId()
        serializer = self._getInputSerializer()
        _stream_in_batches(inps, serializer, destination)
        yield self.settings

    def validateSettings(self):
        """
        Check that a BigQuery project is configured.

        :return: A list with one `SettingsError` if `PROJECT` is None,
            otherwise an empty list
        """
        problems = []
        if PROJECT is None:
            problems.append(stepper.SettingsError(self, _NO_PROJECT_ERR_MSG))
        return problems
class _ChunkBigQueryTable(stepper.MapStep):
    """
    Splits a table into `TableChunk`s of `chunk_size` rows each.
    """

    class Settings(parameters.CompoundParam):
        # Number of rows covered by each chunk.
        chunk_size: int = 10000

    Input = BQTable
    Output = TableChunk

    def mapFunction(self, table):
        """
        Yield one `TableChunk` per `chunk_size` rows of `table`.

        Start indices run from 0 up to the current row count in steps of
        `chunk_size`. Every chunk carries its own deep copy of `table`.

        :param table: The table to split
        :type  table: BQTable
        """
        total_rows = _table_row_count(table.getFullTableId())
        rows_per_chunk = self.settings.chunk_size
        for offset in range(0, total_rows, rows_per_chunk):
            yield TableChunk(start_idx=offset,
                             chunk_size=rows_per_chunk,
                             table=copy.deepcopy(table))
class _DownloadFromBQStep(basesteps.DownloadStep):
    """
    Downloads the rows described by a `TableChunk` and deserializes them.
    """
    Input = TableChunk

    def mapFunction(self, inp):
        """
        Yield the deserialized value of every row in the chunk `inp`.

        :param inp: The chunk of rows to download
        :type  inp: TableChunk
        """
        deserialize = self.getOutputSerializer().fromString
        raw_values = _get_table_data(inp.table.getFullTableId(),
                                     starting_idx=inp.start_idx,
                                     num_rows=inp.chunk_size)
        yield from map(deserialize, raw_values)
class _DedupeStep(basesteps.TableReduceStep):
    """
    Deduplicates a table in place: the table is both the input and the
    output of `_dedupe_table`.
    """
    Input = BQTable
    Output = BQTable

    def _actOnTable(self, table_id):
        _dedupe_table(input_table_id=table_id, output_table_id=table_id)
class _RandomSampleStep(basesteps.TableReduceStep):
    """
    Randomly samples a table in place: the table is both the source and the
    destination of `_random_sample_table`.

    This step does not deduplicate by itself; put a `_DedupeStep` before it
    when sampling should occur after deduplication.

    The config's `n` specifies the average number of rows to keep.
    """

    class Settings(parameters.CompoundParam):
        n: int = 5000

    Input = BQTable
    Output = BQTable

    def _actOnTable(self, table_id):
        sample_size = self.settings.n
        _random_sample_table(src_table_id=table_id,
                             dest_table_id=table_id,
                             n=sample_size)
#===============================================================================
# Chains
# To use these chains with pubsub, specify `use_pubsub` in the batch settings
# for the steps `_UploadToBQStep` and `_DownloadFromBQStep`
#===============================================================================
def _generate_stepid_and_random_suffix(step):
    """
    Return the step's ID followed by "_" and 8 random hexadecimal characters.

    :param step: The step whose `getStepId()` provides the prefix
    :rtype: str
    """
    # The first 8 characters of a UUID's string form are its first 8 hex
    # digits, so taking them from `.hex` is equivalent.
    random_suffix = uuid.uuid4().hex[:8]
    return f"{step.getStepId()}_{random_suffix}"
def _validate_table(step, table):
    """
    Check that `table` has a dataset set.

    :param step: The step to attribute a settings error to
    :param table: The table settings to check
    :type  table: BQTable
    :return: A list with one `SettingsError` if `table.dataset` is falsy,
        otherwise an empty list
    """
    if table.dataset:
        return []
    return [stepper.SettingsError(step, NO_DATASET)]
def _setup_table(step, table):
    """
    Create the BigQuery table described by `table`.

    When `table.table_name` is None, a name built from the step ID and a
    random suffix is assigned to it first; `table` is modified in place.

    :param step: The step used to generate a table name if needed
    :param table: The table settings to create a table for
    :type  table: BQTable
    """
    needs_name = table.table_name is None
    if needs_name:
        table.table_name = _generate_stepid_and_random_suffix(step)
    _create_table(table.getFullTableId())
class BQUniqueSmilesFilter(MolMolMixin, stepper.Chain):
    """
    A Chain that takes in Mol's, uploads them to BigQuery, and deduplicates
    them. To use, set the dataset and table name you'd like to use in
    the step settings. A table will be created in that dataset with that
    name.

    The BQ project is specified with the SCHRODINGER_GCP_PROJECT environment
    variable.
    """
    Settings = BQTable

    def setUp(self):
        """
        Create the BigQuery table and hand the table settings to the first
        step of the chain.
        """
        # May assign a generated table name to the settings, which is why the
        # settings are pushed to the first step again afterwards.
        _setup_table(self, self.settings)
        self[0].setSettings(**self.settings.toDict())
        # This is a bit of a hack to fix a bug when running this chain
        # with two levels of jobcontrol. See AD-359
        self._setConfig(self._getCanonicalizedConfig())

    def buildChain(self):
        """
        Add the steps: upload, dedupe, split into chunks, download.
        """
        self.addStep(_UploadToBQStep(**self.settings.toDict()))
        self.addStep(_DedupeStep())
        self.addStep(_ChunkBigQueryTable())
        self.addStep(_DownloadFromBQStep())

    def validateSettings(self):
        """
        Extend the inherited validation with a check that a dataset is set.

        :return: The list of settings errors found
        """
        ret = super().validateSettings()
        return ret + _validate_table(self, self.settings)
class BQRandomSampleFilter(MolMolMixin, stepper.Chain):
    """
    A Chain that takes in Mol's, uploads them to BigQuery, and outputs a random
    sample of them. To use, set the dataset and table name you'd like to use in
    the step settings. A table will be created in that dataset with that name.

    The settings also has a `n` setting for determining roughly
    how many rows should be sampled. Note that this is an approximate number
    and a few more or less may be output.

    The BQ project is specified with the SCHRODINGER_GCP_PROJECT environment
    variable.
    """

    class Settings(parameters.CompoundParam):
        # The table to upload to and sample.
        table: BQTable
        # Approximate number of rows to output.
        n: int = 5000

    def setUp(self):
        """
        Create the BigQuery table and hand the table settings to the first
        step of the chain.
        """
        # May assign a generated table name to the settings, which is why the
        # settings are pushed to the first step again afterwards.
        _setup_table(self, self.settings.table)
        self[0].setSettings(**self.settings.table.toDict())
        # This is a bit of a hack to fix a bug when running this chain
        # with two levels of jobcontrol. See AD-359
        self._setConfig(self._getCanonicalizedConfig())

    def buildChain(self):
        """
        Add the steps: upload, random sample, split into chunks, download.
        """
        self.addStep(_UploadToBQStep(**self.settings.table.toDict()))
        self.addStep(_RandomSampleStep(n=self.settings.n))
        self.addStep(_ChunkBigQueryTable())
        self.addStep(_DownloadFromBQStep())

    def validateSettings(self):
        """
        Extend the inherited validation with a check that a dataset is set.

        :return: The list of settings errors found
        """
        ret = super().validateSettings()
        return ret + _validate_table(self, self.settings.table)
class BQDedupeAndRandomSampleFilter(MolMolMixin, stepper.Chain):
    """
    Same as BQRandomSampleFilter except the data is deduplicated before
    randomly sampled.
    """

    class Settings(parameters.CompoundParam):
        # The table to upload to, deduplicate and sample.
        table: BQTable
        # Approximate number of rows to output.
        n: int = 5000

    def setUp(self):
        """
        Create the BigQuery table and hand the table settings to the first
        step of the chain.
        """
        # May assign a generated table name to the settings, which is why the
        # settings are pushed to the first step again afterwards.
        _setup_table(self, self.settings.table)
        self[0].setSettings(**self.settings.table.toDict())
        # This is a bit of a hack to fix a bug when running this chain
        # with two levels of jobcontrol. See AD-359
        self._setConfig(self._getCanonicalizedConfig())

    def buildChain(self):
        """
        Add the steps: upload, dedupe, random sample, split into chunks,
        download.
        """
        self.addStep(_UploadToBQStep(**self.settings.table.toDict()))
        self.addStep(_DedupeStep())
        self.addStep(_RandomSampleStep(n=self.settings.n))
        self.addStep(_ChunkBigQueryTable())
        self.addStep(_DownloadFromBQStep())

    def validateSettings(self):
        """
        Extend the inherited validation with a check that a dataset is set.

        :return: The list of settings errors found
        """
        ret = super().validateSettings()
        return ret + _validate_table(self, self.settings.table)