import copy
import os
from pathlib import Path
from typing import TYPE_CHECKING
from typing import List
from typing import Optional
from schrodinger.application.desmond import cmj
from schrodinger.application.desmond import launch_utils
from schrodinger.application.desmond import stage
from schrodinger.application.desmond import util
from schrodinger.application.desmond.constants import FEP_TYPES
from schrodinger.application.desmond.constants import FepLegTypes
from schrodinger.application.desmond.starter.ui import cmdline
from schrodinger.application.desmond.starter.ui.fep_plus import FepPlusArgs
from schrodinger.utils import sea
if TYPE_CHECKING:
from schrodinger.application.desmond import multisim
from schrodinger.application.scisol.packages.fep import graph # noqa: F401
def prepare_files_and_command_for_restart(
        args: cmdline.BaseArgs) -> Optional[List[str]]:
    """
    Read the checkpoint named on the command line, prepare the stage data
    files needed for a restart and assemble the command that restarts the job.

    :param args: Command line arguments.
    :return: The restart command, or `None` if the multisim stage could not
        be found.
    """
    checkpoint_fname, restart_idx = (
        launch_utils.get_checkpoint_file_and_restart_number(args.checkpoint))
    # True only when a positive restart stage number came back together with
    # the checkpoint file name.
    restart_whole = restart_idx is not None and restart_idx > 0

    engine = launch_utils.read_checkpoint_file(checkpoint_fname)
    if not restart_idx:
        # No usable stage number was given: ask the engine where to restart.
        restart_idx = launch_utils.get_restart_stage_from_engine(engine)
    launch_utils.validate_restart_stage(engine, restart_idx)

    multisim_stage_numbers = launch_utils.get_multisim_stage_numbers(engine)
    if not multisim_stage_numbers:
        return None

    data_fnames = launch_utils.prepare_multisim_files_for_restart(
        engine,
        multisim_stage_numbers,
        checkpoint_fname,
        restart_idx,
        restart_whole,
        skip_traj=args.skip_traj)

    restart_cmd = launch_utils.prepare_command_for_restart(
        engine,
        data_fnames,
        args.HOST,
        args.SUBHOST,
        checkpoint_fname,
        maxjob=args.maxjob,
        jobname=args.JOBNAME,
        msj=args.msj,
        rst_stage_idx=restart_idx,
        rst_whole=restart_whole,
        input_fname=args.inp_file)
    restart_cmd += launch_utils.additional_command_arguments(
        data_fnames,
        args.RETRIES,
        args.WAIT,
        args.LOCAL,
        args.DEBUG,
        args.TMPDIR,
        None,  # forcefield: no value is passed on for a restart
        args.OPLSDIR,
        args.NICE,
        args.SAVE)
    return restart_cmd
def is_fmp(input_fname: str) -> bool:
    """
    Tell whether `input_fname` names an fmp file.

    The decision is made on the file extension only (compared
    case-insensitively): `.fmp` and `.pkl` count as fmp files.

    :param input_fname: File name (or path) to inspect.
    :return: True for an fmp file, False otherwise.
    """
    extension = Path(input_fname).suffix.lower()
    return extension == '.fmp' or extension == '.pkl'
def find_fmpdb_file(args: cmdline.BaseArgs) -> Optional[str]:
    """
    Tries to find the fmpdb file's name if it's needed. If it cannot find,
    returns `None` and issues a warning message.

    :param args: Command line arguments; `inp_file` and `JOBNAME` are read.
    :return: Name of an existing .fmpdb file, or `None` when no .fmpdb file
        is needed or none could be located.
    """
    inp_file = args.inp_file
    if inp_file is None:
        # Both restarting and extending (running simulations for longer) the
        # previous job won't set the inp_file argument, then we require the
        # .fmpdb file to be in the CWD and be named after "<jobname>_out.fmpdb".
        # If the .fmpdb file is not found, warn.
        # An empty job name is treated like a missing one so that the
        # "cannot figure out" warning is issued rather than a "not found"
        # warning with an empty file name.
        filename = (args.JOBNAME + '_out.fmpdb') if args.JOBNAME else None
    else:
        # Two cases:
        # 1. This is a graph-expansion job: inp_file is a .fmp file.
        #    Find the .fmpdb file from the .fmp file.
        # 2. This is a from-scratch new job: inp_file is a .mae file or a .fmp
        #    file without an associated .fmpdb file.
        #    No need for a .fmpdb file and warnings.
        filename = None
        if is_fmp(inp_file):
            # Case 1
            from schrodinger.application.scisol.packages.fep import graph  # noqa: F811
            g = graph.Graph.deserialize(inp_file)
            filename = g.fmpdb and g.fmpdb.filename
        if filename is None:
            # Case 2
            return None

    if filename is None:
        # Tried, but cannot figure out the file name.
        print("WARNING: Cannot figure out the .fmpdb file name.")
        return None
    if not os.path.isfile(filename):
        print("WARNING: .fmpdb file not found: %s" % filename)
        return None
    return filename
def prepare_files_and_command_for_fep_restart_extend(
        args: FepPlusArgs,
        edges: List[str],
        launcher_stage_name: str = stage.FepLauncher.NAME
) -> (List[str], List[str]):
    """
    Prepare the files and the command needed to restart or extend an FEP job.

    When `args.extend` is set, a modified checkpoint (`extend_<checkpoint>`)
    and `<JOBNAME>.extend.msj` (plus per-leg extend msj files) are written.
    When `args.RESTART` is set instead, `<JOBNAME>.restart.msj` is written.
    In both cases `args.msj` is updated in place to the newly written msj, and
    `args.checkpoint` is filled in for an extend job that did not specify one.
    If neither flag is set nothing is prepared and two empty lists are
    returned.

    :param args: Command line arguments.
    :param edges: Edges handed to the launcher stage's `restart_edges` when
        extending.
    :param launcher_stage_name: Name of the launcher stage to look up in the
        engine and in the main msj.
    :return: A tuple `(cmd, stage_data_fnames)`: the command to run and the
        (deduplicated) stage data file names.
    :raise RestartException: If the input file required for extending does
        not exist, or if no multisim stage is found.
    """
    cmd = []
    stage_data_fnames = []
    if args.extend:
        if not args.checkpoint:
            args.checkpoint = launch_utils.find_checkpoint_file()
        engine = launch_utils.read_checkpoint_file(args.checkpoint)
        sim_protocols = None
        inp_file = Path(args.inp_file or f"{engine.jobname}_out.fmp")
        if not inp_file.exists():
            raise RestartException(f"ERROR: {inp_file} must exist for "
                                   f"extending.")
        if is_fmp(inp_file):
            from schrodinger.application.scisol.packages.fep import utils
            from schrodinger.application.scisol.packages.fep import graph  # noqa: F811
            g = graph.Graph.deserialize(inp_file)
            if g.fep_type in [FEP_TYPES.ABSOLUTE_BINDING, FEP_TYPES.SOLUBILITY]:
                sim_protocols = {
                    utils.get_ligand_node(e).short_id: e.simulation_protocol
                    for e in g.edges_iter()
                }
            else:
                sim_protocols = {
                    "_".join(e.short_id): e.simulation_protocol
                    for e in g.edges_iter()
                }
        # Modifies the checkpoint file.
        fep_launcher_stage = launch_utils.find_stage(engine.stage,
                                                     launcher_stage_name)
        # Temporarily install this engine as the module-level engine; always
        # put the previous one back, even if `restart_edges` raises.
        current = cmj.ENGINE
        cmj.ENGINE = engine
        try:
            fep_launcher_stage.restart_edges(edges,
                                             sim_protocols=sim_protocols)
        finally:
            cmj.ENGINE = current
        cpt_fname = "extend_%s" % os.path.basename(args.checkpoint)
        orig_cpt_fname = args.checkpoint
        engine.write_checkpoint(cpt_fname)
        # Modifies the msj file.
        main_msj = _update_input_graph_file_param(args, engine)
        args.msj = f"{args.JOBNAME}.extend.msj"
        main_msj.write(args.msj)
        fep_launcher_dispatch = main_msj.get(f"{launcher_stage_name}.dispatch")
        extend_stage_nums = dict()
        for protocol_name, _job in fep_launcher_dispatch.items():
            extend_stage_nums[protocol_name] = _write_extend_msjs(
                args, fep_launcher_dispatch[protocol_name], protocol_name)
        main_msj.put(f"{launcher_stage_name}.dispatch",
                     sea.Map(fep_launcher_dispatch))
        main_msj.put(f"{launcher_stage_name}.restart",
                     sea.Map(extend_stage_nums))
        # Write again now that the dispatch/restart settings are updated.
        main_msj.write(args.msj)
    elif args.RESTART:
        cpt_fname, rst_stage_idx = launch_utils.get_checkpoint_file_and_restart_number(
            args.checkpoint)
        rst_whole = rst_stage_idx is not None and rst_stage_idx > 0
        engine = launch_utils.read_checkpoint_file(cpt_fname)
        if not rst_stage_idx:
            rst_stage_idx = launch_utils.get_restart_stage_from_engine(engine)
        launch_utils.validate_restart_stage(engine, rst_stage_idx)
        main_msj = _update_input_graph_file_param(args, engine)
        args.msj = f"{args.JOBNAME}.restart.msj"
        main_msj.write(args.msj)
        orig_cpt_fname = cpt_fname
    if args.RESTART or args.extend:
        multisim_stage_numbers = launch_utils.get_multisim_stage_numbers(engine)
        if not multisim_stage_numbers:
            raise RestartException("ERROR: multisim stage not found.")
        if args.extend:
            # Extending always restarts from the last multisim stage.
            rst_stage_idx = multisim_stage_numbers[-1]
            rst_whole = False
        stage_data_fnames.extend(
            launch_utils.prepare_multisim_files_for_restart(
                engine,
                multisim_stage_numbers,
                orig_cpt_fname,
                rst_stage_idx,
                rst_whole,
                skip_traj=args.skip_traj))
        stage_data_fnames.extend(
            _prepare_mapper_stages_for_restart(engine, rst_stage_idx))
        # Deduplicate names while keeping first-seen order, so the resulting
        # file list is deterministic from run to run.
        stage_data_fnames = list(dict.fromkeys(stage_data_fnames))
        cmd = launch_utils.prepare_command_for_restart(
            engine,
            stage_data_fnames,
            args.HOST,
            args.SUBHOST,
            cpt_fname,
            maxjob=args.maxjob,
            jobname=args.JOBNAME,
            msj=args.msj,
            rst_stage_idx=rst_stage_idx,
            rst_whole=rst_whole)
    return cmd, stage_data_fnames
def _prepare_mapper_stages_for_restart(engine: cmj.Engine,
                                       rst_stage_idx: int) -> List[str]:
    """
    If the FepMapperReport or FepMapperCleanup stage is present after the
    restart stage, include the FepMapper stage data when restarting the job.

    :param engine: Represents the current job state.
    :param rst_stage_idx: The restart stage index.
    :return: List of filenames to be used for restarting the job. Empty when
        no report/cleanup stage follows the restart stage, when no mapper
        stage can be located, or when the mapper stage is not before the
        restart stage.
    """
    report_stage_names = (stage.FepMapperReport.NAME,
                          stage.FepMapperCleanup.NAME)
    if not any(stg.NAME in report_stage_names
               for stg in engine.stage[rst_stage_idx:]):
        return []

    # The mapper may be any one of these stage flavours; take the first hit.
    mapper_stage_position = (
        launch_utils.find_stage_number(engine.stage, stage.FepMapper.NAME) or
        launch_utils.find_stage_number(engine.stage,
                                       stage.ProteinFepMapper.NAME) or
        launch_utils.find_stage_number(engine.stage,
                                       stage.CovalentFepMapper.NAME))
    if not mapper_stage_position:
        # No mapper stage located: there is no mapper data to include, and
        # subtracting from a falsy lookup result would fail or give a bogus
        # stage number.
        return []

    # find_stage_number returns a 1-based index, and engine.stage's
    # first element is a primer stage which should be uncounted
    mapper_stage_number = mapper_stage_position - 1
    if mapper_stage_number < rst_stage_idx:
        return [f"{engine.jobname}_{mapper_stage_number}-out.tgz"]
    return []
def _update_input_graph_file_param(args: FepPlusArgs,
                                   engine: cmj.Engine) -> "multisim.Msj":
    """
    Parse the main msj and, for certain stages, set `input_graph_file` to the
    previous job's out.fmp file.

    The msj is parsed from `args.msj` when that is set, otherwise from the
    msj content stored on the engine. The stages are only touched if
    `<jobname>_out.fmp` exists as a file.

    :param args: Command line arguments.
    :param engine: Engine read from the checkpoint.
    :return: The parsed (and possibly updated) main msj.
    """
    from schrodinger.application.desmond import multisim  # FIXME: DESMOND-9971

    main_msj = (multisim.parse(args.msj) if args.msj else multisim.parse(
        string=engine.msj_content))

    # Use the previous out.fmp file as input graph
    out_fmp_name = f"{engine.jobname}_out.fmp"
    if not os.path.isfile(out_fmp_name):
        return main_msj

    graph_stage_names = (
        "vacuum_report",
        "fep_mapper_report",
        "fep_mapper_cleanup",
        "fep_absolute_binding_analysis",
        "fep_solubility_analysis",
    )
    for stage_name in graph_stage_names:
        for matched_stage in main_msj.find(stage_name):
            matched_stage.put("input_graph_file", out_fmp_name)
    return main_msj
def _write_extend_msjs(args: FepPlusArgs,
                       job: List[List[str]],
                       protocol_name="default") -> List[int]:
    """
    Write the .extend.msj for every leg command in `job` and collect the
    stage numbers to restart from.

    :param args: Command line arguments.
    :param job: Dispatch commands, one per leg; may be modified in place by
        `_write_extend_msj`.
    :param protocol_name: Name of the protocol the commands belong to.
    :return: The restart stage numbers, in the order of the legs in `job`;
        legs for which no (truthy) stage number is produced are left out.
    """
    restart_stage_numbers = []
    # Order of legs has to be same as order of legs in dispatch.
    # `job` can be mutated while it is processed, so walk a snapshot of it.
    snapshot = copy.deepcopy(job)
    for leg_cmd in snapshot:
        jobname_pos = leg_cmd.index("-JOBNAME") + 1
        leg_jobname = leg_cmd[jobname_pos]
        stage_number = _write_extend_msj(
            args, job, protocol_name,
            util.get_leg_type_from_jobname(leg_jobname),
            util.get_leg_name_from_jobname(leg_jobname))
        if not stage_number:
            continue
        restart_stage_numbers.append(stage_number)
    return restart_stage_numbers
def _write_extend_msj(args: FepPlusArgs, job: List[List[str]],
                      protocol_name: str, leg_type: str, leg_name: str) -> \
        Optional[int]:
    """
    Write the extend msj for one protocol+leg.

    1) Point the dispatch command of the given leg (the argument following
       "-m") at the new extend msj.
    2) Set the added time of the extend stage in the subjob msj and write it
       under the new name.
    3) Return the extend stage number.

    Nothing is written if the leg is not present in `job` (for example, it
    may skip 'vacuum'), if the leg type is one that is never extended, or if
    the requested simulation time is below 1.0, in which case the leg's
    command is removed from `job`.

    :param args: Command line arguments providing the per-leg times.
    :param job: Dispatch commands; modified in place.
    :param protocol_name: Protocol name; "default" is omitted from the new
        msj file name.
    :param leg_type: Type of the leg, used to pick the simulation time.
    :param leg_name: Name of the leg, used to locate its command in `job`.
    :return: Stage index of the "desmond_extend" stage, or `None` when the
        leg is skipped.
    :raise RestartException: If the leg's original msj file is not found.
    """
    from schrodinger.application.desmond import multisim

    # Pick the time to add for this kind of leg.
    if leg_type == FepLegTypes.VACUUM:
        added_time = args.time_vacuum
    elif leg_type == FepLegTypes.COMPLEX:
        added_time = args.time_complex
    elif leg_type == FepLegTypes.SOLVENT:
        added_time = args.time_solvent
    elif leg_type == FepLegTypes.HYDRATION:
        added_time = args.hydration_fep_sim_time
    elif leg_type == FepLegTypes.SUBLIMATION:
        added_time = args.sublimation_fep_sim_time
    elif leg_type in (FepLegTypes.FRAGMENT_HYDRATION,
                      FepLegTypes.RESTRAINED_FRAGMENT_HYDRATION):
        # These legs are never extended
        return None
    else:
        added_time = args.time

    leg_idx = _find_leg_idx(leg_name, job)
    if leg_idx is None:
        return None
    if added_time < 1.0:
        # skip extending a given leg if simulation time is < 1 ps.
        job.pop(leg_idx)
        return None

    if protocol_name != "default":
        new_msj_fname = f"{args.JOBNAME}_{leg_type}_{protocol_name}.extend.msj"
    else:
        new_msj_fname = f"{args.JOBNAME}_{leg_type}.extend.msj"

    # Swap the msj given after "-m" for the new one, remembering the old name.
    leg_cmd = job[leg_idx]
    old_msj_fname = None
    for pos, token in enumerate(leg_cmd):
        if token == "-m":
            old_msj_fname = leg_cmd[pos + 1]
            leg_cmd[pos + 1] = new_msj_fname

    if not (old_msj_fname and os.path.isfile(old_msj_fname)):
        raise RestartException(f"ERROR: File not found: '{old_msj_fname}'")

    subjob_msj = multisim.parse(old_msj_fname)
    subjob_msj.put("desmond_extend.added_time", added_time)
    subjob_msj.write(new_msj_fname)
    return subjob_msj.find_stages("desmond_extend")[0].STAGE_INDEX
def _find_leg_idx(leg: str, job: List) -> Optional[int]:
    """
    Locate the command in `job` that belongs to the given leg.

    :param leg: Leg name to look for.
    :param job: List of commands. Each command must contain "-JOBNAME"
        followed by the job name, from which the leg name is derived.
    :return: Index of the first matching command, or `None` if there is none.
    """

    def _leg_name_of(cmd):
        # The job name is the argument right after "-JOBNAME".
        return util.get_leg_name_from_jobname(cmd[cmd.index("-JOBNAME") + 1])

    matches = (pos for pos, cmd in enumerate(job) if _leg_name_of(cmd) == leg)
    return next(matches, None)
class RestartException(Exception):
    """
    Raised when the files or command for restarting/extending a job cannot
    be prepared (e.g. a required file or the multisim stage is missing).
    """