# Source code for schrodinger.stepper.sideinputs
"""
Utility steps for creating side inputs in stepper workflows.
For example, in a workflow with steps A, B, C, and D, a ForkStep and JoinStep
can be set up so that all outputs from A are passed along to D. This allows
outputs from A to get to D even if B or C would normally filter those inputs.
Example::
    class MyWorkflow(stepper.Chain):
        def buildChain(self):
            a = A()
            self.addStep(a)
            fork = ForkStep(step=a)
            self.addStep(fork)
            self.addStep(B())
            self.addStep(C())
            self.addStep(JoinStep(fork=fork))
            self.addStep(D())
"""
import os
import uuid
from schrodinger.models import parameters
from schrodinger import stepper
# Path to the `run` entry inside the directory named by the SCHRODINGER
# environment variable. Importing this module raises KeyError if that variable
# is not set.
SCHRODINGER_RUN = os.path.join(os.environ['SCHRODINGER'], 'run')
class ForkStep(stepper.UnbatchedReduceStep):
    """
    A step to save some inputs to be reprocessed again. See the module
    docstring for more info and an example.

    Every input that passes through this step is forwarded unchanged and is
    also written, one serialized item per line, to a uniquely named
    dot-prefixed file (relative path, so it is resolved against the working
    directory at the time the file is opened). A `JoinStep` constructed with
    this fork reads that file back later in the chain.
    """

    def __init__(self, step):
        """
        :param step: The step whose outputs should be saved. Its `Output` and
            `OutputSerializer` are used as both the input and the output
            type/serializer of this step.
        """
        self.Input = self.Output = step.Output
        self.InputSerializer = self.OutputSerializer = step.OutputSerializer
        # A random UUID keeps the files of different forks from colliding.
        self._pipe_fname = f'.{uuid.uuid4()}.forkfile'
        super().__init__()

    def reduceFunction(self, inps):
        """
        Pass every input through while recording it in the fork file.

        This is a generator: the fork file is opened when iteration starts
        and each item is written right before it is yielded, so the file is
        only complete once the generator has been fully consumed.

        :param inps: Iterable of inputs to forward.
        :return: Generator yielding each input unchanged.
        """
        serializer = self.getOutputSerializer()
        with open(self._pipe_fname, 'w') as outfile:
            for inp in inps:
                outfile.write(f"{serializer.toString(inp)}\n")
                yield inp

    def getPipeFilename(self):
        """
        :return: Name of the file this step writes its inputs to.
        """
        return self._pipe_fname

    def report(self, prefix=''):
        """
        Log this step's id at info level.

        :param prefix: Text placed in front of the logged line.
        """
        stepper.logger.info(f'{prefix} - {self.getStepId()}')
class JoinStep(stepper.UnbatchedReduceStep):
    """
    A step to read some inputs saved by a preceding ForkStep. See the
    module docstring for more info and an example.
    """

    def __init__(self, fork):
        """
        :param fork: The `ForkStep` whose saved inputs should be injected
            back into the chain. Its `OutputSerializer` is used as both the
            input and the output serializer of this step.
        """
        self._fork = fork
        self.InputSerializer = self.OutputSerializer = fork.OutputSerializer
        self._in_fname = fork.getPipeFilename()
        super().__init__()

    @property
    def Input(self):
        """The input type, taken from the fork this step is joined to."""
        return self._fork.Input

    @property
    def Output(self):
        """The output type, taken from the fork this step is joined to."""
        return self._fork.Output

    def reduceFunction(self, inps):
        """
        Yield the regular inputs first, then the items read from the fork
        file.

        The fork file is only read after `inps` is exhausted.

        :param inps: Iterable of inputs arriving from the previous step.
        :return: Generator over `inps` followed by the deserialized contents
            of the fork file.
        """
        yield from inps
        yield from self.getOutputSerializer().deserialize(self._in_fname)

    def report(self, prefix=''):
        """
        Log this step's id and the id of the fork it reads from at info
        level.

        :param prefix: Text placed in front of the logged line.
        """
        stepper.logger.info(
            f'{prefix} - {self.getStepId()} <- {self._fork.getStepId()}')
class JoinFromFileStep(stepper.UnbatchedReduceStep):
    """
    A step for injecting inputs read from a file into a chain. To use, add into
    your chain and set the step's `join_file` setting to the path of your
    datafile.
    """

    class Settings(parameters.CompoundParam):
        # Path of the data file whose contents are injected into the chain.
        join_file: stepper.StepperFile = None

    def __init__(self, Input=None, InputSerializer=None, **kwargs):
        """
        Exactly one of `Input` and `InputSerializer` must be given.

        :param Input: The type used as both input and output type of this
            step.
        :param InputSerializer: Serializer used as both input and output
            serializer; the step's input/output type is then taken from its
            `DataType` attribute.
        :param kwargs: Forwarded unchanged to the base class initializer.
        :raises TypeError: If neither or both of `Input` and
            `InputSerializer` are supplied.
        """
        if Input is None and InputSerializer is None:
            raise TypeError("Must set either Input or InputSerializer at "
                            "step initialization time.")
        elif Input is not None and InputSerializer is not None:
            raise TypeError("Can't set both Input _and_ InputSerializer")
        # Test against None, matching the validation above, so that a
        # serializer that happens to be falsy is not silently ignored.
        if InputSerializer is not None:
            Input = InputSerializer.DataType
            self.InputSerializer = self.OutputSerializer = InputSerializer
        self.Input = self.Output = Input
        super().__init__(**kwargs)

    def reduceFunction(self, inps):
        """
        Yield the items read from `settings.join_file` first, then the
        regular inputs.

        :param inps: Iterable of inputs arriving from the previous step.
        :return: Generator over the deserialized file contents followed by
            `inps`.
        """
        serializer = self._getInputSerializer()
        yield from serializer.deserialize(self.settings.join_file)
        yield from inps
#==============================================================================
# PUBSUB FUNCTIONALITY
#
# The code below adds functionality to the Join* steps so that extra inputs
# are simply added to the input topics. This optimizes the Join* steps so
# they don't unnecessarily read topics just to append a few extra inputs.
#
# NOTE: This section is designed to be 'transparent', meaning that if a chain
# is not using pubsub, all Join* steps will behave normally. Additionally,
# if this section of code is removed, the base functionality of the Join*
# steps will still work.
#==============================================================================
class JoinStep(stepper.PubsubEnabledStepMixin, JoinStep):
    """
    Pubsub-aware version of `JoinStep`.

    When an input topic is set, the items saved in the fork file are added
    directly to that topic, which is then also used as the output topic.
    Without an input topic the step falls back to the base `JoinStep`
    behavior.
    """

    def outputs(self):
        """
        :return: The step outputs. With pubsub, this is whatever
            `_deserializeFromOutputTopic` returns after the fork file's items
            have been added to the input topic; otherwise the result of the
            base class `outputs`.
        """
        if self.usingPubsub():
            inp_topic = self.getInputTopic()
            # Reuse the input topic as the output topic so existing items do
            # not have to be read just to append the extra ones.
            self.setOutputTopic(inp_topic)
            extra_outputs = self.getOutputSerializer().deserialize(
                self._in_fname)
            self._uploadToTopic(extra_outputs, self.getOutputSerializer(),
                                inp_topic)
            return self._deserializeFromOutputTopic()
        else:
            return super().outputs()

    def usingPubsub(self):
        """
        :return: True if this step has an input topic set, False otherwise.
        """
        return bool(self.getInputTopic())
class JoinFromFileStep(stepper.PubsubEnabledStepMixin, JoinFromFileStep):
    """
    Pubsub-aware version of `JoinFromFileStep`.

    When an input topic is set, the items read from `settings.join_file` are
    added directly to that topic, which is then also used as the output
    topic. Without an input topic the step falls back to the base
    `JoinFromFileStep` behavior.
    """

    def outputs(self):
        """
        :return: The step outputs. With pubsub, this is whatever
            `_deserializeFromOutputTopic` returns after the join file's items
            have been added to the input topic; otherwise the result of the
            base class `outputs`.
        """
        if self.usingPubsub():
            inp_topic = self.getInputTopic()
            # Reuse the input topic as the output topic so existing items do
            # not have to be read just to append the extra ones.
            self.setOutputTopic(inp_topic)
            extra_outputs = self.getOutputSerializer().deserialize(
                self.settings.join_file)
            self._uploadToTopic(extra_outputs, self.getOutputSerializer(),
                                inp_topic)
            return self._deserializeFromOutputTopic()
        else:
            return super().outputs()

    def usingPubsub(self):
        """
        :return: True if this step has an input topic set, False otherwise.
        """
        return bool(self.getInputTopic())