Source code for schrodinger.utils.funcchains
"""
This module provides a system for collecting functions into ordered groups
called chains. The functions in a chain can be executed sequentially, with an
optional callback after each function that can be used to terminate the
processing of the chain. An object may have multiple chains of functions, each
identified by its respective decorator.
To use, create an object that inherits FuncChainMixin, and use a
FuncChainMarker to mark methods in that object as belonging to a particular
chain. For example::

    my_startup_functions = FuncChainMarker('my_startup_functions')

    class Foo(FuncChainMixin):

        def __init__(self):
            self.processFuncChain(my_startup_functions)

        @my_startup_functions(order=1)
        def initVariables(self):
            ...

        @my_startup_functions(order=2)
        def setupWorkspace(self):
            ...

Supply a callback to self.processFuncChain(result_callback=bar) that accepts
the return value of each function and returns True to continue processing the
chain or False to terminate processing of the chain.
"""
import collections
import functools
from schrodinger.utils import funcgroups
from schrodinger.utils.funcgroups import get_marked_func_order # noqa: F401
class FuncChainMarker(funcgroups.FuncGroupMarker):
    """
    A `funcgroups.FuncGroupMarker` whose decorated functions have their return
    values routed through `customizeFuncResult` every time they are called.
    """

    def customizeFuncResult(self, func, result):
        """
        Hook for post-processing the return value of a function in the chain.

        The base implementation hands the value back untouched. Subclasses can
        override it to cast or interpret the value, or to build a custom
        object from it.

        :param func: The function that produced the result
        :param result: The return value of that function
        :return: The (possibly customized) result
        """
        return result

    def __call__(self, func=None, order=None):
        """
        Mark `func` as a member of this chain.

        Supports the bare form (`@marker`), the keyword form
        (`@marker(order=2)`) and the positional form (`@marker(2)`).

        :param func: The function to mark, or None when the marker is being
            called only to supply an order. A non-callable value is treated
            as the order.
        :param order: Ordering value forwarded to the base class marker
        :return: A decorator when no function was supplied; otherwise a
            wrapper around the marked function whose return value is passed
            through `customizeFuncResult`.
        """
        if func is not None and not callable(func):
            # A non-callable first argument is assumed to have been meant as
            # the order, e.g. @funcchain(2)
            func, order = None, func

        if func is None:
            # Called only to specify the order, i.e.
            #     @funcchain(order=2)
            #     def foo(self):
            #         pass
            # so return a decorator that re-enters with the real function.
            def decorate(func):
                return self(func, order=order)

            return decorate

        # Used directly on a function, i.e.
        #     @funcchain
        #     def foo(self):
        #         pass
        marked = super().__call__(func, order=order)

        @functools.wraps(marked)
        def chain_func(*args, **kwargs):
            outcome = marked(*args, **kwargs)
            # The wrapper itself (not the undecorated function) is what gets
            # reported to the customization hook.
            return self.customizeFuncResult(chain_func, outcome)

        return chain_func
# Module-level default chain marker. FuncChainMixin stores it as
# `_default_funcgroup`, which processFuncChain falls back to when no chain is
# passed in.
funcchain = FuncChainMarker('funcchain')
class FuncChainMixin(funcgroups.FuncGroupMixin):
    """
    Mixin that lets an object run the functions of a chain one after another
    via `processFuncChain`.
    """

    # Chain used by processFuncChain when the caller does not name one.
    _default_funcgroup = funcchain

    def __init__(self, *args, **kwargs):
        """
        Create the per-instance `_funcchain_extra_funcs` mapping, then forward
        all arguments to the next `__init__` in the MRO.
        """
        # NOTE(review): not read anywhere in this class - presumably kept for
        # subclasses or older callers; confirm before removing.
        self._funcchain_extra_funcs = collections.defaultdict(list)
        super().__init__(*args, **kwargs)

    def _defaultResultCallback(self, result):
        """
        Fallback callback for processFuncChain when no result_callback is
        given.

        Only the literal value False stops the chain; any other return value
        (including None and other falsy values) lets processing continue.

        :param result: the return value of a function in the chain
        :return: Whether to proceed to the next function in the chain
        :rtype: bool
        """
        if result is False:
            return False
        return True

    def processFuncChain(self, chain=None, result_callback=None):
        """
        Run the functions of a chain one at a time, in the order given by
        `getFuncGroup`.

        After each function runs, its result is appended to the returned list
        and then handed to result_callback. The callback can react to the
        value (e.g. present information to the user, get user feedback, log
        the result, etc.); a falsy return from the callback stops processing
        before the next function is called.

        :param chain: which chain to process. Defaults to
            `_default_funcgroup`.
        :type chain: FuncChainMarker
        :param result_callback: the callback that will get called with the
            result of each function in the chain. Defaults to
            `_defaultResultCallback`.
        :return: a list of the results from the functions that were run
        """
        chain = self._default_funcgroup if chain is None else chain
        should_continue = (self._defaultResultCallback
                           if result_callback is None else result_callback)

        collected = []
        for member in self.getFuncGroup(chain):
            if not hasattr(member, '_marked_method_group'):
                # Added via addFuncToGroup or monkeypatched for a test, so
                # the marker's customization has to be applied here.
                outcome = chain.customizeFuncResult(member, member())
            else:
                # Already wrapped by the decorator, which applies the
                # customization itself.
                outcome = member()
            collected.append(outcome)
            if not should_continue(outcome):
                break
        return collected