"""
This module provides fundamental facilities for writing a multisim driver
script, for writing multisim concrete stage classes, and for dealing with
protocol files.
Copyright Schrodinger, LLC. All rights reserved.
"""
import copy
import glob
import os
import pickle
import shutil
import signal
import subprocess
import sys
import tarfile
import threading
import time
import weakref
from io import BytesIO
from typing import BinaryIO
from typing import Iterable
from typing import List
from typing import Optional
from typing import Union
from pathlib import Path
import schrodinger.application.desmond.bld_ver as bld
import schrodinger.application.desmond.cmdline as cmdline
import schrodinger.application.desmond.envir as envir
import schrodinger.application.desmond.picklejar as picklejar
import schrodinger.application.desmond.util as util
import schrodinger.infra.mm as mm
import schrodinger.job.jobcontrol as jobcontrol
import schrodinger.utils.sea as sea
from schrodinger.application.desmond import constants
from schrodinger.application.desmond import queue
from schrodinger.utils import fileutils
from .picklejar import Picklable
from .picklejar import PicklableMetaClass
from .picklejar import PickleJar
# Contributors: Yujie Wu
# Info
VERSION = "4.0.0"
BUILD = bld.desmond_build_version()
# Machinery
QUEUE = None
ENGINE = None
# Log
LOGLEVEL = [
"silent",
"quiet",
"verbose",
"debug",
]
GENERAL_LOGLEVEL = "quiet"
# Suffixes
PACKAGE_SUFFIX = ".tgz"
CHECKPOINT_SUFFIX = "-multisim_checkpoint"
# Filenames
CHECKPOINT_FNAME = "$MAINJOBNAME" + CHECKPOINT_SUFFIX
_PRODUCTION_SIMULATION_STAGES = ["lambda_hopping", "replica_exchange"]
def _print(loglevel, msg):
if LOGLEVEL.index(loglevel) <= LOGLEVEL.index(GENERAL_LOGLEVEL):
if loglevel == "debug":
print("MSJDEBUG: %s" % msg)
else:
print(msg)
sys.stdout.flush()
def print_tonull(msg):
    pass
def print_silent(msg):
    _print("silent", msg)
def print_quiet(msg):
    _print("quiet", msg)
def print_verbose(msg):
    _print("verbose", msg)
def print_debug(msg):
    _print("debug", msg)
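# A note on the gate in _print() above: a message is emitted only when its
# level comes at or before GENERAL_LOGLEVEL in LOGLEVEL. For example, with
# GENERAL_LOGLEVEL = "quiet", print_silent() and print_quiet() print their
# messages, while print_verbose() and print_debug() are suppressed.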
def _time_str_to_time(time_str, scale=1.0):
h, m, s = [e[:-1] for e in time_str.split()]
return scale * (float(h) * 3600 + float(m) * 60 + float(s))
def _time_to_time_str(inp_time):
h, r = divmod(int(inp_time), 3600)
m, s = divmod(r, 60)
return "%sh %s' %s\"" % (h, m, s)
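# A minimal sketch of how the two helpers above round-trip (the values are
# illustrative):
#
#     s = _time_to_time_str(3725)      # -> 1h 2' 5"
#     seconds = _time_str_to_time(s)   # -> 3725.0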
class JobStatus(object):
# Good status
WAITING = 101
RUNNING = 102
SUCCESS = 103
# Bad status and non-retriable
BACKEND_ERROR = 201
PERMANENT_LICENSE_FAILURE = 202
NON_RETRIABLE_FAILURE = 299
# Bad status and retriable
TEMPORARY_LICENSE_FAILURE = 301
KILLED = 302
FIZZLED = 303
LAUNCH_FAILURE = 304
FILE_NOT_FOUND = 305
FILE_CORRUPT = 306
STRANDED = 307
CHECKPOINT_REQUESTED = 308
CHECKPOINT_WITH_RESTART_REQUESTED = 309
RETRIABLE_FAILURE = 399
STRING = {
WAITING: "is waiting for launching",
RUNNING: "is running",
SUCCESS: "was successfully finished",
PERMANENT_LICENSE_FAILURE: ("could not run due to permanent license "
"failure"),
TEMPORARY_LICENSE_FAILURE: "died due to temporary license failure",
KILLED: "was killed",
FIZZLED: "fizzled",
STRANDED: "was stranded",
LAUNCH_FAILURE: "failed to launch",
FILE_NOT_FOUND: ("was finished, but registered output files were not "
"found"),
FILE_CORRUPT: ("was finished, but an essential output file was found "
"corrupt"),
BACKEND_ERROR: "died due to backend error",
RETRIABLE_FAILURE: "died on unknown retriable failure",
NON_RETRIABLE_FAILURE: "died on unknown non-retriable failure",
CHECKPOINT_REQUESTED: "user requested job be checkpointed",
CHECKPOINT_WITH_RESTART_REQUESTED: "user requested job be checkpointed and restarted"
}
    def __init__(self, code=WAITING):
self._code = code
self._error = None
def __str__(self):
s = ""
try:
s += JobStatus.STRING[self._code]
except KeyError:
if self._error is None:
s += "unknown error"
if self._error is not None:
s += "\n" + self._error
return s
def __eq__(self, other):
if isinstance(other, JobStatus):
return self._code == other._code
else:
try:
return self._code == int(other)
except ValueError:
raise NotImplementedError
def __ne__(self, other):
return not self.__eq__(other)
    def set(self, code, error=None):
        if isinstance(code, JobStatus):
            self._code = code._code
else:
try:
self._code = int(code)
except ValueError:
raise NotImplementedError
self._error = error
    def is_good(self):
        return self._code < 200
    def is_retriable(self):
        return self._code > 300
    def should_restart_from_checkpoint(self):
return self._code == self.CHECKPOINT_WITH_RESTART_REQUESTED
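# A minimal sketch of how a JobStatus is typically driven (codes shown are
# illustrative):
#
#     status = JobStatus()                    # starts as WAITING
#     status.set(JobStatus.BACKEND_ERROR, "backend died")
#     status.is_good()        # False
#     status.is_retriable()   # False -- 2xx codes are non-retriable
#     status.set(JobStatus.KILLED)
#     status.is_retriable()   # True  -- 3xx codes are retriable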
class JobOutput(object):
    def __init__(self):
# Key: file name. Value: None or a callable that checks the file
self._file = {}
self._type = {} # Key: file name. Value: "file" | "dir"
self._tag = {} # Key: tag. Value: file name
self._struct = None
# Note on pickling: Values in `self._file' will be set to None when
# `self' is pickled.
    def __len__(self):
"""
Returns the number of registered output files.
"""
return len(self._file)
def __iter__(self):
"""
Iterates through the registered output files.
        Note that the order of the files here is not necessarily the order in
        which they were registered.
"""
for f in self._file:
yield f
def __list__(self):
        return list(self._file)
def __deepcopy__(self, memo={}): # noqa: M511
newobj = JobOutput()
memo[id(self)] = newobj
newobj._file = copy.deepcopy(self._file)
newobj._type = copy.deepcopy(self._type)
newobj._tag = copy.deepcopy(self._tag)
return newobj
def __getstate__(self):
tmp_dict = copy.copy(self.__dict__)
_file = tmp_dict["_file"]
for k in _file:
_file[k] = None
return tmp_dict
    def update_basedir(self, old_basedir, new_basedir):
old_basedir += os.sep
new_basedir += os.sep
new_file = {}
for k in self._file:
v = self._file[k]
if k.startswith(old_basedir):
k = k.replace(old_basedir, new_basedir)
new_file[k] = v
self._file = new_file
new_type = {}
for k in self._type:
v = self._type[k]
if k.startswith(old_basedir):
k = k.replace(old_basedir, new_basedir)
new_type[k] = v
self._type = new_type
for k in self._tag:
v = self._tag[k]
if v.startswith(old_basedir):
v = v.replace(old_basedir, new_basedir)
self._tag[k] = v
try:
if self._struct and self._struct.startswith(old_basedir):
self._struct = self._struct.replace(old_basedir, new_basedir)
except AttributeError:
pass
try:
new_cms = []
for e in self.cms:
new_cms.append(e.replace(old_basedir, new_basedir))
self.cms = new_cms
except AttributeError:
pass
    def add(self, filename, checker=None, tag=None, type="file"):
        """
        :param type: either "file" or "dir".
        """
if filename:
if type not in ("file", "dir"):
raise ValueError(
'Valid values for \'type\' are "file" and "dir". '
f'But "{type}" is given')
self._file[filename] = checker
self._type[filename] = type
if tag is not None:
if tag in self._tag:
old_filename = self._tag[tag]
del self._file[old_filename]
del self._type[old_filename]
self._tag[tag] = filename
    def remove(self, filename):
        """
        Unregister the output file `filename` and drop any tag that refers to it.
        """
try:
del self._file[filename]
except KeyError:
pass
try:
del self._type[filename]
except KeyError:
pass
for key, value in self._tag.items():
if value == filename:
del self._tag[key]
break
    def get(self, tag):
return self._tag.get(tag)
    def check(self, status):
for fname in self._file:
_print("debug", "checking output file: %s" % fname)
if self._type[fname] == "file":
if os.path.isfile(fname):
checker = self._file[fname]
if checker:
err_msg = checker(fname)
if err_msg:
status.set(JobStatus.FILE_CORRUPT, err_msg)
return
else:
_print("debug", "Output file: %s not found" % fname)
try:
_print(
"debug", "Files in current directory: %s" %
str(os.listdir(os.path.dirname(fname))))
except OSError:
_print(
"debug",
"Directory not found: %s" % os.path.dirname(fname))
status.set(JobStatus.FILE_NOT_FOUND)
return
elif self._type[fname] == "dir":
if not os.path.isdir(fname):
_print("debug", "Output directory: %s not found" % fname)
try:
_print(
"debug", "Files in parent directory: %s" %
str(os.listdir(os.path.dirname(fname))))
except OSError:
_print(
"debug",
"Directory not found: %s" % os.path.dirname(fname))
status.set(JobStatus.FILE_NOT_FOUND)
return
status.set(JobStatus.SUCCESS)
    def set_struct_file(self, fname):
self._struct = fname
if fname not in self._file:
self.add(fname)
    def struct_file(self):
if not self._struct:
for fname in self:
if fname.endswith(
(".mae", ".cms", ".maegz", ".cmsgz", ".mae.gz", ".cms.gz")):
return fname
else:
return self._struct
return None
    def log_file(self):
for fname in self:
if fname.endswith(".log"):
return fname
return None
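# A minimal sketch of registering outputs on a JobOutput (the file names and
# the checker are hypothetical; a checker returns an error string or None):
#
#     def check_nonempty(fname):
#         return None if os.path.getsize(fname) else "empty output file"
#
#     out = JobOutput()
#     out.add("md_1-out.cms", checker=check_nonempty, tag="OUTPUT")
#     out.add("md_1.log")
#     out.add("md_1_trj", type="dir")
#     out.set_struct_file("md_1-out.cms")
#     status = JobStatus()
#     out.check(status)   # sets SUCCESS, FILE_NOT_FOUND, or FILE_CORRUPT
#     out.get("OUTPUT")   # -> "md_1-out.cms"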
class JobErrorHandler:
    @staticmethod
def default(job):
"""
If the job status is bad, attempt to print the log file and
nvidia-smi output.
"""
if not job.status.is_good():
job._print(
"quiet", "jlaunch_dir: %s\n" % job.dir +
"jlaunch_cmd: %s" % subprocess.list2cmdline(job.jlaunch_cmd))
log_fname = job.output.log_file()
if log_fname and os.path.exists(log_fname):
job._print("quiet", "Log file : %s" % log_fname)
with open(log_fname, "r") as f:
log_content = f.readlines()
job._print("quiet",
"Log file content:\n%s" % ">".join(log_content))
job._print("quiet", "(end of log file)\n")
else:
job._print("quiet", "No log file registered for this job\n")
# call nvidia-smi and print output to log file
if job.USE_GPU:
try:
output = subprocess.check_output("nvidia-smi",
universal_newlines=True)
job._print("quiet", "nvidia-smi output:\n%s" % output)
except (FileNotFoundError, subprocess.CalledProcessError):
job._print("quiet", "No nvidia-smi output available\n")
    @staticmethod
def restart_for_backend_error(job):
"""
Run the default handler and if the status is
killed or backend error, mark the failure as retriable.
"""
if not job.status.is_good():
JobErrorHandler.default(job)
if job.status in [JobStatus.BACKEND_ERROR, JobStatus.KILLED]:
job.status.set(JobStatus.RETRIABLE_FAILURE)
def exit_code_is_defined(job):
"""
Return True if job has an exit code. Failed jobs may not have exit codes if
they are killed by the queueing system or otherwise untrackable.
"""
try:
int(job.ExitCode)
except ValueError:
return False
return True
class Job(object):
# most jobs do not use gpu
USE_GPU = False
    class Time(object):
        def __init__(self, launch, start, end, num_cpu, cpu_time, duration):
self.launch = launch
self.start = start
self.end = end
self.num_cpu = num_cpu
self.cpu_time = cpu_time
self.duration = duration
@staticmethod
def _get_time_helper(jobtime):
try:
t = time.mktime(time.strptime(jobtime, jobcontrol.timestamp_format))
s = time.ctime(t)
except AttributeError:
t = None
s = "(unknown)"
return t, s
    @staticmethod
def get_time(jctrl, num_cpu):
launch_time, str_launch_time = Job._get_time_helper(jctrl.LaunchTime)
if jctrl.StartTime:
start_time, str_start_time = Job._get_time_helper(jctrl.StartTime)
else:
return Job.Time(str_launch_time, "(not started)", "N/A", num_cpu,
"N/A", "N/A")
if jctrl.StopTime:
stop_time, str_stop_time = Job._get_time_helper(jctrl.StopTime)
else:
            return Job.Time(str_launch_time, str_start_time, "(stranded)", num_cpu,
"N/A", "N/A")
if start_time is not None and num_cpu != "(unknown)":
cpu_time = util.time_duration(start_time, stop_time, num_cpu)
duration = util.time_duration(start_time, stop_time)
else:
cpu_time = "(unknown)"
duration = "(unknown)"
return Job.Time(str_launch_time, str_start_time, str_stop_time, num_cpu,
cpu_time, duration)
    def get_proc_time(self):
proc_time = Job.get_time(self.jctrl, self.num_cpu).cpu_time
return _time_str_to_time(proc_time) if proc_time != "(unknown)" else 0.0
    def __init__(self,
jobname,
parent,
stage,
jlaunch_cmd,
dir,
host_list=None,
prefix=None,
what=None,
err_handler=JobErrorHandler.default,
is_output=True):
self.jobname = jobname
self.tag = None
# Job object from which this `Job' object was derived.
self.parent = parent
# other Job objects from which this `Job' object was derived.
self.other_parent = None
# Job control object, will be set once the job is launched.
self.jctrl = None
self.jlaunch_cmd = jlaunch_cmd # Job launch command
# List of hosts where this job can be running
self.host_list = host_list
# Actual host where this job is running
self.host = jobcontrol.Host("localhost")
# By default, subjobs do not need a host other than localhost.
self.need_host = False
self.num_cpu = 1
self.use_hostcpu = False
# Launch directory, also where the job's outputs will be copied back
self.dir = dir
self.prefix = prefix # Prefix directory of the launch directory
self.what = what # A string that stores more specific job description
self.output = JobOutput() # Output file names
self.input = JobInput() # Input file names
self.status = JobStatus() # Job status
# `None' or a callable object that will be called to handle job errors.
self.err_handler = err_handler
self._jctrl_hist = []
self._has_run = False
if self.parent and self.prefix is None:
self.prefix = self.parent.prefix
if isinstance(stage, weakref.ProxyType):
self.stage = stage
else:
self.stage = weakref.proxy(stage)
        # Note on pickling: `self.err_handler' will not be pickled.
        self.old = False  # True if this job was run by a previous ENGINE instance (i.e., restored from a checkpoint)
# is_output is used to signal that this is a stage's output. Some
# stages which implement `hook_captured_successful_job` should set
# `is_output=False` on the intermediate jobs and then set
# `is_output=True` on any final jobs it creates
self.is_output = is_output
@property
def is_for_jc(self) -> bool:
"""
Whether or not this job should be submitted to job control
"""
return self.is_launchable and isinstance(self.jlaunch_cmd, list)
@property
def is_launchable(self) -> bool:
return bool(self.jlaunch_cmd)
@property
def failed(self) -> bool:
return not self.status.is_good()
@property
def is_retriable(self) -> bool:
# Retriable means it can be retried by the -RETRIES mechanism which is
# different than just restarting a job when using -RESTART
return self.status.is_retriable()
@property
def is_incomplete(self) -> bool:
return self.failed or self.status in [
JobStatus.WAITING, JobStatus.RUNNING
]
@property
def is_restartable(self):
return self.is_incomplete and self.is_launchable
def __deepcopy__(self, memo={}): # noqa: M511
newobj = object.__new__(self.__class__)
memo[id(self)] = newobj
for k, v in self.__dict__.items():
if k in ["stage", "jctrl", "parent"]:
value = self.__dict__[k]
elif k == "other_parent":
value = copy.copy(self.other_parent)
elif k == "_jctrl_hist":
value = []
else:
value = copy.deepcopy(v, memo)
setattr(newobj, k, value)
return newobj
def __getstate__(self, state=None):
state = state if (state) else copy.copy(self.__dict__)
if "err_handler" in state:
del state["err_handler"]
if "jctrl" in state:
state["jctrl"] = str(self.jctrl)
if "_jctrl_hist" in state:
state["_jctrl_hist"] = ["removed_in_serialization"]
if "jlaunch_cmd" in state:
if callable(state["jlaunch_cmd"]):
state["jlaunch_cmd"] = "removed_in_serialization"
if "stage" in state:
state["stage"] = (self.stage if (isinstance(self.stage, int)) else
self.stage._INDEX)
return state
def __setstate__(self, state):
self.__dict__.update(state)
if "stage" in state and ENGINE:
self.stage = weakref.proxy(ENGINE.stage[self.stage])
def __repr__(self):
"""
        Return a summary of the job: jobname, job control handle, status, stage, and flags.
"""
r = f"<{self.jobname}"
if self.jctrl:
r += f"({self.jctrl})"
r += f" status: {self.status}"
# Temporarily for testing
if isinstance(self.stage, StageBase):
r += f" stage: {self.stage.NAME}"
r += f" is_output: {self.is_output}"
r += f" old: {self.old} is_for_jc: {self.is_for_jc}"
r += ">"
return r
def _print(self, loglevel, msg):
"""
The internal print function of this job. Printing is at the same
'loglevel' as self.stage.
"""
self.stage._print(loglevel, msg)
def _log(self, msg):
"""
The internal log function of this job.
"""
self.stage._log(msg)
def _host_str(self):
"""
Returns a string representing the hosts.
"""
if self.jlaunch_cmd:
if '-HOST' in self.jlaunch_cmd:
return self.jlaunch_cmd[self.jlaunch_cmd.index('-HOST') + 1]
host_str = self.host.name
if self.use_hostcpu and -1 == host_str.find(":"):
host_str += ":%d" % self.num_cpu
return host_str
    def describe(self):
if self.status != JobStatus.LAUNCH_FAILURE:
self._print("quiet", " Launch time: %s" % self.jctrl.LaunchTime)
self._print("quiet", " Host : %s" % self._host_str())
self._print(
"quiet", " Jobname : %s\n" % self.jobname +
" Stage : %d (%s)" % (self.stage._INDEX, self.stage.NAME))
self._print(
"verbose", " Prefix : %s\n" % self.prefix +
" Jlaunch_cmd: %s\n" % subprocess.list2cmdline(self.jlaunch_cmd) +
" Outputs : %s" % str(list(self.output)))
if self.what:
self._print("quiet", " Description: %s" % self.what)
    def process_completed_job(self,
jctrl: jobcontrol.Job,
checkpoint_requested=False,
restart_requested=False):
"""
Check for valid output and set status of job, assuming job is already
complete.
:param checkpoint_requested: Set to True if the job should checkpoint.
            Default is False.
        :param restart_requested: Set to True if the job should checkpoint and restart.
            Default is False.
"""
self.jctrl = jctrl
# Make sure the job data has been downloaded and flushed to disk
self.jctrl.download()
# Not available on windows
if hasattr(os, 'sync'):
os.sync()
self._print(
"debug",
"Job seems finished. Checking its exit-status and exit-code...")
self._print("debug", "Job exit-status = '%s'" % self.jctrl.ExitStatus)
if self.jctrl.ExitStatus == "killed":
self._print("debug", "Job exit-code = N/A")
self.status.set(JobStatus.KILLED)
elif self.jctrl.ExitStatus == "fizzled":
self._print("debug", "Job exit-code = N/A")
self.status.set(JobStatus.FIZZLED)
else:
exit_code = self.jctrl.ExitCode
if not exit_code_is_defined(self.jctrl):
# If the exit code is not set, the backend must have died
# without collecting the exit code. This could happen if a job
# is qdeled, or the backend gets killed by OOM, or the job
# monitoring process is killed by any reason.
# Set status to a retriable status.
self.status.set(JobStatus.KILLED)
elif exit_code == 0:
if checkpoint_requested:
self.status.set(JobStatus.CHECKPOINT_REQUESTED)
elif restart_requested:
self.status.set(JobStatus.CHECKPOINT_WITH_RESTART_REQUESTED)
else:
self.output.check(self.status)
elif exit_code == 17:
# The mmlic3 library will return the following error codes upon
# checkout:
# 0 : success
# 15 : temporary, retryable failure; perhaps the server
# couldn't be contacted
# 16 : all licenses are in use. SGE is capable of requeuing
# the job.
# 17 : fatal, unrecoverable license error.
self.status.set(JobStatus.PERMANENT_LICENSE_FAILURE)
elif exit_code in {15, 16}:
self.status.set(JobStatus.TEMPORARY_LICENSE_FAILURE)
else:
self.status.set(JobStatus.BACKEND_ERROR)
    def requeue(self, jctrl: jobcontrol.Job):
# Make sure the job data has been downloaded and flushed to disk
jctrl.download()
# Not available on windows
if hasattr(os, 'sync'):
os.sync()
# Delete stale checkpoint files that are not needed for restarting
def _filter_tgz(input_fnames: List[str]):
return set(filter(lambda x: x.endswith('-out.tgz'), input_fnames))
stale_input_tgz_fnames = _filter_tgz(jctrl.InputFiles) - _filter_tgz(
jctrl.OutputFiles)
for fname in stale_input_tgz_fnames:
util.remove_file(fname)
self._print("quiet", f"Restart checkpointed job: {self.jlaunch_cmd}")
self._print("quiet",
f"Deleted stale input files: {stale_input_tgz_fnames}")
self.stage.restart_subjobs([self])
self.status.set(JobStatus.WAITING)
    def finish(self):
if self.status != JobStatus.LAUNCH_FAILURE:
jobtime = Job.get_time(self.jctrl, self.num_cpu)
self._print("quiet",
"\n%s %s." % (str(self.jctrl), str(self.status)))
self._print(
"quiet",
" Host : %s\n" % self._host_str() +
" Launch time: %s\n" % jobtime.launch +
" Start time : %s\n" % jobtime.start +
" End time : %s\n" % jobtime.end +
" Duration : %s\n" % jobtime.duration +
" CPUs : %s\n" % self.num_cpu +
" CPU time : %s\n" % jobtime.cpu_time +
" Exit code : %s\n" % self.jctrl.ExitCode +
" Jobname : %s\n" % self.jobname +
" Stage : %d (%s)" % (self.stage._INDEX, self.stage.NAME),
)
if self.err_handler:
self.err_handler(self)
if self.status.is_retriable():
self._print("quiet",
" Retries : 0 - Job has failed too many times.")
self.stage.finalize_job(self)
self.stage.finalize_stage()
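# A minimal sketch of constructing a Job inside a stage's crunch() (the
# command, directory, and file names are hypothetical):
#
#     job = Job(jobname="md_1",
#               parent=prejob,
#               stage=self,
#               jlaunch_cmd=["$SCHRODINGER/desmond", "-c", "md.cfg"],
#               dir=jobdir,
#               err_handler=JobErrorHandler.restart_for_backend_error)
#     job.output.add("md_1-out.cms", tag="OUTPUT")
#     self.add_job(job)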
class _create_param_when_needed(object):
def __init__(self, param):
self._param = param
def __get__(self, obj, cls):
if cls == StageBase:
a = sea.Map(self._param)
a.add_tag("generic")
else:
a = None
for c in cls.__bases__[::-1]: # left-most base takes precedence
if issubclass(c, StageBase):
if a is None:
a = copy.deepcopy(c.PARAM)
else:
a.update(copy.deepcopy(c.PARAM))
a.update(self._param, tag="stagespec")
setattr(cls, "PARAM", a)
return a
class _StageBaseMeta(PicklableMetaClass):
def __init__(cls, name, bases, dict):
PicklableMetaClass.__init__(cls, name, bases, dict)
cls.stage_cls[cls.NAME] = cls
class StageBase(Picklable, metaclass=_StageBaseMeta):
    count = 0
stage_cls = {}
stage_obj = {} # key = stage name; value = stage instance.
NAME = "generic"
    RESTARTABLE = False  # Whether or not a stage can be restarted after it has already run
# Basic stage parameters
PARAM = _create_param_when_needed("""
DATA = {
title = ?
should_sync = true
dryrun = false
prefix = ""
jobname = "$MAINJOBNAME_$STAGENO"
dir = "$[$JOBPREFIX/$]$[$PREFIX/$]$MAINJOBNAME_$STAGENO"
compress = "$MAINJOBNAME_$STAGENO-out%s"
struct_output = ""
should_skip = false
effect_if = ?
jlaunch_opt = []
transfer_asap = no
}
VALIDATE = {
title = [{type = none} {type = str}]
should_sync = {type = bool}
dryrun = {type = bool}
prefix = {type = str }
jobname = {type = str }
dir = {type = str }
compress = {type = str }
struct_output = {type = str }
should_skip = {type = bool}
effect_if = [{type = none} {type = list size = -2 _skip = all}]
jlaunch_opt = {
type = list size = 0
elem = {type = str}
check = ""
black_list = ["-HOST" "-USER" "-JOBNAME"]
}
transfer_asap = {type = bool}
}
""" % (PACKAGE_SUFFIX,))
    def __init__(self, should_pack=True):
# Will be set by the parser (see `parse_msj' function below).
self.param = None
self._PREV_STAGE = None # Stage object of the previous stage
self._NEXT_STAGE = None # Stage object of the next stage
self._ID = StageBase.count # ID number of this stage
self._INDEX = None # Stage index. Not serialized.
self._is_shown = False
self._is_packed = False
self._should_pack = should_pack
# For parameter validation
# function objects to be called before the main parameter check
self._precheck = []
# function objects to be called after the main parameter check
self._postcheck = []
self._files4pack = []
self._files4copy = []
self._pack_fname = ""
# Has the `prestage' method been called?
self._is_called_prestage = False
self._used_jobname = []
self._start_time = None # Holds per stage start time
self._stage_duration = None # Holds per stage duration time
self._gpu_time = 0.0 # Accumulates total GPU time
self._num_gpu_subjobs = 0 # Number of GPU subjobs
self._packed_fnames = set()
self._job_manager = JobManager()
StageBase.count += 1
@property
def jobs(self) -> List[Job]:
return self._job_manager.jobs
    def get_prejobs(self) -> List[Job]:
"""
Get the stage's input jobs
"""
if self._PREV_STAGE is None:
return []
return self._PREV_STAGE.get_output_jobs()
    def add_jobs(self, jobs: Iterable[Job]):
"""
Add jobs to the stage's job manager
"""
self._job_manager.add_jobs(jobs)
    def add_job(self, job: Job):
"""
Shortcut for `add_jobs`
"""
self._job_manager.add_jobs([job])
    def get_output_jobs(self) -> List[Job]:
"""
Get the stage's output jobs
"""
return self.filter_jobs(status=[JobStatus.SUCCESS], is_output=[True])
    def filter_jobs(self, **kwargs) -> List[Job]:
"""
Return a list of jobs based on a matching a set of criteria given as
arguments. Read `JobManager.filter_jobs` for available arguments.
"""
return self._job_manager.filter_jobs(**kwargs)
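    # A minimal sketch of how a concrete stage typically combines the helpers
    # above: crunch() walks self.get_prejobs(), builds one new Job per input
    # and registers it with self.add_job(); finalize_stage() later reads the
    # results back through self.filter_jobs(...).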
def __getstate__(self, state=None):
state = state if (state) else picklejar.PickleState()
state.NAME = self.NAME
state._ID = self._ID
state._is_shown = self._is_shown
state._is_packed = self._is_packed
state._job_manager = self._job_manager
try:
state._pack_fname = self._pack_fname
except AttributeError:
state._pack_fname = ""
return state
def __setstate__(self, state):
if state.NAME != self.NAME:
raise TypeError("Unmatched stage: %s vs %s" %
(state.NAME, self.NAME))
self.__dict__.update(state.__dict__)
def _print(self, loglevel, msg):
_print(loglevel, msg)
def _log(self, msg):
self._print("quiet", "stage[%d] %s: %s" % (self._INDEX, self.NAME, msg))
def _get_macro_dict(self):
macro_dict = copy.copy(ENGINE.macro_dict)
macro_dict["$STAGENO"] = self._INDEX
return macro_dict
def _gen_unique_jobname(self, suggested_jobname):
trial_jobname = suggested_jobname
number = 1
while trial_jobname in self._used_jobname:
trial_jobname = suggested_jobname + ("_%d" % number)
number += 1
self._used_jobname.append(trial_jobname)
sea.update_macro_dict({"$JOBNAME": trial_jobname})
return trial_jobname
def _get_jobname_and_dir(self, job, macro_dict={}): # noqa: M511
sea.set_macro_dict(self._get_macro_dict())
sea.update_macro_dict(macro_dict)
if self.param.prefix.val != "":
sea.update_macro_dict({"$PREFIX": self.param.prefix.val})
if job.prefix != "" and job.prefix is not None:
sea.update_macro_dict({"$JOBPREFIX": job.prefix})
try:
if job.tag is not None:
sea.update_macro_dict({"$JOBTAG": job.tag})
except AttributeError:
pass
util.chdir(ENGINE.base_dir)
sea.update_macro_dict({"$JOBNAME": self.param.jobname.val})
return (
self.param.jobname.val,
os.path.abspath(self.param.dir.val),
)
def _param_jlaunch_opt_check(self, key, val_list, prefix, ev):
try:
black_list = set(self.PARAM.VALIDATE.jlaunch_opt.black_list.val)
except AttributeError:
return
jlaunch_opt = set(val_list.val)
bad_opt = jlaunch_opt & black_list
if bad_opt:
s = " ".join(bad_opt)
ev.record_error(
prefix,
"Bad values for jlaunch_opt of %s stage: %s" % (self.NAME, s))
def _reg_param_precheck(self, func):
if func not in self._precheck:
self._precheck.append(func)
def _reg_param_postcheck(self, func):
if func not in self._postcheck:
self._postcheck.append(func)
def _set(self, key, setter, transformer=None):
param = self.param[key]
if param.has_tag("setbyuser"):
if callable(setter):
setter(param)
elif isinstance(setter, sea.Atom):
if callable(transformer):
setter.val = transformer(param.val)
else:
setter.val = param.val
def _effect(self, param):
effect_if = param.effect_if
if isinstance(effect_if, sea.List):
for condition, block in zip(effect_if[0::2], effect_if[1::2]):
# TODO: Don't use private function
val = sea.evalor._eval(PARAM, condition)
if isinstance(val, bool):
condition = val
elif isinstance(val[0], str):
condition = _operator[val[0]](self, PARAM, val[1:])
else:
condition = val[0]
if condition:
if isinstance(block, sea.Atom):
block = sea.Map(block.val)
                    # Check whether the 'effect_if' parameter is set within
                    # the `block'.
if "effect_if" not in block:
block.effect_if = sea.Atom("none")
# TODO what is the purpose of this line below?
effect_if[1] = block
block = block.dval
param.update(block)
self._effect(param)
return param
    def describe(self):
self._print("quiet", "\nStage %d - %s" % (self._INDEX, self.NAME))
self._print("verbose",
"{\n" + self.param.__str__(" ", tag="setbyuser") + "}")
    def migrate_param(self, param: sea.Map):
"""
Subclasses can implement this to migrate params to provide backward
compatibility with older msj files, ideally with a deprecation warning.
"""
    def check_param(self):
def clear_trjidx(prmdata):
"""
do not use idx files
"""
try:
if "maeff_output" in prmdata:
del prmdata["maeff_output"]["trjidx"]
except (KeyError, TypeError):
pass
check_func_name = "multisim_stage_%d_jlaunch_opt_check" % self._ID
self.PARAM.VALIDATE.jlaunch_opt.check.val = check_func_name
sea.reg_xcheck(check_func_name, self._param_jlaunch_opt_check)
# Note that `self.param's parent should be the global `PARAM'.
# But this statement will implicitly change its parent to `self.PARAM'.
# At the end of this function we need to change it back to `PARAM'.
orig_param_data = self.PARAM.DATA
self.PARAM.DATA = self.param
clear_trjidx(self.PARAM.DATA)
ev = sea.Evalor(self.param, "\n")
for func in self._precheck:
try:
func()
except ParseError as e:
ev.record_error(err=str(e))
sea.check_map(self.PARAM.DATA, self.PARAM.VALIDATE, ev, "setbyuser")
for func in self._postcheck:
try:
func()
except ParseError as e:
ev.record_error(err=str(e))
self.param.set_parent(PARAM.stage)
self.PARAM.DATA = orig_param_data
return ev
    def push(self, job):
if not self._is_called_prestage and not self.param.should_skip.val:
self._is_called_prestage = True
self.prestage()
if job is None:
self._print(
"debug", "All surviving jobs have been pushed into stage[%d]." %
self._INDEX)
self.release()
else:
self._print(
"debug",
"Job was just pushed into stage[%d]: %s" %
(self._INDEX, str(job)),
)
if not self.param.should_sync.val:
self.release()
    def determine(self):
param = self._effect(self.param)
if param.should_skip.val:
self.add_jobs(self.get_prejobs())
    def crunch(self):
        """
        This is where the jobs of this stage are created. This function should
        be overridden by the subclass.
        """
    def restart_subjobs(self, jobs):
"""
Subclass should override this if it supports subjob restarting.
"""
    def release(self, is_restarting=False):
        """
        Calls the 'crunch' method to generate new job objects and submits
        them to the 'QUEUE'.
        """
util.chdir(ENGINE.base_dir)
if not self._is_shown:
self.describe()
self._is_shown = True
is_restarting = True
if not self.param.should_skip.val:
self._is_packed = False
if self._start_time is None:
self._start_time = time.time()
self.determine()
if is_restarting:
self.restart_subjobs(self.filter_jobs(is_restartable=[True]))
if not self.filter_jobs(old=[False]) and not \
self.param.should_skip.val:
# If no new jobs were created from restart_subjobs, run crunch
self.crunch()
jlaunch_opt = [str(e) for e in self.param.jlaunch_opt.val]
if jlaunch_opt != [""]:
for job in self.filter_jobs(is_for_jc=[True],
status=[JobStatus.WAITING]):
job.jlaunch_cmd += jlaunch_opt
if not self.param.dryrun.val:
ENGINE.write_checkpoint()
self._job_manager.submit_jobs(QUEUE)
for job in self.filter_jobs(is_for_jc=[False], old=[False]):
if not job._has_run and callable(job.jlaunch_cmd):
if not self.param.dryrun.val:
job.jlaunch_cmd(job)
job._has_run = True
self.finalize_job(job)
if self.param.dryrun.val:
for job in self.filter_jobs(is_for_jc=[True],
status=[JobStatus.WAITING]):
job.status.set(JobStatus.SUCCESS)
self.finalize_job(job)
if self.jobs:
self.finalize_stage()
    def finalize_job(self, job: Job):
"""
Call `hook_captured_successful_job` on any successful jobs and write
a checkpoint
"""
self._print("debug", "Captured %s" % job)
if job.status == JobStatus.SUCCESS and not self.param.should_skip.val:
self.hook_captured_successful_job(job)
if self.param.transfer_asap.val:
self.pack_stage(force=True)
if job.USE_GPU:
self._gpu_time += job.get_proc_time()
self._num_gpu_subjobs += 1
ENGINE.write_checkpoint()
self._print("debug", "running jobs:")
self._print(
"debug",
self.filter_jobs(status=[JobStatus.WAITING, JobStatus.RUNNING],
old=[False]))
self._print("debug", "successful jobs:")
self._print("debug",
self.filter_jobs(status=[JobStatus.SUCCESS], old=[False]))
self._print("debug", "failed jobs:")
self._print("debug", self.filter_jobs(failed=[True], old=[False]))
    def finalize_stage(self):
"""
If the stage is done running, pack the stage, and if the stage is
successful, continue to the next stage
"""
running_jobs = self.filter_jobs(
status=[JobStatus.WAITING, JobStatus.RUNNING], old=[False])
failed_jobs = self.filter_jobs(failed=[True], old=[False])
successful_jobs = self.filter_jobs(status=[JobStatus.SUCCESS],
old=[False])
if not running_jobs:
if not failed_jobs:
# All jobs were successful
if self.param.should_skip.val:
self._print("quiet", f"\nStage {self._INDEX} is skipped.\n")
else:
self._print(
"quiet", f"\nStage {self._INDEX} completed "
f"successfully.\n")
self.poststage()
move_on = True
elif successful_jobs and self._check_partial_success():
# Some stages can pass with partial success
self.poststage()
move_on = True
else: # No jobs were successful
self._print(
"quiet", f"\nStage {self._INDEX} failed. "
f"No subjobs completed.\n")
move_on = False
self.pack_stage(force=self.param.transfer_asap.val)
if self._NEXT_STAGE is not None and move_on:
self._NEXT_STAGE.push(None)
def _check_partial_success(self):
"""
Check whether or not the stage is considered successful based on
whether it allows completion with some failed/some successful
subjobs. Should be overridden by subclasses that need to implement
this functionality.
"""
return False
    def prestage(self):
        pass
    def poststage(self):
        pass
    def hook_captured_successful_job(self, job):
        pass
    def time_stage(self):
this_stop_time = time.time()
self._stage_duration = util.time_duration(self._start_time,
this_stop_time)
    def pack_stage(self, force=False):
if force or ((not self.param.should_skip.val) and self._should_pack and
(not self._is_packed)):
self._pack_stage()
def _pack_stage(self):
self._is_packed = True
util.chdir(ENGINE.base_dir)
# Standard checkpoint to a file
pack_fname = None
if self.param.compress.val != "":
sea.update_macro_dict({"$STAGENO": self._INDEX})
pack_fname = self.param.compress.val
if not pack_fname.lower().endswith((
PACKAGE_SUFFIX,
"tar.gz",
)):
pack_fname += PACKAGE_SUFFIX
self.param.compress.val = pack_fname
self._pack_fname = pack_fname
print_debug(f"pack_stage: pack_fname:{pack_fname}")
# Collects all data paths for transferring.
data_paths = set()
for job in self.jobs:
# Some stages just pass on a job from the previous stage
# directly to the next stage. So we check the stage ID
# to avoid packing the same job again.
if job.stage._ID == self._ID:
if job.dir and pack_fname:
data_paths.add(job.dir)
else:
for e in job.output:
data_paths.add(e)
reg_file = []
if isinstance(job.jctrl, jobcontrol.Job):
reg_file.extend(job.jctrl.OutputFiles)
reg_file.extend(job.jctrl.InputFiles)
reg_file.extend(job.jctrl.LogFiles)
if job.jctrl.StructureOutputFile:
reg_file.append(job.jctrl.StructureOutputFile)
for fname in reg_file:
if not os.path.isabs(fname):
data_paths.add(os.path.join(job.dir, fname))
# Creates a stage-specific checkpoint file -- just a symbolic link
# to the current checkpoint file.
ENGINE.write_checkpoint()
stage_checkpoint_fname = None
if os.path.isfile(CHECKPOINT_FNAME):
stage_checkpoint_fname = (os.path.basename(CHECKPOINT_FNAME) + "_" +
str(self._INDEX))
shutil.copyfile(CHECKPOINT_FNAME, stage_checkpoint_fname)
# Includes this checkpoint file for transferring.
data_paths.add(os.path.abspath(stage_checkpoint_fname))
if pack_fname:
with tarfile.open(pack_fname,
mode="w:gz",
format=tarfile.GNU_FORMAT,
compresslevel=1) as pack_file:
pack_file.dereference = True
for path in data_paths | set(self._files4pack):
print_debug(f"pack_stage: add_to_tar: {path} exists: "
f"{os.path.exists(path)} cwd: {os.getcwd()}")
if os.path.exists(path):
relpath = util.relpath(path, ENGINE.base_dir)
pack_file.add(relpath)
data_paths = [pack_fname]
if ENGINE.JOBBE:
for path in data_paths:
# Makes all paths relative. Otherwise jobcontrol won't
# transfer them!!!
path = util.relpath(path, ENGINE.base_dir)
if not path:
continue
if path in self._packed_fnames:
continue
self._packed_fnames.add(path)
print_debug(f"pack_stage: outputFile: {path} relpath: "
f"{util.relpath(path, ENGINE.base_dir)} "
f"cwd: {os.getcwd()}")
                # Transferring files ASAP is only allowed when we do NOT
                # compress files. DESMOND-7401.
if (self.param.transfer_asap.val and
not pack_fname) and os.path.exists(path):
ENGINE.JOBBE.copyOutputFile(path)
else:
ENGINE.JOBBE.addOutputFile(path)
for path in self._files4copy:
path = util.relpath(path, ENGINE.base_dir)
if path in self._packed_fnames:
continue
self._packed_fnames.add(path)
print_debug(f"pack_stage: files4copy: {path} relpath: "
f"{util.relpath(path, ENGINE.base_dir)} "
f"cwd: {os.getcwd()}")
ENGINE.JOBBE.copyOutputFile(path)
try:
self.time_stage()
self._print(
"quiet",
"Stage %d duration: %s\n" % (self._INDEX, self._stage_duration))
except TypeError:
self._print(
"quiet",
"Stage %d duration could not be calculated." % self._INDEX)
class JobManager:
"""
A class for managing a stage's jobs. The jobs are stored in the `_jobs`
list internally but should only be accessed by the `jobs` property or
`filter_jobs`.
"""
    def __init__(self):
self._jobs: List[Job] = []
@property
def jobs(self) -> List[Job]:
return [*self._jobs] # Return copy so list is not modified by user
    def clear(self):
        self._jobs = []
    def add_jobs(self, jobs: Iterable[Job]):
"""
        Add the given jobs to the job manager, skipping any duplicates.
"""
for job in jobs:
job.old = False
existing_jobs = set(self.jobs)
self._jobs.extend(job for job in jobs if job not in existing_jobs)
    def submit_jobs(self, queue: queue.Queue):
jobs = self.filter_jobs(status=[JobStatus.WAITING], is_for_jc=[True])
queue.push(jobs)
    def filter_jobs(self,
status=None,
old=None,
is_for_jc=None,
is_output=None,
failed=None,
is_launchable=None,
is_restartable=None,
is_incomplete=None) -> List[Job]:
"""
Get a subset of the job manager's jobs. Each argument can either be
None, to indicate no filtering on the property, or a list of
acceptable values for the given argument's property.
When passing in multiple arguments, the function returns jobs which
satisfy all given criteria.
"""
def _filter_job(job):
if status:
if job.status not in status:
return False
if old:
if job.old not in old:
return False
if is_for_jc:
if job.is_for_jc not in is_for_jc:
return False
if is_output:
if job.is_output not in is_output:
return False
if failed:
if job.failed not in failed:
return False
if is_launchable:
if job.is_launchable not in is_launchable:
return False
if is_restartable:
if job.is_restartable not in is_restartable:
return False
if is_incomplete:
if job.is_incomplete not in is_incomplete:
return False
return True
return [*filter(_filter_job, self.jobs)]
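# A minimal sketch of filter_jobs(): every keyword takes a list of acceptable
# values, and a job must satisfy all given criteria, e.g.
#
#     manager.filter_jobs(status=[JobStatus.WAITING], is_for_jc=[True])
#     manager.filter_jobs(failed=[True], old=[False])
#     manager.filter_jobs(is_restartable=[True])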
class StructureStageBase(StageBase):
"""
StructureStageBase can be used for stages that take in
a path to a structure, apply some transformation,
and then write out an updated structure.
"""
    def __init__(self, *args, **kwargs):
self.TAG = self.NAME.upper()
super().__init__(*args, **kwargs)
    def crunch(self):
self._print("debug", f"In {self.NAME}.crunch")
for pj in self.get_prejobs():
jobname, jobdir = self._get_jobname_and_dir(pj)
if not os.path.isdir(jobdir):
os.makedirs(jobdir)
with fileutils.chdir(jobdir):
new_job = copy.deepcopy(pj)
new_job.stage = weakref.proxy(self)
new_job.output = JobOutput()
new_job.need_host = False
new_job.dir = jobdir
new_job.status.set(JobStatus.SUCCESS)
new_job.parent = pj
output_fname = self.run(jobname, pj.output.struct_file())
if output_fname is None:
new_job.status.set(JobStatus.BACKEND_ERROR)
else:
new_job.output.set_struct_file(
os.path.abspath(output_fname))
self.add_job(new_job)
self._print("debug", f"Out {self.NAME}.crunch")
    def run(self, jobname: str, input_fname: str) -> Optional[str]:
"""
:param jobname: Jobname for this stage.
:param input_fname: Filename for the input structure.
:return: Filename for the output structure or `None`
if there was an error generating the output.
"""
raise NotImplementedError
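# A minimal sketch of a StructureStageBase subclass (the stage name and the
# transformation are hypothetical):
#
#     class StripWaters(StructureStageBase):
#         NAME = "strip_waters"
#
#         def run(self, jobname, input_fname):
#             out_fname = jobname + "-out.mae"
#             # ... read input_fname, remove waters, write out_fname ...
#             return out_fname   # or None on failure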
class _get_jc_backend_when_needed(object):
def __get__(self, obj, cls):
jobbe = jobcontrol.get_backend()
setattr(cls, "JOBBE", jobbe)
return jobbe
class Engine(object):
    JOBBE = _get_jc_backend_when_needed()
    def __init__(self, opt=None):
# This may be reset by the command options.
self.jobname = None
self.username = None
self.mainhost = None
self.host = None
self.cpu = None
self.inp_fname = None
self.msj_fname = None # The .msj file of this restarting job.
self.MSJ_FNAME = None # Original .msj file name.
self.msj_content = None
self.out_fname = None
# Not serialized because it will be always reset at restarting
self.set = None
self.cfg = None
self.cfg_content = None
self.maxjob = None
self.max_retry = None
self.relay_arg = None
self.launch_dir = None
self.description = None
self.loglevel = GENERAL_LOGLEVEL
self.stage = [] # Serialized. Will be set when serialization.
self.date = None # Date of the original job.
self.time = None # Time of the original job.
self.START_TIME = None # Start time of the original job.
self.start_time = None # Start time. Will change in restarting.
self.stop_time = None # Stop time. Will change in restarting.
self.base_dir = None # Current base dir. Will change in restarting.
# Stage No. to restart from. Will change in restarting.
self.refrom = None
self.base_dir_ = None # Base dir of last job.
self.jobid = None # Current job ID. Will change in restarting.
# Job ID of the original job. Not affected by restarting.
self.JOBID = None
# version numbers and installation will change in restarting
self.version = VERSION # MSJ version.
self.build = BUILD
        self.mmshare_ver = envir.CONST.MMSHARE_VERSION
# Installation dir. Will change in restarting.
self.schrodinger = envir.CONST.SCHRODINGER
# Installation dir of the previous run. Will change in restarting.
self.schrod_old = None
self.old_jobnames = []
# Will be set when probing the checkpoint file
self.chkpt_fname = None
self.chkpt_fh = None
self.restart_stage = None
self.__more_init()
if opt:
self.reset(opt)
def __more_init(self):
"""
        Will be called by '__init__' and 'deserialize'.
This is introduced to avoid breaking the previous checkpoint file by
adding a new attribute.
"""
self.notify = None
self.macro_dict = None
self.max_walltime = None
self.checkpoint_requested_event = None
def __find_restart_stage_helper(self, stage):
if stage.filter_jobs(is_incomplete=[True]) or not stage.jobs:
self.restart_stage = self.restart_stage if (
self.restart_stage) else stage
stage._is_shown = False
stage._is_packed = False
def _find_restart_stage(self):
self.restart_stage = None
self._foreach_stage(self.__find_restart_stage_helper)
def _fix_job(self, stage):
if not stage.RESTARTABLE and stage.filter_jobs(is_incomplete=[True]):
stage._job_manager.clear()
for job in stage.jobs:
job.old = True
if job.dir and job.dir.startswith(self.base_dir_ + os.sep):
job.dir = job.dir.replace(self.base_dir_, self.base_dir)
elif job.dir and job.dir == self.base_dir_:
# With JOB_SERVER, the job dir may not be a subdirectory
# so replace the top-level dir too. This is needed for
# restarting MD jobs from the production stage.
job.dir = job.dir.replace(self.base_dir_, self.base_dir)
job.output.update_basedir(self.base_dir_, self.base_dir)
try:
job.input.update_basedir(self.base_dir_, self.base_dir)
except AttributeError:
pass
# Fixes the "stage" attribute of all jobs of this stage. And fixes job
# launching command.
if isinstance(job.stage, int):
job.stage = weakref.proxy(self.stage[job.stage])
if isinstance(job.jlaunch_cmd, list) and isinstance(
job.jlaunch_cmd[0], str):
job.jlaunch_cmd[0] = job.jlaunch_cmd[0].replace(
self.schrod_old, self.schrodinger)
    def restore_stages(self, print_func=print_quiet):
# DESMOND-7934: Preserve the task stage from the checkpoint
# if a custom msj is specified.
checkpoint_stage_list = None
if self.msj_fname and self.msj_content:
checkpoint_stage_list = parse_msj(None,
msj_content=self.msj_content,
pset=self.set)
parsee0 = "the multisim script file" if (self.msj_fname) else None
parsee1 = "the '-set' option" if (self.set) else None
parsee = (parsee0 + " and " + parsee1 if
(parsee0 and parsee1) else parsee0 if (parsee0) else parsee1)
if parsee:
print_func("\nParsing %s..." % parsee)
try:
msj_content = None if (self.msj_fname) else self.msj_content
stage_list = parse_msj(self.msj_fname, msj_content, self.set)
except ParseError as a_name_to_make_flake8_happy:
print_quiet("\n%s\nParsing failed." %
str(a_name_to_make_flake8_happy))
sys.exit(1)
if checkpoint_stage_list and stage_list:
refrom = self.refrom
# Find the restart stage index if not specified
if refrom is None:
                # The first stage has the parameters we want to restore
                # from the checkpoint.
refrom = 2
for idx, s in enumerate(self.stage):
if s.filter_jobs(is_incomplete=[True]) or not s.jobs:
refrom = idx
break
# Restore stages before the restart stage from the checkpoint
# and update the ones after the checkpoint
stage_list = checkpoint_stage_list[:refrom -
1] + stage_list[refrom - 1:]
if "task" != stage_list[0].NAME:
print("ERROR: The first stage is not a 'task' stage.")
sys.exit(1)
if self.cfg:
with open(self.cfg, "r") as fh:
cfg = sea.Map(fh.read())
for stage in stage_list:
if "task" == stage.NAME:
if "desmond" in stage.param.set_family:
stage.param.set_family.desmond.update(cfg)
else:
stage.param.set_family["desmond"] = cfg
if self.cpu:
# Value of `self.cpu' is a string, which specifies either a single
# integer or 3 integers separated by spaces. We must parse the
# string to get the integers and assign the latter to stages.
cpu_str = self.cpu.split()
try:
cpu = [int(e) for e in cpu_str]
n_cpu = len(cpu)
cpu = cpu[0] if (1 == n_cpu) else cpu
if 1 != n_cpu and 3 != n_cpu:
raise ValueError("Incorrect configuration of the CPU: %s" %
self.cpu)
except ValueError:
raise ParseError("Invalid value for the 'cpu' parameter: '%s'" %
self.cpu)
for stage in stage_list:
if stage.NAME in [
"simulate",
"minimize",
"replica_exchange",
"lambda_hopping",
"vrun",
"fep_vrun",
"watermap",
]:
stage.param["cpu"] = cpu
stage.param.cpu.add_tag("setbyuser")
elif stage.NAME in [
"mcpro_simulate",
"watermap_cluster",
"ffbuilder",
]:
stage.param["cpu"] = (cpu if (1 == n_cpu) else
(cpu[0] * cpu[1] * cpu[2]))
stage.param.cpu.add_tag("setbyuser")
stage_state = [
e.__getstate__() for e in (self.stage[:self.refrom] if (
self.refrom and self.refrom > 0) else self.stage)
]
self.stage = build_stages(stage_list, self.out_fname, stage_state)
for stage in self.stage:
self._fix_job(stage)
# `self.msj_content' contains only user's settings. `stage_list[1:-1]'
# will avoid the initial ``primer'' and the final ``concluder'' stages.
self.msj_content = write_msj(stage_list[1:-1], to_str=True)
    def reset(self, opt):
"""
Resets this engine with the command options.
"""
# Resets the '_is_reset_*' attributes.
for k in self.__dict__:
if k[:10] == "_is_reset_":
self.__dict__[k] = False
if opt.refrom:
self.refrom = opt.refrom
if opt.jobname:
self.jobname = opt.jobname
if opt.user:
self.username = opt.user
if opt.mainhost:
self.mainhost = opt.mainhost
if opt.host:
self.host = opt.host
if opt.cpu:
self.cpu = opt.cpu
if opt.inp:
self.inp_fname = os.path.abspath(opt.inp)
if opt.msj:
self.msj_fname = os.path.abspath(opt.msj)
if opt.out:
self.out_fname = opt.out
if opt.set:
self.set = opt.set
if opt.maxjob is not None:
self.maxjob = opt.maxjob
if opt.max_retries is not None:
self.max_retry = opt.max_retries
if opt.relay_arg:
self.relay_arg = sea.Map(opt.relay_arg)
if opt.launch_dir:
self.launch_dir = opt.launch_dir
if opt.notify:
self.notify = opt.notify
if opt.encoded_description:
self.description = cmdline.get_b64decoded_str(
opt.encoded_description)
if opt.quiet:
self.loglevel = "quiet"
if opt.verbose:
self.loglevel = "verbose"
if opt.debug:
self.loglevel = "debug"
if opt.max_walltime:
self.max_walltime = opt.max_walltime
self.cfg = opt.cfg
    def boot_setup(self, base_dir=None):
"""
Set up an `Engine` object, but do not start the queue.
:param base_dir: Set to the path for the base_dir or
`None`, the default, to use the cwd.
"""
global ENGINE, GENERAL_LOGLEVEL, CHECKPOINT_FNAME
self._init_signals()
GENERAL_LOGLEVEL = self.loglevel
ENGINE = self
if self.loglevel == "debug":
_print("quiet", "Multisim debugging mode is on.\n")
if self.description:
_print("quiet", self.description)
#######################################################################
# Boots the engine.
_print("quiet", "Booting the multisim workflow engine...")
self.date = time.strftime("%Y%m%d") if (not self.date) else self.date
self.time = time.strftime("%Y%m%dT%H%M%S") if (
not self.time) else self.time
self.start_time = time.time() if (
not self.start_time) else self.start_time
self.base_dir_ = self.base_dir
self.base_dir = base_dir or os.getcwd()
self.jobid = envir.get("SCHRODINGER_JOBID")
self.JOBID = self.JOBID if (self.JOBID) else self.jobid
self.maxjob = 0 if self.maxjob < 1 else self.maxjob
self.max_retry = (self.max_retry if
(self.max_retry is not None) else int(
envir.get("SCHRODINGER_MAX_RETRIES", 3)))
self.MSJ_FNAME = self.MSJ_FNAME if (self.MSJ_FNAME) else self.msj_fname
# Resets these variables.
self.version = VERSION
self.build = BUILD
self.mmshare_ver = envir.CONST.MMSHARE_VERSION
self.schrod_old = self.schrodinger
self.schrodinger = envir.CONST.SCHRODINGER
_print("quiet", " multisim version: %s" % self.version)
_print("quiet", " mmshare version: %s" % self.mmshare_ver)
_print("quiet", " Jobname: %s" % self.jobname)
_print("quiet", " Username: %s" % self.username)
_print("quiet", " Main job host: %s" % self.mainhost)
_print("quiet", " Subjob host: %s" % self.host)
_print("quiet", " Job ID: %s" % self.jobid)
_print(
"quiet", " multisim script: %s" %
os.path.basename(self.msj_fname if
(self.msj_fname) else self.MSJ_FNAME))
_print(
"quiet", " Structure input file: %s" %
os.path.basename(self.inp_fname))
if self.cpu:
_print("quiet", ' CPUs per subjob: "%s"' % self.cpu)
else:
_print("quiet",
" CPUs per subjob: (unspecified in command)")
_print("quiet",
" Job start time: %s" % time.ctime(self.start_time))
_print("quiet", " Launch directory: %s" % self.launch_dir)
_print("quiet", " $SCHRODINGER: %s" % self.schrodinger)
if oplsdir := os.getenv(mm.OPLS_DIR_ENV):
_print("quiet",
f" $OPLS_DIR: {os.path.basename(oplsdir)}")
# Only need to copy the file once
if os.getenv(constants.SCHRODINGER_MULTISIM_DONT_COPY_OPLS) is None:
os.environ[constants.SCHRODINGER_MULTISIM_DONT_COPY_OPLS] = '1'
opls_fname = f'{self.jobname}-out.opls'
if not Path(opls_fname).exists():
shutil.copyfile(oplsdir, opls_fname)
self.JOBBE.copyOutputFile(opls_fname)
else:
_print("quiet", " $OPLS_DIR: <empty>")
sys.stdout.flush()
self.macro_dict = {
"$MAINJOBNAME": self.jobname,
"$MASTERJOBNAME": self.jobname, # TODO: Here to read old msj files
"$USERNAME": self.username,
"$SUBHOST": self.host,
}
sea.set_macro_dict(copy.copy(self.macro_dict))
self.restore_stages()
if self.chkpt_fh:
def show_job_state(stage, engine=self):
engine._check_stage(stage)
if stage._final_status[0] == "1":
_print("quiet", " Jobnames of failed subjobs:")
for job in stage.filter_jobs(failed=[True]):
_print("quiet", " %s" % job.jobname)
_print("quiet", "")
_print("quiet", "Checkpoint state:")
self._foreach_stage(show_job_state)
_print("quiet", "\nSummary of user stages:")
for stage in self.stage[1:-1]:
if stage.param.title.val:
_print(
"quiet", " stage %d - %s, %s" %
(stage._INDEX, stage.NAME, stage.param.title.val))
else:
_print("quiet", " stage %d - %s" % (stage._INDEX, stage.NAME))
_print("quiet", "(%d stages in total)" % (len(self.stage) - 2))
CHECKPOINT_FNAME = os.path.join(
self.base_dir,
sea.expand_macro(CHECKPOINT_FNAME, sea.get_macro_dict()))
    def boot(self):
"""
Boot the `Engine` and run the jobs.
"""
global QUEUE
self.boot_setup()
max_walltime_timer = None
if self.max_walltime:
self.checkpoint_requested_event = threading.Event()
_print("quiet", f"Checkpoint after {self.max_walltime} seconds.")
max_walltime_timer = threading.Timer(
self.max_walltime,
lambda: self.checkpoint_requested_event.set())
max_walltime_timer.start()
if self.host is None:
_print(
"quiet", "\nCould not determine host. "
"Please check schrodinger.hosts and the queue configuration and try again."
)
self.cleanup(exit_code=1, skip_stage_check=True)
return
QUEUE = queue.Queue(self.host,
self.maxjob,
max_retries=self.max_retry,
periodic_callback=self.handle_jobcontrol_message)
self.JOBBE.addOutputFile(os.path.basename(CHECKPOINT_FNAME))
self.start_time = time.time()
_print("quiet", "\nWorkflow is started now.")
try:
if self.START_TIME is None:
self.START_TIME = self.start_time
self.stage[0].start(self.inp_fname)
else:
if self.refrom is None or self.refrom < 1:
self._find_restart_stage()
else:
self.restart_stage = self.stage[self.refrom]
if self.restart_stage:
if self.msj_fname:
_print(
"quiet", "Updating stages with the new .msj file: "
f"{self.msj_fname}...")
_print(
"quiet",
f"Stage {self.restart_stage._INDEX} and after "
"will be affected by the new .msj file.")
# We need to rerun the `set_family' functions.
self.run_set_family(self.restart_stage._INDEX)
_print(
"quiet", "Restart workflow from stage %d." %
self.restart_stage._INDEX)
self.restart_stage.push(None)
else:
_print(
"quiet",
"The previous multisim job has completed successfully.")
_print(
"quiet",
"If you want to restart from a completed stage, "
"specify its stage number to")
_print(
"quiet", "the '-RESTART' option as: "
"-RESTART <checkpoint-file>:<stage_number>.")
QUEUE.run()
exit_code = 0
skip_stage_check = False
except SystemExit:
sys.exit(1)
except StopRequest:
restart_fname = queue.CHECKPOINT_REQUESTED_FILENAME
with open(restart_fname, 'w') as f:
pass
self.JOBBE.addOutputFile(os.path.basename(restart_fname))
exit_code = 0
skip_stage_check = True
except StopAndRestartRequest:
restart_fname = queue.CHECKPOINT_WITH_RESTART_REQUESTED_FILENAME
with open(restart_fname, 'w') as f:
pass
self.JOBBE.addOutputFile(os.path.basename(restart_fname))
exit_code = 0
skip_stage_check = True
except Exception:
ei = sys.exc_info()
sys.excepthook(ei[0], ei[1], ei[2])
_print(
"quiet", "\n\nUnexpected exception occurred. Terminating the "
"multisim execution...")
exit_code = 1
skip_stage_check = False
if max_walltime_timer is not None:
max_walltime_timer.cancel()
self.cleanup(exit_code, skip_stage_check=skip_stage_check)
    def run_set_family(self, max_stage_idx=None):
"""
Re-run set_family for all task stages up to `max_stage_idx`.
"""
max_stage_idx = max_stage_idx or len(self.stage)
stage = self.stage[0]
while stage is not None and stage._INDEX < max_stage_idx:
if stage.NAME == "task":
stage.set_family()
stage = stage._NEXT_STAGE
    def handle_jobcontrol_message(self, stop=False):
restart = False
if self.checkpoint_requested_event is not None:
restart = self.checkpoint_requested_event.is_set()
if not (stop or restart or self.JOBBE.haltRequested()):
return
        _print("quiet",
               "\nReceived 'halt' message. Stopping job on user's request...")
_print("quiet",
f"{len(QUEUE.running_jobs)} subjob(s) are currently running.")
num_killed = QUEUE.stop()
if num_killed:
_print("quiet",
f"{num_killed} subjob(s) failed to stop and were killed.")
else:
_print("quiet", "Subjobs stopped successfully.")
if restart:
raise StopAndRestartRequest()
raise StopRequest()
def _init_signals(self):
# Signal handling stuff.
for signal_name in [
"SIGTERM", "SIGINT", "SIGHUP", "SIGUSR1", "SIGUSR2"
]:
# Certain signals are not available depending on the OS.
if hasattr(signal, signal_name):
signal.signal(
getattr(signal, signal_name),
lambda x, stack_frame: self._handle_signal(signal_name),
)
def _reset_signals(self):
signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGINT, signal.SIG_DFL)
signal.signal(signal.SIGUSR1, signal.SIG_DFL)
signal.signal(signal.SIGUSR2, signal.SIG_DFL)
try:
signal.signal(signal.SIGHUP, signal.SIG_DFL)
except AttributeError:
pass
def _handle_signal(self, signal_name):
self._reset_signals()
print("\n\n%s: %s signal received" % (time.asctime(), signal_name))
return self.handle_jobcontrol_message(stop=True)
def _foreach_stage(self, callback):
stage = self.stage[0]._NEXT_STAGE
while stage._NEXT_STAGE is not None:
callback(stage)
stage = stage._NEXT_STAGE
def _check_stage(self, stage, print_func=print_quiet):
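        # The leading digit of each value below is a coarse grade (0 = failed
        # or not run, 1 = partially completed, 2 = completed or skipped); it
        # is kept in stage._final_status and read back via _final_status[0],
        # while the text after it (status[2:]) is what gets printed.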
INTERPRETATION = {
-2: "2 was skipped",
-1: "0 not run",
0: "0 failed",
1: "1 partially completed",
2: "2 completed",
}
subjob = ""
if stage._is_shown:
if stage.param.should_skip.val:
status = INTERPRETATION[-2]
else:
num_done = len(
stage.filter_jobs(status=[JobStatus.SUCCESS], old=[False]))
num_incomplete = len(
stage.filter_jobs(is_incomplete=[True], old=[False]))
if num_done > 0:
if num_incomplete == 0:
status = INTERPRETATION[2]
else:
status = INTERPRETATION[1]
subjob = " %d subjobs failed, %d subjobs done." % (
num_incomplete, num_done)
else:
if num_incomplete > 0:
status = INTERPRETATION[0]
else:
status = INTERPRETATION[-1]
else:
status = INTERPRETATION[-1]
print_func(" Stage %d %s.%s" % (stage._INDEX, status[2:], subjob))
stage._final_status = status
    def cleanup(self, exit_code=0, skip_stage_check=False):
"""
:param skip_stage_check: Set to True to skip
checking each stage to determine the exit code.
"""
print("Cleaning up files...")
sys.stdout.flush()
self._foreach_stage(
lambda stage: stage._is_shown and stage.pack_stage())
self.stop_time = time.time()
job_duration = util.time_duration(self.start_time, self.stop_time)
print("\nMultisim summary (%s):" % time.ctime(self.stop_time))
self._foreach_stage(self._check_stage)
# FIXME: duration for this restarting?
print(" Total duration: %s" % job_duration)
all_gpu_times = []
all_gpu_subjobs = []
self._foreach_stage(lambda stage: all_gpu_times.append(stage._gpu_time))
self._foreach_stage(
lambda stage: all_gpu_subjobs.append(stage._num_gpu_subjobs))
total_gpu_time = sum(all_gpu_times)
if total_gpu_time:
print(" Total GPU time: %s (used by %d subjob(s))" %
(_time_to_time_str(total_gpu_time), sum(all_gpu_subjobs)))
final_status = []
for stage in self.stage[1:-1]:
if stage.filter_jobs(old=[False]):
final_status.append(int(stage._final_status[0]))
if final_status:
is_successful = min(final_status)
else:
is_successful = 0 # Fail if no stages ran
if exit_code == 0:
if is_successful == 2:
print("Multisim completed.")
elif is_successful == 1:
print("Multisim partially completed.")
else:
print("Multisim failed.")
else:
print("Multisim failed.")
if self.notify:
recipients = (self.notify if (isinstance(self.notify, list)) else [
self.notify,
])
print("\nSending log file to the email address(es): %s" %
", ".join(recipients))
sys.stdout.flush()
log_fname = self.jobname + "_multisim.log"
if os.path.isfile(log_fname):
email_message = open(log_fname, "r").read()
else:
email_message = "Log file: %s not found.\n"
email_message += str(self.JOBID) + "\n"
email_message += self.launch_dir + "\n"
email_message += self.description + "\n"
if exit_code == 0:
if is_successful == 2:
email_message += "Multisim completed."
elif is_successful == 1:
email_message += "Multisim partially completed."
else:
email_message += "Multisim failed."
else:
email_message += "Multisim failed."
import smtplib
from email.mime.text import MIMEText
composer = MIMEText(email_message)
composer["Subject"] = "Multisim: %s" % self.jobname
composer["From"] = "noreply@schrodinger.com"
composer["To"] = ", ".join(recipients)
try:
smtp = smtplib.SMTP()
smtp.connect()
smtp.sendmail("noreply@schrodinger.com", recipients,
composer.as_string())
smtp.close()
except Exception:
print("WARNING: Failed to send notification email.")
print("WARNING: There is probably no SMTP server running on "
"main host.")
if exit_code == 0 and is_successful != 2 and not skip_stage_check:
exit_code = 1
sys.exit(exit_code)
[docs] def serialize(self, fh: BinaryIO):
self.msj_fname = None
self.set = None
self.refrom = None
self.chkpt_fh = None
self.stop_time = time.ctime()
pickle.dump(self, fh)
PickleJar.serialize(fh)
[docs] def serialize_bytes(self) -> bytes:
"""
Return the binary contents of the serialized engine.
"""
fh = BytesIO()
self.serialize(fh)
fh.flush()
return fh.getvalue()
def __getstate__(self):
tmp_dict = copy.copy(self.__dict__)
# threading.Event objects cannot be pickled, so drop the
# checkpoint-request event before serialization.
tmp_dict["checkpoint_requested_event"] = None
return tmp_dict
[docs] @staticmethod
def deserialize(fh: BinaryIO):
unpickler = picklejar.CustomUnpickler(fh, encoding="latin1")
engine = unpickler.load()
# This adds class metadata that was serialized
# above. Without this, these values are reset to
# the default.
PickleJar.deserialize(fh)
engine.chkpt_fh = fh
engine.__more_init()
try:
engine.old_jobnames.append(engine.jobname)
except AttributeError:
engine.old_jobnames = [
engine.jobname,
]
return engine
[docs] def write_checkpoint(self, fname=None, num_retry=10):
if not fname:
fname = CHECKPOINT_FNAME
# Write to a temporary file
fname_lock = fname + ".lock"
with open(fname_lock, "wb") as fh:
self.serialize(fh)
for i in range(num_retry):
try:
# os.replace overwrites an existing destination atomically
# (it was not available in py2, hence the fallback below).
os.replace(fname_lock, fname)
return
except AttributeError:
# py2 fallback: rename fails on Windows if the destination
# already exists, so remove it first.
if os.path.isfile(fname):
os.remove(fname)
os.rename(fname_lock, fname)
return
except PermissionError as err: # TODO: DESMOND-9511
print(i, os.getcwd(), fname_lock, fname)
for fn in glob.glob("*"):
print(i, fn)
if i == num_retry - 1:
raise err
else:
print(f"retry {i+1} due to err: {err}")
time.sleep(30)
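# Illustrative sketch of a checkpoint round trip (the file name is
# hypothetical): write_checkpoint() serializes the engine atomically via a
# ".lock" file, and Engine.deserialize() restores it from an open binary
# handle:
#   engine.write_checkpoint("myjob-multisim_checkpoint")
#   with open("myjob-multisim_checkpoint", "rb") as fh:
#       restored_engine = Engine.deserialize(fh)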
[docs]class StopRequest(Exception):
pass
[docs]class StopAndRestartRequest(Exception):
pass
[docs]class ParseError(Exception):
pass
[docs]def is_restartable_version(version_string):
version_number = [int(e) for e in version_string.split(".")]
current = [int(e) for e in VERSION.split(".")]
for v, c in zip(version_number[:3], current[:3]):
if v < c:
return False
return True
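# Example (illustrative; assumes the current VERSION is "4.0.0"):
#   is_restartable_version("4.0.0")   # -> True
#   is_restartable_version("3.9.2")   # -> False (written by an older version)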
[docs]def is_restartable_build(engine):
from . import bld_def as bd
bld_comm = bd.bld_types[bd.DESMOND_COMMERCIAL]
try:
restart_files_build = engine.build
except AttributeError:
return True
return restart_files_build != bld_comm or BUILD == bld_comm
[docs]def build_stages(stage_list, out_fname=None, stage_state=[]): # noqa: M511
"""
Build up the stages for the job, adding the initial Primer
and final Concluder stages.
"""
import schrodinger.application.desmond.stage as stg
primer_stage = stg.Primer()
concluder_stage = stg.Concluder(out_fname)
primer_stage.param = copy.deepcopy(stg.Primer.PARAM.DATA)
concluder_stage.param = copy.deepcopy(stg.Concluder.PARAM.DATA)
stage_list.insert(0, primer_stage)
stage_list.append(concluder_stage)
build_stagelinks(stage_list)
for stage, state in zip(stage_list, stage_state):
if stage.NAME == state.NAME:
stage.__setstate__(state)
return stage_list
[docs]def build_stagelinks(stage_list):
for i, stage in enumerate(stage_list[1:-1]):
# Note the list that we are traversing here is `stage_list[1:-1]'.
stage._PREV_STAGE = stage_list[i]
stage._NEXT_STAGE = stage_list[i + 2]
stage_list[0]._PREV_STAGE = None
stage_list[-1]._NEXT_STAGE = None
try:
stage_list[0]._NEXT_STAGE = stage_list[1]
stage_list[-1]._PREV_STAGE = stage_list[-2]
except IndexError:
stage_list[0]._NEXT_STAGE = None
stage_list[-1]._PREV_STAGE = None
for i, stage in enumerate(stage_list):
stage._INDEX = i
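# Illustrative sketch (the user-supplied stage list is hypothetical):
# build_stages() wraps the list with Primer/Concluder stages, and
# build_stagelinks() turns it into a doubly linked chain that can be walked
# via the _NEXT_STAGE attribute:
#   stages = build_stages(user_stage_list, out_fname="out.cms")
#   s = stages[0]                     # Primer, s._INDEX == 0
#   while s is not None:
#       print(s._INDEX, s.NAME)
#       s = s._NEXT_STAGE             # None after the Concluder stage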
[docs]def probe_checkpoint(fname, indent=""):
print(indent + "Probing checkpoint file: %s" % fname)
with open(fname, "rb") as fh:
engine = Engine.deserialize(fh)
engine.schrod_old = engine.schrodinger
def probe_print(s):
print(indent + " " + s)
probe_print(" multisim version: %s" % engine.version)
probe_print(" mmshare version: %s" % engine.mmshare_ver)
probe_print(" Jobname: %s" % engine.jobname)
probe_print(" Previous jobnames: %s" % engine.old_jobnames)
probe_print(" Username: %s" % engine.username)
probe_print(" Main job host: %s" % engine.mainhost)
probe_print(" Subjob host: %s" % engine.host)
if engine.cpu:
probe_print(' CPUs per subjob: "%s"' % engine.cpu)
else:
probe_print(" CPUs per subjob: unspeficied in command")
probe_print(" Original start time: %s" % time.ctime(engine.START_TIME))
probe_print(" Checkpoint time: %s" % engine.stop_time)
probe_print(" Main job ID: %s" % engine.jobid)
probe_print(" Structure input file: %s" %
os.path.basename(engine.inp_fname))
probe_print(" Original *.msj file: %s" %
os.path.basename(engine.MSJ_FNAME))
engine.base_dir_ = engine.base_dir
engine.restore_stages(print_func=print_tonull)
probe_print("\nStages:")
engine.chkpt_fname = fname
def show_failed_jobs(stage, engine=engine):
engine._check_stage(stage, probe_print)
if stage._final_status[0] == "1":
probe_print(" Jobnames of failed subjobs:")
for job in stage.filter_jobs(failed=[True]):
probe_print(" %s" % job.jobname)
engine._foreach_stage(show_failed_jobs)
print()
print("Current version of multisim is %s" % VERSION)
print("This checkpoint file "
"can%sbe restarted with the current version of multisim." %
(" " if (is_restartable_version(engine.version)) else " not "))
return engine
[docs]def escape_string(s):
ret = ""
should_quote = False
if s == "":
return '""'
for c in s:
if c == '"':
ret += '\\"'
should_quote = True
elif c == "'" and ret[-1] == "\\":
ret = ret[:-1] + "'"
should_quote = True
else:
ret += c
if c <= " ":
should_quote = True
if should_quote:
ret = '"' + ret + '"'
return ret
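# Example (illustrative):
#   escape_string("simple")      # -> 'simple'
#   escape_string("two words")   # -> '"two words"' (whitespace forces quoting)
#   escape_string("")            # -> '""'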
[docs]def append_stage(
cmj_fname,
stage_type,
cfg_file=None,
jobname=None,
dir=None,
compress=None,
parameter={}, # noqa: M511
):
if not os.path.isfile(cmj_fname):
return None
try:
fh = open(cmj_fname, "r")
s = fh.read()
fh.close()
except IOError:
print("error: Reading failed. file: '%s'", cmj_fname)
return None
if stage_type == "simulate":
s += "simulate {\n"
elif stage_type == "minimize":
s += "minimize {\n"
elif stage_type == "replica_exchange":
s += "replica_exchange {\n"
else:
print("error: Unknown stage type '%s'" % stage_type)
return None
if cfg_file is not None:
s += ' cfg_file = "%s"\n' % cfg_file
if jobname is not None:
s += ' jobname = "%s"\n' % jobname
if dir is not None:
s += ' dir = "%s"\n' % dir
if compress is not None:
s += ' compress = "%s"\n' % compress
for p in parameter:
if parameter[p] is not None:
s += " %s = %s\n" % (p, parameter[p])
s += "}\n"
return s
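# Illustrative sketch (file names and parameters are hypothetical): the
# function returns the extended .msj text as a string (or None on error)
# and does not write it back to disk:
#   new_text = append_stage("setup.msj", "simulate",
#                           cfg_file="md.cfg", jobname="prod",
#                           parameter={"time": 1200.0})
#   if new_text is not None:
#       with open("setup.msj", "w") as fh:
#           fh.write(new_text)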
[docs]def concatenate_relaxation_stages(raw):
"""
Attempts to concatenate relaxation stages by finding all adjacent
non-production `simulate` stages. If no concatenatable stages are found,
None is returned. Otherwise, a new raw map with the relaxation `simulate`
stages replaced with a single `concatenate` stage is returned.
:param raw: the raw map representing the MSJ
:type raw: `sea.Map`
:return: a new raw map representing the updated msj, or None.
:rtype: `sea.Map` or `None`
"""
new_raw = copy.deepcopy(raw)
while True:
stages_to_concat, insertion_point = get_concat_stages(new_raw.stage)
if len(stages_to_concat) > 1:
concat_stage = sea.Map()
concat_stage.__NAME__ = "concatenate"
concat_simulate_stages = sea.List()
concat_simulate_stages.add_tag("setbyuser")
for stage in stages_to_concat:
new_raw.stage.remove(stage)
concat_simulate_stages.append(stage)
concat_stage.simulate = concat_simulate_stages
concat_stage.title = concat_stage.simulate[0].title
if 'maeff_output' in concat_stage.simulate[0].val:
concat_stage.maeff_output = concat_stage.simulate[
0].maeff_output
new_raw.stage.insert(insertion_point, concat_stage)
new_raw.stage.add_tag("setbyuser", propagate=False)
else:
break
if len(new_raw.stage) != len(raw.stage):
return new_raw
return None
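# Illustrative sketch (assumes `raw` was produced by msj2sea_full() so that
# every stage carries its full parameter map):
#   new_raw = concatenate_relaxation_stages(raw)
#   if new_raw is not None:
#       raw = new_raw  # relaxation stages collapsed into one `concatenate` stage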
[docs]def get_concat_stages(stages, param_attr=""):
"""
Get a list of the stages that can be concatenated together, and the
insertion point of the resulting concatenate stage. Stages can be
concatenated if they are adjacent simulate stages with the same restraints,
excluding the final production stage, which can be lambda hopping, replica
exchange, or otherwise the last simulate stage.
:param stages: A list of objects representing multisim stages.
For flexibility, these can be either maps or stages. For stages, a
param attribute must be passed that will give the location of the
param on the stage.
:type stages: list of (sea.Map or stage.Stage)
:param param_attr: optional name of the attribute holding the stage's
param, in case of a stage.Stage object.
:type param_attr: str
:return: a tuple of (stages to concatenate, insertion point for the
resulting concatenate stage)
"""
stages_to_concat = []
insertion_point = None
i = last_stage = 0
has_permanent_restrain = False
first_simulate_param = None
last_gcmc_block = None
def is_restrained(param):
return (("restrain" in param and param.restrain.val != "none") or
bool(has_explicit_restraints(param)))
for stage in stages:
stage_param = getattr(stage, param_attr) if param_attr else stage
try:
if stage_param.should_skip.val:
# don't let skipped stages break up otherwise consecutive
# simulate stages
if last_stage:
last_stage = i
i += 1
continue
except AttributeError:
pass
name = stage_param.__NAME__
# TODO we can't check stage.AssignForcefield.NAME here because we can't
# import module stage (would be circular). That's pretty strong
# evidence that we should move these concatenation-related functions to
# a stage_utils module
if name == "assign_forcefield":
has_permanent_restrain |= is_restrained(stage_param)
if name in _PRODUCTION_SIMULATION_STAGES:
break
elif name == "simulate":
# simulate stages must be adjacent to concatenate
if last_stage and last_stage != i - 1:
break
# gcmc stages can only be concatenated if gcmc blocks are identical
# across stages
if "gcmc" in stage_param.keys(tag="setbyuser"):
gcmc_param = stage_param.gcmc
if (last_gcmc_block is not None and
gcmc_param.val != last_gcmc_block.val):
break
else:
gcmc_param = sea.Atom("none")
last_gcmc_block = gcmc_param
# conditions on restrain block to concatenate
if first_simulate_param is None:
# we use whole `stage_param` instead of the restraints
# themselves to (partially) support both old-style "restrain"
# and new-style "restraints" in "Concatenate" stage (single
# "flavor" per stage); ideally this needs to be
# revised/tightened at some point during or after DESMOND-10079
first_simulate_param = stage_param
if restraints_incompatible(stage_param, first_simulate_param,
has_permanent_restrain):
break
if insertion_point is None:
insertion_point = i
last_stage = i
stages_to_concat.append(stage)
i += 1
# the production stage can be either the last simulate stage or one of those
# defined in _PRODUCTION_SIMULATION_STAGES. if we've reached the last stage
# without breaking it means the production stage is a normal simulate stage.
# In that case, we need to remove the production stage from the list of
# stages to concatenate
if i == len(stages) and stages_to_concat:
stages_to_concat.pop()
return stages_to_concat, insertion_point
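# Illustrative sketch: for raw sea.Map stages no `param_attr` is needed,
# while for stage.Stage objects the attribute holding the parameters must
# be named (both `raw` and `stage_list` are hypothetical here):
#   to_concat, where = get_concat_stages(raw.stage)
#   to_concat, where = get_concat_stages(stage_list, param_attr="param")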
[docs]def make_empty_restraints(existing='ignore') -> sea.Map:
outcome = sea.Map()
outcome["existing"] = existing
outcome["new"] = sea.List()
return outcome
[docs]def get_restrain(sm: sea.Map) -> sea.Sea:
try:
return sm.get_value("restrain")
except KeyError:
return sea.Atom("none")
[docs]def get_restraints(sm: sea.Map) -> sea.Map:
try:
return sm.get_value("restraints")
except KeyError:
return make_empty_restraints()
[docs]def get_restraints_xor_convert_restrain(param: sea.Map) -> sea.Map:
"""
Returns `restraints` or `restrain` (converted into `restraints`
format) from the `param`. Raises `ValueError` if both are set.
:param param: stage parameters
:return: restraints block
"""
restrain = get_restrain(param)
if has_explicit_restraints(param):
if restrain.val != 'none':
raise ValueError("Concatenate stage cannot include "
"`restrain` and `restraints` simultaneously")
else:
return get_restraints(param)
else:
return _restraints_from_restrain(restrain)
[docs]def restraints_incompatible(param: sea.Map, initial_param: sea.Map,
has_permanent_restrain: bool):
"""
Checks whether restraint parameters are compatible with switching
during a concatenate stage. For compatibility the parameters have to
differ from the initial ones by only a scaling factor (which can include
zero). Furthermore, there can be no differences between restraints and
initial restraints if `has_permanent_restrain` is truthy, as there is no
way to selectively scale restraints.
:param param: the param for a given stage
:type param: `sea.Map`
:param initial_param: parameters for the first stage
:type initial_param: `sea.Map`
:param has_permanent_restrain: whether or not there are restraints applied
to all stages via the `permanent_restraints` mechanism
:type has_permanent_restrain: bool
:return: a message declaring how the restraints are incompatible, or
an empty string if they are compatible
:rtype: str
"""
param_restrain = get_restrain(param)
initial_param_restrain = get_restrain(initial_param)
have_restrain = (param_restrain.val != "none" or
initial_param_restrain.val != "none")
have_restraints = (has_explicit_restraints(param) or
has_explicit_restraints(initial_param))
if have_restrain and have_restraints:
return ("We cannot concatenate stages that mix restraints "
"given via the `restraints` and `restrain` parameters")
if have_restrain:
current = _restraints_from_restrain(param_restrain)
initial = _restraints_from_restrain(initial_param_restrain)
else:
current = get_restraints(param)
initial = get_restraints(initial_param)
return _check_restraints_compatibility(
current=current,
initial=initial,
has_permanent_restrain=has_permanent_restrain)
[docs]def has_explicit_restraints(param: sea.Map):
"""
:param param: the param for a given stage
:return: whether or not the `restraints` block has new or existing
restraints
"""
if "restraints" in param:
explicit_restraints = param.restraints
has_new = "new" in explicit_restraints and explicit_restraints.new.val
has_existing = ("existing" in explicit_restraints and
explicit_restraints.existing.val !=
constants.EXISTING_RESTRAINT.IGNORE)
return has_new or has_existing
return False
[docs]def check_restrain_diffs(restrain, initial_restrain):
"""
See if the differences between two restrain blocks are
concatenation-compatible, meaning they are both `sea.Map` objects and
differ only by a force constant.
:param restrain: the restrain block for a given stage
:type restrain: `sea.Map` or `sea.List`
:param initial_restrain: the restraints for the first stage
:type initial_restrain: `sea.Map` or `sea.List`
:return: a message declaring how the restraints are incompatible, or
an empty string if they are compatible
:rtype: str
"""
if restrain == initial_restrain:
return ""
def head_if_single(o):
return o[0] if isinstance(o, sea.List) and len(o) == 1 else o
restrain = head_if_single(restrain)
initial_restrain = head_if_single(initial_restrain)
if isinstance(restrain, sea.Map) and isinstance(initial_restrain, sea.Map):
for restrain_diff in sea.diff(restrain, initial_restrain):
for key in restrain_diff:
if key not in ["force_constant", "fc", "force_constants"]:
return ("We cannot change restraint parameters other than "
"the force constant between integrators")
return ""
elif isinstance(restrain, sea.List) or isinstance(initial_restrain,
sea.List):
return ("We cannot change between lists of restraint parameters "
"unless they are identical.")
else:
raise ValueError("restraints definition blocks expected to be "
"`sea.List` or `sea.Map`")
def _check_restraints_compatibility(initial: sea.Map, current: sea.Map,
has_permanent_restrain: bool) -> str:
"""
Checks whether the restraint parameters are compatible with switching
during a concatenate stage. For compatibility, `current` may differ
from `initial` only by a scaling factor (which can include zero).
:param initial: preceding `restraints` block
:type initial: `sea.Map`
:param current: `restraints` block
:type current: `sea.Map`
:param has_permanent_restrain: whether or not there are restraints applied
to all stages via the `permanent_restraints` mechanism
:type has_permanent_restrain: bool
:return: a message declaring how the restraints are incompatible, or
an empty string if they are compatible
:rtype: str
"""
def get(m, n):
return m[n].val if n in m else None
def is_none(r):
return get(r, 'existing') == 'ignore' and not get(r, 'new')
def is_retain(r):
return get(r, 'existing') == 'retain' and not get(r, 'new')
if current != initial and not is_retain(current):
# there can be no difference between restraints
# blocks if system has permanent restraints
if has_permanent_restrain:
return ("Subsequent simulate blocks cannot have differing "
"restrain blocks when permanent restraints are used")
# we cannot go from no restrain to some restrain
if is_none(initial):
return ("Subsequent simulate blocks cannot have restrain block "
"unless the first simulate block or concatenate stage does")
elif not is_none(current): # none is acceptable
if current.existing != initial.existing:
return ("Subsequent simulate blocks cannot have "
"differing restraints")
else:
return check_restrain_diffs(current.new, initial.new)
return ""
def _restraints_from_restrain(
old: Union[sea.Atom, sea.List, sea.Map]) -> sea.Map:
"""
Translates an old-style restraints specification ("restrain")
into the equivalent new-style block. The current version is incomplete,
limited to the features needed for the concatenation support.
:param old: old-style "restrain" block (string, map or list)
:return: equivalent new-style "restraints" block
"""
outcome = make_empty_restraints(
existing='retain' if old.val == 'retain' else 'ignore')
if old.val in ('none', 'retain'):
pass
elif isinstance(old, sea.Map):
outcome["new"].append(old) # copies `old`
elif isinstance(old, sea.List):
outcome["new"].extend(old) # copies `old`
else:
raise ValueError("`restrain` block must be `none`, `retain`, "
"`sea.Map` or `sea.List`")
for blk in outcome["new"]:
fc = blk["fc"] if "fc" in blk else blk.force_constant
blk["force_constants"] = fc # copies `fc`
return outcome
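# Illustrative example of the translation (values shown schematically; the
# restraint keys are hypothetical):
#   old style:  restrain = {atom = solute_heavy_atom force_constant = 50.0}
#   new style:  restraints = {existing = ignore
#                             new = [{atom = solute_heavy_atom
#                                     force_constant = 50.0
#                                     force_constants = 50.0}]}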
PARAM = None # `sea.Map' object containing the whole job's msj setting
[docs]def msj2sea(fname, msj_content=None):
"""
Parses a file as specified by 'fname' or a string given by 'msj_content'
(if both are given, the former will be ignored), and returns a 'sea.Map'
object that represents the stage settings with a structure like the
following::
stage = [
{ <stage 1 settings> }
{ <stage 2 settings> }
{ <stage 3 settings> }
...
]
Each stage's name can be accessed in this way: raw.stage[1].__NAME__, where
'raw' is the returned 'sea.Map' object.
"""
if not msj_content:
msj_file = open(fname, "r")
msj_content = msj_file.read()
msj_file.close()
raw = sea.Map("stage = [" + msj_content + "]")
# User might set a stage as "stagename = {...}" by mistake. Raises a
# meaningful exception when this happens.
for s in raw.stage:
if isinstance(s, sea.Atom) and s.val == "=":
raise SyntaxError(
"Stage name cannot be followed by the assignment operator: '='")
stg = list(range(len(raw.stage)))[::2]
for i in stg:
try:
s = raw.stage[i + 1]
name = raw.stage[i].val.lower()
s.__NAME__ = name
except IndexError:
raise SyntaxError("stage %d is undefined" % i + 1)
stg.reverse()
for i in stg:
del raw.stage[i]
return raw
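# Example (illustrative; values shown schematically):
#   raw = msj2sea(None, msj_content="minimize { max_steps = 2000 }")
#   raw.stage[0].__NAME__       # -> "minimize"
#   raw.stage[0].max_steps.val  # -> 2000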
[docs]def msj2sea_full(fname, msj_content=None, pset=""):
raw = msj2sea(fname, msj_content)
for i, e in enumerate(raw.stage):
try:
stage_cls = StageBase.stage_cls[e.__NAME__]
except KeyError:
raise ParseError("Unrecognized stage name: %s\n" % e.__NAME__)
param = copy.deepcopy(stage_cls.PARAM.DATA)
param.update(e, tag="setbyuser")
param.__NAME__ = e.__NAME__
param.__CLS__ = stage_cls
raw.stage[i] = param
if pset:
raw.stage.insert(0, sea.Atom("dummy"))
pset = pset.split(chr(30))
for e in pset:
i = e.find("=")
if i <= 0:
raise ParseError("Syntax error in setting: %s" % e)
try:
key = e[:i].strip()
value = e[i + 1:].strip()
except IndexError:
raise ParseError("Syntax error in setting: %s" % e)
if key == "" or value == "":
raise ParseError("Syntax error in setting: %s" % e)
raw.set_value(key,
sea.Map("value = %s" % value).value.val,
tag="setbyuser")
del raw.stage[0]
return raw
[docs]def parse_msj(fname, msj_content=None, pset=""):
"""
sea.update_macro_dict must be called prior to calling this function.
"""
try:
global PARAM
PARAM = msj2sea_full(fname, msj_content, pset)
PARAM.stage.insert(0, sea.Atom("dummy"))
except Exception as e:
raise ParseError(str(e))
print_debug("All settings of this multisim job...")
print_debug(PARAM)
print_debug("All settings of this multisim job... End")
# Constructs stage objects and their parameters.
stg = []
error = ""
for i, e in enumerate(PARAM.stage[1:], start=1):
s = e.__CLS__() # Creates a stage instance.
# handle backward-compatibility issues
s.migrate_param(e)
s.param = e
# FIXME: How to deal with exceptions raised by the parsing and checking
# functions?
ev = s.check_param()
if ev.err != "":
error += "Value error(s) for stage[%d]:\n%s\n" % (i, ev.err)
if ev.unchecked_map:
error += "Unrecognized parameters for stage[%d]: %s\n\n" % (
i, ev.unchecked_map)
stg.append(s)
if error:
raise ParseError(error)
return stg
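# Illustrative sketch (the file name and macro values are hypothetical; the
# macro dictionary must be populated before parsing, as noted in the
# docstring):
#   sea.update_macro_dict({"$MAINJOBNAME": "myjob"})
#   stage_objects = parse_msj("myjob.msj")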
[docs]def write_msj(stage_list, fname=None, to_str=True):
"""
Given a list of stages, writes out a .msj file of the name 'fname'.
If 'to_str' is True, returns a string containing the contents of the
.msj file.
If 'to_str' is False and no file name is provided, this function does
nothing.
"""
if fname is None and to_str is False:
return
s = ""
for stage in stage_list:
s += stage.NAME + " {\n"
s += stage.param.__str__(" ", tag="setbyuser")
s += "}\n\n"
if fname is not None:
fh = open(fname, "w")
print(s, file=fh)
fh.close()
if to_str:
return s
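# Example (illustrative): round-trip a parsed protocol back to .msj text
# (`stage_objects` as returned by parse_msj() above):
#   msj_text = write_msj(stage_objects)                       # returns the text
#   write_msj(stage_objects, fname="copy.msj", to_str=False)  # writes the file only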
[docs]def write_sea2msj(stage_list, fname=None, to_str=True):
if fname is None and to_str is False:
return
s = ""
for stage in stage_list:
name = stage.__NAME__
s += name + " {\n"
s += stage.__str__(" ", tag="setbyuser")
s += "}\n\n"
if fname is not None:
fh = open(fname, "w")
print(s, file=fh)
fh.close()
if to_str:
return s
def _collect_inputfile_from_file_list(list_, fnames):
for v in list_:
if isinstance(v, sea.Atom) and isinstance(v.val, str) and v.val != "":
fnames.append(v.val)
elif isinstance(v, sea.Map):
_collect_inputfile_from_file_map(v, fnames)
elif isinstance(v, sea.List):
_collect_inputfile_from_file_list(v, fnames)
return fnames
def _collect_inputfile_from_file_map(map, fnames):
for k, v in map.key_value():
if isinstance(v, sea.Atom) and isinstance(v.val, str) and v.val != "":
fnames.append(v.val)
elif isinstance(v, sea.Map):
_collect_inputfile_from_file_map(v, fnames)
elif isinstance(v, sea.List):
_collect_inputfile_from_file_list(v, fnames)
return fnames
def _collect_inputfile_from_list(list_, fnames):
for v in list_:
if isinstance(v, sea.Map):
_collect_inputfile_from_map(v, fnames)
elif isinstance(v, sea.List):
_collect_inputfile_from_list(v, fnames)
return fnames
def _collect_inputfile_from_map(map, fnames):
for k, v in map.key_value():
if (isinstance(v, sea.Atom) and k.endswith("_file") and
isinstance(v.val, str) and v.val != ""):
fnames.append(v.val)
elif isinstance(v, sea.Map):
if k.endswith("_file"):
_collect_inputfile_from_file_map(v, fnames)
else:
_collect_inputfile_from_map(v, fnames)
elif isinstance(v, sea.List):
if k.endswith("_file"):
_collect_inputfile_from_file_list(v, fnames)
else:
_collect_inputfile_from_list(v, fnames)
return fnames
[docs]class AslValidator(object):
CTSTR = """hydrogen
1 0 0 0 1 0 999 V2000
-1.6976 2.1561 0.0000 C 0 0 0 0 0 0
M END
$$$$
"""
CT = None
[docs] def __init__(self):
self.invalid_asl_expr = []
[docs] def is_valid(self, asl):
if AslValidator.CT is None:
import schrodinger.structure as structure
AslValidator.CT = next(
structure.StructureReader.fromString(AslValidator.CTSTR,
format="sd"))
import schrodinger.structutils.analyze as analyze
try:
analyze.evaluate_asl(AslValidator.CT, asl)
except mm.MmException:
return False
return True
[docs] def validate(self, a):
if isinstance(a, sea.Atom):
v = a.val
if (isinstance(v, str) and v[:4].lower() == "asl:" and
not self.is_valid(v[4:])):
self.invalid_asl_expr.append(v)
elif isinstance(a, sea.Sea):
a.apply(lambda x: self.validate(x))
[docs]def validate_asl_expr(stage_list):
"""
Validates all ASL expressions that start with the "asl:" prefix.
"""
validator = AslValidator()
for stage in stage_list:
if not stage.param.should_skip.val:
stage.param.apply(lambda x: validator.validate(x))
return validator.invalid_asl_expr
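# Illustrative sketch (`stage_objects` as returned by parse_msj()):
#   invalid = validate_asl_expr(stage_objects)
#   if invalid:
#       raise ParseError("Invalid ASL expression(s): %s" % ", ".join(invalid))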
# - Registered functions should share this prototype: foo( stage, PARAM, arg ),
# and should return a boolean value, where PARAM is a sea.Map object in the
# global scope of this module.
_operator = {}
[docs]def reg_checking(name, func):
_operator[name] = func
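# Illustrative sketch (`check_sim_time` is a hypothetical checker following
# the prototype described in the comment above):
#   def check_sim_time(stage, param, arg):
#       return param.time.val > 0
#   reg_checking("sim_time", check_sim_time)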