"""
Framework for writing computational workflows and running them in a highly
distributed manner. Each step of the workflow is either a "mapping" operation
(see `MapStep`) or a "reducing" operation (see `ReduceStep`). These steps can then
be chained together using the `Chain` class.
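A minimal sketch of a map step (the step and values here are illustrative
only; see the class docstrings below for details)::
    class DoubleStep(MapStep):
        Input = int
        Output = int
        def mapFunction(self, inp):
            yield inp * 2
    step = DoubleStep()
    step.setInputs([1, 2, 3])
    print(step.getOutputs())  # [2, 4, 6]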
For a more complete introduction, see the WordCount tutorial:
    https://confluence.schrodinger.com/display/~jtran/Stepper+WordCount+Tutorial
For documentation on specific stepper features, see the following feature list.
You can ctrl+f the feature tag to jump to the relevant docstrings.
+----------------+-------------------+
| Feature        | Tag               |
+================+===================+
| MapStep        | _map_step_        |
+----------------+-------------------+
| ReduceStep     | _reduce_step_     |
+----------------+-------------------+
| Chain          | _chain_           |
+----------------+-------------------+
| Settings       | _settings_        |
+----------------+-------------------+
| Serialization  | _serialization_   |
+----------------+-------------------+
| File Handling  | _file_handling_   |
+----------------+-------------------+
| Custom Workflow| _custom_workflows_|
+----------------+-------------------+
| Double Batching| _dbl_batching_    |
+----------------+-------------------+
#===============================================================================
# Running stepper with custom, undistributed workflows     <_custom_workflows_>
#===============================================================================
To run steps that aren't defined in the core suite:
The script should be executed from the working directory and should import its
steps from a local package in that directory.
Working dir contents::
    script.py
    my_lib/
        __init__.py
        steps.py
Minimal code in script.py if it needs to run under job control::
    from schrodinger.job import launchapi
    from schrodinger.ui.qt.appframework2 import application
    from my_lib.steps import MyStep
    def get_job_spec_from_args(argv):
        jsb = launchapi.JobSpecificationArgsBuilder(argv)
        jsb.setInputFile(__file__)
        jsb.setInputDirectory('my_lib')
        return jsb.getJobSpec()
    def main():
        step = MyStep()
        step.getOutputs()
    if __name__ == '__main__':
        application.run_application(main)
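The script can then be launched under job control in the usual way, for
example (the exact launch command may vary; host name is illustrative)::
    $SCHRODINGER/run script.py -HOST my_cluster_host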
#===============================================================================
# Double Batching                                              <_dbl_batching_>
#===============================================================================
Job launch speed at the time of writing is about one job every 3 or 4 seconds.
This rate becomes insufficient once we need more than a few hundred workers.
To get around this, stepper employs a pattern we coin "double batching", where
we create subjobs whose sole purpose is to themselves create the subjobs
that actually run the steps.
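For a sense of scale: at one launch every ~3.5 seconds, starting 1,000 workers
serially would take close to an hour, whereas launching ~10 intermediate
subjobs that each launch ~100 workers (in parallel) brings the launch time
down to a handful of minutes.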
.. NOTE:: We use double-batching for the PubSub implementation of stepper as
    well as the file-based implementation. The literal meaning of
    "double-batching" doesn't apply as well to the PubSub implementation, but
    the general pattern of launching subjobs to launch more subjobs still
    applies.
#===============================================================================
# Environment variables and global settings
#===============================================================================
Settings:
    - SCHRODINGER_STEPPER_DEBUG
        Set to 1 to have most files brought back from a workflow run.
        Set to 2 to have _all_ files brought back.
    - SCHRODINGER_GCP_PROJECT
        Expected when running stepper with pubsub and bigquery. Should just
        be a string with the GCP project name, e.g. ad-pydev-dev
    - SCHRODINGER_GCP_KEY
        Expected when running stepper with pubsub and bigquery. Should be
        a path to the gcp service key. See
            https://cloud.google.com/iam/docs/creating-managing-service-account-keys
        for more information on generating gcp service keys.
    - SCHRODINGER_PUBSUB_TOPIC_PREFIX
        Optional setting. If set, all topics and subscriptions created during
        workflow runs will have the specified prefix added to the name. This is
        useful when a cloud provider searches topics by prefix and holds no
        project level separation, such as AWS.
    - SCHRODINGER_PUBSUB_TOPIC_SUFFIX
        Optional setting. If set, all topics and subscriptions created during
        workflow runs will have the specified suffix added to the name.
        This is useful for searching for all topics and subscriptions created
        for a particular run.
    - SCHRODINGER_GCP_DUPLICATE_SUBSCRIPTIONS
        Optional debug setting. If set, whenever a subscription is created,
        a second one will also be created. The second sub will have the
        same name plus an additional '_debug' appended to it. This is useful
        for debugging runs and looking at what data was generated by all pubsub
        steps.
    - SCHRODINGER_CLOUD_WORKER_TIMEOUT
        Optional debug setting. If set, pubsub workers will timeout after
        SCHRODINGER_CLOUD_WORKER_TIMEOUT minutes.
    - SCHRODINGER_GCP_NUM_PUBSUB_WORKERS
        Sets the default number of pubsub workers that will be used. If not
        set, one will be used. Note that this value can still be overridden
        by a workflow's configuration.
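Example shell setup for a PubSub/BigQuery run (values are illustrative)::
    export SCHRODINGER_GCP_PROJECT=ad-pydev-dev
    export SCHRODINGER_GCP_KEY=/path/to/service-account-key.json
    export SCHRODINGER_GCP_NUM_PUBSUB_WORKERS=4
    export SCHRODINGER_STEPPER_DEBUG=1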
"""
import collections
import copy
import enum
import glob
import inspect
import itertools
import logging
import math
import os
import pprint
import re
import shutil
import subprocess
import time
import uuid
import zipfile
from typing import Any
from typing import Iterable
from typing import List
from typing import Set
from typing import Optional
import more_itertools
from ruamel import yaml
from schrodinger.application.steps import env_keys
from schrodinger.job import jobcontrol
from schrodinger.models import json
from schrodinger.models import parameters
from schrodinger.models import paramtools
from schrodinger.Qt import QtCore
from schrodinger.tasks import hosts
from schrodinger.tasks import jobtasks
from schrodinger.tasks import queue
from schrodinger.tasks import tasks
from schrodinger.ui.qt.appframework2 import application
from schrodinger.utils import env
from schrodinger.utils import imputils
MODULE_ROOT_BLACKLIST = ('schrodinger',)
DOUBLE_BATCH_THRESHOLD = float('inf')
TOPIC_PREFIX_LIMIT = 10
TOPIC_SUFFIX_LIMIT = 6
TOPIC_STEP_ID_LIMIT_FOR_AWS = 50
TOPIC_STEP_ID_LIMIT_FOR_GCP = 230
#===============================================================================
# Logging
# Stepper uses a special logger that includes a timestamp relative to the start
# time of a workflow. Note that by nature the logger and formatter are global
# objects.
#===============================================================================
def get_debug_level():
    return int(os.environ.get('SCHRODINGER_STEPPER_DEBUG', 0))
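class ElapsedFormatter(logging.Formatter):
    """
    Minimal sketch of the elapsed-time formatter used by the stepper logger;
    the original definition is not reproduced in this excerpt. It prefixes
    each record with the time elapsed since `start()` was called (i.e. since
    the workflow was started).
    """
    def __init__(self):
        super().__init__()
        self._start_time = None
    def start(self):
        # Mark the workflow start time; called when a batched step begins.
        self._start_time = time.time()
    def format(self, record):
        elapsed = 0.0 if self._start_time is None else time.time() - self._start_time
        return f'[{elapsed:10.2f}s] {record.getMessage()}'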
logger = logging.getLogger('schrodinger.stepper.stepper')
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
FORMATTER = ElapsedFormatter()
handler.setFormatter(FORMATTER)
logger.addHandler(handler)
#===============================================================================
# Batching
#===============================================================================
def ichunked(iterable, n):
    """
    Reimplementation of more_itertools.ichunked that does not cache
    n items of iterable at a time.
    Breaks `iterable` into sub-iterables with `n` elements each.
    Note that unlike more_itertools.ichunked, an error will be raised if
    you try to iterate over a chunk before its previous chunk has been
    consumed.
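    Example (illustrative)::
        >>> chunks = ichunked(range(5), 2)
        >>> [list(chunk) for chunk in chunks]
        [[0, 1], [2, 3], [4]]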
    """
    source = iter(iterable)
    _marker = object()
    while True:
        # Check to see whether we're at the end of the source iterable
        item = next(source, _marker)
        if item is _marker:
            return
        ichunk = itertools.islice(itertools.chain([item], source), n)
        yield ichunk
        try:
            next(ichunk)
        except StopIteration:
            pass
        else:
            raise RuntimeError("Previous chunks must be exhausted before "
                               "iterating over following chunks") 
def _assert_step_hasnt_started(func):
    """
    Decorator that prevents a step method from running if the output generator
    has already been created.
    """
    def wrapped_func(self, *args, **kwargs):
        if self._outputs_gen is not None:
            raise RuntimeError(
                f'Cannot call {func.__name__} because this step has already '
                'started (i.e. outputs() or getOutput() has already been called).'
            )
        return func(self, *args, **kwargs)
    return wrapped_func
def _prettify_time(time_in_float):
    utc_time = time.gmtime(time_in_float)
    return time.strftime('%Y-%m-%d %H:%M:%S %Z', utc_time)
def _prettify_duration(time_in_sec):
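    """
    Format a duration in seconds as 'HH:MM:SS', prefixed with 'DD:' when the
    duration spans one or more days (e.g. 3725 -> '01:02:05').
    """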
    def div_w_remainder(numer, denom):
        return int(numer // denom), numer % denom
    days, remaining_sec = div_w_remainder(time_in_sec, 24 * 60 * 60)
    hours, remaining_sec = div_w_remainder(remaining_sec, 60 * 60)
    minutes, remaining_sec = div_w_remainder(remaining_sec, 60)
    seconds = int(remaining_sec)
    pretty_string = f'{hours:02d}:{minutes:02d}:{seconds:02d}'
    if days:
        pretty_string = f'{days:02d}:{pretty_string}'
    return pretty_string
class ResourceType(enum.Enum):
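    """
    LOCAL resources live on the launch machine and must be copied to compute
    hosts; STATIC resources are assumed to already be available on the compute
    hosts (e.g. on a shared filesystem) and are never transferred.
    """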
    LOCAL = enum.auto()
    STATIC = enum.auto() 
class _StepperResource(json.JsonableClassMixin, str):
    """
    See `_BaseStep` for documentation.
    """
    LOCAL = ResourceType.LOCAL
    STATIC = ResourceType.STATIC
    def __new__(cls, value='', *args, **kwargs):
        # explicitly only pass value to the str constructor
        return super().__new__(cls, value)
    def __init__(self, path='', resource_type=LOCAL):
        self.resource_type = resource_type
    @classmethod
    def fromJsonImplementation(cls, json_obj):
        if isinstance(json_obj, str):
            return cls(json_obj)
        return cls(json_obj['path'],
                   ResourceType[json_obj['resource_type'].upper()])
    def toJsonImplementation(self):
        return {'path': str(self), 'resource_type': self.resource_type.name}
class StepperFile(_StepperResource):
    pass
class StepperFolder(_StepperResource):
    pass
class _DehydratedStep(parameters.CompoundParam):
    """
    See `_BaseStep._dehydrateStep` for documentation.
    """
    step_module_path: str
    step_class_name: str
    step_id: str
    step_config: dict
    starting_step_id: str = None
    input_file: StepperFile = None
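class StepTaskInput(parameters.CompoundParam):
    """
    Minimal sketch of the task input param; the original definition is not
    reproduced in this excerpt. The field names below are inferred from how
    `self.input` is used by the task classes that follow (`dehydrated_step`,
    `misc_input_filenames`, `debug_mode`, and the `_double_batch` flag set
    when batches are created).
    """
    dehydrated_step: _DehydratedStep = None
    misc_input_filenames: List[jobtasks.TaskFile]
    debug_mode: bool = False
    _double_batch: bool = False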
class StepTaskOutput(parameters.CompoundParam):
    output_file: jobtasks.TaskFile = None
    run_info: dict
    misc_output_filenames: List[jobtasks.TaskFile] 
class StepTaskMixin(parameters.CompoundParamMixin):
    """
    This class must be mixed in with a subclass of AbstractComboTask. The
    resulting task class may be used to run any step as a task, provided the
    input, output, and settings classes are all JSONable.
    """
    input: StepTaskInput
    output: StepTaskOutput
    DEFAULT_TASKDIR_SETTING = tasks.AUTO_TASKDIR
    def __init__(self, *args, step=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._step_class = None
        self._step = None
        if step is not None:
            self.setStep(step) 
    def addLicenseReservation(self, license, num_tokens=1):
        try:
            super().addLicenseReservation(license, num_tokens)
        except AttributeError:
            pass 
    def _setUpInputFile(self, filepath):
        """
        Given a filepath, do any necessary setup to register the file (e.g.
        add it to a list of input files) and return the path that should be
        used by the backend task. (e.g. the absolute path for a subprocess task
        or a relative path for a job task)
        """
        raise NotImplementedError
    def setStep(self, step):
        self._step = step
        for s in self._getAllStepsAndChains(step):
            def skip_none_map_func(map_func):
                def wrapped_map_func(param_val):
                    if param_val is None:
                        return None
                    return map_func(param_val)
                return wrapped_map_func
            paramtools.map_subparams(skip_none_map_func(self._setUpInputFile),
                                     s.settings, StepperFile)
            paramtools.map_subparams(skip_none_map_func(self._setUpInputFolder),
                                     s.settings, StepperFolder)
        dehyd_step = step._dehydrateStep()
        self.input.dehydrated_step = dehyd_step
        self._step_class = type(step)
        self._setUpStepTask(dehyd_step) 
    def _getAllStepsAndChains(self, step: '_BaseStep') -> Set['_BaseStep']:
        """
        Given a step, return a set of all steps it contains and itself.
        For example, given a chain A with the following topology::
                A
            |-------|
            B       C
                |-------|
                D       E
        this method will return::
            A -> set([A, B, C, D, E])
            B -> set([B])
            C -> set([C, D, E])
        """
        if not isinstance(step, Chain):
            return set([step])
        leaf_steps = set(more_itertools.collapse(step))
        max_depth = max((s.getStepId().count('.') for s in leaf_steps),
                        default=0)
        all_steps_and_chains = set([step])
        # Iteratively collapse to include every level up to the max depth
        for level in range(max_depth):
            all_steps_and_chains.update(
                more_itertools.collapse(step, levels=level))
        return all_steps_and_chains
    def getStepClass(self):
        return self._step_class 
    def _setUpStepTask(self, dehyd_step: _DehydratedStep):
        self._preprocessModuleRoot(dehyd_step)
        self._preprocessInputFiles(dehyd_step.input_file)
        if dehyd_step.input_file is not None:
            dehyd_step.input_file = self._setUpInputFile(dehyd_step.input_file)
    def _preprocessModuleRoot(self, dehyd_step: _DehydratedStep):
        """
        If the dehydrated step is defined in a package in a non-blacklisted
        folder in the working directory, add the package as an input folder for
        the task so it will be available for import in the backend task folder.
        If the step is defined in the main script we will not be able to import
        it in the backend, so a `ValueError` exception is raised.
        :raise ValueError: if the module root is __main__.
        """
        root = dehyd_step.step_module_path.split('.')[0]
        if root == '__main__':
            raise ValueError(
                f'Step class {dehyd_step.step_class_name} should be defined'
                f' outside of __main__.')
        if os.path.isdir(root) and not root.lower() in MODULE_ROOT_BLACKLIST:
            print(f'Using nonstandard package {root}')
            self._setUpInputFolder(root)
    def _preprocessInputFiles(self, input_file: str):
        """
        Before the starting the task, convert any input StepperFiles to paths
        relative to the backend machine.
        """
        if self._step.Input is StepperFile:
            # If the inputs for the steps are StepperFiles, then we need
            # to read the input file and register the inputs and convert
            # them to the right path for the backend.
            serializer = self._step._getInputSerializer()
            # Read the inputs and register them
            inp_files = []
            for inp in serializer.deserialize(input_file):
                inp_files.append(self._setUpInputFile(inp))
            # Write out the inputs again with the correct paths for
            # the backend
            serializer.serialize(inp_files, input_file)
    def _makeBackendStep(self):
        step = _rehydrate_step(self.input.dehydrated_step)
        self._step_class = type(step)
        if not self.input._double_batch:
            step.setBatchSettings(None)
        return step
    def mainFunction(self):
        try:
            self._step = self._makeBackendStep()
            step = self._step
            batch_outp_name = self.name + '.out'
            step.writeOutputsToFile(batch_outp_name)
            self.output.output_file = batch_outp_name
            self.output.run_info = step._run_info
            if self.input.debug_mode:
                self._runDebug()
        except Exception:
            self._registerOutputFilesFromDir(self.getTaskDir())
            raise 
    def _runDebug(self):
        pass
    def _postprocessOutputFiles(self):
        """
        After the task returns, convert any output StepperFiles to paths
        relative to the frontend machine.
        """
        if self._step.Output is StepperFile:
            output_file = self.output.output_file
            self._processOutputStepperFiles(output_file)
    def _processOutputStepperFiles(self, output_file: str):
        """
        Reads in the output file containing stepper files, processes them
        (e.g. register them as output files, convert them to the correct paths),
        and then writes them back out again.
        :param output_file: File storing list of unprocessed output stepper
            files
        """
        processed_outputs = []
        serializer = self._step.getOutputSerializer()
        for outp in serializer.deserialize(output_file):
            outp = self._setUpOutputFile(outp)
            processed_outputs.append(outp)
        serializer.serialize(processed_outputs, output_file)
    def _setUpOutputFile(self, outp_file):
        return StepperFile(
            os.path.join(self.getTaskDir(), self.getTaskFilename(outp_file))) 
class StepSubprocessTask(StepTaskMixin, tasks.ComboSubprocessTask):
    def _setUpInputFile(self, filepath):
        return StepperFile(os.path.abspath(filepath))
    def _setUpInputFolder(self, folderpath):
        return StepperFolder(os.path.abspath(folderpath))
    @tasks.postprocessor
    def _postprocessOutputFiles(self):
        return super()._postprocessOutputFiles()
    def _registerOutputFilesFromDir(self, *args):
        pass 
class StepJobTask(StepTaskMixin, jobtasks.ComboJobTask):
    _use_async_jobhandler: bool = True
    input: StepTaskInput
    output: StepTaskOutput
    def _makeBackendStep(self, *args, **kwargs):
        step = super()._makeBackendStep(*args, **kwargs)
        return step
    def _isLocalResource(self, fpath):
        return fpath.resource_type is ResourceType.LOCAL
    def _setUpInputFile(self, filepath):
        """
        Register a file as an input file and convert the path into a path
        that will be valid for the compute host. For example...
            self._setUpInputFile(StepperFile('/path/to/inps.txt')) -> 'inps.txt'
            self._setUpInputFile(StepperFile('inps.txt')) -> 'inps.txt'
            self._setUpInputFile(StepperFile('data/inps.txt')) -> 'data/inps.txt'
        Static filepaths will have their paths returned unchanged since the
        compute host will have access to them using the same path.
        """
        if self._isLocalResource(filepath):
            self.input.misc_input_filenames.append(filepath)
        if self._isLocalResource(filepath) and (os.path.isabs(filepath) or
                                                filepath.startswith('..')):
            return os.path.basename(filepath)
        else:
            return filepath
    def _setUpInputFolder(self, folderpath):
        """
        Register a file as an input file and convert the path into a path
        that will be valid for the compute host.
        """
        if self._isLocalResource(folderpath):
            self.addInputDirectory(folderpath)
        if self._isLocalResource(folderpath) and (os.path.isabs(folderpath) or
                                                  folderpath.startswith('..')):
            return os.path.basename(folderpath)
        else:
            return folderpath
    def _runDebug(self):
        # Register all input and output files so they're brought back to the
        # launch machine.
        self.output.misc_output_filenames.extend(
            list(_get_stepper_debug_files()))
    def mainFunction(self):
        self._job = jobcontrol.get_backend().getJob()
        super().mainFunction()
        if get_debug_level() >= 2:
            self._registerAllFiles()
        if self._step.Output is StepperFile:
            self._processOutputStepperFiles(self.output.output_file) 
    def _registerOutputFilesFromDir(self, dir):
        self.output.misc_output_filenames.append(dir)
    def _registerAllFiles(self):
        self._registerOutputFilesFromDir('.')
    @tasks.postprocessor
    def _postprocessOutputFiles(self):
        return super()._postprocessOutputFiles()
    def _setUpOutputFile(self, outp_file):
        result = StepperFile(
            os.path.join(self.getTaskDir(), self.getTaskFilename(outp_file)))
        self.output.misc_output_filenames.append(result)
        return result 
class _PubSubTaskInput(StepTaskInput):
    input_topic: str = None
    output_topic: str = None
    batch_size: int = None
class _PubSubTaskOutput(StepTaskOutput):
    num_outputs: int = None
    num_inputs: int = None
class _PubSubWorkerTask(StepJobTask):
    """
    A task implementing one PubSub worker. This task will take a step,
    input topic, and an output topic, and run the step on batches of inputs
    pulled from the input topic, uploading the results to the supplied output
    topic.
    """
    Input = _PubSubTaskInput
    Output = _PubSubTaskOutput
    def _makePubsubCmd(self):
        inp = self.input
        step = self._step
        args = [
            'stepper',
            inp.input_topic,
            inp.output_topic,
            self._getSettingsFilename(),
            step._getStepPath(),
            step.getStepId(),
        ]
        if inp.batch_size is not None:
            args.append(str(inp.batch_size))
        return [SCHRODINGER_RUN, *step.cloud_service_script] + args
    def _getSettingsFilename(self):
        step = self._step
        return f'{step.getStepId()}_settings.yaml'
    def _generateSettingsFile(self):
        settings_fname = self._getSettingsFilename()
        with open(settings_fname, 'wt') as settings_file:
            # We deepcopy the config to get rid of any special types that
            # yaml won't know how to process. See AD-378 for more info.
            yaml.dump(copy.deepcopy(self._step._getCanonicalizedConfig()),
                      settings_file)
    def mainFunction(self):
        self._step = self._makeBackendStep()
        self._generateSettingsFile()
        cmd = self._makePubsubCmd()
        stdout_list = []
        with subprocess.Popen(cmd,
                              stdout=subprocess.PIPE,
                              bufsize=1,
                              universal_newlines=True) as p:
            for line in p.stdout:
                stdout_list.append(line)
                print(line, end='')  # process line here
        if p.returncode != 0:
            self._registerAllFiles()
            raise subprocess.CalledProcessError(p.returncode, p.args)
        stdout = ''.join(stdout_list)
        last_line = stdout.split('\n')[-2]
        output = self.output
        output.num_inputs, output.num_outputs = map(
            int,
            last_line.strip('()').split(' '))
class _PubSubWorkerLauncherTask(StepJobTask):
    """
    A task implementing a PubSub worker launcher. This task's purpose is simply
    to launch many `_PubSubWorkerTask`s. This is necessary in order to launch
    enough `_PubSubWorkerTask`s simultaneously given jobserver's job launch
    speed. See `_dbl_batching_` in the module docstring for more info.
    """
    Input = _PubSubTaskInput
    Output = _PubSubTaskOutput
    def mainFunction(self):
        """
        Launch DOUBLE_BATCH_THRESHOLD _PubSubWorkerTasks
        This is done by simply running the backend step set to this task.
        The backend step will already have batch settings on it with
        `num_pubsub_workers` set to `DOUBLE_BATCH_THRESHOLD`, so we just
        need to set the topics on the step and run it.
        """
        step = self._step = self._makeBackendStep()
        self._step.setInputTopic(self.input.input_topic)
        self._step.setOutputTopic(self.input.output_topic)
        step.outputs()
        self.output.num_outputs = step._output_count
        self.output.num_inputs = step._input_count
    def _makeBackendStep(self):
        step = _rehydrate_step(self.input.dehydrated_step)
        self._step_class = type(step)
        return step
#===============================================================================
# Running steps in batches
#===============================================================================
def _get_default_num_pubsub_workers():
    return int(os.environ.get('SCHRODINGER_GCP_NUM_PUBSUB_WORKERS', '1'))
class BatchSettings(parameters.CompoundParam):
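    """
    Settings controlling how a batchable step distributes its work:
    `size` is the number of inputs per batch, `task_class` is the task class
    used to run each batch, `hostname` is the host the batch tasks are
    submitted to, `use_pubsub` selects the PubSub implementation, and
    `num_pubsub_workers` is the number of PubSub workers to launch.
    """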
    size: int = None
    task_class: type = StepJobTask
    hostname: str = 'localhost'
    use_pubsub: bool
    num_pubsub_workers: int = 1  # default value set in initializeValue
    def initializeValue(self):
        super().initializeValue()
        self.num_pubsub_workers = _get_default_num_pubsub_workers()  
class Serializer:
    """                                                     <_serialization_>
    A class for defining special serialization for some datatype. Serialization
    by default uses the `json` protocol, but if a specialized protocol is wanted
    instead, users can subclass this class to do so.
    Subclasses should:
        - Define `DataType`. This is the class that this serializer can
            encode/decode.
        - Define `toString(self, output)`, which defines how to serialize
            an output.
        - Define `fromString(self, input_str)`, which defines how to
            deserialize an input.
    This can then be used as the `InputSerializer` or `OutputSerializer` for
    any step.
    Here's an example for defining an int that's serialized in base-two
    as opposed to base-ten::
        class IntBaseTwoSerializer(Serializer):
            DataType = int
            def toString(self, output):
                return bin(output) # 7 -> '0b111'
            def fromString(self, input_str):
                return int(input_str[2:], 2) # '0b111' -> 7
    This can then be used anywhere you'd use an int as the output or input in a
    step. For example::
        class SquaringStep(MapStep):
            Input = int
            InputSerializer = IntBaseTwoSerializer
            Output = int
            OutputSerializer = IntBaseTwoSerializer
            def mapFunction(self, inp):
                yield inp**2
    Now, any time that a `SquaringStep` would read its inputs from a file
    or write its outputs to a file, it'll do so using a base-two
    representation.
    """
    DataType = NotImplemented
    def serialize(self, items, fname):
        """
        Write `items` to a file named `fname`.
        :type items: iterable[self.DataType]
        :type fname: str
        """
        with open(fname, 'w') as outfile:
            for outp in items:
                outfile.write(self.toString(outp) + '\n') 
    def deserialize(self, fname):
        """
        Read in items from `fname`.
        :type fname: str
        :rtype: iterable[self.DataType]
        """
        if fname is None:
            raise TypeError("deserialize called with None")
        with open(fname, 'r') as infile:
            for line in infile:
                inp = self.fromString(line.strip('\n'))
                yield inp 
    def fromString(self, input_str):
        raise NotImplementedError 
    def toString(self, output):
        raise NotImplementedError 
    @classmethod
    def __init_subclass__(cls):
        if cls.DataType is NotImplemented:
            raise NotImplementedError(
                "DataType must be specified for Serializers")
        super().__init_subclass__() 
class _DynamicSerializer(Serializer):
    """
    The default serializer that simply uses `json.loads` and `json.dumps`
    """
    DataType = object
    def __init__(self, dataclass):
        self._dataclass = dataclass
    def fromString(self, inp_str):
        try:
            return json.loads(inp_str, DataClass=self._dataclass)
        except Exception:
            print(f"Error while trying to decode: {inp_str}")
            raise
    def toString(self, outp):
        return json.dumps(outp)
class ValidationIssue(RuntimeError):
    def __init__(self, source_step, msg):
        self.source_step = source_step
        self.msg = msg
        super().__init__(msg) 
    def __repr__(self):
        return f'{type(self).__name__}("{self.source_step.getStepId()}", "{self.msg}")'
    def __str__(self):
        return f'{type(self).__name__}("{self.source_step.getStepId()}", "{self.msg}")' 
class SettingsError(ValidationIssue):
    """
    Used in conjunction with `_BaseStep.validateSettings` to report an error
    with settings. Constructed with the step that has the invalid settings and
    an error message, e.g.
    `SettingsError(bad_step, "Step does not have required settings.")`
    """ 
class SettingsWarning(ValidationIssue):
    """
    Used in conjunction with `_BaseStep.validateSettings` to report a warning
    with settings. Constructed with the step that has the invalid settings and
    an error message, e.g.
    `SettingsWarning(bad_step, "Step setting FOO should ideally be positive")`
    """ 
class ResourceError(ValidationIssue):
    """
    Used in conjunction with `_BaseStep.validateSettings` to report an error
    with a resource setting. Constructed with the step that has the invalid
    setting and an error message, e.g.
    `ResourceError(bad_step, "Step setting 'file' has not been set.")`
    """ 
class LocalResourceError(ResourceError):
    """
    A ResourceError specifically for local StepperFile and StepperFolder
    validations, i.e., resources that are on a job submission host and may have
    to be transferred to compute resources
    """ 
class StaticResourceError(ResourceError):
    """
    A ResourceError specifically for static StepperFile and StepperFolder
    validations, i.e., resources that are not necessarily available on a job
    submission host
    """ 
class _BaseStep(QtCore.QObject):
    """
    The features and behavior described in this docstring apply to all steps
    and chains.
    To use a step, instantiate it, set the inputs, and request outputs.
    Accessing outputs causes the step to get input from the input source and
    run the step operation. There is no concept of "running" or "starting" the
    step. For example::
        class SquareStep(MapStep):
            def mapFunction(self, inp):
                yield inp * inp
        step = SquareStep()
        step.setInputs([1, 2, 3])
        print(step.getOutputs()) # [1, 4, 9]
    The outputs are produced with a generator. Thus, calling
    `step.getOutputs()` twice will always result in an empty list for the
    second call.
    Settings
    ========                                                <_settings_>
    Every step can parameterize how it operates using a set of settings. The
    settings of a step are defined as a subclass of `CompoundParam` at the
    class level, and can be set per-instance using keyword arguments at
    instantiation time. Example::
        class MultiplyByStep(MapStep):
            class Settings(parameters.CompoundParam):
                multiplier: int = 1
        by_4_step = MultiplyByStep(multiplier=4)
        by_4_step.setInputs([1, 2, 3])
        by_4_step.getOutputs() == [4, 8, 12]
    =============
    Configuration
    =============
    A configuration is a dictionary that specifies settings values for steps
    within a chain.
    A step can take a configuration dictionary that maps step
    selectors to default setting values. For example::
        Chain(config={'A':{'max_rounds':10}})
    This configuration will propagate through the `Chain` and set the
    `max_rounds` setting of every A step to 10.
    There are three currently supported selectors:
        General selectors e.g. "A":
            This will select all steps of type "A" (Note that this does not
            select subclasses of "A")
        Child selectors e.g. "A>B"
            This will select all steps of type "B" that
            are in chains of type "A". Multiple ">" operators can be linked
            together. For example, "A>B>C" will select all "C" steps in "B"
            chains which are in the "A" chain.
        ID selector e.g. "A.B_0"
            This will select the first "B" step in chain "A". The top level
            chain never has an index. Steps in a chain are indexed relative to
            other steps of the same type in that chain. For example,
            if chain "A" is composed of steps BCBCC, then the ids would be
            "A.B_0", "A.C_0", "A.B_1", "A.C_1", "A.C_2"
    As a convenience, you can set the special key __DEFAULT_BATCH_SETTINGS__
    to a dictionary to use as the new default batch settings.
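    For example, assuming a chain "A" containing "B" steps that have a
    `max_rounds` setting (names and values are illustrative)::
        config = {
            'B': {'max_rounds': 5},        # every B step, anywhere
            'A>B': {'max_rounds': 10},     # B steps inside A chains
            'A.B_1': {'max_rounds': 20},   # the second B step directly in A
            '__DEFAULT_BATCH_SETTINGS__': {'size': 100},
        }
        chain = A(config=config)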
    ============================================================================
    File Handling                                           <_file_handling_>
    ============================================================================
    To specify a file, use the `StepperFile` class as the input type, output
    type, or as a subparam on the `Settings` class. Local files specified in
    these locations will automatically be copied to and from compute machines.
    You can similarly specify `StepperFolder` to have folders copied over
    to compute machines. Currently, `StepperFolder` can only be used with
    step settings, not as step inputs or outputs.
    Strings specified in `config` for `StepperFile` and `StepperFolder` will
    be automatically cast.
    If a step depends on a shared resource (e.g. on a shared filesystem,
    built into a node's image), the file or folder can be marked as a STATIC
    resource, signifying to the framework that it does not need to be
    copied over. To do this, set a stepper file or folder with a dictionary
    specifying path and the resource type. For example::
        my_step = MyStep(config={
            'compute_library':{
                'path':'/path/to/shared/resource',
                'resource_type':'STATIC'
            }})
    ========
    Licenses
    ========
    Some steps may require a license for each node that they're run on. All
    batchable steps support this feature.
    To specify the number of license reservations a step requires, override
    `getLicenseRequirements` and return a dictionary mapping licenses
    to the number of tokens required for that license. For example::
        from schrodinger.utils import license
        class LicenseRequiringStep(MapStep):
            Input = str
            Output = str
            def getLicenseRequirements(self):
                return {license.GLIDE_MAIN: 2}
    Once you've specified what licenses are required, any batched steps will
    automatically have the right number of licenses reserved.
    .. NOTE:: Batched `Chain` objects by default account for any reservations
        that might be necessary to run any component steps.
    """
    Input = None
    InputSerializer = _DynamicSerializer
    Output = None
    OutputSerializer = _DynamicSerializer
    Settings = parameters.CompoundParam
    def __init__(self,
                 settings=None,
                 config=None,
                 step_id=None,
                 _run_info=None,
                 **kwargs):
        super().__init__()
        if not step_id:
            self._step_id = type(self).__name__
        else:
            self._step_id = step_id
        if _run_info is None:
            _run_info = collections.defaultdict(dict)
        self._setRunInfo(_run_info)
        self._outputs_gen = None
        self.setSettings(settings, **kwargs)
        self._setCompositionPath(type(self).__name__)
        self._setConfig(config)
        self._input_file = None
        self._inputs = None
        self._input_count = 0
    @classmethod
    def __init_subclass__(cls):
        """
        Validate the subclass definition.
        """
        cls._validateInputSerializer()
        cls._validateOutputSerializer()
        if (not isinstance(cls.Settings, type) or
                not issubclass(cls.Settings, parameters.CompoundParam)):
            raise TypeError("Custom settings must subclass CompoundParam")
        super().__init_subclass__()
    @classmethod
    def _validateInputSerializer(cls):
        if cls.InputSerializer is not _DynamicSerializer:
            if cls.Input is None or not issubclass(
                    cls.Input, cls.InputSerializer.DataType):
                msg = (
                    'Incompatible InputSerializer specified. \n'
                    f'Step "{cls.__name__}" has Input "{cls.Input}" '
                    f'but InputSerializer has DataType "{cls.InputSerializer.DataType}"'
                )
                raise TypeError(msg)
    @classmethod
    def _validateOutputSerializer(cls):
        if cls.OutputSerializer is not _DynamicSerializer:
            if cls.Output is None or (
                    cls.Output != cls.OutputSerializer.DataType and
                    not issubclass(cls.Output, cls.OutputSerializer.DataType)):
                msg = (
                    'Incompatible OutputSerializer specified. \n'
                    f'Step "{cls.__name__}" has Output "{cls.Output}" '
                    f'but OutputSerializer has DataType "{cls.OutputSerializer.DataType}"'
                )
                raise TypeError(msg)
    def _getCanonicalizedConfig(self):
        return {self.getStepId(): self.settings.toDict()}
    def report(self, prefix=''):
        """
        Report the settings and batch settings for this step.
        """
        logger.info(f'{prefix} - {self.getStepId()}')
        all_options = [self.settings]
        if hasattr(self, '_batch_settings'):
            all_options.append(self._batch_settings)
        for opts in all_options:
            if opts and opts.toDict():
                logger.info(
                    f'{prefix}     {opts.__class__.__name__}: {opts.toDict()}')
    def prettyPrintRunInfo(self):
        """
        Format and print info about the step's run.
        """
        run_info = copy.deepcopy(self.getRunInfo())
        self._prettifyRunInfo(run_info)
        # Listify the dict into tuples since pprint doesn't respect
        # dictionary order
        run_info = list(run_info.items())
        pprint.pprint(run_info)
    def _prettifyRunInfo(self, run_info_dict):
        """
        Recurse through `run_info_dict` and listify dicts into item tuples.
        This improves the readability of pretty-print and preserves the
        dictionary insertion order.
        """
        for k, v in run_info_dict.items():
            if isinstance(v, dict):
                self._prettifyRunInfo(v)
    def __copy__(self):
        copied_step = type(self)(settings=copy.copy(self.settings),
                                 config=self._getCanonicalizedConfig(),
                                 step_id=self.getStepId())
        return copied_step
    def _getInputSerializer(self):
        if issubclass(self.InputSerializer, _DynamicSerializer):
            return _DynamicSerializer(dataclass=self.Input)
        else:
            return self.InputSerializer()
    def getOutputSerializer(self):
        if issubclass(self.OutputSerializer, _DynamicSerializer):
            return _DynamicSerializer(dataclass=self.Output)
        else:
            return self.OutputSerializer()
    def _validateStepperFileSettings(self):
        """
        Look through settings for StepperFiles and StepperFolders and confirm
        that they're set to valid file and folder paths.
        StaticResourceErrors will be returned for StepperFile or StepperFolder
        instances that are static resources; LocalResourceErrors otherwise.
        :return: A list of `SettingsError`, one for each invalid stepper file
        :rtype: list[SettingsError or ResourceError]
        """
        results = []
        if self.settings is None:
            return results
        settings = self.settings
        for subparam_name, abstract_subparam in self.Settings.getSubParams(
        ).items():
            what = f"<{self._step_id}> setting '{subparam_name}'"
            resource = abstract_subparam.getParamValue(settings)
            if abstract_subparam.DataClass in (StepperFile, StepperFolder):
                if resource is None:
                    results.append(
                        SettingsError(self, f"{what} has not been set."))
                    continue
                error = (StaticResourceError if
                         (resource and
                          resource.resource_type is resource.STATIC) else
                         LocalResourceError)
            if abstract_subparam.DataClass is StepperFile:
                if not os.path.isfile(resource):
                    results.append(
                        error(
                            self,
                            f"{what} set to invalid file path: '{str(resource)}'"
                        ))
                elif (resource.resource_type is resource.STATIC and
                      not os.path.isabs(resource)):
                    results.append(
                        error(
                            self,
                            f"{what} is set as a static file with a relative"
                            f" path: {str(resource)}"))
            if abstract_subparam.DataClass is StepperFolder:
                if not os.path.isdir(resource):
                    results.append(
                        error(
                            self,
                            f"{what} set to invalid dir path: '{str(resource)}'"
                        ))
                elif (resource.resource_type is resource.STATIC and
                      not os.path.isabs(resource)):
                    results.append(
                        error(
                            self,
                            f"{what} is set as a static folder with a relative"
                            f" path: {str(resource)}"))
        return results
    def validateSettings(self):
        """
        Check whether the step settings are valid and return a list of
        `SettingsError` and `SettingsWarning` to report any invalid settings.
        Default implementation checks that all stepper files are set to valid
        file paths.
        :rtype: list[SettingsError or SettingsWarning]
        """
        return self._validateStepperFileSettings()
    def getResources(self, param_type, resource_type):
        """
        Get the stepper resources in the settings that are instances of
        `param_type` and have a resource_type attribute that is `resource_type`.
        Note that this does not work for list/set/tuple subparams in the
        settings.
        :param param_type: the resource parameter type
        :type param_type: _StepperResource
        :param resource_type: the type of resource to get
        :type resource_type: ResourceType
        :return: the set of stepper resources of `resource_type`
        :rtype: set of _StepperResource
        """
        if self.settings is None:
            return set()
        def _add(value):
            if value and value.resource_type == resource_type:
                resources.add(value)
            return value
        resources = set()
        paramtools.map_subparams(_add, self.settings, param_type)
        return resources
    def _setCompositionPath(self, path):
        """
        Update the composition path. The composition path is the string
        that defines a step's ancestry. For example, a composition path "A>B>C"
        means that this step, C, is in a chain B, which is itself in a chain
        A.
        """
        self._comp_path = path
    def _setStepId(self, new_id):
        self._step_id = new_id
    def getStepId(self):
        return self._step_id
    def _setRunInfo(self, run_info):
        self._run_info = run_info
    def getRunInfo(self):
        return self._run_info
    def _setConfig(self, config):
        def split(path):
            return re.split('[.>]', path)
        if config:
            # Sort by number of split items in the selectors so that we apply
            # child selectors by order of selectivity.
            if '__sorted' not in config:
                config = dict(
                    sorted(config.items(),
                           key=lambda item: len(split(item[0]))))
                config['__sorted'] = True
            for k in config:
                split_k = split(k)
                last_comp_path = split(self._comp_path)[-len(split_k):]
                # only apply the settings if the last items in the composition
                # path matches all items in the selector key
                # to avoid C>BA or C.BA getting settings from A
                if last_comp_path == split_k:
                    self._applyConfigSettings(config[k])
            # Apply ID selector settings last so they take final priority
            if self._step_id in config:
                self._applyConfigSettings(config[self._step_id])
        self._config = config
    def _applyConfigSettings(self, new_settings):
        if new_settings:
            for k, v in new_settings.items():
                if v is None:
                    continue
                if not hasattr(self.Settings, k):
                    raise SettingsError(
                        self, f"Step \"{type(self).__name__}\""
                        f" has no setting \"{k}\"")
            self.settings.setValue(**new_settings)
            def deserialize_res(resource, resource_class):
                if not isinstance(resource,
                                  resource_class) and resource is not None:
                    return resource_class.fromJson(resource)
                return resource
            paramtools.map_subparams(
                lambda res: deserialize_res(res, StepperFile), self.settings,
                StepperFile)
            paramtools.map_subparams(
                lambda res: deserialize_res(res, StepperFolder), self.settings,
                StepperFolder)
    def setInputFile(self, fname):
        self._input_file = fname
        self.setInputs(self._inputsFromFile(fname))
    def _inputsFromFile(self, fname):
        serializer = self._getInputSerializer()
        yield from serializer.deserialize(fname)
    def writeOutputsToFile(self, fname):
        """
        Write outputs to `fname`. By default, the output file will consist of
        one line for each output with whatever is produced when passing the
        output to `str`. Override this method if more complex behavior is needed.
        """
        serializer = self.getOutputSerializer()
        serializer.serialize(self.outputs(), fname)
    def setUp(self):
        """
        Hook for adding any type of work that needs to happen before any
        outputs are created.
        """
        pass
    def cleanUp(self):
        """
        Hook for adding any type of work that needs to happen after all
        outputs are exhausted or if some outputs are created and the step
        is destroyed.
        """
        pass
    @_assert_step_hasnt_started
    def setSettings(self, settings=None, **kwargs):
        """
        Supply the settings for this step to use when running. The supplied
        settings must match the Settings class or, if None is passed in, a
        default settings object will be used.
        """
        if settings is not None and kwargs:
            raise ValueError('Cannot specify both settings and kwargs')
        elif self.Settings is None:
            if settings is not None or kwargs:
                raise ValueError("Specified settings for a step that doesn't "
                                 "expect settings")
        elif settings is None:
            settings = self.Settings(**kwargs)
        elif not isinstance(settings, self.Settings):
            raise ValueError(f"settings should be of type {self.Settings}, not "
                             f"{type(settings)}.")
        self.settings = settings
    @_assert_step_hasnt_started
    def setInputs(self, inputs):
        """
        Set the input source for this step. This should be an iterable. Items
        from the input source won't actually be accessed until the outputs for
        this step are accessed.
        """
        if inputs is None:
            inputs = []
        self._inputs = inputs
    def inputs(self):
        yield from self._inputs
    @_assert_step_hasnt_started
    def outputs(self):
        """
        Creates the output generator for this step and returns it.
        """
        self.setUp()
        self._run_info[self.getStepId()] = {}
        outputs_gen = self._makeOutputGenerator()
        outputs_gen = self._outputsWithCounting(outputs_gen)
        self._outputs_gen = self._cleanUp_after_generator(outputs_gen)
        return self._outputs_gen
    def _outputsWithCounting(self, output_gen):
        self._output_count = 0
        self._end_time = None
        def wrapped_output_gen():
            for output in output_gen:
                self._output_count += 1
                yield output
            self._end_time = time.time()
            self._updateRunInfo()
        return wrapped_output_gen()
    def _cleanUp_after_generator(self, gen):
        """
        Call the step's cleanUp method once the generator has been
        exhausted.
        """
        try:
            for output in gen:
                yield output
        finally:
            self.cleanUp()
    def _updateRunInfo(self):
        step_run_info = self._run_info[self.getStepId()]
        start_time = getattr(self, '_start_time', None)
        end_time = getattr(self, '_end_time', None)
        step_run_info['num_inputs'] = self._input_count
        step_run_info['num_outputs'] = getattr(self, '_output_count', 0)
        if start_time and end_time:
            duration = self._end_time - self._start_time
        elif start_time:
            duration = time.time() - self._start_time
        else:
            duration = None
        if start_time:
            step_run_info['start_time'] = _prettify_time(start_time)
        if end_time:
            step_run_info['end_time'] = _prettify_time(end_time)
        if duration:
            step_run_info['duration'] = _prettify_duration(duration)
    def _getElapsedTime(self):
        if self._start_time is None:
            raise RuntimeError("Can't get elapsed time when step hasn't been "
                               "started.")
        return _prettify_duration(time.time() - self._start_time)
    def _makeOutputGenerator(self):
        raise NotImplementedError()
    def getOutputs(self):
        """
        Gets all the outputs in a list by fully iterating the output generator.
        """
        return list(self.outputs())
    def getLicenseRequirements(self):
        return {}
def _rehydrate_step(dehydrated_step: _DehydratedStep):
    """
    Recreate the step that `dehydrated_step` was created from.
    """
    with env.prepend_sys_path(os.getcwd()):
        step_module = imputils.get_module_from_path(
            dehydrated_step.step_module_path)
    step_class = getattr(step_module, dehydrated_step.step_class_name)
    return step_class._rehydrateStep(dehydrated_step)
SCHRODINGER_RUN = os.path.join(os.environ['SCHRODINGER'], 'run')
def _clean_up_task(task):
    assert task.status in (task.DONE, task.FAILED)
    assert task.taskDirSetting() is not None
    shutil.rmtree(task.getTaskDir())
class _BatchableStepMixin:
    """
    A step that can distribute its input into multiple batches and process
    them in parallel as tasks. Example::
        # Running a batcher as a single step
        b = ProcessSmilesChain(batch_size=10)
        b.setInputFile(smiles_filename)
        for output in b.outputs():
            print(output)
    """
    def __init__(self, *args, batch_size=None, batch_settings=None, **kwargs):
        if batch_size and batch_settings:
            raise ValueError("Can't pass both batch_size and batch_settings")
        elif batch_size is not None:
            batch_settings = BatchSettings(size=batch_size)
        self._batch_settings = batch_settings
        super().__init__(*args, **kwargs)
    @_assert_step_hasnt_started
    def setBatchSettings(self, batch_settings):
        """
        Set the batch settings for this step. Will raise an exception if this
        is done after the step has already started processing inputs.
        :type batch_settings: BatchSettings
        """
        self._batch_settings = batch_settings
    def _prettifyRunInfo(self, run_info_dict):
        super()._prettifyRunInfo(run_info_dict)
        if 'batches' in run_info_dict:
            batch_infos = []
            if not isinstance(run_info_dict['batches'], dict):
                return
            for batch_job_id, batch_info in run_info_dict['batches'].items():
                self._prettifyRunInfo(batch_info)
                batch_infos.append((batch_job_id, list(batch_info.items())))
            run_info_dict['batches'] = batch_infos
    def _setConfig(self, config):
        if config:
            if defaults := config.get('__DEFAULT_BATCH_SETTINGS__'):
                if '__DEFAULTS_APPLIED__' not in config:
                    for k, v in config.items():
                        if isinstance(v, dict) and 'batch_settings' in v:
                            v['batch_settings'] = {
                                **defaults,
                                **v['batch_settings']
                            }
                    config['__DEFAULTS_APPLIED__'] = True
                if self._batch_settings is not None:
                    self._batch_settings.setValue(**defaults)
        return super()._setConfig(config)
    def _applyConfigSettings(self, new_settings):
        new_settings = copy.deepcopy(new_settings)
        if 'batch_settings' in new_settings:
            batch_settings = new_settings['batch_settings']
            if batch_settings is None:
                self.setBatchSettings(new_settings.pop('batch_settings'))
            else:
                for k in new_settings['batch_settings']:
                    if not hasattr(BatchSettings, k):
                        raise SettingsError(
                            self,
                            f"Specified batch setting does not exist: \"{k}\"")
                self.setBatchSettings(
                    BatchSettings(**new_settings.pop('batch_settings')))
        super()._applyConfigSettings(new_settings)
    def _getCanonicalizedConfig(self):
        """
        Return a config that can be used to set the settings for a different
        instance of this step to the same settings as this step.
        """
        if isinstance(self.settings, parameters.CompoundParam):
            canon_config = super()._getCanonicalizedConfig()
            if self._batch_settings:
                batch_settings_dict = self._batch_settings.toDict()
                # Setting task class through config is currently unsupported
                batch_settings_dict.pop('task_class')
                canon_config[
                    self.getStepId()]['batch_settings'] = batch_settings_dict
            return canon_config
        return {}
    def _dehydrateStep(self):
        """
        Create a `_DehydratedStep` from this instance of a step. A dehydrated
        step has all the information necessary to recreate a step sans inputs
        and can be serialized in a json file.
        """
        dehyd = _DehydratedStep()
        step_module = inspect.getmodule(self)
        dehyd.step_module_path = imputils.get_path_from_module(step_module)
        dehyd.step_class_name = type(self).__name__
        dehyd.step_id = self._step_id
        dehyd.step_config = self._getCanonicalizedConfig()
        if self._input_file is not None:
            dehyd.input_file = StepperFile(self._input_file)
        return dehyd
    def _getStepPath(self):
        step_module = inspect.getmodule(self)
        step_module_path = imputils.get_path_from_module(step_module)
        step_class_name = type(self).__name__
        return f"{step_module_path}.{step_class_name}"
    @classmethod
    def _rehydrateStep(cls, dehydrated_step):
        """
        Recreate the step that `dehydrated_step` was created from.
        """
        step = cls(step_id=dehydrated_step.step_id,
                   config=dehydrated_step.step_config)
        if dehydrated_step.input_file:
            step.setInputFile(dehydrated_step.input_file)
        return step
    def _makeStep(self, input_file):
        step = copy.copy(self)
        step.setInputFile(input_file)
        return step
    def getLicenseRequirements(self):
        return {}
    def _makeBatchTask(self, batch_file, double_batch: bool):
        step = self._makeStep(batch_file)
        task = self._batch_settings.task_class(step=step)
        task.input._double_batch = double_batch
        if issubclass(self._batch_settings.task_class, StepJobTask):
            task.job_config.host_settings.host = hosts.Host(
                self._batch_settings.hostname)
        if not double_batch:
            for req_license, num_tokens in self.getLicenseRequirements().items(
            ):
                task.addLicenseReservation(req_license, num_tokens)
        return task
    def _queueBatchSteps(self, task_queue):
        for batch_num, batch_file, double_batch in self._splitInputsIntoBatchFiles(
        ):
            application.process_events()
            task = self._makeBatchTask(batch_file, double_batch)
            task.name, _ = os.path.splitext(os.path.basename(batch_file))
            task_queue.addTask(task)
    def _splitInputsIntoBatchFiles(self):
        serializer = self._getInputSerializer()
        inps = self._inputsWithCounting()
        continue_with_double_batching = False
        MAX_BATCHES = int(
            os.environ.get('SCHRODINGER_MAX_NUM_BATCHES', 999999999))
        for batch_num, batch_of_lines in enumerate(
                ichunked(inps, self._batch_settings.size)):
            batch_fname = self.getStepId() + '_batch_' + str(batch_num) + '.in'
            serializer.serialize(batch_of_lines, batch_fname)
            yield batch_num, batch_fname, False
            if batch_num + 1 >= MAX_BATCHES:
                break
            if batch_num + 1 >= DOUBLE_BATCH_THRESHOLD:
                continue_with_double_batching = True
                break
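        # Past DOUBLE_BATCH_THRESHOLD regular batches, fall back to "double
        # batching": emit larger batches whose subjobs will themselves split
        # them into regular batches and launch the actual work (see the
        # module docstring).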
        if continue_with_double_batching:
            double_batch_size = self._batch_settings.size * DOUBLE_BATCH_THRESHOLD
            double_batches = ichunked(inps, double_batch_size)
            for batch_num, batch_of_lines in enumerate(double_batches,
                                                       start=batch_num + 1):
                batch_fname = self.getStepId() + '_batch_' + str(
                    batch_num) + '.in'
                serializer.serialize(batch_of_lines, batch_fname)
                yield batch_num, batch_fname, True
                if batch_num + 1 >= MAX_BATCHES:
                    break
    @_assert_step_hasnt_started
    def outputs(self):
        """
        Like the superclass method, returns a generator for the outputs.
        Calling this method begins the batching process by requesting outputs
        from the input source (previous step), accumulating them into batches
        of the specified size, and queuing them all up.
        """
        if self._batch_settings is None:
            return super().outputs()
        else:
            self._start_time = time.time()
            FORMATTER.start()
            self._run_info[self.getStepId()] = {
                'batches': collections.defaultdict(dict)
            }
            task_dj = queue.TaskDJ(max_failures=queue.NOLIMIT)
            self._queueBatchSteps(task_dj)
            if not task_dj.waiting_jobs:
                # We didn't have any batches to process, so just return early
                return []
            outputs_gen = self._makeBatchedOutputsGenerator(task_dj)
            outputs_gen = self._outputsWithCounting(outputs_gen)
            self._outputs_gen = outputs_gen
            return outputs_gen
    def _updateBatchRunInfo(self, batch_name, new_batch_info):
        stepid = self.getStepId()
        batch_info = self._run_info[stepid]['batches'][batch_name]
        batch_info.update(new_batch_info)
        batch_info.update(batch_info.pop(stepid))
    def _makeBatchedOutputsGenerator(self, task_dj):
        for task in task_dj.updatedTasks():
            if task.status is task.DONE:
                self._updateBatchRunInfo(task.name, task.output.run_info)
                branch_count = task.name.count('.')
                logger.info(f'{">"*branch_count}START {task.name} log')
                logger.info(task.getLogAsString().strip())
                logger.info(f'{">"*branch_count}END {task.name} log')
                task.wait()
                outp_file = task.output.output_file
                assert os.path.isfile(outp_file), outp_file
                serializer = self.getOutputSerializer()
                for outp in serializer.deserialize(outp_file):
                    yield outp
            elif task.status is task.FAILED:
                logger.error("task failed")
                branch_count = task.name.count('.')
                logger.error(f"FAILURE WHEN RUNNING {task.name}")
                try:
                    _write_repro_file(task)
                except Exception:
                    logger.error(
                        "Error when writing the reproduction zip. Try "
                        f"reproducing manually with {task.name}'s inputs.")
                else:
                    logger.error(
                        f"Files for reproducing step saved to: {task.name}_repro.rzip"
                    )
                logger.error(f'{">"*branch_count}START {task.name} log')
                logger.error(task.getLogAsString())
                logger.error(f'{">"*branch_count}END {task.name} log')
def _topic_cloud_console_link(pubsub_topic: str) -> str:
    """
    Given a pubsub topic name (and assuming there's a subscription to it with
    the same name), return a URL for viewing the subscription in the browser.
    """
    if env_keys.is_aws_service_available():
        # FIXME: JIRA-ID
        return 'no_link_for_now'
    # else GCP
    return (
        "https://console.cloud.google.com/cloudpubsub/subscription/detail/" +
        pubsub_topic + "?project=" + os.environ['SCHRODINGER_GCP_PROJECT'])
class PubsubEnabledStepMixin:
    """
    A mixin that allows a step to be run using PubSub.
    Steps with this mixin have batch settings with a `use_pubsub` flag and a
    `num_pubsub_workers` integer. Turning `use_pubsub` on makes the step load
    all of its inputs into a pubsub topic before spinning up
    `num_pubsub_workers` subjobs that each pull inputs from the input topic,
    run the step's computation on them, and upload the results to an output
    topic. Calling `my_pubsub_step.getOutputs()` returns all the outputs from
    the output topic, so to a user this is all an implementation detail.
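    Example of switching a step over to pubsub (a hedged sketch;
    `MyPubsubStep` is an illustrative step class and the worker count is
    arbitrary)::
        step = MyPubsubStep()
        step.setBatchSettings(
            BatchSettings(use_pubsub=True, num_pubsub_workers=50))
        for output in step.getOutputs():
            ...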
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.setInputTopic(None)
        self.setOutputTopic(None) 
    @property
    def topic_prefix(self):
        if not hasattr(self, '_topic_prefix'):
            self._topic_prefix = os.environ.get(
                'SCHRODINGER_PUBSUB_TOPIC_PREFIX', '')
        return self._topic_prefix
    @property
    def topic_suffix(self):
        if not hasattr(self, '_topic_suffix'):
            self._topic_suffix = os.environ.get(
                'SCHRODINGER_PUBSUB_TOPIC_SUFFIX',
                str(uuid.uuid4())[:6])
        return self._topic_suffix
    @property
    def cloud_service_script(self):
        if env_keys.is_aws_service_available():
            return ['sqs.py']
        return ['python3', '-m', 'schrodinger.stepper._cloud.gcp']
    @_assert_step_hasnt_started
    def outputs(self):
        if self.usingPubsub():
            self.initializeTopics()
            self._runWithPubsub()
            return self._deserializeFromOutputTopic()
        else:
            return super().outputs() 
    def usingPubsub(self):
        return False 
    def _generateInputTopicName(self):
        return self._generateTopicName('inputs')
    def _generateOutputTopicName(self):
        return self._generateTopicName('outputs')
    def _generateTopicName(self, topic_type):
        if env_keys.is_aws_service_available():
            step_id_limit = TOPIC_STEP_ID_LIMIT_FOR_AWS
        else:
            step_id_limit = TOPIC_STEP_ID_LIMIT_FOR_GCP
        sections = [
            self.topic_prefix[:TOPIC_PREFIX_LIMIT],
            self.getStepId()[:step_id_limit], topic_type,
            self.topic_suffix[:TOPIC_SUFFIX_LIMIT]
        ]
        return '_'.join(sections).strip('_')
    def getOutputTopic(self) -> Optional[str]:
        if isinstance(self, Chain) and len(self):
            return self._output_topic or self[-1].getOutputTopic()
        else:
            return self._output_topic 
    def setOutputTopic(self, outp_topic: Optional[str]):
        self._output_topic = outp_topic 
    def initializeTopics(self):
        batch_settings = self._batch_settings
        if batch_settings is None or batch_settings.use_pubsub is False:
            raise RuntimeError(
                "Can't initialize topics for a step that's not "
                "using pubsub. To use pubsub for a step, set batch settings "
                "on it that have `use_pubsub` set to True.")
        script = self.cloud_service_script
        if self.getInputTopic() is None:
            inp_topic = self._generateInputTopicName()
            self._input_topic = inp_topic
            subprocess.run(
                [SCHRODINGER_RUN, *script, 'create', inp_topic, inp_topic])
            self._uploadToTopic(self._inputsWithCounting(),
                                self._getInputSerializer(), inp_topic)
        if self.getOutputTopic() is None:
            outp_topic = self._generateOutputTopicName()
            self.setOutputTopic(outp_topic)
            subprocess.run(
                [SCHRODINGER_RUN, *script, 'create', outp_topic, outp_topic]) 
    def _uploadToTopic(self, generator, serializer, topic):
        inp_fname = f"{topic}_msgs.txt"
        serializer.serialize(generator, inp_fname)
        subprocess.run([
            SCHRODINGER_RUN, *self.cloud_service_script, 'upload', topic,
            inp_fname
        ])
    def _downloadFromTopic(self, topic, fname):
        subprocess.run([
            SCHRODINGER_RUN, *self.cloud_service_script, 'download', topic,
            fname
        ])
    def _runWithPubsub(self):
        self._start_time = time.time()
        batch_settings = self._batch_settings
        tasks = []
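        # When more workers are requested than DOUBLE_BATCH_THRESHOLD, launch
        # "launcher" tasks that each spin up a block of workers rather than
        # launching every worker directly (double batching; see the module
        # docstring).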
        double_batch = batch_settings.num_pubsub_workers > DOUBLE_BATCH_THRESHOLD
        if double_batch:
            num_workers = math.ceil(batch_settings.num_pubsub_workers /
                                    DOUBLE_BATCH_THRESHOLD)
            base_worker_name = f'{self.getStepId()}_worker_launcher'
            batch_settings.num_pubsub_workers = DOUBLE_BATCH_THRESHOLD
            worker_class = _PubSubWorkerLauncherTask
        else:
            num_workers = batch_settings.num_pubsub_workers
            base_worker_name = f'{self.getStepId()}_worker'
            worker_class = _PubSubWorkerTask
        for idx in range(num_workers):
            task = worker_class(step=copy.copy(self))
            task.job_config.host_settings.host = hosts.Host(
                batch_settings.hostname)
            task.input.input_topic = self.getInputTopic()
            task.input.output_topic = self.getOutputTopic()
            if batch_settings.size is not None:
                task.input.batch_size = batch_settings.size
            for req_license, num_tokens in self.getLicenseRequirements().items(
            ):
                task.addLicenseReservation(req_license, num_tokens)
            tasks.append(task)
            task.taskDone.connect(self._onTaskDone)
            task.taskStarted.connect(self._onTaskStarted)
            task.taskFailed.connect(self._onTaskFailed)
        successful_tasks = queue.run_tasks_in_parallel(
            tasks, basename=base_worker_name)
        msgs_pulled = sum(t.output.num_inputs for t in successful_tasks)
        msgs_pushed = sum(t.output.num_outputs for t in successful_tasks)
        if (self._input_count != 0 and self._input_count != msgs_pulled):
            logger.warning(
                f"{self.getStepId()}: The number of messages uploaded to the input topic "
                f"({self._input_count}) was not equal to the number of "
                f"messages pulled and processed ({msgs_pulled})")
        self._input_count = msgs_pulled
        self._output_count = msgs_pushed
        self._end_time = time.time()
        self._updateRunInfo()
        run_info = self._run_info[self.getStepId()]
        inp_topic = self.getInputTopic()
        run_info['input_topic'] = inp_topic
        run_info['input_topic_link'] = _topic_cloud_console_link(inp_topic)
        outp_topic = self.getOutputTopic()
        run_info['output_topic'] = outp_topic
        run_info['output_topic_link'] = _topic_cloud_console_link(outp_topic)
    def _deserializeFromTopic(self, serializer, topic):
        topic_download_file = f"{topic}_msgs_{str(uuid.uuid4())[:6]}.txt"
        self._downloadFromTopic(topic, topic_download_file)
        for output in serializer.deserialize(topic_download_file):
            yield output
    def _deserializeFromInputTopic(self):
        yield from self._deserializeFromTopic(self._getInputSerializer(),
                                              self.getInputTopic())
    def _deserializeFromOutputTopic(self):
        yield from self._deserializeFromTopic(self.getOutputSerializer(),
                                              self.getOutputTopic())
    def _onTaskDone(self):
        pass
        #task = self.sender()
        #_clean_up_task(task)
    def _onTaskStarted(self):
        task = self.sender()
        print(f'Batch {task.name} started')
    def _onTaskFailed(self):
        task = self.sender()
        print(f'Batch {task.name} failed!')
        try:
            print(task.getLogAsString().strip())
        except Exception as e:
            print(f"{e} raised while trying to print failed task's log") 
class _BatchableStepMixin(PubsubEnabledStepMixin, _BatchableStepMixin):
    def usingPubsub(self):
        return bool(self._batch_settings and self._batch_settings.use_pubsub)
class UnbatchedReduceStep(_BaseStep):
    """
    An unbatchable ReduceStep. See ReduceStep for more information.
    """
    def _makeOutputGenerator(self):
        self._start_time = time.time()
        FORMATTER.start()
        return self.reduceFunction(self._inputsWithCounting())
    def _inputsWithCounting(self):
        self._updateRunInfo()
        if self._inputs is None:
            raise RuntimeError(
                f"Inputs have not been set for {self.getStepId()}")
        for input in self._inputs:
            self._input_count += 1
            yield input
    def reduceFunction(self, inputs):
        raise NotImplementedError  
class ReduceStep(_BatchableStepMixin, UnbatchedReduceStep):
    """                                                     <_reduce_step_>
    A computational step that performs a function on a collection of inputs
    to produce output items.
    To construct a ReduceStep (see the example below):
        * Implement reduceFunction
        * Define Input (the type expected by the reduceFunction)
        * Define Output (the type of item produced by the reduceFunction)
        * Define Settings (data class for any settings needed by the
          reduceFunction)
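    Example (a minimal sketch, assuming Input and Output are declared as class
    attributes; `UniqueWords` is an illustrative name, not a step from this
    module)::
        class UniqueWords(ReduceStep):
            Input = str
            Output = str
            def reduceFunction(self, words):
                # Yield each distinct word exactly once.
                seen_words = set()
                for word in words:
                    if word not in seen_words:
                        seen_words.add(word)
                        yield word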
    """
    def reduceFunction(self, inputs):
        """
        The main computation for this step. This function should take in an
        iterable of inputs and return an iterable of outputs.
        Example::
            def reduceFunction(self, words):
                # Find all unique words
                seen_words = set()
                for word in words:
                    if word not in seen_words:
                        seen_words.add(word)
                        yield word
        """
        return super().reduceFunction(inputs)  
class UnbatchedMapStep(UnbatchedReduceStep):
    """                                                     <_unbatchability_>
    An unbatchable MapStep. See MapStep for more information.
    """
    def reduceFunction(self, inputs):
        for input in inputs:
            for output in self.mapFunction(input):
                yield output 
    def mapFunction(self, input):
        raise NotImplementedError()  
class MapStep(_BatchableStepMixin, UnbatchedMapStep):
    """                                                         <_map_step_>
    A computational step that performs a function on input items from an input
    source to produce output items.
    To construct a MapStep (see the example below):
    * Implement mapFunction
    * Define Input (the type expected by the mapFunction)
    * Optionally define an InputSerializer (see `Serializer` for more info.)
    * Define Output (the type of item produced by the mapFunction)
    * Optionally define an OutputSerializer (see `Serializer` for more info.)
    * Define Settings (data class for any settings needed by the mapFunction)
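    Example (a minimal sketch, assuming Input and Output are declared as class
    attributes; `UppercaseWords` is an illustrative name, not a step from this
    module)::
        class UppercaseWords(MapStep):
            Input = str
            Output = str
            def mapFunction(self, word):
                # A single output per input may simply be yielded.
                yield word.upper()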
    """
    def mapFunction(self, input):
        """
        The main computation for this step. This function should take in a
        single input item and return an iterable of outputs. This allows a
        single input to produce multiple outputs (e.g. enumeration).
        The output may be yielded as a generator, in order to reduce memory
        usage.
        If only a single output is produced for each input, return it as a
        single-element list.
        :param input: this will be a single input item from the input source.
            Implementer is encouraged to use a more descriptive, context-
            specific variable name. Example:
                def mapFunction(self, starting_smiles):
                    ...
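        Example (a hedged sketch; `enumerate_tautomers` is an illustrative
        helper, not part of this module)::
            def mapFunction(self, starting_smiles):
                # A single input may yield any number of outputs.
                for tautomer_smiles in enumerate_tautomers(starting_smiles):
                    yield tautomer_smiles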
        """
        return super().mapFunction(input)  
class UnbatchedChain(UnbatchedReduceStep):
    @property
    def Input(self):
        if not self._steps:
            return super().Input
        return self[0].Input
    @property
    def Output(self):
        if not self._steps:
            return super().Output
        return self[-1].Output
    @property
    def InputSerializer(self):
        if not self._steps:
            return super().InputSerializer
        return self[0].InputSerializer
    @property
    def OutputSerializer(self):
        if not self._steps:
            return super().OutputSerializer
        return self[-1].OutputSerializer
    # Since the serializers are just inferred from the steps and the steps
    # have their serializers validated, we don't do it at chain declaration
    # level.
    @classmethod
    def _validateInputSerializer(cls):
        pass
    @classmethod
    def _validateOutputSerializer(cls):
        pass
    def __copy__(self):
        copied_step = super().__copy__()
        copied_step.setStartingStep(self._starting_step_id)
        return copied_step
    def setStartingStep(self, starting_step: str):
        if starting_step is not None:
            self._validateStartingStepId(starting_step)
        self._starting_step_id = starting_step 
    def validateSettings(self):
        """
        Check whether the chain settings are valid and return a list of
        `SettingsError` and `SettingsWarning` to report any invalid settings.
        Default implementation simply returns problems from all child steps.
        :rtype: list[TaskError or TaskWarning]
        """
        problems = []
        for step in self:
            problems += step.validateSettings()
        return problems 
    def getResources(self, param_type, resource_type):
        """
        Get the stepper resources in the settings for the chain as well as for
        every step in the chain that are instances of `param_type` and have a
        resource_type attribute that is `resource_type`.
        Note: this does not work for list/set/tuple subparams in the settings.
        :param param_type: the resource parameter type
        :type param_type: _StepperResource
        :param resource_type: the type of resource to get
        :type resource_type: ResourceType
        :return: the set of stepper resources of `resource_type`
        :rtype: set of _StepperResource
        """
        resources = super().getResources(param_type, resource_type)
        for step in self:
            resources |= step.getResources(param_type, resource_type)
        return resources 
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.setStartingStep(None)
        self._updateChain() 
    def __getitem__(self, idx):
        return self._steps[idx]
    def _setStepId(self, new_id):
        super()._setStepId(new_id)
        self._updateChain()
    def __len__(self):
        return len(self._steps) 
    def _setConfig(self, config):
        super()._setConfig(config)
        self._updateChain()
    def _getCanonicalizedConfig(self):
        """
        Return a config that can be used to set the settings for a different
        instance of this chain and its substeps to the same settings as this
        chain and its substeps.
        """
        config = super()._getCanonicalizedConfig()
        for child_step in self:
            config.update(child_step._getCanonicalizedConfig())
        return config
    def _updateChain(self):
        self._steps = []
        self.buildChain()
        self._updateComponentStepIDs()
        self._updateComponentStepConfigs()
        self.validateChain()
    def _updateComponentStepIDs(self):
        step_type_counter = collections.Counter()
        for step in self:
            step_count = step_type_counter[type(step)]
            step._setStepId(
                f'{self._step_id}.{type(step).__name__}_{step_count}')
            step_type_counter[type(step)] += 1
    def addStep(self, step):
        self._steps.append(step)
        step._setCompositionPath(self._comp_path + '>' + step._comp_path)
        step._setRunInfo(self._run_info) 
    def _updateComponentStepConfigs(self):
        for step in self:
            step._setConfig(self._config)
    def report(self, prefix=''):
        """
        Report the workflow steps and their settings (recursively).
        :param prefix: the text to start each line with
        :type prefix: str
        """
        super().report(prefix)
        for step in self:
            step.report(prefix + '  ') 
    def validateChain(self):
        """
        Checks that the declaration of the chain is internally consistent - i.e.
        that each step is valid and each step's Input class matches the
        preceding step's Output class.
        """
        if len(self) == 0:
            return
        for prev_step, next_step in more_itertools.pairwise(self):
            err_msg = (f"Mismatched Input and Output.\n"
                       f"Previous step: {prev_step}\n"
                       f"Output: {prev_step.Output}\n"
                       f"Next step: {next_step}\n"
                       f"Input: {next_step.Input}\n")
            if None in (next_step.Input, prev_step.Output):
                assert prev_step.Output is next_step.Input, err_msg
            else:
                assert prev_step.Output == next_step.Input or issubclass(
                    prev_step.Output, next_step.Input), err_msg
        first_step = self[0]
        msg = (f'Mismatched input of first step. The Input for the chain '
               f'("{type(self).__name__}") is specified as {self.Input}'
               ' but the Input for the first step '
               f'("{type(first_step).__name__}") is {first_step.Input}')
        assert first_step.Input is self.Input, msg
        last_step = self[-1]
        msg = (f'Mismatched output of last step. The Output for the chain '
               f'("{type(self).__name__}") is specified as {self.Output}'
               ' but the Output for the last step '
               f'("{type(last_step).__name__}") is {last_step.Output}')
        assert last_step.Output is self.Output, msg 
    def _validateStartingStepId(self, step_id: str):
        """
        Checks to see if the `step_id` actually matches a step in this chain.
        If not, raise a ValueError.
        """
        if step_id == self.getStepId():
            return
        for idx, step in enumerate(self):
            if step_id.startswith(step.getStepId()):
                if isinstance(step, Chain):
                    step._validateStartingStepId(step_id)
                    break
                else:
                    if step.getStepId() == step_id:
                        break
        else:
            raise ValueError("Invalid starting step ID: " + step_id)
    def reduceFunction(self, inputs):
        self._updateChain()
        if len(self) == 0:
            return inputs
        # Determine starting step and propagate starting step id
        starting_step_id = self._starting_step_id
        starting_step_idx = 0
        if starting_step_id is not None:
            for idx, step in enumerate(self):
                if self._starting_step_id.startswith(step.getStepId()):
                    starting_step_idx = idx
                    break
        starting_step = self[starting_step_idx]
        if starting_step_id and isinstance(starting_step, Chain):
            starting_step.setStartingStep(starting_step_id)
        # Set inputs, whether by input topic or by generator
        if (isinstance(starting_step, PubsubEnabledStepMixin) and
                self._input_topic):
            starting_step.setInputTopic(self._input_topic)
        else:
            starting_step.setInputs(inputs)
        for prev_step, next_step in more_itertools.pairwise(
                self[starting_step_idx:]):
            self._connectSteps(prev_step, next_step)
        # Set outputs
        last_step = self[-1]
        return last_step.outputs() 
    def _connectSteps(self, prev_step, next_step):
        output_gen = prev_step.outputs()
        prev_has_output_topic = (isinstance(prev_step, PubsubEnabledStepMixin)
                                 and prev_step.getOutputTopic() is not None)
        compatible_serializers = prev_step.OutputSerializer is next_step.InputSerializer
        next_using_pubsub = isinstance(next_step, PubsubEnabledStepMixin)
        if (prev_has_output_topic and compatible_serializers and
                next_using_pubsub):
            next_step.setInputTopic(prev_step.getOutputTopic())
        else:
            next_step.setInputs(output_gen)
    def buildChain(self):
        """
        This method must be implemented by subclasses to build the chain. The
        chain is built by adding steps, e.g. via `self.addStep`. The chain's
        composition may depend on `self.settings`.
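        Example (a minimal sketch; `PrepareStep` and `ScoreStep` are
        illustrative step classes, not part of this module)::
            def buildChain(self):
                self.addStep(PrepareStep())
                self.addStep(ScoreStep())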
        """
        raise NotImplementedError()  
class Chain(_BatchableStepMixin, UnbatchedChain):
    """                                                             <_chain_>
    Run a series of steps. The steps must be created by overriding buildChain.
    """
    def getLicenseRequirements(self):
        req_licenses = collections.Counter()
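        # Steps with their own batch settings run as subjobs that reserve
        # their licenses themselves (see _makeBatchTask and _runWithPubsub),
        # so they are excluded from the chain-level requirements.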
        for step in self:
            if not (isinstance(step, _BatchableStepMixin) and
                    step._batch_settings is not None):
                req_licenses = req_licenses | collections.Counter(
                    step.getLicenseRequirements())
        return dict(req_licenses) 
    def _dehydrateStep(self):
        dehyd = super()._dehydrateStep()
        dehyd.starting_step_id = self._starting_step_id
        return dehyd
    @classmethod
    def _rehydrateStep(cls, dehydrated_step: _DehydratedStep) -> 'Chain':
        """
        Recreate the step that `dehydrated_step` was created from.
        """
        step = super()._rehydrateStep(dehydrated_step)
        step.setStartingStep(dehydrated_step.starting_step_id)
        return step 
def _line_count(filename):
    count = 0
    with open(filename, 'r') as file:
        for line in file:
            count += 1
    return count
### Debugging helper methods, not for use in production.
def _get_all_stepper_input_files():
    input_file_pattern = os.path.join('**', '*.in')
    return glob.glob(input_file_pattern, recursive=True)
def _get_all_stepper_output_files():
    output_file_pattern = os.path.join('**', '*.out')
    return glob.glob(output_file_pattern, recursive=True)
def _get_all_stepper_zip_files():
    output_file_pattern = os.path.join('**', '*.rzip')
    return glob.glob(output_file_pattern, recursive=True)
def _write_repro_file(steptask):
    """
    Write a rzip with...
        - the input file for the step
        - the yaml config file for the step
        - a command for rerunning the step with the above input files
        - any necessary settings files/folders
    """
    repro_fname = f'{steptask.name}_repro.rzip'
    with zipfile.ZipFile(repro_fname, 'w') as repro_zipfile:
        dehyd_step = steptask._step._dehydrateStep()
        for step_id, step_settings in dehyd_step.step_config.items():
            if step_id.startswith(steptask._step.getStepId()):
                for name, value in step_settings.items():
                    if isinstance(value, StepperFile):
                        repro_zipfile.write(value, value)
                    elif isinstance(value, StepperFolder):
                        for root, _, files in os.walk(value):
                            for filename in files:
                                src_path = os.path.join(root, filename)
                                repro_zipfile.write(src_path)
        yaml_fname = f'{steptask.name}.yaml'
        with open(yaml_fname, 'w') as yaml_file:
            yaml.dump(dict(dehyd_step.step_config), yaml_file)
        cmd_fname = f'{steptask.name}.sh'
        with open(cmd_fname, 'w') as cmd_file:
            cmd_file.write(
                f'$SCHRODINGER/run stepper.py '
                f'{dehyd_step.step_module_path}.{dehyd_step.step_class_name} '
                f'{dehyd_step.input_file} bad_step.out -config {yaml_fname} '
                f'-workflow-id {dehyd_step.step_id}')
        repro_zipfile.write(dehyd_step.input_file,
                            os.path.basename(dehyd_step.input_file))
        repro_zipfile.write(yaml_fname)
        repro_zipfile.write(cmd_fname)
def _get_stepper_debug_files():
    # Return all stepper repro zip files
    return _get_all_stepper_zip_files()