import copy
import os
from pathlib import Path
from typing import TYPE_CHECKING
from typing import List
from typing import Optional
from typing import Tuple

from schrodinger.application.desmond import cmj
from schrodinger.application.desmond import launch_utils
from schrodinger.application.desmond import stage
from schrodinger.application.desmond import util
from schrodinger.application.desmond.constants import FEP_TYPES
from schrodinger.application.desmond.constants import UiMode
from schrodinger.application.desmond.constants import SIMULATION_PROTOCOL
from schrodinger.application.desmond import multisim
from schrodinger.application.desmond.starter.ui import cmdline
from schrodinger.application.desmond.starter.ui.cmdline import FepArgs
from schrodinger.utils import sea

if TYPE_CHECKING:
    from schrodinger.application.scisol.packages.fep import graph  # noqa: F401
def is_fmp(input_fname: str) -> bool:
    """
    Return whether ``input_fname`` names an fmp file.

    Only the final file extension is inspected, case-insensitively; both
    ``.fmp`` and ``.pkl`` count as fmp files.
    """
    extension = Path(input_fname).suffix
    return extension.lower() in ('.fmp', '.pkl')
def find_fmpdb_file(args: "cmdline.BaseArgs") -> Optional[str]:
    """
    Tries to find the fmpdb file's name if it's needed. If it cannot find,
    returns `None` and issues a warning message.

    :param args: Command line arguments; only ``inp_file`` and ``JOBNAME``
        are read.
    :return: The name of an existing .fmpdb file, or `None` if no .fmpdb
        file is needed or none could be found.
    """
    inp_file = args.inp_file
    if inp_file is None:
        # Both restarting and extending (running simulations for longer) the
        # previous job won't set the inp_file argument, then we require the
        # .fmpdb file to be in the CWD and be named after "<jobname>_out.fmpdb".
        # If the .fmpdb file is not found, warn.
        filename = args.JOBNAME and (args.JOBNAME + '_out.fmpdb')
    elif is_fmp(inp_file):
        # Case 1: This is a graph-expansion job: inp_file is a .fmp file.
        # Find the .fmpdb file from the .fmp file.
        from schrodinger.application.scisol.packages.fep import graph  # noqa: F811
        g = graph.Graph.deserialize(inp_file)
        filename = g.fmpdb and g.fmpdb.filename
        if not filename:
            # Case 2: a from-scratch job from a .fmp file without an
            # associated .fmpdb file. No .fmpdb file or warning is needed.
            return None
    else:
        # Case 2: a from-scratch job from a .mae (non-fmp) file.
        return None

    # An empty JOBNAME yields an empty name; treat it like a missing one
    # rather than reporting a blank "file not found" path.
    if not filename:
        # Tried, but cannot figure out the file name.
        print("WARNING: Cannot figure out the .fmpdb file name.")
        return None
    if not os.path.isfile(filename):
        print("WARNING: .fmpdb file not found: %s" % filename)
        return None
    return filename
def prepare_files_and_command_for_restart(
        args: cmdline.BaseArgs) -> List[str]:
    """
    Return a command for launching the restart multisim job.

    :param args: Command line arguments.
    :raise RestartException: If the checkpointed workflow has no multisim
        stage.
    :return: The command for launching the restart job, including the
        additional job-control arguments.
    """
    cpt_fname, rst_stage_idx = \
        launch_utils.get_checkpoint_file_and_restart_number(args.checkpoint)
    # An explicit, positive restart stage from the command line restarts that
    # stage as a whole; this must be decided before the fallback below.
    rst_whole = rst_stage_idx is not None and rst_stage_idx > 0
    engine = launch_utils.read_checkpoint_file(cpt_fname)
    if not rst_stage_idx:
        rst_stage_idx = launch_utils.get_restart_stage_from_engine(engine)

    multisim_stage_numbers = launch_utils.get_multisim_stage_numbers(engine)
    if not multisim_stage_numbers:
        raise RestartException("ERROR: multisim stage not found.")
    launch_utils.validate_restart_stage(engine, rst_stage_idx)

    stage_data_fnames = launch_utils.prepare_multisim_files_for_restart(
        engine,
        multisim_stage_numbers,
        rst_stage_idx,
        rst_whole,
        skip_traj=args.skip_traj)
    # The host is passed through unchanged; the processor count goes in via
    # `maxjob`, matching the FEP restart path.
    cmd = launch_utils.prepare_command_for_restart(engine,
                                                   stage_data_fnames,
                                                   args.HOST,
                                                   args.SUBHOST,
                                                   cpt_fname,
                                                   maxjob=args.ppj,
                                                   jobname=args.JOBNAME,
                                                   msj=args.msj,
                                                   rst_stage_idx=rst_stage_idx,
                                                   rst_whole=rst_whole)
    # No force field override is passed when restarting.
    forcefield = None
    cmd += launch_utils.additional_command_arguments(
        stage_data_fnames, args.RETRIES, args.WAIT, args.LOCAL, args.DEBUG,
        args.TMPDIR, forcefield, args.OPLSDIR, args.NICE, args.SAVE)
    return cmd
def prepare_files_and_command_for_fep_restart_extend(
        args: FepArgs,
        edges: List[str],
        launcher_stage_name: str = stage.FepLauncher.NAME
) -> Tuple[List[str], List[str]]:
    """
    Prepare the files and the command for restarting or extending an FEP job.

    In ``UiMode.EXTEND`` mode the given edges are restarted through the FEP
    launcher stage. The following files are written, with names relative to
    the current directory:

    - a modified checkpoint ``extend_<checkpoint basename>``;
    - a main ``<JOBNAME>.extend.msj``;
    - per-leg extend msj files, via `_write_extend_msjs`.

    In ``UiMode.RESTART`` mode a ``<JOBNAME>.restart.msj`` is written, and
    the stage data files needed for the restart are collected.

    :param args: Command line arguments. In both modes ``args.msj`` is
        overwritten with the name of the newly written main msj file.
    :param edges: Edges to restart; only used in ``UiMode.EXTEND`` mode.
    :param launcher_stage_name: Name of the FEP launcher stage in the
        checkpointed workflow.
    :raise RestartException: If the checkpointed workflow has no multisim
        stage.
    :return: The launch command, and the deduplicated list of stage data
        file names.
    """
    stage_data_fnames = []
    cpt_fname, rst_stage_idx = \
        launch_utils.get_checkpoint_file_and_restart_number(args.checkpoint)
    engine = launch_utils.read_checkpoint_file(cpt_fname)
    rst_whole = False
    multisim_stage_numbers = launch_utils.get_multisim_stage_numbers(engine)
    if not multisim_stage_numbers:
        raise RestartException("ERROR: multisim stage not found.")

    if args.mode == UiMode.EXTEND:
        from schrodinger.application.scisol.packages.fep import utils
        from schrodinger.application.scisol.packages.fep import graph  # noqa: F811
        # Extension always resumes from the last multisim stage.
        rst_stage_idx = multisim_stage_numbers[-1]
        g = graph.Graph.deserialize(f"{engine.jobname}_out.fmp")
        if g.fep_type in [FEP_TYPES.ABSOLUTE_BINDING, FEP_TYPES.SOLUBILITY]:
            # These FEP types key each edge by its ligand node's short id.
            sim_protocols = {
                utils.get_ligand_node(e).short_id: e.simulation_protocol
                for e in g.edges_iter()
            }
        else:
            sim_protocols = {
                "_".join(e.short_id): e.simulation_protocol
                for e in g.edges_iter()
            }

        # Modifies the checkpoint file.
        fep_launcher_stage = launch_utils.find_stage(engine.stage,
                                                     launcher_stage_name)
        # Temporarily install this engine as the global one while the
        # launcher restarts the edges, and always restore the previous
        # engine, even if restart_edges() raises.
        previous_engine = cmj.ENGINE
        cmj.ENGINE = engine
        try:
            fep_launcher_stage.restart_edges(edges, sim_protocols=sim_protocols)
        finally:
            cmj.ENGINE = previous_engine
        cpt_fname = "extend_%s" % os.path.basename(args.checkpoint)
        engine.write_checkpoint(cpt_fname)

        # Modifies the msj file.
        main_msj = _update_input_graph_file_param(args, engine)
        args.msj = f"{args.JOBNAME}.extend.msj"
        main_msj.write(args.msj)
        dispatch_key = f"{launcher_stage_name}.dispatch"
        fep_launcher_dispatch = main_msj.get(dispatch_key)
        extend_stage_nums = {}
        for protocol_name, _ in fep_launcher_dispatch.items():
            # Pass the mapping's own value (looked up by key), because
            # _write_extend_msjs() edits it in place. The mapping is
            # written back to the msj below.
            extend_stage_nums[protocol_name] = _write_extend_msjs(
                args, fep_launcher_dispatch[protocol_name], protocol_name)
        main_msj.put(dispatch_key, sea.Map(fep_launcher_dispatch))
        main_msj.put(f"{launcher_stage_name}.restart",
                     sea.Map(extend_stage_nums))
        main_msj.write(args.msj)
    elif args.mode == UiMode.RESTART:
        rst_whole = rst_stage_idx is not None and rst_stage_idx > 0
        if not rst_stage_idx:
            rst_stage_idx = launch_utils.get_restart_stage_from_engine(engine)
        launch_utils.validate_restart_stage(engine, rst_stage_idx)
        main_msj = _update_input_graph_file_param(args, engine)
        args.msj = f"{args.JOBNAME}.restart.msj"
        main_msj.write(args.msj)
        stage_data_fnames.extend(
            launch_utils.prepare_multisim_files_for_restart(
                engine,
                multisim_stage_numbers,
                rst_stage_idx,
                rst_whole,
                skip_traj=args.skip_traj))
        stage_data_fnames.extend(
            _prepare_mapper_stages_for_restart(engine, rst_stage_idx))

    # Deduplicate names, keeping first-seen order so the generated command is
    # reproducible (set() order varies between runs for strings).
    stage_data_fnames = list(dict.fromkeys(stage_data_fnames))
    cmd = launch_utils.prepare_command_for_restart(engine,
                                                   stage_data_fnames,
                                                   args.HOST,
                                                   args.SUBHOST,
                                                   cpt_fname,
                                                   maxjob=args.maxjob,
                                                   jobname=args.JOBNAME,
                                                   msj=args.msj,
                                                   rst_stage_idx=rst_stage_idx,
                                                   rst_whole=rst_whole)
    return cmd, stage_data_fnames
def _prepare_mapper_stages_for_restart(engine: cmj.Engine,
rst_stage_idx: int) -> List[str]:
"""
If the FepMapperReport or FepMapperCleanup stage is present after the
restart stage, include the FepMapper stage data when restarting the job.
:param engine: Represents the current job state.
:param rst_stage_idx: The restart stage index.
:return: List of filenames to be used for restarting the job.
"""
stage_data_fnames = []
for stg in engine.stage[rst_stage_idx:]:
if stg.NAME in [
stage.FepMapperReport.NAME, stage.FepMapperCleanup.NAME
]:
mapper_stage_number = (
launch_utils.find_stage_number(engine.stage,
stage.FepMapper.NAME) or
launch_utils.find_stage_number(engine.stage,
stage.ProteinFepMapper.NAME) or
launch_utils.find_stage_number(
engine.stage, stage.CovalentFepMapper.NAME)) - 1
# find_stage_number returns a 1-based index, and engine.stage's
# first element is a primer stage which should be uncounted
if mapper_stage_number < rst_stage_idx:
stage_data_fnames.append(
f"{engine.jobname}_{mapper_stage_number}-out.tgz")
break
return stage_data_fnames
def _update_input_graph_file_param(args: FepArgs,
                                   engine: cmj.Engine) -> multisim.Msj:
    """
    Parse the main msj and point selected stages at the previous output graph.

    The msj comes from ``args.msj`` when that is set. Otherwise it comes
    from the msj content stored in ``engine``. If ``<jobname>_out.fmp``
    exists in the current directory, each matching stage gets its
    ``input_graph_file`` parameter set to that file.

    :param args: Command line arguments; only ``args.msj`` is read.
    :param engine: Job state restored from the checkpoint.
    :return: The parsed, possibly updated, main msj.
    """
    if not args.msj:
        main_msj = multisim.parse(string=engine.msj_content)
    else:
        main_msj = multisim.parse(args.msj)

    out_fmp_name = f"{engine.jobname}_out.fmp"
    if not os.path.isfile(out_fmp_name):
        return main_msj

    # Stages whose input graph is replaced by the previous out.fmp file.
    graph_stage_names = (
        "vacuum_report",
        "fep_mapper_report",
        "fep_mapper_cleanup",
        "fep_absolute_binding_analysis",
        "fep_solubility_analysis",
    )
    for stage_name in graph_stage_names:
        for msj_stage in main_msj.find(stage_name):
            msj_stage.put("input_graph_file", out_fmp_name)
    return main_msj
def _write_extend_msjs(args: FepArgs,
                       job: List[List[str]],
                       protocol_name="default") -> List[int]:
    """
    Write an ``.extend.msj`` for each leg in ``job`` and collect restart
    stage numbers.

    ``job`` holds the dispatch commands of one simulation protocol.
    `_write_extend_msj` may edit it in place: it can remove legs and
    rewrite their ``-m`` argument.

    :param args: Command line arguments.
    :param job: Dispatch commands, one token list per leg; each contains a
        ``-JOBNAME <name>`` pair.
    :param protocol_name: Simulation protocol the commands belong to.
    :return: The extend stage number of every leg that was extended, in the
        same order as the legs appear in the dispatch.
    """
    stage_nums = []
    # Walk a snapshot, because `_write_extend_msj` may pop entries from `job`.
    snapshot = copy.deepcopy(job)
    for leg_cmd in snapshot:
        leg_jobname = leg_cmd[leg_cmd.index("-JOBNAME") + 1]
        stage_num = _write_extend_msj(
            args, job, protocol_name,
            util.get_leg_type_from_jobname(leg_jobname),
            util.get_leg_name_from_jobname(leg_jobname))
        if stage_num:
            stage_nums.append(stage_num)
    return stage_nums
def _write_extend_msj(args: FepArgs, job: List[List[str]],
                      protocol_name: str, leg_type: str,
                      leg_name: str) -> Optional[int]:
    """
    1) Modify the dispatch command for the given protocol+leg in the main
       msj to point to the new extend msjs.
    2) Modify the added time of the extend stage in the subjob msj.
    3) Return the extend stage number.

    This will only modify the msj if the leg existed in the original msj (for
    example, it may skip 'vacuum'). A leg whose requested time is below 1 ps
    is removed from ``job`` instead of being extended.

    :param args: Command line arguments; ``JOBNAME`` and
        ``get_time_for_leg`` are used.
    :param job: Dispatch commands of one protocol; modified in place.
    :param protocol_name: Simulation protocol of the commands.
    :param leg_type: Leg type used to look up the requested time and to name
        the new msj.
    :param leg_name: Leg name used to find the leg's command in ``job``.
    :raise RestartException: If any of these holds:

        - the leg command has no ``-m <msj>`` argument;
        - the msj file does not exist;
        - the msj has no ``desmond_extend`` stage.
    :return: Stage index of the subjob's ``desmond_extend`` stage, or `None`
        if the leg is not extended.
    """
    sim_time = args.get_time_for_leg(leg_type)
    if sim_time is None:
        return None
    leg_idx = _find_leg_idx(leg_name, job)
    if leg_idx is None:
        return None
    if sim_time < 1.0:
        # skip extending a given leg if simulation time is < 1 ps.
        job.pop(leg_idx)
        return None

    if protocol_name == SIMULATION_PROTOCOL.DEFAULT:
        new_msj_fname = f"{args.JOBNAME}_{leg_type}.extend.msj"
    elif protocol_name in [
            SIMULATION_PROTOCOL.CHARGED, SIMULATION_PROTOCOL.FORMALCHARGED
    ]:
        new_msj_fname = f"{args.JOBNAME}_{leg_type}_chg.extend.msj"
    else:
        new_msj_fname = f"{args.JOBNAME}_{leg_type}_{protocol_name}.extend.msj"

    leg_cmd = job[leg_idx]
    msj_fname = None
    # Redirect the leg's "-m" argument to the new msj, remembering the old
    # one as the template. A trailing "-m" with no value is ignored here and
    # reported below.
    for i, token in enumerate(leg_cmd):
        if token == "-m" and i + 1 < len(leg_cmd):
            msj_fname = leg_cmd[i + 1]
            leg_cmd[i + 1] = new_msj_fname
    if msj_fname is None:
        raise RestartException(
            f"ERROR: No '-m' msj argument found for leg '{leg_name}'")
    if not os.path.isfile(msj_fname):
        raise RestartException(f"ERROR: File not found: '{msj_fname}'")

    subjob_msj = multisim.parse(msj_fname)
    subjob_msj.put("desmond_extend.added_time", sim_time)
    subjob_msj.write(new_msj_fname)
    extend_stages = subjob_msj.find_stages("desmond_extend")
    if not extend_stages:
        raise RestartException(
            f"ERROR: No desmond_extend stage found in '{msj_fname}'")
    return extend_stages[0].STAGE_INDEX
def _find_leg_idx(leg: str, job: List) -> Optional[int]:
"""
:param: job
List of commands.
"""
for idx, cmd in enumerate(job):
jobname = cmd[cmd.index("-JOBNAME") + 1]
if util.get_leg_name_from_jobname(jobname) == leg:
return idx
class RestartException(Exception):
    """
    Raised when a job cannot be restarted or extended from its checkpoint,
    for example when no multisim stage or a required msj file is found.
    """