# Source code for schrodinger.utils.funcchains
"""
This module provides a system for collecting functions into ordered groups
called chains. The functions in a chain can be executed sequentially, with an
optional callback after each function that can be used to terminate the
processing of the chain. An object may have multiple chains of functions, each
identified by its respective decorator.

To use, create a class that inherits FuncChainMixin, and use a
FuncChainMarker instance as a decorator to mark methods of that class as
belonging to a particular chain. Pass that same marker to processFuncChain to
run the chain. For example::

    my_startup_functions = FuncChainMarker('startup')

    class Foo(FuncChainMixin):
        def __init__(self):
            super().__init__()
            self.processFuncChain(my_startup_functions)

        @my_startup_functions(order=1)
        def initVariables(self):
            ...

        @my_startup_functions(order=2)
        def setupWorkspace(self):
            ...

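Pass a callback as ``self.processFuncChain(result_callback=bar)`` to control
processing. It is called with the return value of each function in the chain
and should return a truthy value to continue processing the chain or a falsy
value to terminate it. Without a callback, processing terminates only when a
function's result is exactly ``False``.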
"""
import collections
import functools
from schrodinger.utils import funcgroups
from schrodinger.utils.funcgroups import get_marked_func_order  # noqa: F401
class FuncChainMarker(funcgroups.FuncGroupMarker):
    """
    A function-group marker whose decorated functions form an ordered chain.

    Each decorated function is first marked by the base FuncGroupMarker (with
    the requested order), then wrapped so that its return value is passed
    through `customizeFuncResult` before being handed back to the caller.
    """

    def customizeFuncResult(self, func, result):
        """
        Hook for adjusting the value returned by a function in the chain.

        Override this to cast or interpret the return value, or to use it to
        build a custom object. The base implementation returns ``result``
        unchanged.

        :param func: The function that produced the result. When called from
            the decorator's wrapper, this is the wrapper function itself.
        :param result: The return value of that function
        :return: the (possibly customized) result
        """
        return result

    def __call__(self, func=None, order=None):
        """
        Mark ``func`` as part of this chain, or return a decorator that will.

        Supported usages::

            @marker
            def foo(self): ...

            @marker(order=2)
            def foo(self): ...

            @marker(2)
            def foo(self): ...

        :param func: the function to mark; ``None`` or a non-callable value
            (which is then treated as ``order``) when the decorator is being
            parameterized
        :param order: the position of the function within the chain
        :return: the wrapped function, or a decorator awaiting the function
        """
        if func is not None and not callable(func):
            # A positional non-callable argument, as in @marker(2), is the
            # order rather than the function to decorate.
            func, order = None, func

        if func is None:
            # Parameterized use: wait until the function itself is supplied.
            return lambda target: self(target, order=order)

        # Bare use (@marker) or the second stage of a parameterized use.
        marked = super().__call__(func, order=order)

        # functools.wraps also copies the marker's attributes onto the wrapper,
        # so the wrapper is still recognised as a marked function.
        @functools.wraps(marked)
        def chain_link(*args, **kwargs):
            outcome = marked(*args, **kwargs)
            return self.customizeFuncResult(chain_link, outcome)

        return chain_link
# Module-level default chain marker. FuncChainMixin uses it (via its
# _default_funcgroup class attribute) when processFuncChain is called without
# an explicit chain. Mark methods with @funcchain or @funcchain(order=N).
funcchain = FuncChainMarker('funcchain')
class FuncChainMixin(funcgroups.FuncGroupMixin):
    """
    Mixin that lets an object run its chains of marked methods in order.

    Methods are added to a chain by decorating them with a FuncChainMarker
    instance; `processFuncChain` then runs them sequentially. When no chain is
    given, the class attribute ``_default_funcgroup`` (the module-level
    ``funcchain`` marker unless a subclass overrides it) is used.
    """
    _default_funcgroup = funcchain

    def __init__(self, *args, **kwargs):
        # NOTE(review): not read anywhere in this module; presumably used by
        # subclasses or other code to hold extra per-chain functions -- confirm
        # before removing.
        self._funcchain_extra_funcs = collections.defaultdict(list)
        super().__init__(*args, **kwargs)

    def _defaultResultCallback(self, result):
        """
        Result callback used when processFuncChain is called without a
        ``result_callback``.

        Only an explicit result of ``False`` terminates chain processing; any
        other value (including ``None`` and other falsy values) continues it.

        :param result: the (customized) return value of a function in the chain
        :return: Whether to proceed to the next function in the chain
        :rtype: bool
        """
        return result is not False

    def processFuncChain(self, chain=None, result_callback=None):
        """
        Execute each function in the specified chain sequentially, in order.

        Each function is called with no arguments. The ``result_callback`` is
        called after each function with that function's result (as customized
        by the chain's ``customizeFuncResult``). It can be used to respond to
        the result, for example to present information to the user, get user
        feedback or log the result.

        If the callback returns a falsy value, processing stops and no further
        functions in the chain are run. Otherwise processing proceeds to the
        next function.

        :param chain: which chain to process; defaults to
            ``self._default_funcgroup``
        :type chain: FuncChainMarker
        :param result_callback: the callback that will get called with the
            result of each function in the chain; defaults to
            `_defaultResultCallback`
        :type result_callback: callable
        :return: the results of the functions that were run, in order. If
            processing was terminated early, the list ends with the result
            that caused termination.
        :rtype: list
        """
        if chain is None:
            chain = self._default_funcgroup
        if result_callback is None:
            result_callback = self._defaultResultCallback
        results = []
        for func in self.getFuncGroup(chain):
            if hasattr(func, '_marked_method_group'):
                # Already wrapped by the marker decorator, which applies
                # customizeFuncResult itself.
                result = func()
            else:
                # Added via addFuncToGroup or monkeypatched for a test, so the
                # chain's customization must be applied here.
                result = chain.customizeFuncResult(func, func())
            results.append(result)
            if not result_callback(result):
                break
        return results