import traceback
import numpy as np
from itertools import combinations
import schrodinger.application.desmond.fep_edge_data as fep_edge_data
from schrodinger.application.desmond import ana
from schrodinger.application.desmond import constants
from schrodinger.application.desmond import struc
from schrodinger.application.desmond import util
from schrodinger.application.desmond.ana import Option
from schrodinger.application.desmond.ana import Premise
from schrodinger.application.desmond.packages import analysis
from schrodinger.application.desmond.packages import topo
from schrodinger.application.desmond.packages import traj
# FIXME: why using private methods here?
# Module-level aliases of private helpers of the legacy
# `fep_edge_data.FEPEdgeData` class; `_GetRepFrameIndex.execute` calls them.
# NOTE(review): `_sort_residue_tags` is bound to `_parse_residue_tags` -- the
# alias name does not match the helper it refers to; confirm intent.
(_cpx_sid_protein_residues,
 _sort_residue_tags,
 _pl_interaction_similarity_matrix) = (
    fep_edge_data.FEPEdgeData._set_cpx_sid_protein_residues,
    fep_edge_data.FEPEdgeData._parse_residue_tags,
    fep_edge_data.FEPEdgeData._pl_interaction_similarity_matrix,
)

# Printed by `postprocess_traj` when a ligand/protein selectivity FEP run has
# no LigandASL recorded for one of the lambda states.
NO_LIGAND_ASL_ERR = ("WARNING: LigandASL is not defined, required for "
                     "parched trajectory generation. Skipping parched.")
class _GetRepFrameIndex(ana.Task):
    """
    Picks a representative frame from protein-ligand interaction data.

    The per-frame protein-ligand interaction results stored in the database
    are turned into a frame-by-frame similarity matrix, the frames are
    clustered on that matrix, and the center frame of the most populated
    cluster is reported.

    Results:
      - key = "Keywords[i].ResultLambda{fep_lambda}.RepFrameIndex"
      - val = Frame index
    """

    def __init__(self, name: str, fep_lambda: int):
        super().__init__(name)
        self._results_field = f"Keywords[i].ResultLambda{fep_lambda}"
        pl_inter = f"{self._results_field}.Keywords[i].ProtLigInter"
        dictprot = f"{pl_inter}.DictProtein"
        # yapf: disable
        # (option name, field under ProtLigInter) -- inputs for the
        # similarity matrix.
        interaction_fields = (
            ("hbonds",      "HBondResult"),
            ("hydrophobic", "HydrophobicResult"),
            ("pi_cat",      "PiCatResult"),
            ("pi_pi",       "PiPiResult"),
            ("polar",       "PolarResult"),
            ("water_br",    "WaterBridgeResult"),
        )
        # (option name, suffix appended directly to "DictProtein", with no
        # dot in between) -- inputs for the residue tags; default to [].
        residue_tag_fields = (
            ('prot_hbond',       'Hbond'),
            ('prot_hydrophobic', 'Hydrophobic'),
            ('prot_ionic',       'Ionic'),
            ('prot_pi',          'Pi'),
            ('prot_water_br',    'WaterBridge'),
        )
        # yapf: enable
        interaction_options = [
            (opt_name, Option(f"{pl_inter}.{field}"))
            for opt_name, field in interaction_fields
        ]
        residue_tag_options = [
            (opt_name, Option(f'{dictprot}{suffix}[*][0]', []))
            for opt_name, suffix in residue_tag_fields
        ]
        self._options = [interaction_options + residue_tag_options]

    def execute(self, _, **kwargs):
        # FIXME: This function is basically a copy of the old `_get_sim_matrix`.
        #        We should **rewrite** it to make data-dependency explicit and to
        #        remove all dependency on the legacy `fep_edge_data`.
        residue_tags = _sort_residue_tags(_cpx_sid_protein_residues(kwargs))

        # The residue-tag options are consumed above; only the interaction
        # results feed the similarity matrix.
        tag_only = {'prot_hbond', 'prot_hydrophobic', 'prot_ionic', 'prot_pi',
                    'prot_water_br'}
        interactions = {
            key: value for key, value in kwargs.items() if key not in tag_only
        }
        similarity = _pl_interaction_similarity_matrix(interactions,
                                                       residue_tags)

        centers, frame_labels = analysis.cluster(similarity)
        cluster_sizes = {
            label: frame_labels.count(label) for label in set(frame_labels)
        }
        biggest = max(cluster_sizes, key=cluster_sizes.get)
        self.results = [
            ana.Datum(f"{self._results_field}.RepFrameIndex", centers[biggest])
        ]
class _GetRepSolubilityFrameIndex(ana.Task):
    """
    Calculates the similarity matrix based on the molecular interaction data
    saved in the database, and then clusters the trajectory frames with the
    similarity matrix. The final result is the index of the frame in the center
    of the largest cluster.

    Two frames are compared by the Jaccard index of the interaction types
    present in each frame (|A & B| / |A | B|). A pair in which neither frame
    has any interaction gets 0.0, and every frame has similarity 1.0 with
    itself.

    Results:
      - key = "Keywords[i].ResultLambda{fep_lambda}.RepFrameIndex"
      - val = Frame index
    """

    def __init__(self, name: str, fep_lambda: int):
        super().__init__(name)
        # yapf: disable
        self._results_field = f"Keywords[i].ResultLambda{fep_lambda}"
        ac_inter = f"{self._results_field}.Keywords[i].AmorphousCrystalInter"
        self._options = [[
            # For calculating the similarity matrix
            ("hbonds",      Option(f"{ac_inter}.HBondResult")),
            ("hydrophobic", Option(f"{ac_inter}.HydrophobResult")),
            ("pi_cat",      Option(f"{ac_inter}.CatPiResult")),
            ("pi_pi",       Option(f"{ac_inter}.PiPiResult")),
            ("polar",       Option(f"{ac_inter}.PolarResult")),
        ]]
        # yapf: enable

    @staticmethod
    def _jaccard_similarity_matrix(inter: np.ndarray) -> np.ndarray:
        """
        Return the symmetric frame-by-frame Jaccard similarity matrix.

        :param inter: Boolean array with one row per frame and one column per
            interaction type.
        :return: Float array of shape (nframes, nframes). Entries for pairs
            with an empty union are 0.0, and the diagonal is 1.0.
        """
        # Cast to integers so the matrix product counts shared interactions
        # (a boolean matmul would only tell whether any are shared).
        counts = inter.astype(np.int64)
        shared = counts @ counts.T
        per_frame = counts.sum(axis=1)
        union = per_frame[:, None] + per_frame[None, :] - shared
        sim = np.zeros(shared.shape, dtype=float)
        # Leave 0.0 wherever the union is empty instead of dividing by zero.
        np.divide(shared, union, out=sim, where=union != 0)
        np.fill_diagonal(sim, 1.0)
        return sim

    def execute(self, _, **kwargs):
        # Rows are frames and columns are the interaction options, in option
        # order. Non-zero values count as "interaction present".
        # NOTE(review): assumes every option value is a flat per-frame sequence
        # of the same length -- confirm against the producing analysis.
        inter = np.array([kwargs[k] for k in kwargs], dtype=bool).T
        mtx = self._jaccard_similarity_matrix(inter)
        centers, frame_labels = analysis.cluster(mtx)
        largest_cluster = max(set(frame_labels), key=frame_labels.count)
        self.results = [
            ana.Datum(f"{self._results_field}.RepFrameIndex",
                      centers[largest_cluster])
        ]
class _GenRepStructureForFepLambda(ana.Task):
    """
    Generates the representative structure corresponding to the representative
    frame of the parched trajectory.

    The task consumes three results of lambda state `fep_lambda`, declared via
    the `Premise` annotations on `execute`: the parched cms file name, the
    parched trajectory file name and the representative frame index. It writes
    the full-system structure at that frame to a .mae file and records the
    file name.

    :param name: Name of this task.
    :param fep_lambda: Lambda state whose
        "Keywords[i].ResultLambda{fep_lambda}" results are read and written.
    :param out_bname_pattern: Output base name without the ".mae" extension.
        NOTE(review): it is expanded with `eval` as an f-string inside
        `execute`, so any `{...}` placeholder in it is evaluated against the
        local scope of `execute`. Only pass trusted, literal patterns.

    Results:
      - key = "Keywords[i].ResultLambda{fep_lambda}.RepMaeFname"
      - val = Name of the representative structure file: '{out_bname_pattern}.mae'
    """
    def __init__(self, name: str, fep_lambda: int, out_bname_pattern: str):
        super().__init__(name)
        results_field = f"Keywords[i].ResultLambda{fep_lambda}"
        # yapf: disable
        # RE. code formatting:
        # The code reads so much clearer if we align the `Premise`s vertically.
        # NOTE(review): `ana` appears to use the `Premise(...)` annotations to
        # decide which stored values are passed in as arguments -- confirm.
        def execute(_,
                    cms_fname: Premise(f"{results_field}.ParchedCmsFname"), # noqa: F722
                    trj_fname: Premise(f"{results_field}.ParchedTrjFname"), # noqa: F722
                    frame_index: Premise(f"{results_field}.RepFrameIndex")  # noqa: F722
                    ):
            # yapf: enable
            util.verify_file_exists(cms_fname)
            util.verify_traj_exists(trj_fname)
            # The pattern is evaluated as an f-string in this function's scope.
            # For a literal pattern such as "rep_lambda0" the result is just
            # "rep_lambda0.mae".
            out_fname = eval(f"f'{out_bname_pattern}.mae'")
            # NOTE(review): the whole trajectory is read only to pick one
            # frame; check for a single-frame reader if memory is a concern.
            tr = traj.read_traj(trj_fname)
            _, cms_model = topo.read_cms(cms_fname)
            fsys_ct = cms_model.fsys_ct
            # Presumably copies the coordinates of frame `frame_index` into
            # `fsys_ct` -- confirm against `topo.update_ct`.
            topo.update_ct(fsys_ct, cms_model, tr[frame_index])
            fsys_ct.property['i_fep_representative_frame'] = frame_index
            # Strip the `constants.CT_TYPE` property, the original cms /
            # trajectory file references and every 'r_chorus*' property before
            # writing the structure.
            struc.delete_structure_properties(fsys_ct, [
                constants.CT_TYPE, "s_m_original_cms_file",
                "s_chorus_trajectory_file"
            ] + [e for e in fsys_ct.property if e.startswith('r_chorus')])
            fsys_ct.write(out_fname)
            # `self` here is the enclosing task instance captured by the closure.
            self.results = [
                ana.Datum(f"{results_field}.RepMaeFname", out_fname)
            ]
        self.execute = execute
class GenRepStructureForFep(ana.Task):
    """
    Task to generate representative structures for both FEP lambda states.

    For each lambda state, a `_GetRepFrameIndex` subtask picks the
    representative frame and a `_GenRepStructureForFepLambda` subtask writes
    the corresponding structure. See their docstrings for details.

    Results:
      - key = "Keywords[i].ResultLambda0.RepMaeFname"
      - val = Name of the representative structure file: 'rep_lambda0.mae'
      - key = "Keywords[i].ResultLambda1.RepMaeFname"
      - val = Name of the representative structure file: 'rep_lambda1.mae'
    """

    def __init__(self, name: str):
        lambda_states = (0, 1)
        # Subtasks are listed as all frame-index tasks first, then all
        # structure-generation tasks.
        frame_index_tasks = tuple(
            _GetRepFrameIndex(f"{name}_frame_index_{lam}", lam)
            for lam in lambda_states)
        structure_tasks = tuple(
            _GenRepStructureForFepLambda(f"{name}_{lam}", lam,
                                         f"rep_lambda{lam}")
            for lam in lambda_states)
        super().__init__(name, frame_index_tasks + structure_tasks)
class GenRepStructureForAbsoluteFep(ana.Task):
    """
    Task to generate the representative structure of the fully interacting
    state (lambda 0) of Absolute Binding FEP.

    A `_GetRepFrameIndex` subtask picks the representative frame and a
    `_GenRepStructureForFepLambda` subtask writes the structure. See their
    docstrings for details.

    Results:
      - key = "Keywords[i].ResultLambda0.RepMaeFname"
      - val = Name of the representative structure file: 'rep_lambda0.mae'
    """

    def __init__(self, name: str):
        interacting = 0
        subtasks = (
            _GetRepFrameIndex(f"{name}_frame_index_{interacting}",
                              interacting),
            _GenRepStructureForFepLambda(f"{name}_{interacting}", interacting,
                                         f"rep_lambda{interacting}"),
        )
        super().__init__(name, subtasks)
class GenRepStructureForSublimationFep(ana.Task):
    """
    Task to generate representative structures for the fully interacting
    state of Solubility FEP.

    Please refer to the docstrings of `_GetRepSolubilityFrameIndex` and
    `_GenRepStructureForFepLambda` for details about how the representative
    structure is generated.

    Results:
      - key = "Keywords[i].ResultLambda1.RepMaeFname"
      - val = Name of the representative structure file: 'rep_lambda1.mae'
    """

    def __init__(self, name: str):
        # Lambda 1 is the state handled here, per the Results above.
        super().__init__(
            name, (_GetRepSolubilityFrameIndex(f"{name}_frame_index_1", 1),
                   _GenRepStructureForFepLambda(f"{name}_1", 1, "rep_lambda1")))
def postprocess_traj(arkdb_fname: str, ref_ct_fname: str = None):
    """
    Run trajectory post-processing for an FEP simulation: generate the parched
    trajectory and, for applicable legs, the representative structure(s).

    :param arkdb_fname: Path of the ark database file to open.
    :param ref_ct_fname: Stored in the database under 'ReferenceStruct'
        (stored as-is, including None).
    :return: A 6-tuple of (lambda0 parched cms, lambda0 parched trj,
        lambda1 parched cms, lambda1 parched trj, lambda0 representative mae,
        lambda1 representative mae) file names. Each entry is None when
        missing from the database. All six are None when the database cannot
        be read, or when a selectivity FEP run has no LigandASL for either
        lambda. Task failures are printed, not raised.
    """
    nothing = (None,) * 6
    try:
        arkdb = ana.ArkDb(arkdb_fname)
    except Exception as e:
        # FIXME: Change print to exception?
        print("ERROR: Reading %s failed. "
              "Cancelled postprocessing trajectory.\n%s\n%s" %
              (arkdb_fname, e, traceback.format_exc()))
        return nothing

    arkdb.put('ReferenceStruct', ref_ct_fname)

    fep_type = arkdb.get('Keywords[i].FEPSimulation.PerturbationType', None)
    leg_type = (arkdb.get('Keywords[i].LigandInfo.LegType', None) or
                arkdb.get('Keywords[i].PeptideInfo.LegType', None))

    # Choose the (parched-trajectory task, representative-structure task)
    # classes for this FEP type.
    if fep_type == constants.FEP_TYPES.ABSOLUTE_BINDING:
        parch_cls, rep_cls = (ana.ParchTrajectoryForAbsoluteFep,
                              GenRepStructureForAbsoluteFep)
    elif fep_type == constants.FEP_TYPES.SOLUBILITY:
        parch_cls, rep_cls = (ana.TrajectoryForSolubilityFep,
                              GenRepStructureForSublimationFep)
    else:
        parch_cls, rep_cls = (ana.ParchTrajectoryForFep,
                              GenRepStructureForFep)
    tasks = [parch_cls("parch_trajectory")]

    selectivity_types = (constants.FEP_TYPES.LIGAND_SELECTIVITY,
                         constants.FEP_TYPES.PROTEIN_SELECTIVITY)
    if fep_type in selectivity_types:
        ligand_asls = [
            arkdb.get(f'Keywords[i].ResultLambda{lam}.LigandASL', None)
            for lam in (0, 1)
        ]
        # LigandASL is absent when the SID report did no ligand analysis,
        # e.g. for a multisite mutation.
        if any(asl is None for asl in ligand_asls):
            print(NO_LIGAND_ASL_ERR)
            return nothing

    # Legs of these types (or an unknown leg type) get no representative
    # structure.
    no_rep_leg_types = (constants.FepLegTypes.HYDRATION,
                        constants.FepLegTypes.SOLVENT,
                        constants.FepLegTypes.VACUUM, None)
    if leg_type not in no_rep_leg_types:
        tasks.append(rep_cls("gen_rep_structure"))

    if not ana.execute(arkdb, tasks):
        print(
            "ERROR: Postprocessing trajectory for FEP simulations failed with "
            "errors:\n %s\n" % "\n ".join(ana.collect_logs(tasks)))

    # FIXME: Since the results have been stored in the database, returning the
    #        results by this function may be unnecessary. But changing this
    #        behavior demands refactoring at a larger scope.
    result_keys = (
        "Keywords[i].ResultLambda0.ParchedCmsFname",
        "Keywords[i].ResultLambda0.ParchedTrjFname",
        "Keywords[i].ResultLambda1.ParchedCmsFname",
        "Keywords[i].ResultLambda1.ParchedTrjFname",
        "Keywords[i].ResultLambda0.RepMaeFname",
        "Keywords[i].ResultLambda1.RepMaeFname",
    )
    return tuple(arkdb.get(key, None) for key in result_keys)
if __name__ == '__main__':  # pragma: no cover
    import sys
    # Command-line entry point: post-process the FEP ark database given as the
    # first argument and print the returned file names.
    if len(sys.argv) < 2:
        # Without this check a missing argument crashed with an IndexError.
        sys.exit(f"Usage: {sys.argv[0]} ARKDB_FNAME")
    arkdb_fname = sys.argv[1]
    files = postprocess_traj(arkdb_fname)
    print(f"OUTPUT: {files}")