"""
A task represents a block of work that has a defined input and output and runs
without user intervention. Different task classes share a common external API
but have different implementations for defining and executing the work, such as
blocking calls, threads, subprocesses, or job control (see jobtasks).
To define a task, follow these basic instructions:
1. Choose a task class to subclass. The choice of task class is primarily
dictated by how the task needs to run - thread, subprocess, job, etc. See the
Task Class Selection Guide for help.
2. Override the input and output params. The task.input and task.output params
may be of any Param type, including CompoundParam (typical). For CompoundParams,
either use an existing class to override task.input, OR define a nested class
named Input within the task. Doing so will automatically override task.input.
The same goes for task.output. Example::
class FooTask(tasks.ThreadFunctionTask):
input = AtomPair() # AtomPair is an existing CompoundParam subclass
        # Defining this nested class automatically overrides
        # FooTask.output with an Output() instance
class Output(parameters.CompoundParam):
charge: float
processed_atom_pair: AtomPair
3. Define the work of the task. This is done differently for different task
classes, but generally involves overriding a method to either provide python
logic directly as the work to be done or to construct a command line with the
appropriate arguments that will be invoked via the appropriate mechanism for the
task type.
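For example, a minimal thread-based task might override mainFunction (a
sketch; the x, y, and total fields are illustrative, not framework names)::
    class FooThreadTask(tasks.ThreadFunctionTask):
        class Input(parameters.CompoundParam):
            x: int
            y: int
        class Output(parameters.CompoundParam):
            total: int
        def mainFunction(self):
            # Runs in a worker thread; read from input, write to output.
            self.output.total = self.input.x + self.input.y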
Once a task is defined, it can be instantiated, set up, and started::
task = FooThreadTask()
task.input.x = 3
task.input.y = 4
task.start()
assert task.status is tasks.Status.RUNNING
task.wait()
assert task.status is tasks.Status.DONE
print(task.output)
.. warning::
`wait()` executes a local event loop, so it should not be called directly
from a GUI - see PANEL-18317 for discussion. `wait()` is safe to call
inside a subprocess or job (e.g. if a jobtask spawns child tasks).
Run `git grep "task[.]wait("` to see safe examples annotated with "# OK".
==================
Pre/postprocessors
==================
Tasks support pre/post processing functions. These can either be methods in the
class that are decorated with the preprocessor or postprocessor decorators, or
external functions that are added to a task instance. Example::
class MyTask(tasks.BlockingFunctionTask):
@tasks.preprocessor
def checkInput(self):
            if self.input.x < 0:
return False, 'x must be a nonnegative number.'
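Postprocessors are defined the same way and run after the task's main work
completes successfully. A minimal sketch (output.result_file is an
illustrative field)::
    class MyTask(tasks.BlockingFunctionTask):
        @tasks.postprocessor
        def checkOutput(self):
            if not os.path.isfile(self.output.result_file):
                return False, 'Expected result file was not written.'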
For more information, see the module-level preprocessor and postprocessor
decorators as well as the start(), preprocessors(), and addPreprocessor()
methods of AbstractTask.
========================
Task directory (taskdir)
========================
Tasks have a concept of a taskdir. While the task framework will never actually
chdir into a different directory, the task provides functions for specifying
and accessing a directory that is considered that task's directory by
convention. Subprocesses started by the task will use the taskdir as their
working directory.
To specify a taskdir, override AbstractTask.DEFAULT_TASKDIR_SETTING or use
task.specifyTaskDir(). Example::
class MyTask(tasks.BlockingFunctionTask):
DEFAULT_TASKDIR_SETTING = tasks.AUTO_TASKDIR
task = MyTask()
task.specifyTaskDir('foo_dir')
The taskdir is created during preprocessing. Once the taskdir is created, use
task.getTaskDir() and task.getTaskFilename() when reading and writing files for
the task. Example::
class MyTask(tasks.SubprocessCmdTask):
@tasks.preprocessor(order=tasks.AFTER_TASKDIR)
def writeInputFiles(self):
with open(self.getTaskFilename('foo_data.txt'), 'w') as f:
f.write(self.input.foo_data)
For more details on taskdir, see task.specifyTaskDir() and task.getTaskDir().
==========================
Input/Output File Handling
==========================
To specify a task input file or folder, use the `TaskFile` or `TaskFolder`
classes as a subparam on the task.input param. If the task runs its unit
of work on a different machine or process, the input files/folders will
automatically be copied to the right location on the compute host. The
path to the `TaskFile`/`TaskFolder` will also be updated so it points
to the right location, regardless of when or where it's accessed.
`TaskFile`/`TaskFolder` params may be nested under the input param in
supported container types. The supported container types are:
- List
- Dict
- Set
- Tuple
- CompoundParam
There are few restrictions on how deeply a `TaskFile` or `TaskFolder` may be
nested under the input param. For example, if you have a variable number of
input files, you can define the input with a list::
class Input(parameters.CompoundParam):
receptor_filename: TaskFile
ligand_filenames: List[TaskFile]
Task output files/folders behave in the exact same way as task input
files/folders except they're defined as `TaskFile` or `TaskFolder` on the
output param.
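For example (the field names are illustrative)::
    class Output(parameters.CompoundParam):
        log_file: TaskFile
        pose_filenames: List[TaskFile]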
"""
import contextlib
import copy
import enum
import inspect
import os
import pathlib
import pickle
import random
import shutil
import string
import sys
import tempfile
import traceback
import typing
from collections import namedtuple
from datetime import datetime
from typing import List
from schrodinger.models import json
from schrodinger.models import jsonable
from schrodinger.models import parameters
from schrodinger.models import paramtools
from schrodinger.Qt import QtCore
from schrodinger.Qt.QtCore import QProcess
from schrodinger.tasks import cmdline
from schrodinger.ui.qt.appframework2 import application
from schrodinger.utils import fileutils
from schrodinger.utils import funcchains
from schrodinger.utils import imputils
from schrodinger.utils import qt_utils
from schrodinger.utils import scollections
from schrodinger.utils import subprocess as subprocess_utils
class TaskDirNotFoundError(RuntimeError):
    pass
class TaskFile(str):
    """
    See the "Input/Output File Handling" section of the module docstring
    for information.
    """
class TaskFolder(str):
    """
    See the "Input/Output File Handling" section of the module docstring
    for information.
    """
#===============================================================================
# Task pre/post processing
#===============================================================================
# Ordering constants
BEFORE_TASKDIR = -2000 # Runs preprocessor before taskdir creation
AFTER_TASKDIR = 0 # Runs preprocessor after taskdir creation (default)
_TASKDIR_ORDER = -1000 # Order for taskdir creation
_WRITE_JSON_ORDER = 10000
class TaskDirSetting(enum.Enum):
    AUTO_TASKDIR = enum.auto()
    TEMP_TASKDIR = enum.auto()
# Taskdir settings
AUTO_TASKDIR = TaskDirSetting.AUTO_TASKDIR
TEMP_TASKDIR = TaskDirSetting.TEMP_TASKDIR
class _ProcessorMarker(funcchains.FuncChainMarker):
def customizeFuncResult(self, func, result):
return _cast_processing_result(result, func)
"""
The preprocessor and post processor decorators can be used to mark functions to
be run before/after a task. These decorators may be used on task methods both
with or without args::
class MyTask(tasks.BlockingFunctionTask):
@tasks.preprocessor # Use without args
def checkInput(self):
pass
@tasks.preprocessor(order=tasks.AFTER_TASKDIR) # Use with args
def writeInput(self):
pass
The optional order argument is a float that is used as a sorting key to
determine the order of execution of pre/postprocessors. It's recommended that
one of the module level ordering constants is used, with +/- increments to fine-
tune the order. For example::
class MyTask(tasks.BlockingFunctionTask):
@tasks.preprocessor(order=tasks.AFTER_TASKDIR)
def checkInput(self):
pass
        @tasks.preprocessor(order=tasks.AFTER_TASKDIR + 1)
        def writeInput(self):
            pass
External functions may also be decorated. In this case, the function must also
be added to a task instance. Example::
@tasks.preprocessor(order=tasks.AFTER_TASKDIR)
    def foo():
pass
task = MyTask()
task.addPreprocessor(foo)
Pre/postprocessors may optionally return a ProcessingResult. As a convenience,
a (passed, message) tuple return value will automatically be cast into a
ProcessingResult by the decorator. Examples::
@tasks.preprocessor
def checkInput(self):
if self.input.x < 0: # Preprocessing failure
return False, 'x must be nonnegative.'
if self.input.x > 100: # Preprocessing warning
return True, 'Large values of x may take a long time.'
return True # Pass (equivalent to returning None)
Returning False without a message will be a silent failure.
"""
preprocessor = _ProcessorMarker('preprocessor')
postprocessor = _ProcessorMarker('postprocessor')
class ProcessingResult:
"""
A general-purpose return value for task pre/post processors
"""
    def __init__(self, passed, message=None):
"""
:param passed: Whether the result is considered to be passing
:type passed: bool
:param message: A message for this result
:type message: str
"""
self.func = None
self.passed = passed
self.message = message
    def processorName(self):
if self.func is None:
return
return self.func.__name__
def __bool__(self):
return self.passed
def __repr__(self):
return str(self)
def __str__(self):
msg = ''
if self.func is not None:
msg += f'{self.func.__name__}: '
if self.passed and not self.message:
msg += 'Passed'
elif self.passed and self.message:
msg += f'WARNING - {self.message}'
elif not self.passed and not self.message:
msg += 'FAILED'
elif not self.passed and self.message:
msg += f'FAILED - {self.message}'
return msg
class CallingContext(enum.IntEnum):
    CMDLINE = enum.auto()
    GUI = enum.auto()
def _cast_processing_result(result, func=None):
"""
Convert the return value of a pre/post-processor to a ProcessingResult,
if necessary. If a func is supplied, it will be recorded in the
ProcessingResult.
:param func: the function that produced this result
:param result: the return value of the pre/post-processor. This can be
        represented in one of three ways: (1) True/False for passed (2) a tuple
of (passed, message) (3) a ProcessingResult instance
:type result: bool, tuple, or ProcessingResult
:return: the wrapped return value
:rtype: ProcessingResult
"""
if result is None:
result = True
if isinstance(result, bool):
result = ProcessingResult(result)
if isinstance(result, tuple):
result = ProcessingResult(*result)
if isinstance(result, ProcessingResult):
result.func = func
return result
raise TypeError(f'Return value should be bool or tuple. Got {result}')
#===============================================================================
# Task exceptions
#===============================================================================
class TaskFailure(Exception):
    """
    Exception raised when a task fails for reasons other than an unexpected
    error occurring during execution.
    """
    # This class intentionally left blank.
class TaskKilled(TaskFailure):
    pass
class _TaskTestTimeout(TaskFailure):
"""
Exception raised if a task times out under pytest
"""
pass
#===============================================================================
# Status
#===============================================================================
FailureInfo = namedtuple('FailureInfo', 'exception traceback message')
class FailureInfo(FailureInfo):
def __str__(self):
if self.exception is None:
return 'No failure recorded.'
else:
return f'Task failure:\n{self.traceback}\n{self.exception}'
class Status(jsonable.JsonableIntEnum):
WAITING, RUNNING, FAILED, DONE = range(4)
FINISHED_STATUSES = {Status.FAILED, Status.DONE}
STARTABLE_STATUSES = {Status.WAITING, Status.FAILED, Status.DONE}
NON_RUNNING_STATUSES = {Status.WAITING, Status.FAILED, Status.DONE}
def _wait(task, timeout=None):
"""
Block until the task is finished executing or `timeout` seconds have
passed.
:param timeout: Amount of time in seconds to wait before timing out. If
None or a negative number, this method will wait until the task
is finished.
:type timeout: NoneType or int
:return: whether the task finished during the wait. Returns False if wait
timed out
"""
return _wait_for(task, NON_RUNNING_STATUSES, timeout=timeout)
@application.require_application(use_qtcore_app=True)
def _wait_for(task, end_statuses, timeout=None):
"""
Block until a task reaches one of the specified statuses. Blocks using a
local event loop.
:param task: the task to wait on
:param end_statuses: the task statuses to wait for
:param timeout: an optional timeout in seconds
:return: whether the wait succeeded. Returns False if wait timed out
"""
if task.status in end_statuses:
return True
event_loop = QtCore.QEventLoop()
def check_status(status):
if status in end_statuses:
event_loop.exit()
def time_out_event_loop():
event_loop.exit()
if timeout is not None:
QtCore.QTimer.singleShot(timeout * 1000, time_out_event_loop)
task.statusChanged.connect(check_status)
event_loop.exec()
return task.status in end_statuses
#===============================================================================
# Abstract Task
#===============================================================================
@qt_utils.add_enums_as_attributes(Status)
@qt_utils.add_enums_as_attributes(CallingContext)
class AbstractTask(funcchains.FuncChainMixin, parameters.CompoundParam):
input: parameters.CompoundParam
output: parameters.CompoundParam
status: Status
name: str
progress: int
max_progress: int
progress_string: str
calling_context = parameters.NonParamAttribute()
failure_info = parameters.NonParamAttribute()
# Convenience Signals
taskDone = QtCore.pyqtSignal()
taskStarted = QtCore.pyqtSignal()
taskFailed = QtCore.pyqtSignal()
DEFAULT_TASKDIR_SETTING = None
AUTO_TASKDIR = AUTO_TASKDIR # Add these to the class namespace for
TEMP_TASKDIR = TEMP_TASKDIR # convenience.
_all_task_tempdirs = []
_is_debug_enabled = False
#===========================================================================
# Construction
#===========================================================================
    @classmethod
def runFromCmdLine(cls):
return cmdline.run_task_from_cmdline(cls)
    @classmethod
def fromJsonFilename(cls, filename):
with open(filename) as f:
json_dict = json.load(f)
task = cls.fromJson(json_dict)
return task
    def initConcrete(self):
super().initConcrete()
self.statusChanged.connect(self.__onStatusChanged)
self.failure_info = None
self._taskdir = None
self._taskdir_setting = self.DEFAULT_TASKDIR_SETTING
self.calling_context = None
self._in_preprocessing = False
self._interruption_requested = False
self._tempdir = None
    def initializeValue(self):
"""
@overrides: parameters.CompoundParam
"""
if not self.name:
self.name = self.__class__.__name__
#===========================================================================
# Abstract Methods
#===========================================================================
INTERRUPT_ENABLED = False
    def run(self):
# Implementations of run are responsible for directly calling
# `_finish` or connecting a signal to `_finish`.
raise NotImplementedError()
    def kill(self):
"""
Implementations are responsible for immediately stopping the task. No
threads or processes should be running after this method is complete.
This method should be called sparingly since in many contexts the task
will be forced to terminate without a chance to clean up or free
resources.
"""
raise NotImplementedError()
#===========================================================================
# Public API
#===========================================================================
    def start(self, skip_preprocessing=False):
"""
        This is the main method for starting a task. Start verifies that the
        task is not already running, runs preprocessing, and then runs the
        task.
Failures in preprocessing will interrupt the task start, and the task
will never enter the RUNNING state.
:param skip_preprocessing: whether to skip preprocessing. This can be
useful if preprocessing was already performed prior to calling
start.
:type skip_preprocessing: bool
"""
self.printDebug('start')
if not self.isStartable():
raise RuntimeError(
f"Can't start a task with status {self.status.name}")
if not self.name:
raise RuntimeError("Can't start a task with name: ''")
self.status = Status.WAITING
self._interruption_requested = False
self.failure_info = None
if not skip_preprocessing:
with self.guard():
self.runPreprocessing(callback=self._processingCallback)
if self.failure_info is not None:
self.status = self.FAILED
return
self.status = self.RUNNING
with self.guard():
self.run()
if self.failure_info is not None:
self.status = self.FAILED
return
    def wait(self, timeout=None):
r"""
Block until the task is finished executing or `timeout` seconds have
passed.
.. warning::
This should not be called directly from GUI code - see PANEL-18317.
It is safe to call inside a subprocess or job. Run
`git grep "task\.wait("` to see safe examples annotated with "# OK".
:param timeout: Amount of time in seconds to wait before timing out. If
None or a negative number, this method will wait until the task
is finished.
:type timeout: NoneType or int
"""
# Call the module-level wait function
self.printDebug(f'wait({timeout})')
try:
with self.guard():
return _wait(self, timeout)
finally:
self.printDebug('wait done')
    def isRunning(self):
return self.status is self.RUNNING
    def isStartable(self):
return self.status in STARTABLE_STATUSES
    def specifyTaskDir(self, taskdir_spec):
"""
Specify the taskdir creation behavior. Use one of the following options:
A directory name (string). This may be a relative or absolute path
None - no taskdir is requested. The task will use the CWD as its taskdir
AUTO_TASKDIR - a new subdirectory will be created in the CWD using the
task name as the directory name.
TEMP_TASKDIR - a temporary directory will be created in the schrodinger
temp dir. This directory is cleaned up when the task is deleted.
:param taskdir_spec: one of the four options listed above
"""
if ((self._in_preprocessing and self._taskdir is not None) or
self.isRunning()):
raise RuntimeError('Taskdir specification may not be changed once '
'the taskdir is created.')
self._taskdir_setting = taskdir_spec
self._taskdir = None
    def taskDirSetting(self):
"""
Returns the taskdir spec. See specifyTaskDir() for details.
"""
return self._taskdir_setting
    def getTaskDir(self):
"""
        Returns the full path of the task directory. This is only available
        once the task directory exists (i.e. after the taskdir is created
        during preprocessing or, if no taskdir is specified, at any time).
"""
if self._taskdir_setting is None:
return os.getcwd()
if isinstance(self._taskdir_setting, (str, pathlib.Path)):
if os.path.exists(self._taskdir_setting):
self._taskdir = os.path.abspath(self._taskdir_setting)
if self._taskdir is None:
raise TaskDirNotFoundError(
'Taskdir has not been created yet. Consider '
'moving this call to an AFTER_TASKDIR '
'preprocessor.')
return self._taskdir
    def getTaskFilename(self, fname):
"""
Return the appropriate absolute path for an input or output file in the
taskdir.
"""
parent_dir = self.getTaskDir()
return os.path.join(parent_dir, fname)
    def addPreprocessor(self, func, order=None):
"""
        Adds a preprocessor function to this task instance. If the function has
been decorated with @preprocessor, the order specified by the decorator
will be used as the default.
:param func: the function to add
:param order: the sorting order for the function relative to all other
preprocessors. Takes precedence over order specified by the
preprocessor decorator.
:type order: float
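        Example (check_env is an illustrative external function)::
            @preprocessor(order=BEFORE_TASKDIR)
            def check_env():
                if not os.getenv('SCHRODINGER'):
                    return False, 'SCHRODINGER environment is not set.'
            task.addPreprocessor(check_env)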
"""
if order is None:
decorated_order = funcchains.get_marked_func_order(func)
if decorated_order is None:
order = AFTER_TASKDIR
else:
order = decorated_order
self.addFuncToGroup(func, preprocessor, order)
    def addPostprocessor(self, func, order=0):
"""
        Adds a postprocessor function to this task instance. If the function
has been decorated with `@postprocessor`, the order specified by the
decorator will be used.
:param func: the function to add
:type func: typing.Callable
        :param order: the sorting order for the function relative to all other
            postprocessors. Takes precedence over the order specified by the
            postprocessor decorator.
:type order: float
"""
self.addFuncToGroup(func, postprocessor, order)
    def preprocessors(self):
"""
:return: A list of preprocessors (both decorated methods on the task and
external functions that have been added via
addPreprocessor)
"""
return self.getFuncGroup(preprocessor)
    def postprocessors(self):
"""
:return: A list of postprocessors, both decorated methods on the task
and external functions that have been added via `addPostprocessor()`
:rtype: list[typing.Callable]
"""
return self.getFuncGroup(postprocessor)
    def reset(self, *args, **kwargs):
if not args and not kwargs:
if self.status is self.RUNNING:
raise RuntimeError("Can't reset a task while it's running")
elif self.status is self.FAILED:
self.failure_info = None
super().reset(*args, **kwargs)
    def replicate(self):
"""
Create a new task with the same input and settings (but no output)
"""
old_task = self
new_task = self.__class__()
new_task.specifyTaskDir(old_task.taskDirSetting())
old_preprocess_callbacks = old_task.getAddedFuncs(preprocessor)
for func, order in old_preprocess_callbacks:
new_task.addPreprocessor(func, order)
for func, order in old_task.getAddedFuncs(postprocessor):
new_task.addPostprocessor(func, order)
if isinstance(new_task.input, parameters.CompoundParam):
new_task.input.setValue(old_task.input)
else:
new_task.input = old_task.input
return new_task
    def isDebugEnabled(self):
return self._is_debug_enabled
    def printDebug(self, *args):
if not self.isDebugEnabled():
return
info = self.getDebugString()
print(f'{info}:', *args)
    def getDebugString(self):
return f'{datetime.now()} {self.name}-{self.status.name}'
    def requestInterruption(self):
"""
Request the task to stop.
To enable this feature, subclasses should periodically check whether an
interruption has been requested and terminate if it has been. If such
logic has been included, `INTERRUPT_ENABLED` should be set to `True`.
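        Example of a cooperative mainFunction (a sketch; input.items and the
        process() helper are illustrative)::
            INTERRUPT_ENABLED = True
            def mainFunction(self):
                for item in self.input.items:
                    if self.isInterruptionRequested():
                        raise TaskKilled('Task was interrupted.')
                    process(item)  # illustrative per-item work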
"""
if not self.INTERRUPT_ENABLED:
raise RuntimeError("Interruption is not enabled for this task.")
self._interruption_requested = True
    def isInterruptionRequested(self):
return self._interruption_requested
#===========================================================================
# Internal methods
#===========================================================================
@preprocessor(order=BEFORE_TASKDIR - 1000)
def _validateTaskName(self):
is_valid = fileutils.is_valid_jobname(self.name)
if not is_valid:
return False, fileutils.INVALID_JOBNAME_ERR % self.name
def __copy__(self):
task_copy = super().__copy__()
if task_copy.status is task_copy.RUNNING:
task_copy.status = task_copy.WAITING
return task_copy
def __deepcopy__(self, memo):
task_copy = super().__deepcopy__(memo)
if task_copy.status is task_copy.RUNNING:
task_copy.status = task_copy.WAITING
return task_copy
def __eq__(self, other):
"""
Tasks compare equal if all params excluding the status are equal.
"""
is_eq = super().__eq__(other)
if is_eq:
return True
else:
if isinstance(other, self.__class__):
self_copy = copy.copy(self)
other_copy = copy.copy(other)
return self_copy.toDict() == other_copy.toDict()
return False
def __onStatusChanged(self, status):
if status is self.RUNNING:
self.taskStarted.emit()
elif status is self.FAILED:
self.taskFailed.emit()
elif status is self.DONE:
self.taskDone.emit()
def _processingCallback(self, result):
if not result.passed:
self._recordFailure(TaskFailure(result.message))
return result.passed
def _defaultResultCallback(self, result):
"""
@overrides: funcchains.FuncChainMixin
"""
if not result.passed:
raise TaskFailure(result.message)
return True
    @typing.final
def runPreprocessing(self, callback=None, calling_context=None):
"""
Run the preprocessors one-by-one. By default, any failing preprocessor
will raise a TaskFailure exception and terminate processing. This
behavior may be customized by supplying a callback function which will
be called after each preprocessor with the result of that preprocessor.
This method is "final" so that all preprocessing logic will be enclosed
in the try/finally block.
:param callback: a function that takes result and returns a bool that
indicates whether to continue on to the next preprocessor
:param calling_context: specify a value here to indicate the context
in which this preprocessing is being called. This value will be
stored in an instance variable, self.calling_context, which can be
accessed from any preprocessor method on this task. Typically this
value will be either self.GUI, self.CMDLINE, or None, but any value
may be supplied here and checked for in the preprocessor methods.
self.calling_context always reverts back to None at the end of
runPreprocessing.
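        Example of a callback that collects every result instead of stopping
        at the first failure (a sketch; results is a local list)::
            results = []
            def collect_result(result):
                results.append(result)
                return True  # always continue to the next preprocessor
            task.runPreprocessing(callback=collect_result)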
"""
self.printDebug('runPreprocessing')
self._in_preprocessing = True
self._taskdir = None
self.calling_context = calling_context
try:
return self.processFuncChain(preprocessor, result_callback=callback)
finally:
self.calling_context = None
self._in_preprocessing = False
self.printDebug('done preprocessing')
def _runPostprocessing(self, callback=None):
return self.processFuncChain(postprocessor, result_callback=callback)
def _makeTempTaskDir(self):
parent_dir = fileutils.get_directory_path(fileutils.TEMP)
self._tempdir = tempfile.TemporaryDirectory(dir=parent_dir)
self._taskdir = self._tempdir.name
self._registerTempDir(self._tempdir)
def _registerTempDir(self, tmpdir):
"""
Register a tempdir to the class. This is used to clean up all tempdirs
in unit tests.
"""
self._all_task_tempdirs.append(tmpdir)
def _makeDir(self, taskdir):
os.makedirs(taskdir)
@preprocessor(order=_TASKDIR_ORDER)
def _createTaskDir(self):
"""
Create a task directory for running the task in.
"""
if self._taskdir_setting is TEMP_TASKDIR:
self._makeTempTaskDir()
return True
cwd = os.getcwd()
if self._taskdir_setting is None:
self._taskdir = cwd
return True
if self._taskdir_setting is AUTO_TASKDIR:
taskdir = os.path.abspath(self.name)
else:
taskdir = os.path.abspath(self._taskdir_setting)
self._taskdir = taskdir
try:
self._makeDir(taskdir)
except FileExistsError:
if self._taskdir_setting is not AUTO_TASKDIR:
# Allow specified path to already exist
return True
if self.calling_context is self.GUI:
return (True, f'Task directory {self._taskdir} already exists. '
'Contents will be overwritten. Continue?')
return False, f"Task directory {self._taskdir} already exists."
return True
def _recordFailure(self, exception, exc_traceback_str=None):
"""
Store the exception in `failure_info` and set status to failed
"""
if self.failure_info is not None:
return
message = str(exception)
self.failure_info = FailureInfo(exception=exception,
traceback=exc_traceback_str,
message=message)
if exc_traceback_str:
tb = exc_traceback_str
else:
tb = ''
print(
f'{tb}{repr(self)}> failed: {type(exception).__name__}("{message}")'
)
    @contextlib.contextmanager
def guard(self):
"""
        Context manager that catches any Exception raised inside and records
        it as a task failure instead of propagating it
"""
try:
yield
except Exception:
err_type, exc_value, exc_traceback = sys.exc_info()
if err_type is TaskFailure:
exc_traceback_str = None
else:
exc_traceback_str = ''.join(
traceback.format_tb(exc_traceback)[-10:])
# We have to delete the traceback to prevent a circular ref.
# See the `traceback` module documentation for additional info.
del exc_traceback
self._recordFailure(exc_value, exc_traceback_str)
def _finish(self):
self.printDebug('_finish')
if self.failure_info is not None:
self.status = Status.FAILED
return
with self.guard():
self._runPostprocessing(callback=self._processingCallback)
if self.failure_info is not None:
self.status = Status.FAILED
return
self.status = Status.DONE
def __repr__(self):
if self.isAbstract():
return super().__repr__()
return (f'<{self.__class__.__name__}: {self.name} - '
f'{Status(self.status).name}>') # sometimes status is an int
@classmethod
def _populateClassParams(cls):
cls._convertNestedClassToDescriptor('Input', 'input')
cls._convertNestedClassToDescriptor('Output', 'output')
super()._populateClassParams()
@classmethod
def _convertNestedClassToDescriptor(cls, nested_class_name,
descriptor_name):
"""
If a nested class of the specified name is defined, this method will
instantiate that class and set that instance as a class variable. Ex:
class Foo:
class Bar:
pass
Calling Foo._convertNestedClassToDescriptor('Bar', 'bar') will do the
equivalent of putting bar = Bar() inside the Foo class. Typically used
        to instantiate Param classes as descriptors on the class.
        :param nested_class_name: the name of the class to look for
        :param descriptor_name: the name under which the descriptor instance
            will be added to the class.
"""
if nested_class_name in cls.__dict__:
nested_class = getattr(cls, nested_class_name)
desc = nested_class()
desc.__set_name__(cls, descriptor_name)
setattr(cls, descriptor_name, desc)
#===============================================================================
# Task interfaces
#===============================================================================
class _AbstractFunctionTask(AbstractTask):
def run(self):
self._runMainFunction()
def _guardedMain(self):
with self.guard():
self.mainFunction()
def _runMainFunction(self):
raise NotImplementedError()
def mainFunction(self):
raise NotImplementedError()
class AbstractCmdTask(AbstractTask):
    def run(self):
cmd = self.makeCmd()
for idx, arg in enumerate(cmd):
if not isinstance(arg, str):
msg = (f"makeCmd() must return a string of lists. Item {idx} "
f"is type {type(arg)}.")
raise ValueError(msg)
self.runCmd(cmd)
    def runCmd(self, cmd):
raise NotImplementedError()
    def makeCmd(self):
return []
class AbstractComboTask(AbstractCmdTask, _AbstractFunctionTask):
"""
Subclasses should only define params inside of input or output. Top-level
params defined in subclasses do NOT get serialized between the frontend and
backend task instances. Thus, any modifications of new top-level params in
the backend (i.e. mainFunction) will not have any effect on the rehydrated
frontend task.
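    A minimal sketch of a well-formed subclass (the value and squared fields
    are illustrative)::
        class SquareTask(ComboSubprocessTask):
            class Input(parameters.CompoundParam):
                value: float
            class Output(parameters.CompoundParam):
                squared: float
            def mainFunction(self):
                # Only input and output round-trip between the frontend and
                # backend instances.
                self.output.squared = self.input.value ** 2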
"""
_run_as_backend: bool = False
ENTRYPOINT = 'combotask_entry_point.py'
# Private params, not for use by child classes
_task_module: str
_task_class: str
_task_script: str
_failure_info: str = None
_failure_tb: str = None
_combo_id: str = None
# Only these params will be serialized in frontend/backend conversions
_FRONTEND_TO_BACKEND_PARAMS = [
'name', 'input', '_run_as_backend', '_task_module', '_task_class',
'_task_script', '_combo_id'
]
_BACKEND_TO_FRONTEND_PARAMS = [
'output', 'status', '_run_as_backend', '_failure_info', '_failure_tb'
]
def _regenerateComboId(self):
"""
Generate a new combo id for this task. A combo id is a random string
that is used to prevent tasks with the same task name from overwriting
each other's combo files (i.e. _frontend.json and _backend.json).
"""
alphabet = string.ascii_lowercase + string.digits
self._combo_id = ''.join(random.choices(alphabet, k=12))
    def initializeValue(self):
super().initializeValue()
if self._combo_id is None: # no combo id from rehydrated json file
self._regenerateComboId()
@property
def json_filename(self):
return self.getTaskFilename(
f'.{self.name}_{self._combo_id}_frontend.json')
@property
def json_out_filename(self):
return self.getTaskFilename(
f'.{self.name}_{self._combo_id}_backend.json')
    def start(self, *args, **kwargs):
"""
@overrides: AbstractTask
"""
if self.isBackendMode():
return self.runBackend()
super().start(*args, **kwargs)
    def isBackendMode(self):
return self._run_as_backend
    def makeCmd(self):
"""
@overrides: AbstractCmdTask
"""
cmd = [
get_schrodinger_run(), self.ENTRYPOINT, '--task_json',
self._getFrontEndJsonArg()
]
return cmd
def _getFrontEndJsonArg(self):
return self.json_filename
def _writeFrontendJsonFile(self):
task_module = self._get_module()
backend_task = copy.deepcopy(self)
# deepcopy of a compoundparam only copies params
backend_task._taskdir = self._taskdir
if task_module == '__main__':
print(f'{self} is defined outside the build. Will attempt to copy '
'script to backend dir to run. If the script needs to import '
'other files, the task will still fail. In this case, move '
'the script and its dependencies to an importable location.')
cp_filename = self._copyScriptToBackend()
backend_task._task_script = os.path.basename(cp_filename)
backend_task._task_module = task_module
backend_task._task_class = type(self).__name__
# need to get json_filename before setting _run_as_backend to True
json_filename = self.json_filename
backend_task._processTaskFilesForFrontendWrite()
backend_task._run_as_backend = True
backend_task._writeComboJsonFile(json_filename)
def _copyScriptToBackend(self):
script_filename = inspect.getfile(type(self))
try:
return shutil.copy(script_filename, self.getTaskDir())
except shutil.SameFileError:
return script_filename
@preprocessor(order=_WRITE_JSON_ORDER)
def _prepareComboTask(self, *args, **kwargs):
self._writeFrontendJsonFile()
def _finish(self):
super()._finish()
# The next time this task is started, it should have a new combo id
self._regenerateComboId()
    def backendMain(self):
raise NotImplementedError
def _processBackend(self):
json_out_path = self.json_out_filename
if not os.path.isfile(json_out_path):
msg = "No json file was returned from the backend. "
logfile = self.getTaskFilename(self._getLogFilename())
if os.path.isfile(logfile):
msg += f"Check {logfile} for more information."
self.printDebug(f'Log file contents:\n{self.getLogAsString()}')
else:
msg += f"Log file not found at {logfile}"
exception = RuntimeError(msg)
self._recordFailure(exception)
else:
with open(json_out_path, 'r') as infile:
# Create a new instance from the backend json output
TaskClass = type(self)
try:
rehydrated_backend = TaskClass.fromJson(json.load(infile))
except json.JSONDecodeError as e:
self._recordFailure(e)
else:
self._updateFromBackend(rehydrated_backend)
def _updateFromBackend(self, rehydrated_backend):
"""
Update the frontend task based on the rehydrated backend task
"""
if isinstance(self.output, parameters.CompoundParam):
self.output.setValue(rehydrated_backend.output)
self._processTaskFilesForBackendRehydration()
else:
self.output = rehydrated_backend.output
if rehydrated_backend.status == rehydrated_backend.FAILED:
backend_exc = pickle.loads(
rehydrated_backend._failure_info.encode())
backend_tb = rehydrated_backend._failure_tb
self._recordFailure(backend_exc, backend_tb)
def _writeComboJsonFile(self, filename):
if self.status is self.FAILED:
# Use protocol 0 since it's ascii-encodable
self._failure_info = pickle.dumps(self.failure_info.exception,
0).decode()
backend_tb = ''.join(
traceback.format_tb(self.failure_info.exception.__traceback__))
self._failure_tb = backend_tb
ser_task = self._createSerializationTask()
try:
with open(filename, 'w') as f:
json.dump(ser_task, f, indent=4)
except:
# If something goes wrong during serialization, we should make
# sure to remove the empty json file.
os.remove(filename)
raise
    def runBackend(self):
self._processTaskFilesForBackendExecution()
self.progressChanged.connect(self._onBackendProgressChanged)
self.max_progressChanged.connect(self._onBackendProgressChanged)
self.progress_stringChanged.connect(self._onBackendProgressChanged)
with self.guard():
try:
self.backendMain()
except NotImplementedError:
self.mainFunction()
with self.guard():
self._processTaskFilesForBackendWrite()
if self.failure_info:
self.status = self.FAILED
if not isinstance(self.failure_info.exception, TaskFailure):
print(self.failure_info.traceback)
if self.failure_info.message:
print(self.failure_info.message)
# Mark as frontend to ensure correct params are serialized
self._run_as_backend = False
self._writeComboJsonFile(self.json_out_filename)
def _onBackendProgressChanged(self):
"""
Implement logic that will communicate progress change from the backend
to the front-end.
"""
def _get_module(self):
"""
Return the module string defining where the class for `self` is defined.
"""
return imputils.get_path_from_module(inspect.getmodule(self))
def _createSerializationTask(self) -> 'AbstractComboTask':
"""
Return a new instance of this task that has serialization param values
set for frontend/backend conversion. Non-serialization params have
default values.
"""
ser_task = self.__class__()
ser_param_names = self._getSerializationParamNames()
for param_name in ser_param_names:
param_value = getattr(self, param_name)
if isinstance(param_value, parameters.CompoundParam):
param_to_serialize = getattr(ser_task, param_name)
param_to_serialize.setValue(param_value)
else:
setattr(ser_task, param_name, param_value)
return ser_task
def _getSerializationParamNames(self) -> List[str]:
"""
Return a list of the names of params that should be serialized for
frontend/backend combo task conversion.
"""
if self._run_as_backend:
param_names = self._FRONTEND_TO_BACKEND_PARAMS
else:
param_names = self._BACKEND_TO_FRONTEND_PARAMS
return param_names
#===========================================================================
# TaskFile Processing
#===========================================================================
def _processTaskFilesForFrontendWrite(self):
"""
This will be called before writing out the combotask frontend json file.
Transforms all TaskFile and TaskFolder paths in self.input so that the
json file within the taskdir will be portable, if possible.
Raises a ValueError if any files/directories do not exist.
"""
def process_input(path, launchdir):
path = os.path.abspath(path)
return path
self._assertTaskFileExistence(self.input)
self._processTaskFiles(self.input, process_func=process_input)
def _processTaskFilesForBackendExecution(self):
"""
This will be called in the backend before executing the mainFunction of
the combotask. Override if the file paths are different in the backend
compared to the paths used in the frontend.
Raises a ValueError if any files/directories do not exist.
"""
self._assertTaskFileExistence(self.input)
def _processTaskFilesForBackendWrite(self):
"""
This will be called in the backend after the mainFunction returns before
writing the combotask backend json file. Converts absolute paths into
relative paths so that file references can remain valid if the taskdir
is copied or moved.
Raises a ValueError if any files/directories do not exist.
"""
def process_output(path, launchdir):
path = os.path.relpath(path)
return path
self._assertTaskFileExistence(self.output)
self._processTaskFiles(self.output, process_func=process_output)
def _processTaskFilesForBackendRehydration(self):
"""
This will be called before the output of the backend task is set back on
the frontend task.
Raises a ValueError if any files/directories do not exist.
"""
self._assertTaskFileExistence(self.output)
def _assertTaskFileExistence(self, param):
def assert_taskfile_existence(path):
if path is None:
return None
if not os.path.exists(path):
raise ValueError(
f'Filepath "{path}" does not exist. Make sure all '
'taskfiles and task folders point to existing files before '
'starting or completing the task.')
return path
if isinstance(param, parameters.CompoundParam):
paramtools.map_subparams(assert_taskfile_existence, param, TaskFile)
if isinstance(param, parameters.CompoundParam):
paramtools.map_subparams(assert_taskfile_existence, param,
TaskFolder)
def _processTaskFiles(self, param, *, process_func, dir=None):
if dir is None:
dir = self.getTaskDir()
def process_taskfile(path):
if path is None:
return None
if process_func is None:
return path
else:
new_path = process_func(path, dir)
return new_path
if isinstance(param, parameters.CompoundParam):
paramtools.map_subparams(process_taskfile, param, TaskFile)
if isinstance(param, parameters.CompoundParam):
paramtools.map_subparams(process_taskfile, param, TaskFolder)
#===============================================================================
# Task execution mixins
#===============================================================================
def get_schrodinger_run():
return 'run'
class _SaveTaskReferenceMixin:
def __init_subclass__(cls):
super().__init_subclass__()
# Let each class have its own set so failures are easier to understand
cls._saved_task_references = scollections.IdSet()
def start(self, *args, **kwargs):
super().start(*args, **kwargs)
if self.status == Status.RUNNING:
self._saveTaskReference()
def _finish(self):
super()._finish()
self._discardTaskReference()
def _saveTaskReference(self):
self._saved_task_references.add(self)
def _discardTaskReference(self):
self._saved_task_references.discard(self)
class BlockingMixin:
"""
Compatible with subclasses of AbstractFunctionTask.
"""
def _runMainFunction(self):
self._guardedMain()
self._finish()
class ThreadMixin(_SaveTaskReferenceMixin):
MAX_THREAD_TASKS = 500
qthread = parameters.NonParamAttribute()
    def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.qthread = None
    def kill(self):
"""
@overrides: AbstractTask
        Killing threads is dangerous and can lead to deadlocks on
Windows, so we intentionally leave it unimplemented rather than
using QThread.terminate.
"""
raise NotImplementedError
def _runMainFunction(self):
# Make sure that there is a QApplication running. If there isn't,
# create a QCoreApplication.
application.get_application(create=True, use_qtcore_app=True)
self.qthread = QtCore.QThread()
# TODO: Decide whether to leave this as a monkey-patch or hook up
# qthread.started to _guardedMain instead. If we leave it as a patch,
# we should add a strong warning against calling .start() from multiple
# threads.
self.qthread.run = self._guardedMain
self.qthread.finished.connect(self.__onThreadFinished)
self.qthread.start()
@typing.final
def __onThreadFinished(self):
self._finish()
class QProcessError(Exception):
    def __init__(self, message):
        super().__init__(message)
class QProcessFailedToStartError(QProcessError):
    pass
class QProcessCrashedError(QProcessError):
    pass
class QProcessTimedout(QProcessError):
    pass
class QProcessWriteError(QProcessError):
    pass
class QProcessReadError(QProcessError):
    pass
class QProcessUnknownError(QProcessError):
    pass
_QProcessErrorToException = {
QProcess.FailedToStart: QProcessFailedToStartError,
QProcess.Crashed: QProcessCrashedError,
QProcess.Timedout: QProcessTimedout,
QProcess.WriteError: QProcessWriteError,
QProcess.ReadError: QProcessReadError,
QProcess.UnknownError: QProcessUnknownError
}
class SubprocessMixin(_SaveTaskReferenceMixin):
cmd = parameters.NonParamAttribute()
exit_code = parameters.NonParamAttribute()
qprocess = parameters.NonParamAttribute()
    def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cmd = None
self.exit_code = None
self.qprocess = None
self._printing_output_to_terminal = False
    def printingOutputToTerminal(self):
"""
:return: whether the `StdOut` and `StdErr` output from this task is
being printed to the terminal
:rtype: bool
"""
return self._printing_output_to_terminal
    def setPrintingOutputToTerminal(self, print_to_terminal):
"""
Set this task to print `StdOut` and `StdErr` output to terminal, or not.
:param print_to_terminal: whether to send process output to terminal
:type print_to_terminal: bool
"""
self._printing_output_to_terminal = print_to_terminal
    def runCmd(self, cmd):
# Make sure that there is a QApplication running. If there isn't,
# create a QCoreApplication.
application.get_application(create=True, use_qtcore_app=True)
self.exit_code = None
self.qprocess = None
self.cmd = cmd
cmd[0] = subprocess_utils.abs_schrodinger_path(cmd[0])
self._setupQProcess()
self.qprocess.start(cmd[0], cmd[1:])
def _setupQProcess(self):
self.qprocess = QtCore.QProcess()
if self.printingOutputToTerminal():
self.qprocess.setProcessChannelMode(
QtCore.QProcess.ForwardedChannels)
else:
self.qprocess.setProcessChannelMode(QtCore.QProcess.MergedChannels)
self.qprocess.setStandardOutputFile(self._getLogFilename())
self.qprocess.setWorkingDirectory(self.getTaskDir())
self.qprocess.finished.connect(self.__onSubprocessCompleted)
self.qprocess.errorOccurred.connect(self.__onErrorOccurred)
@typing.final
def __onSubprocessCompleted(self):
with self.guard():
self._onSubprocessCompleted()
self._finish()
def _onSubprocessCompleted(self):
self.exit_code = self.qprocess.exitCode()
if self.exit_code != 0:
msg = f'{self} returned non-zero exit code.'
log_str = self.getLogAsString()
if len(log_str) > 200:
# Elide
log_str = log_str[:200] + '...'
msg += f'\n{log_str}'
self._recordFailure(TaskFailure(msg))
@typing.final
def __onErrorOccurred(self, error):
with self.guard():
self._onErrorOccurred(error)
self._finish()
def _onErrorOccurred(self, error):
qprocess_exception = _QProcessErrorToException[error](
message=
f"Command: {self.cmd} had fatal error: {self.qprocess.errorString()}"
)
self.exit_code = self.qprocess.exitCode()
self._recordFailure(qprocess_exception)
def _getLogFilename(self):
return self.getTaskFilename(self.name + '.log')
    def getLogAsString(self) -> str:
log_fn = self.getTaskFilename(self._getLogFilename())
if not os.path.isfile(log_fn):
return f'Log file not found: {log_fn}'
with open(log_fn) as log_file:
return log_file.read()
    def kill(self):
"""
@overrides: AbstractTask
Kill the subprocess and set the status to FAILED.
"""
if self.status is not self.RUNNING:
raise RuntimeError("Can't kill a task that's not running.")
if self.qprocess:
self.qprocess.finished.disconnect(self.__onSubprocessCompleted)
self.qprocess.errorOccurred.disconnect(self.__onErrorOccurred)
self.qprocess.kill()
self.qprocess.waitForFinished()
self._recordFailure(TaskKilled())
self._finish()
#===============================================================================
# Prepackaged Task Classes
#===============================================================================
class BlockingFunctionTask(BlockingMixin, _AbstractFunctionTask):
"""
A task that simply runs a function and blocks for the duration of it.
To use, implement `mainFunction`.
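    A minimal sketch (the text and n_words fields are illustrative)::
        class WordCountTask(BlockingFunctionTask):
            class Input(parameters.CompoundParam):
                text: str
            class Output(parameters.CompoundParam):
                n_words: int
            def mainFunction(self):
                # Blocks the caller until the count is finished.
                self.output.n_words = len(self.input.text.split())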
"""
class ThreadFunctionTask(ThreadMixin, _AbstractFunctionTask):
"""
A task that runs a function in a separate thread.
To use, implement `mainFunction`.
Note: this class should not be used except in limited
circumstances, as much of our internal code is not thread
safe (e.g. structure.Structure - see PANEL-16783).
New implementations will have to register their usage
in test_thread_usage.py, and include the following warning
in the mainFunction of the task:
# This logic will be run in a worker thread and must not
# access thread-unsafe libraries, including structure.Structure.
"""
class SubprocessCmdTask(SubprocessMixin, AbstractCmdTask):
"""
A task that launches a subprocess.
To use, implement `makeCmd` and return a list of strings.
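    A minimal sketch (the script name and the input.infile field are
    illustrative)::
        class AlignTask(SubprocessCmdTask):
            def makeCmd(self):
                return ['run', 'align.py', self.input.infile,
                        '-o', self.getTaskFilename('aligned.maegz')]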
"""
class ComboBlockingFunctionTask(AbstractComboTask):
"""
This is mostly for testing purposes.
"""
    def runCmd(self, cmd):
cls = type(self)
backend_task = cls.fromJsonFilename(self.json_filename)
backend_task.specifyTaskDir(self.getTaskDir())
backend_task.start()
os.rename(backend_task.json_out_filename, self.json_out_filename)
self._processBackend()
self._finish()
class ComboSubprocessTask(SubprocessMixin, AbstractComboTask):
"""
A task that runs a function in a subprocess.
To use, implement `mainFunction`.
"""
def _processTaskFilesForBackendRehydration(self):
def process_input(path, backend_dir):
return self.getTaskFilename(path)
self._processTaskFiles(self.output, process_func=process_input)
super()._processTaskFilesForBackendRehydration()
    def runBackend(self):
        # Specify the task dir as the cwd since we've already chdir'd into
# the directory with all the task files
self.specifyTaskDir(None)
return super().runBackend()
    def getTaskDir(self):
if self.isBackendMode():
return ''
return super().getTaskDir()
def _finish(self):
with self.guard():
self._processBackend()
super()._finish()
class SignalTask(AbstractTask):
"""
A task that relies on signals to proceed. Runs asynchronously via the event
loop without requiring a worker thread. To use, implement setUpMain to
connect any per-run signals and slots. Any slots should be decorated with
SignalTask.guard_method so that exceptions in slots get converted into
task failures. To end the task, emit self.mainDone to indicate the task has
successfully completed. To fail, raise a TaskFailure or other exception.
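    A minimal sketch (the QTimer-based delay is illustrative)::
        class DelayTask(SignalTask):
            def setUpMain(self):
                self.timer = QtCore.QTimer()
                self.timer.setSingleShot(True)
                self.timer.timeout.connect(self._onTimeout)
                self.timer.start(1000)
            @SignalTask.guard_method
            def _onTimeout(self):
                self.mainDone.emit()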
"""
mainDone = QtCore.pyqtSignal()
    def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mainDone.connect(self._finish)
    @staticmethod
def guard_method(func):
def wrapped_func(self, *args, **kwargs):
with self.guard():
return func(self, *args, **kwargs)
if self.failure_info:
self._finish()
return wrapped_func
    def run(self):
with self.guard():
self.setUpMain()
    def setUpMain(self):
raise NotImplementedError()