"""
Steps for deduplicating and randomly sampling data (currently limited to
use with Mols).
See the chains at the bottom of the module for the steps you'll most likely want
to use.
"""
import copy
import time
import itertools
import os
import uuid
import more_itertools
from schrodinger.application.steps import basesteps
from schrodinger.models import parameters
from schrodinger import stepper
from schrodinger.stepper import logger
from .basesteps import MolMolMixin
try:
from google.auth import compute_engine
from google.cloud import bigquery
from google.cloud import exceptions
from google.cloud import storage
from google.oauth2 import service_account
except ImportError:
    compute_engine = None
    bigquery = None
    storage = None
    exceptions = None
    service_account = None
# Scopes
BIGQUERY_SCOPE = "https://www.googleapis.com/auth/bigquery"
CLOUD_SCOPE = "https://www.googleapis.com/auth/cloud-platform"
# Error messages
NO_DATASET = 'No BigQuery dataset set'
_NO_PROJECT_ERR_MSG = (
    "No BigQuery project is defined for this run. Set the "
    "SCHRODINGER_GCP_PROJECT environment variable and try again.")
#===============================================================================
# BigQuery Functions
#===============================================================================
PROJECT = os.environ.get('SCHRODINGER_GCP_PROJECT')
KEY_PATH = os.environ.get('SCHRODINGER_GCP_KEY') # Service account key
BQ_CLIENT = None
def _create_table(table_id):
table_id = _get_fully_qualified_table_id(table_id)
client = _get_bq_client()
schema = [bigquery.SchemaField("data", "STRING", mode="REQUIRED")]
table = bigquery.Table(table_id, schema=schema)
client.create_table(table)
logger.debug(f"creating table... {table_id}")
def _is_bigquery_enabled() -> bool:
    # Any non-empty value of SCHRODINGER_GCP_ENABLED enables BigQuery
    return bool(os.environ.get('SCHRODINGER_GCP_ENABLED'))
def _wait_for_query_job(job):
result = job.result()
if job.errors:
err_str = (f"Errors while running query {job.query}\n" +
f"{job.errors=}")
logger.error(err_str)
return result
def _dedupe_table(input_table_id, output_table_id):
    """
    Copy the distinct rows of `input_table_id` into `output_table_id`,
    overwriting it. The two ids may refer to the same table.
    """
bq_client = _get_bq_client()
input_table_id = _get_fully_qualified_table_id(input_table_id)
output_table_id = _get_fully_qualified_table_id(output_table_id)
job_config = bigquery.QueryJobConfig(destination=output_table_id,
write_disposition='WRITE_TRUNCATE')
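    # Grouping by the lone `data` column keeps one row per distinct value,
    # making the query below equivalent to a SELECT DISTINCT.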
query_str = f"""
SELECT
data
FROM
`{input_table_id}`
GROUP BY
data"""
query_job = bq_client.query(query_str, job_config=job_config)
_wait_for_query_job(query_job)
logger.debug(f"deduped table {input_table_id}")
def _generate_credentials():
if KEY_PATH is None:
return None
credentials = service_account.Credentials.from_service_account_file(
KEY_PATH,
scopes=[CLOUD_SCOPE, BIGQUERY_SCOPE],
)
return credentials
def _generate_client():
credentials = _generate_credentials()
bq_client = bigquery.Client(project=PROJECT, credentials=credentials)
return bq_client
def _get_bq_client():
global BQ_CLIENT
if BQ_CLIENT is None:
        BQ_CLIENT = _generate_client()
return BQ_CLIENT
def _get_fully_qualified_table_id(table_id):
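    """
    Prefix `table_id` with the GCP project if it doesn't already contain it.
    For example, with a (hypothetical) project of 'my-project':

        'my_dataset.my_table' -> 'my-project.my_dataset.my_table'
    """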
if PROJECT is None:
raise ValueError(_NO_PROJECT_ERR_MSG)
if PROJECT not in table_id:
return f'{PROJECT}.{table_id}'
else:
return table_id
def _stream_in_batches(gen,
serializer,
table_id,
chunk_size=10000,
skip_sleep=False):
"""
Load batches of outputs into a table specified by `table_id`. outputs are
batched so csv files are around `csv_size_limit` bytes. The csv files
are written in chunks of `chunk_size` before being checked for size.
After streaming in the data, this function will sleep for 5 seconds.
This is to give BigQuery enough time to process the new results, otherwise
any queries that happen immediately after this function will sometimes
not process the new data. You can set `skip_sleep` to True if you don't
expect to make any queries soon after.
:param gen: A generator of outputs to load into the table
:type gen: Iterator
:param serializer: A serializer to serialize the outputs, see `Serializer`
:type serializer: Serializer
:param table_id: The table to load the outputs into. Should include both
dataset and table name, i.e. "<DATASET>.<TABLE>"
:type table_id: str
"""
client = _get_bq_client()
table_id = _get_fully_qualified_table_id(table_id)
output_generator = more_itertools.peekable(gen)
def get_chunk_of_rows(gen):
rows = []
for output in itertools.islice(gen, chunk_size):
rows.append((serializer.toString(output),))
return rows
fields = [bigquery.SchemaField("data", "STRING", mode="REQUIRED")]
while True:
rows_to_insert = get_chunk_of_rows(output_generator)
if not rows_to_insert:
break
errors = client.insert_rows(
table_id, rows_to_insert,
selected_fields=fields) # Make an API request.
if errors:
logger.error(f"while streaming in rows: {errors}")
if not skip_sleep:
time.sleep(5)
logger.debug(f"streamed in data to... {table_id}")
def _random_sample_table(src_table_id, dest_table_id, n):
bq_client = _get_bq_client()
input_table_id = _get_fully_qualified_table_id(src_table_id)
output_table_id = _get_fully_qualified_table_id(dest_table_id)
job_config = bigquery.QueryJobConfig(destination=output_table_id,
write_disposition='WRITE_TRUNCATE')
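    # The WHERE clause below keeps each row independently with probability
    # n / COUNT(*), so the number of surviving rows is binomially distributed
    # with mean ~n; slightly more or fewer than `n` rows may come back.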
query_str = f"""
SELECT
data
FROM
`{input_table_id}`
WHERE
RAND() < {n}/(SELECT COUNT(*) FROM `{src_table_id}`);"""
query_job = bq_client.query(query_str, job_config=job_config)
_wait_for_query_job(query_job)
def _get_table_data(table_id, starting_idx=None, num_rows=None):
"""
    Get the contents of the table specified by `table_id`. If `starting_idx`
    is given, the data starts at that row index. At most `num_rows` rows
    are returned.
    """
    if table_id is None:
        raise TypeError("table_id must be a string, not None")
bq_client = _get_bq_client()
table_id = _get_fully_qualified_table_id(table_id)
def _unwrap_row_iterator():
for row in bq_client.list_rows(table_id,
start_index=starting_idx,
max_results=num_rows):
yield row['data']
return _unwrap_row_iterator()
def _table_row_count(table_id):
# We use a query to get the number of rows. This has higher costs than
# just checking a property on the table but table properties don't update
# for streaming inserts.
table_id = _get_fully_qualified_table_id(table_id)
bq_client = _get_bq_client()
query = ("SELECT data " f"FROM `{table_id}`")
query_job = bq_client.query(query,)
results = _wait_for_query_job(query_job)
return results.total_rows
#===============================================================================
# Data and Setting Classes
#===============================================================================
class BQTable(parameters.CompoundParam):
    dataset: str = 'bq_testing_dataset'
    table_name: str = None

    def getFullTableId(self):
        if not self.table_name:
            raise ValueError("Table name is not specified.")
        return self.dataset.replace('.', '-') + '.' + self.table_name.replace(
            '.', '-')
class TableChunk(parameters.CompoundParam):
start_idx: int
chunk_size: int
table: BQTable
class _DownloadSettings(BQTable):
chunk_size: int = 10000
#===============================================================================
# Steps
#===============================================================================
class _UploadToBQStep(basesteps.UploadStep):
Settings = BQTable
Output = BQTable
def reduceFunction(self, inps):
table_id = self.settings.getFullTableId()
_stream_in_batches(inps, self._getInputSerializer(), table_id)
yield self.settings
def validateSettings(self):
if PROJECT is None:
return [stepper.SettingsError(self, _NO_PROJECT_ERR_MSG)]
return []
class _ChunkBigQueryTable(stepper.MapStep):
class Settings(parameters.CompoundParam):
chunk_size: int = 10000
Input = BQTable
Output = TableChunk
def mapFunction(self, table):
num_rows = _table_row_count(table.getFullTableId())
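        # e.g. with chunk_size=10000, a table of 25,000 rows yields chunks
        # starting at rows 0, 10000, and 20000 (the last chunk is short)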
for idx in range(0, num_rows, self.settings.chunk_size):
table_copy = copy.deepcopy(table)
yield TableChunk(start_idx=idx,
chunk_size=self.settings.chunk_size,
table=table_copy)
class _DownloadFromBQStep(basesteps.DownloadStep):
Input = TableChunk
def mapFunction(self, inp):
op_serializer = self.getOutputSerializer()
for datum in _get_table_data(inp.table.getFullTableId(),
starting_idx=inp.start_idx,
num_rows=inp.chunk_size):
yield op_serializer.fromString(datum)
class _DedupeStep(basesteps.TableReduceStep):
    """
    Step that removes duplicate rows from a table in place.
    """
    Input = BQTable
    Output = BQTable
def _actOnTable(self, table_id):
_dedupe_table(table_id, table_id)
class _RandomSampleStep(basesteps.TableReduceStep):
"""
Deduplication step with random sampling enabled.
Sampling occurs after deduplication.
The config's `n` specifies the average number of rows to keep.
"""
class Settings(parameters.CompoundParam):
n: int = 5000
Output = BQTable
Input = BQTable
def _actOnTable(self, table_id):
_random_sample_table(table_id, table_id, self.settings.n)
#===============================================================================
# Chains
# To use these chains with pubsub, specify `use_pubsub` in the batch settings
# for the steps `_UploadToBQStep` and `_DownloadFromBQStep`
#===============================================================================
def _generate_stepid_and_random_suffix(step):
return f"{step.getStepId()}_{str(uuid.uuid4())[:8]}"
def _validate_table(step, table):
errs = []
if not table.dataset:
errs.append(stepper.SettingsError(step, NO_DATASET))
return errs
def _setup_table(step, table):
if table.table_name is None:
table.table_name = _generate_stepid_and_random_suffix(step)
table_id = table.getFullTableId()
_create_table(table_id)
class BQUniqueSmilesFilter(MolMolMixin, stepper.Chain):
    """
    A Chain that takes in Mols, uploads them to BigQuery, and deduplicates
    them. To use, set the dataset and table name you'd like to use in the
    step settings. A table will be created in that dataset with that name.

    The BigQuery project is specified with the SCHRODINGER_GCP_PROJECT
    environment variable.
    """
Settings = BQTable
    def setUp(self):
_setup_table(self, self.settings)
self[0].setSettings(**self.settings.toDict())
# This is a bit of a hack to fix a bug when running this chain
# with two levels of jobcontrol. See AD-359
self._setConfig(self._getCanonicalizedConfig())
    def buildChain(self):
self.addStep(_UploadToBQStep(**self.settings.toDict()))
self.addStep(_DedupeStep())
self.addStep(_ChunkBigQueryTable())
self.addStep(_DownloadFromBQStep())
    def validateSettings(self):
ret = super().validateSettings()
return ret + _validate_table(self, self.settings)
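# A usage sketch for BQUniqueSmilesFilter (not from the original module;
# the dataset/table names and `input_mols` are hypothetical, and the
# setInputs/getOutputs calls assume the standard stepper API):
#
#     dedupe = BQUniqueSmilesFilter()
#     dedupe.setSettings(dataset='my_dataset', table_name='unique_smiles')
#     dedupe.setInputs(input_mols)
#     unique_mols = list(dedupe.getOutputs())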
class BQRandomSampleFilter(MolMolMixin, stepper.Chain):
    """
    A Chain that takes in Mols, uploads them to BigQuery, and outputs a
    random sample of them. To use, set the dataset and table name you'd
    like to use in the step settings. A table will be created in that
    dataset with that name.

    The settings also include an `n` option determining roughly how many
    rows should be sampled. Note that this is an approximate number and a
    few more or fewer may be output.

    The BigQuery project is specified with the SCHRODINGER_GCP_PROJECT
    environment variable.
    """
    class Settings(parameters.CompoundParam):
table: BQTable
n: int = 5000
    def setUp(self):
_setup_table(self, self.settings.table)
self[0].setSettings(**self.settings.table.toDict())
# This is a bit of a hack to fix a bug when running this chain
# with two levels of jobcontrol. See AD-359
self._setConfig(self._getCanonicalizedConfig())
    def buildChain(self):
self.addStep(_UploadToBQStep(**self.settings.table.toDict()))
self.addStep(_RandomSampleStep(n=self.settings.n))
self.addStep(_ChunkBigQueryTable())
self.addStep(_DownloadFromBQStep())
    def validateSettings(self):
ret = super().validateSettings()
return ret + _validate_table(self, self.settings.table)
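# A usage sketch for BQRandomSampleFilter (hypothetical names; note the
# nested `table` setting, unlike BQUniqueSmilesFilter's flat settings):
#
#     sampler = BQRandomSampleFilter()
#     sampler.setSettings(table=BQTable(dataset='my_dataset',
#                                       table_name='sample_table'),
#                         n=1000)
#     sampler.setInputs(input_mols)
#     sampled_mols = list(sampler.getOutputs())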
class BQDedupeAndRandomSampleFilter(MolMolMixin, stepper.Chain):
    """
    Same as BQRandomSampleFilter except that the data is deduplicated
    before being randomly sampled.
    """
    class Settings(parameters.CompoundParam):
table: BQTable
n: int = 5000
    def setUp(self):
_setup_table(self, self.settings.table)
self[0].setSettings(**self.settings.table.toDict())
# This is a bit of a hack to fix a bug when running this chain
# with two levels of jobcontrol. See AD-359
self._setConfig(self._getCanonicalizedConfig())
    def buildChain(self):
self.addStep(_UploadToBQStep(**self.settings.table.toDict()))
self.addStep(_DedupeStep())
self.addStep(_RandomSampleStep(n=self.settings.n))
self.addStep(_ChunkBigQueryTable())
self.addStep(_DownloadFromBQStep())
    def validateSettings(self):
ret = super().validateSettings()
return ret + _validate_table(self, self.settings.table)