"""
Tasks for performing trajectory analyses.
Copyright Schrodinger, LLC. All rights reserved.
"""
import enum
import itertools
from collections import namedtuple
from typing import List
from schrodinger.models import parameters
from schrodinger.models.jsonable import JsonableEnum
from schrodinger.tasks import tasks, jobtasks
from schrodinger.utils import sea, fileutils
try:
from schrodinger.application.desmond.packages import analysis
from schrodinger.application.desmond.packages import traj_util
except ImportError:
# Desmond not installed
analysis = None
traj_util = None
# Analyzer classes whose per-frame result is a collection of interactions;
# TrajectoryAnalysisTaskMixin.mainFunction counts them per frame and records
# their atom numbers. Empty when Desmond could not be imported.
LIST_ANALYZERS = [
    analysis.CatPiFinder,
    analysis.PiPiFinder,
    analysis.HalogenBondFinder,
    analysis.HydrogenBondFinder,
    analysis.SaltBridgeFinder,
] if analysis else []

# ASL expression that selects no atoms.
EMPTY_ASL = 'NOT all'
class AnalysisMode(JsonableEnum):
    """
    Enum defining the analysis classes used.

    When Desmond is available, every member except ``Energy`` has an entry in
    ANALYSIS_MODE_MAP giving its display name, analyzer class and unit label.

    Values are generated with ``enum.auto()``. NOTE(review): if JsonableEnum
    serializes members by value, reordering members would change saved
    data, so new members should be appended at the end.
    """
    Distance = enum.auto()
    Angle = enum.auto()
    Torsion = enum.auto()
    PlanarAngle = enum.auto()
    HydrogenBondFinder = enum.auto()
    HalogenBondFinder = enum.auto()
    SaltBridgeFinder = enum.auto()
    ProtLigPiInter = enum.auto()
    ProtProtPiInter = enum.auto()
    RMSD = enum.auto()
    AtomRMSF = enum.auto()
    ResRMSF = enum.auto()
    Gyradius = enum.auto()
    PolarSurfaceArea = enum.auto()
    SolventAccessibleSurfaceArea = enum.auto()
    MolecularSurfaceArea = enum.auto()
    PiPiFinder = enum.auto()
    CatPiFinder = enum.auto()
    # Energy plots. Not in ANALYSIS_MODE_MAP; handled by
    # TrajectoryEnergyJobTask. NOTE(review): an older note named a
    # "SelectionAnalsyisMatrix" analyzer; this likely means the
    # 'SelectionEnergyMatrix' keyword written to the .st2 file — confirm.
    Energy = enum.auto()
# Modes backed by the non-covalent bond finders.
NON_COVALENT_MODES = {
    AnalysisMode.HalogenBondFinder,
    AnalysisMode.HydrogenBondFinder,
    AnalysisMode.SaltBridgeFinder,
}

# Modes backed by the pi-interaction finders.
PI_INTERACTION_MODES = {AnalysisMode.PiPiFinder, AnalysisMode.CatPiFinder}

# Every mode whose plotted value is a number of interactions per frame.
INTERACTION_MODES = NON_COVALENT_MODES | PI_INTERACTION_MODES | {
    AnalysisMode.ProtLigPiInter,
    AnalysisMode.ProtProtPiInter,
}

# Interaction modes plus the RMSD/RMSF modes. NOTE(review): presumably the
# modes whose settings are identified by a ``settings_hash`` — confirm.
HASHABLE_MODES = INTERACTION_MODES | {
    AnalysisMode.RMSD,
    AnalysisMode.AtomRMSF,
    AnalysisMode.ResRMSF,
}

# Per-atom and per-residue RMSF. NOTE(review): presumably plotted against
# atom/residue index rather than frame — confirm.
INDEX_BASED_MODES = {AnalysisMode.AtomRMSF, AnalysisMode.ResRMSF}

# The two RMSF variants.
RMSF_PLOTS = {AnalysisMode.ResRMSF, AnalysisMode.AtomRMSF}

# Advanced plots display in the advanced frame, and open in new window
ADVANCED_PLOTS = {
    AnalysisMode.Energy,
    AnalysisMode.ResRMSF,
    AnalysisMode.AtomRMSF,
}
# Unit/axis label strings used as the ``unit`` field of AnalysisModeInfo.
ANGSTROMS_RMSD = 'RMSD (Å)'
ANGSTROMS_RMSF = 'RMSF (Å)'
ANGSTROMS_DIST = 'Distance (Å)'
ANGSTROMS_AREA = 'Area (Ų)'
ANGLE_DEGREES = 'Angle (Degrees)'
DIHEDRAL_DEGREES = 'Dihedral (Degrees)'
INSTANCES = 'Number of Interactions'

# Display name, analyzer class and unit label for one AnalysisMode.
# The typename must match the module-level name: pickle locates the class by
# its qualified name in its module, and a mismatch makes instances
# unpicklable.
AnalysisModeInfo = namedtuple('AnalysisModeInfo',
                              ['name', 'analysis_class', 'unit'])
# Display name, Desmond analyzer class and unit label for each AnalysisMode.
# Stays empty when Desmond cannot be imported. AnalysisMode.Energy has no
# entry: energy analyses are run by TrajectoryEnergyJobTask instead.
ANALYSIS_MODE_MAP = {}
if analysis:
    ANALYSIS_MODE_MAP = {
        mode: AnalysisModeInfo(name=label,
                               analysis_class=analyzer_cls,
                               unit=unit)
        for mode, label, analyzer_cls, unit in (
            (AnalysisMode.Distance, 'Distance', analysis.Distance,
             ANGSTROMS_DIST),
            (AnalysisMode.Angle, 'Angle', analysis.Angle, ANGLE_DEGREES),
            (AnalysisMode.Torsion, 'Dihedral', analysis.Torsion,
             DIHEDRAL_DEGREES),
            (AnalysisMode.PlanarAngle, 'Planar Angle', analysis.PlanarAngle,
             ANGLE_DEGREES),
            (AnalysisMode.HydrogenBondFinder, 'Hydrogen Bonds',
             analysis.HydrogenBondFinder, INSTANCES),
            (AnalysisMode.HalogenBondFinder, 'Halogen Bonds',
             analysis.HalogenBondFinder, INSTANCES),
            (AnalysisMode.SaltBridgeFinder, 'Salt Bridge',
             analysis.SaltBridgeFinder, INSTANCES),
            (AnalysisMode.ProtLigPiInter, 'Protein-ligand interaction',
             analysis.ProtLigInter, INSTANCES),
            (AnalysisMode.ProtProtPiInter, 'Protein-protein interaction',
             analysis.ProtProtPiInter, INSTANCES),
            (AnalysisMode.RMSD, 'RMSD', analysis.RMSD, ANGSTROMS_RMSD),
            (AnalysisMode.AtomRMSF, 'RMSF (Per Atom)', analysis.RMSF,
             ANGSTROMS_RMSF),
            (AnalysisMode.ResRMSF, 'RMSF (Per Residue)',
             analysis.ProteinRMSF, ANGSTROMS_RMSF),
            (AnalysisMode.Gyradius, 'Radius of Gyration', analysis.Gyradius,
             'Radius of Gyr.'),
            (AnalysisMode.PolarSurfaceArea, 'Polar surface area',
             analysis.PolarSurfaceArea, ANGSTROMS_AREA),
            (AnalysisMode.SolventAccessibleSurfaceArea,
             'Solvent accessible surface area',
             analysis.SolventAccessibleSurfaceArea, ANGSTROMS_AREA),
            (AnalysisMode.MolecularSurfaceArea, 'Molecular Surface Area',
             analysis.MolecularSurfaceArea, ANGSTROMS_AREA),
            (AnalysisMode.PiPiFinder, 'Pi-Pi Stacking', analysis.PiPiFinder,
             INSTANCES),
            (AnalysisMode.CatPiFinder, 'Pi-Cation', analysis.CatPiFinder,
             INSTANCES),
        )
    }
class ResidueInfo(parameters.CompoundParam):
    """
    Model for protein residues in RMSF Residue plots.

    Filled by TrajectoryAnalysisTaskMixin.mainFunction for
    AnalysisMode.ResRMSF only.
    """
    # Residue labels: the first element of the per-residue RMSF analyzer's
    # result.
    residue_names = parameters.ListParam()
    # One list per residue, holding that residue's secondary-structure entry
    # for each analyzed frame.
    secondary_structures = parameters.ListParam()
    # NOTE: res atoms and b factor are indexed by residue label
    # NOTE(review): no res-atom or b-factor fields are declared here; this
    # note may be stale — confirm.
class TrajectoryAnalysisOutput(parameters.CompoundParam):
    """
    Output of TrajectoryAnalysisTask and TrajectoryAnalysisSubprocTask.
    """
    # Analysis values. Interaction analyzers store one interaction count per
    # frame; other analyzers store their result converted to floats.
    result = parameters.ListParam()
    # Title of the system's full-system CT (cms_model.fsys_ct.title).
    system_title = parameters.StringParam()
    # Flattened atom numbers of every interaction found across all frames
    # (list analyzers only).
    atom_numbers = parameters.ListParam()  # TODO move to input
    # Display name of the analysis mode, from ANALYSIS_MODE_MAP.
    legend_name = parameters.StringParam()
    # Not set by mainFunction; presumably filled by the caller.
    fit_asl = parameters.StringParam()  # TODO move to input
    # Not set by mainFunction; presumably filled by the caller.
    settings_hash = parameters.StringParam()  # TODO move to input
    residue_info = ResidueInfo()  # Used by ResRMSF only
class TrajectoryEnergyOutput(jobtasks.CmdJobTask.Output):
    """
    Output of TrajectoryEnergyJobTask.
    """
    # Absolute path of the results produced by analyze_simulation.py; set in
    # TrajectoryEnergyJobTask.makeCmd.
    results_file: str
    # Not set by TrajectoryEnergyJobTask; presumably filled by the caller.
    system_title: str
    # Not set by TrajectoryEnergyJobTask; presumably filled by the caller.
    legend_name: str
    # Defaults to the ASL that selects no atoms.
    fit_asl: str = EMPTY_ASL  # TODO move to input
    settings_hash = parameters.StringParam()  # TODO move to input
class TrajectoryAnalysisTaskMixin:
    """
    Mixin for task classes to perform trajectory analyses.

    The host task class must provide ``input`` and ``output`` params.
    ``mainFunction`` reads the analysis settings from ``self.input`` and
    stores the processed results on ``self.output``.
    """

    def mainFunction(self):
        """
        Run the analyzer for ``self.input.analysis_mode`` over the trajectory
        and store the processed results on ``self.output``.

        If ``msys_model``, ``cms_model`` and ``trajectory`` are all set on
        the input, they are used directly. Otherwise the system and
        trajectory are read from ``input.cms_fname``.
        """
        inp = self.input
        mode = inp.analysis_mode
        mode_info = ANALYSIS_MODE_MAP[mode]
        self.output.legend_name = mode_info.name

        preloaded = (inp.msys_model, inp.cms_model, inp.trajectory)
        if all(preloaded):
            msys_model, cms_model, trj = preloaded
        else:
            msys_model, cms_model, trj = traj_util.read_cms_and_traj(
                inp.cms_fname)
        self.output.system_title = cms_model.fsys_ct.title

        # Positional analyzer arguments: the models, then one Centroid per
        # centroid ASL, then any caller-supplied extra arguments.
        analyzer_args = [msys_model, cms_model]
        analyzer_args.extend(
            analysis.Centroid(msys_model, cms_model, asl)
            for asl in (inp.centroid_asl_list or []))
        analyzer_args.extend(inp.additional_args or [])

        analysis_class = mode_info.analysis_class
        analyzer = analysis_class(*analyzer_args, **inp.additional_kwargs)
        result = analysis.analyze(trj, analyzer)

        if analysis_class in LIST_ANALYZERS:
            # Each frame yields a collection of interactions; each interaction
            # is an iterable of atom numbers.
            for frame_inters in result:
                self.output.atom_numbers.extend(
                    list(itertools.chain(*frame_inters)))
                self.output.result.append(len(frame_inters))
        elif mode in (AnalysisMode.ProtProtPiInter,
                      AnalysisMode.ProtLigPiInter):
            # FIXME - for ProtProtPiInter, result should be a list of dicts,
            # but seems to currently be a single dict.
            for inter_map in result:
                self.output.result.append(
                    sum(len(entries) for entries in inter_map.values()))
        elif mode == AnalysisMode.ResRMSF:
            res_labels, rmsf_vals = result
            # Secondary structure is computed alongside per-residue RMSF.
            residue_info = self.output.residue_info
            residue_info.secondary_structures = self._runSecondaryStructure(
                trj, msys_model, cms_model)
            residue_info.residue_names = res_labels
            self.output.result = [float(val) for val in rmsf_vals]
        else:
            # Results may be numpy float32 type so need to map to float to
            # make Jsonable.
            self.output.result = [float(val) for val in result]

    def isInteractionTask(self):
        """
        :return: Whether the selected analysis mode is an interaction mode.
        :rtype: bool
        """
        mode = self.input.analysis_mode
        return mode in INTERACTION_MODES

    def isMultiseriesInteractionTask(self):
        """
        :return: Whether this is an interaction task meant for a
            multi-series plot.
        """
        wants_multiseries = self.input.for_multiseries_plot
        if not wants_multiseries:
            return wants_multiseries
        return self.isInteractionTask()

    def _runSecondaryStructure(self, trj, msys_model, cms_model):
        """
        Run a secondary structure analysis on the atom ids given as the
        first of ``self.input.additional_args`` and regroup the per-frame
        results by residue.

        :return: One list per residue label, holding that residue's entry
            from every frame in order.
        :rtype: list[list]
        """
        aids = self.input.additional_args[0]
        ss_analyzer = analysis.SecondaryStructure(msys_model, cms_model, aids)
        res_labels, frames = analysis.analyze(trj, ss_analyzer)
        per_residue = [[] for _ in range(len(res_labels))]
        for frame in frames:
            for res_idx, ss_entry in enumerate(frame):
                per_residue[res_idx].append(ss_entry)
        return per_residue
class TrajectoryAnalysisTask(TrajectoryAnalysisTaskMixin,
                             tasks.BlockingFunctionTask):
    """
    A blocking task to run relatively quick analyses.

    The analysis logic comes from TrajectoryAnalysisTaskMixin.mainFunction.
    """
    # NOTE(review): TrajectoryAnalysisInput is not defined in the visible
    # part of this module — confirm where it is defined or imported.
    input = TrajectoryAnalysisInput()
    output = TrajectoryAnalysisOutput()
class TrajectoryAnalysisSubprocTask(TrajectoryAnalysisTaskMixin,
                                    tasks.ComboSubprocessTask):
    """
    This is a general task to perform trajectory analyses.

    Shares its analysis logic (TrajectoryAnalysisTaskMixin.mainFunction)
    with TrajectoryAnalysisTask. It is built on tasks.ComboSubprocessTask,
    presumably so longer analyses run in a subprocess.
    """
    input = TrajectoryAnalysisInput()
    output = TrajectoryAnalysisOutput()
class TrajectoryEnergyJobTask(jobtasks.CmdJobTask):
    """
    Task for running Energy analysis jobs.

    The job runs ``analyze_simulation.py`` on the input CMS file and
    trajectory. A generated ``.st2`` file requests a 'SelectionEnergyMatrix'
    over ``input.set_asls``.
    """
    input = TrajectoryEnergyInput()
    output = TrajectoryEnergyOutput()

    class JobConfig(jobtasks.JobConfig):
        # Only GPU hosts may be selected for this job.
        host_settings: jobtasks.HostSettings = jobtasks.HostSettings(
            allowed_host_types=jobtasks.AllowedHostTypes.GPU_ONLY)

    def makeCmd(self):
        """
        Build the ``analyze_simulation.py`` command line.

        As a side effect, this writes the ``.st2`` energy-matrix file and
        stores the absolute results path in ``self.output.results_file``.

        :return: Command-line arguments for the job.
        :rtype: list
        """
        cms_file = self.input.cms_fname
        trj_dir = self.input.trj_dir
        cfg_file = self.input.cfg_fname
        # Write the energy matrix file:
        st2_file = self.getTaskFilename(self.name + '.st2')
        self.writeSt2File(st2_file)
        results_file = self.name + '-results'
        # The job gets the relative name; save the absolute path in the
        # output.
        self.output.results_file = self.getTaskFilename(results_file)
        return [
            'run',
            'analyze_simulation.py',
            cms_file,
            trj_dir,
            results_file,
            st2_file,
            '-sim-cfg',
            cfg_file,
        ]

    def writeSt2File(self, st2_file):
        """
        Write an `*st2` file containing set ASLs for this task.

        :param st2_file: Path of the file to write.
        :type st2_file: str

        :raise ValueError: If ``self.input.set_asls`` is empty.
        """
        # Raise explicitly: an ``assert`` is stripped under ``python -O``,
        # which would silently produce a file with no selections.
        if not self.input.set_asls:
            raise ValueError('No set ASLs were specified for the energy '
                             'analysis.')
        asl_data = sea.Map()
        asl_data['Asl_Selections'] = self.input.set_asls
        matrix_data = sea.Map()
        matrix_data['SelectionEnergyMatrix'] = asl_data
        m = sea.Map()
        m['Keywords'] = [matrix_data]
        with open(st2_file, 'w') as fh:
            fh.write(str(m))

    def isInteractionTask(self):
        return False

    def isMultiseriesInteractionTask(self):
        # This plot only shows one series at a time
        return False
class TrajPlotModel(parameters.CompoundParam):
    """
    Main model for the trajectory plot GUI.

    Holds one instance of each task type defined in this module.
    """
    # In-process (blocking) analysis task.
    traj_analysis_task = TrajectoryAnalysisTask()
    # ComboSubprocessTask-based analysis task.
    traj_analysis_subproctask = TrajectoryAnalysisSubprocTask()
    # Job task that runs the energy analysis.
    traj_energy_jobtask = TrajectoryEnergyJobTask()
class RmsfPlotModel(parameters.CompoundParam):
    """
    Model for an RMSF plot panel.
    """
    # Presumably toggles coloring the plot by secondary structure — confirm.
    secondary_structure_colors: bool = True
    # Presumably toggles showing a B-factor series — confirm.
    b_factor_plot: bool = True
class SetRow(parameters.CompoundParam):
    """
    One row of the set list shown in an Energy plot panel.
    """
    # Presumably the display name of an atom set — confirm.
    name: str
class EnergyPlotModel(parameters.CompoundParam):
    """
    Model for an Energy plot panel.
    """
    # All available sets, and the subset currently selected.
    sets: List[SetRow]
    selected_sets: List[SetRow]
    # Presumably excludes a set's interaction energy with itself — confirm.
    exclude_self_terms: bool = False
    # Presumably toggles for which energy components are plotted — confirm.
    coulomb: bool = True
    van_der_waals: bool = True
    bond: bool = True
    angle: bool = True
    dihedral: bool = True