"""
This module contains classes that wrap prime backends that predict sequence
structures. Many of the parameters and class constants are from a time when
documentation was sparse. In the future, it's possible we'll tweak these
numbers as needed.
"""
import enum
import os
from schrodinger.job.util import hunt
from schrodinger.models import jsonable
from schrodinger.models import parameters
from schrodinger.protein import alignment
from schrodinger.protein import annotation
from schrodinger.protein import residue
from schrodinger.protein import sequence
from schrodinger.protein.constants import SSA_MAP
from schrodinger.tasks import tasks
from schrodinger.utils import fileutils
# Shorthand for the annotation types declared on ProteinSequenceAnnotations;
# its members are used as the keys of PRED_ANNO_TO_PRED_FUNC.
SEQ_ANNO_TYPES = annotation.ProteinSequenceAnnotations.ANNOTATION_TYPES
# Location reported by `hunt('psp', 'data')`. Predictor model files are looked
# up under its `predictors/<PREDICTOR_NAME>` subdirectory.
PSP_DATA_DIR = os.path.join(hunt('psp', 'data'))
class AbstractPredictor(tasks.SubprocessCmdTask):
    """
    Base class for all predictors. Derived classes are expected to implement
    class constants for:

        - EXE - A string that should match to the predictors executable. Most
          of the time this is the same as PREDICTOR_NAME
        - PREDICTOR_NAME - A string with the name of the predictor. This is
          used to find the Prime data directory that holds the model
          parameters used by the predictor.
        - CLASS_NUM - A parameter specific to the predictor. Usually found by
          looking through the Prime predictors source code.
        - NU - Another model parameter.
        - NY - Another model parameter.

    In addition, derived classes should implement the following methods:

        generateInputFile - Should generate the required input file at
            the file described by `input_fname`
        prediction - Should read `self.getLogAsString()` and parse out the
            actual prediction from the backend
        makeCmd - This only needs to be implemented if the backend
            takes a command different from the form:
            `executable model_fname input_fname`
    """
    EXE = NotImplemented
    PREDICTOR_NAME = NotImplemented
    CLASS_NUM = NotImplemented
    NU = NotImplemented
    NY = NotImplemented

    # NOTE(review): `Input` is not defined in the visible source; presumably
    # it is a nested parameter class declaring at least `seq` and `aln`
    # (both are read/written through `self.input` in this module) - confirm.
    input = Input()
    input_fname: str
    model_fname: str

    def prediction(self):
        """
        Return the actual prediction. This can take various forms depending
        on the predictor.

        :raise NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError

    def generateAlignmentFile(self):
        """
        Write the alignment file to be used as an input for the predictor.

        A temporary file name is reserved in the current working directory
        and stored in `self._aln_file`; its basename is stored in
        `self._aln_fname`. The first line of the file holds the number of
        sequences in `self.input.aln`, followed by one line per sequence.
        Gaps in the alignment file are represented as '.'.

        The file is removed again during postprocessing.
        """
        self._aln_file = fileutils.tempfilename(suffix=None,
                                                temp_dir=os.getcwd())
        self._aln_fname = os.path.basename(self._aln_file)
        # The file is opened through its basename, i.e. relative to the
        # current working directory, which is where the name was reserved.
        with open(self._aln_fname, 'w', newline="\n") as aln_file:
            aln_file.write(str(len(self.input.aln)) + '\n')
            for seq in self.input.aln:
                # Predictor alignment files use '.' for gap characters
                seq_str = str(seq).replace(seq.gap_char, '.')
                aln_file.write(seq_str + '\n')

    # NOTE(review): the alignment file is removed both here and in
    # `postprocess` below. Confirm that the second `remove()` call is
    # harmless, or drop one of the two cleanup paths.
    @tasks.postprocessor
    def _cleanUpTmpFiles(self):
        """
        Remove the temporary alignment file written by
        `generateAlignmentFile`.
        """
        self._aln_file.remove()

    def generateModelFile(self):
        """
        Generate the model definition file with the name `self.model_fname`.

        This is done by finding the Prime data directory for the predictor
        (`<PSP_DATA_DIR>/predictors/<PREDICTOR_NAME>`) and getting the names
        of all the entries in it.

        The model file includes a header describing the number of model
        files and the predictors class number (`self.CLASS_NUM`), and a
        sorted list of the model file paths. Every path is followed by a
        space and a newline.
        """
        model_dir = os.path.join(PSP_DATA_DIR, 'predictors',
                                 self.PREDICTOR_NAME)
        model_list = sorted(
            os.path.join(model_dir, model_file)
            for model_file in os.listdir(model_dir))
        with open(self.model_fname, "w", newline="\n") as model_file:
            model_file.write(f"{len(model_list)} {self.CLASS_NUM}\n")
            # The joined listing is a single string, so `write` is the right
            # call; `writelines` would iterate over it character by character.
            model_file.write(' \n'.join(model_list))
            model_file.write(' \n')

    @tasks.preprocessor
    def _generateInputFiles(self):
        """
        Generate the alignment, input and model files, in that order.
        """
        self.generateAlignmentFile()
        self.generateInputFile()
        self.generateModelFile()

    def makeCmd(self):
        """
        Return the command to run the predictor backend. The default
        implementation returns the predictor executable, the model
        file name, and the input file name.

        :rtype: list[str]
        """
        exe_path = self._getPredictorExe()
        return [exe_path, self.model_fname, self.input_fname]

    def _getPredictorExe(self):
        """
        Return the path of the predictor executable, i.e. `self.EXE` joined
        onto the location reported by `hunt('psp')`.

        :rtype: str
        """
        return os.path.join(hunt('psp'), self.EXE)

    def _getInputHeader(self):
        """
        Return the header for the input file. For most predictors, the header
        contains the number of sequences to predict properties for, the
        NU parameter, and the NY parameter. All the predictors in the module
        predict properties for only one sequence at a time.

        Note that this is a convenience function useful for most predictors
        but not necessarily /all/ predictors.

        :return: A newline-terminated line of the form `1 <NU> <NY>`
        :rtype: str
        """
        num_seqs = 1
        return f'{num_seqs} {self.NU} {self.NY}\n'

    def postprocess(self):
        """
        Remove the temporary alignment file, then defer to the base class.
        """
        self._aln_file.remove()
        super().postprocess()
class SsproPredictor(AbstractPredictor):
    """
    Secondary structure predictor, run through the `sspro4` executable.
    """
    EXE = 'sspro4'
    PREDICTOR_NAME = 'sspro'
    CLASS_NUM = 3
    NU = 20
    NY = 3
    input_fname: str = 'sspro.inp'
    model_fname: str = 'sspro_model.def'

    def makeCmd(self):
        """
        Usage: $PSP_PATH/sspro4 model_definition dataset_file alignment_directory dataset_format

        :return: The base command extended with the alignment directory and
            the dataset format.
        :rtype: list[str]
        """
        # See sspro for more information on the different formats.
        # This class only supports format 0 currently.
        return [*super().makeCmd(), './', '0']

    def _validateStdout(self):
        """
        Check that the backend output has at least two lines and that the
        second line only contains characters known to `SSA_MAP`.

        :raise RuntimeError: If either check fails. The message includes the
            full predictor stdout.
        """
        stdout = self.getLogAsString()
        lines = stdout.split('\n')
        if len(lines) < 2:
            raise RuntimeError(
                "Predictor returned incorrectly formatted output. \n"
                "Predictor stdout:\n" + stdout)
        if not all(c in SSA_MAP for c in lines[1]):
            raise RuntimeError("Got unexpected character in output.\n"
                               "Predictor stdout:\n" + stdout)

    def rawPrediction(self):
        """
        :return: The raw prediction string containing one character per residue
            in the input sequence.
        :rtype: str
        """
        # NOTE(review): this takes the second whitespace-separated token,
        # whereas `_validateStdout` checks the second newline-separated line.
        # The two only agree when the first line is a single token - confirm
        # against the backend's output format.
        tokens = self.getLogAsString().split()
        return tokens[1]

    def prediction(self):
        """
        :return: A list of ssa types (the values of `SSA_MAP`), one for each
            character of the raw prediction.
        :rtype: list
        """
        self._validateStdout()
        ssa = []
        for char in self.rawPrediction():
            ssa.append(SSA_MAP[char])
        return ssa
# Solvent accessibility classes; AccproPredictor maps its per-residue output
# characters ('b' / 'e') onto these members.
SolventAccessibility = jsonable.JsonableEnum('SolventAccessibility',
                                             'BURIED EXPOSED')
class AccproPredictor(AbstractPredictor):
    """
    Solvent accessibility predictor, run through the `accpro` executable.
    """
    EXE = 'accpro'
    PREDICTOR_NAME = 'accpro'
    CLASS_NUM = 20
    NU = 20
    NY = 3
    input_fname: str = 'accpro.inp'
    model_fname: str = 'accpro_model.def'

    # Translation of the backend's output characters into enum members.
    CHAR_TO_ACC_MAP = {
        'e': SolventAccessibility.EXPOSED,
        'b': SolventAccessibility.BURIED,
    }

    def makeCmd(self):
        """
        Usage: $PSP_PATH/accpro model_definition dataset_file alignment_directory dataset_format threshold_index

        :return: The base command extended with the alignment directory, the
            dataset format and the threshold index.
        :rtype: list[str]
        """
        extra_args = [
            './',  # alignment_directory
            '2',  # dataset_format
            '5',  # threshold_index
        ]
        return super().makeCmd() + extra_args

    def rawPrediction(self):
        """
        Return the backend output with surrounding whitespace removed.

        Example:
            eeebbbebebebebbebbebebeebbbbbbbeeeee

        e = exposed
        b = buried

        :rtype: str
        """
        return self.getLogAsString().strip()

    def prediction(self):
        """
        :return: One `SolventAccessibility` member per character of the raw
            prediction.
        :rtype: list[SolventAccessibility]
        :raise KeyError: If the raw prediction contains a character other
            than the keys of `CHAR_TO_ACC_MAP`.
        """
        accessibilities = []
        for char in self.rawPrediction():
            accessibilities.append(self.CHAR_TO_ACC_MAP[char])
        return accessibilities
# Reverse lookup of AccproPredictor.CHAR_TO_ACC_MAP: SolventAccessibility
# member -> the single character the accpro backend uses for it.
INVERSE_ACC_MAP = {v: k for k, v in AccproPredictor.CHAR_TO_ACC_MAP.items()}
def encode_acc(acc):
    """
    Encode solvent accessibility values as a string with one character per
    value, using `INVERSE_ACC_MAP`.

    :param acc: Iterable of values that are keys of `INVERSE_ACC_MAP`
    :return: The concatenated characters
    :rtype: str
    :raise KeyError: If a value is not a key of `INVERSE_ACC_MAP`
    """
    chars = (INVERSE_ACC_MAP[value] for value in acc)
    return ''.join(chars)
# Score classes that DisproPredictor.prediction assigns to each residue.
# TODO: Confirm what these values should be
Disordered = jsonable.JsonableEnum('Disordered',
                                   'HIGHSCORE MEDIUMSCORE LOWSCORE')
class SsAccDependentPredictors(AbstractPredictor):
    """
    Base class for predictors that use secondary structure and solvent
    accessibility predictions as inputs.

    Two preprocessors fill in `input.ss_prediction` and
    `input.acc_prediction` when they are still None, either from predictions
    already stored on the sequence or by running the corresponding predictor.
    """
    # NOTE(review): `Input` is not defined in the visible source; presumably
    # a nested parameter class that adds `ss_prediction` and `acc_prediction`
    # - confirm.
    input = Input()

    @tasks.preprocessor(order=-1)
    def _generateSsInput(self):
        """
        Make sure `input.ss_prediction` holds an encoded secondary structure
        string. Does nothing if it is already set.
        """
        task_input = self.input
        seq = task_input.seq
        if task_input.ss_prediction is not None:
            return
        stored = seq.pred_secondary_structures
        if len(stored) == 1 and stored[0][1] is None:
            # Nothing is stored on the sequence: run the predictor, leaving
            # the sequence itself untouched.
            ss_task = predict_secondary_structure(task_input.seq,
                                                  task_input.aln,
                                                  mutate_in_place=False)
            task_input.ss_prediction = ss_task.rawPrediction()
            return
        task_input.ss_prediction = encode_ssa(
            res.pred_secondary_structure for res in seq.residues())

    @tasks.preprocessor(order=-1)
    def _generateAccInput(self):
        """
        Make sure `input.acc_prediction` holds an encoded solvent
        accessibility string. Does nothing if it is already set.
        """
        task_input = self.input
        seq = task_input.seq
        if task_input.acc_prediction is not None:
            return
        stored = [res.pred_accessibility for res in seq]
        if any(value is not None for value in stored):
            task_input.acc_prediction = encode_acc(
                res.pred_accessibility for res in seq.residues())
            return
        # NOTE(review): unlike the secondary structure fallback, this call
        # uses the default `mutate_in_place=True`, so the predictions are
        # also stored on the sequence - confirm this is intended.
        acc_task = predict_solvent_accessibility(task_input.seq,
                                                 task_input.aln)
        task_input.acc_prediction = acc_task.rawPrediction()
class DisproPredictor(SsAccDependentPredictors):
    """
    Disordered regions predictor.
    """
    EXE = 'dispro'
    PREDICTOR_NAME = 'dispro'
    CLASS_NUM = 2
    NU = 25
    NY = 2
    input_fname: str = 'dispro.inp'
    model_fname: str = 'dispro_model.def'

    def makeCmd(self):
        """
        :return: The base command extended with the alignment directory.
        :rtype: list[str]
        """
        alignment_directory = './'
        return super().makeCmd() + [alignment_directory]

    def rawPrediction(self):
        """
        :return: The backend output split into lines.
        :rtype: list[str]
        """
        stdout = self.getLogAsString()
        return stdout.split('\n')

    def prediction(self):
        """
        Classify each residue from the second and third output lines.

        The second line holds one character per residue and the third line
        holds whitespace-separated probabilities. A residue flagged 'N' is
        `Disordered.LOWSCORE`; any other residue is `Disordered.HIGHSCORE`
        when its probability exceeds 0.9 and `Disordered.MEDIUMSCORE`
        otherwise. If the two lines differ in length, the extra entries of
        the longer one are ignored.

        :return: One `Disordered` member per residue
        :rtype: list[Disordered]
        :raise IndexError: If the output has fewer than three lines
        """
        # Parse stdout once rather than once per field.
        lines = self.rawPrediction()
        disordered_pred = lines[1].strip()
        probabilities = [float(p) for p in lines[2].split()]
        pred = []
        for dis, prob in zip(disordered_pred, probabilities):
            if dis == 'N':
                pred.append(Disordered.LOWSCORE)
            elif prob > 0.9:
                pred.append(Disordered.HIGHSCORE)
            else:
                pred.append(Disordered.MEDIUMSCORE)
        return pred
# Per-residue classes produced by DomproPredictor.prediction: an 'N' in the
# backend output maps to Interdomain, any other character to DomainForming.
DomainArrangement = jsonable.JsonableEnum('DomainArrangement',
                                          'Interdomain DomainForming')
class DomproPredictor(SsAccDependentPredictors):
    """
    Domain arrangement predictor, run through the `dompro` executable.
    """
    EXE = 'dompro'
    PREDICTOR_NAME = 'dompro'
    CLASS_NUM = 2
    NU = 25
    NY = 3
    input_fname: str = 'dompro.inp'
    model_fname: str = 'dompro_model.def'

    def makeCmd(self):
        """
        :return: The base command extended with the alignment directory.
        :rtype: list[str]
        """
        return [*super().makeCmd(), './']

    def rawPrediction(self):
        """
        :return: The second line of the backend output, unstripped.
        :rtype: str
        :raise IndexError: If the output has fewer than two lines
        """
        lines = self.getLogAsString().split('\n')
        return lines[1]

    def prediction(self):
        """
        :return: One `DomainArrangement` member per character of the raw
            prediction: `Interdomain` for 'N', `DomainForming` for anything
            else.
        :rtype: list[DomainArrangement]
        """
        arrangement = []
        for char in self.rawPrediction():
            if char == 'N':
                arrangement.append(DomainArrangement.Interdomain)
            else:
                arrangement.append(DomainArrangement.DomainForming)
        return arrangement
class DiproPredictor(SsAccDependentPredictors):
    """
    Disulfide bonds predictor.
    """
    EXE = 'dipro'
    PREDICTOR_NAME = 'dipro'
    CLASS_NUM = 0.5
    input_fname: str = 'dipro.inp'
    model_fname: str = 'dipro_model.def'

    def makeCmd(self):
        """
        Usage: $PSP_PATH/dipro model_file sequence_file alignment_file format

        :return: The base command extended with the alignment file name and
            the integer value of `DiproFormat.NewDipro`.
        :rtype: list[str]
        """
        # NOTE(review): `DiproFormat` is not defined in the visible source;
        # presumably a nested integer enum on this class - confirm.
        cmd = super().makeCmd()
        cmd.extend([self._aln_fname, str(int(self.DiproFormat.NewDipro))])
        return cmd

    def prediction(self):
        """
        Parse the bond table that follows the "Bond_Index" header line in
        the backend output. The table ends at the first blank line or at the
        end of the output.

        :return: A list of disulfide bonds represented by 2-tuples with
            two zero-based residue indexes
        :rtype: list[tuple[int, int]]
        :raise ValueError: If a table row does not consist of exactly three
            fields, or if its indexes are not integers
        """
        output = self.getLogAsString().split('\n')
        # NOTE(review): when no "Bond_Index" header is present, the first
        # line is treated as the header and parsing starts at the second
        # line - confirm this fallback is intended.
        body_index = 0
        for idx, line in enumerate(output):
            if line.startswith("Bond_Index"):
                body_index = idx
                break
        bonds = []
        for line in output[body_index + 1:]:
            fields = line.split()
            # Any blank line ends the table, including whitespace-only lines
            # such as the '\r' left over from CRLF line endings.
            if not fields:
                break
            _, idx_a, idx_b = fields
            # bond idxs are 1-indexed so we subtract 1 here.
            bonds.append((int(idx_a) - 1, int(idx_b) - 1))
        return bonds
class BetaproPredictor(AbstractPredictor):
    """
    Beta strand contacts predictor, run through the `betapro` executable.
    """
    EXE = 'betapro'
    PREDICTOR_NAME = 'betapro'
    CLASS_NUM = ''
    NU = 20
    NY = 3
    input_fname: str = 'betapro.inp'
    model_fname: str = 'betapro_model.def'
    # NOTE(review): `Input` is not defined in the visible source - confirm
    # which parameter class this refers to.
    input = Input()

    def makeCmd(self):
        """
        Usage: $PSP_PATH/betapro model_file, protein_file, alignment_file

        :return: The base command extended with the alignment file name.
        :rtype: list[str]
        """
        return [*super().makeCmd(), self._aln_fname]

    def prediction(self):
        """
        :return: The unprocessed backend output
        :rtype: str
        """
        # For now, we just return the stdout. In the future, we'll probably
        # have to do some processing depending on how we want to present
        # this data.
        stdout = self.getLogAsString()
        return stdout
class PredictorWrapperTask(tasks.BlockingFunctionTask):
    """
    Task that runs the prediction function registered for one annotation
    type in `PRED_ANNO_TO_PRED_FUNC`.
    """

    def __init__(self, anno, seq, blast_ann):
        """
        :param anno: Annotation type; must be a key of
            `PRED_ANNO_TO_PRED_FUNC`
        :param seq: Sequence handed to the prediction function as its first
            argument
        :param blast_ann: Object handed to the prediction function as its
            second (`aln`) argument
        :raise KeyError: If `anno` has no registered prediction function
        """
        super().__init__()
        predict = PRED_ANNO_TO_PRED_FUNC[anno]
        self._pred_func = predict
        self._seq, self._blast_ann = seq, blast_ann

    def mainFunction(self):
        """
        Call the selected prediction function on the stored arguments. The
        function's return value is discarded.
        """
        predict = self._pred_func
        predict(self._seq, self._blast_ann)
def _run_prediction(pred, seq, aln):
    """
    Run a predictor task to completion and return its parsed prediction.

    :param pred: Predictor task to run; `seq` and `aln` are stored on its
        `input` before it is started
    :param seq: Sequence to predict properties for
    :param aln: Alignment passed to the predictor
    :return: Whatever `pred.prediction()` returns
    :raise Exception: The exception recorded in `pred.failure_info` if the
        task ends with status `FAILED`
    """
    task_input = pred.input
    task_input.seq = seq
    task_input.aln = aln
    pred.start()
    pred.wait()
    if pred.status is not pred.FAILED:
        return pred.prediction()
    failure = pred.failure_info
    print(failure)
    raise failure.exception
def predict_secondary_structure(seq, aln, mutate_in_place=True):
    """
    Run `SsproPredictor` on a sequence.

    :param seq: Sequence to predict the secondary structure of
    :param aln: Alignment passed to the predictor
    :param mutate_in_place: Whether to store the result on `seq` through
        `setSSAPredictions`
    :return: The finished predictor task
    :rtype: SsproPredictor
    """
    predictor = SsproPredictor()
    result = _run_prediction(predictor, seq, aln)
    if not mutate_in_place:
        return predictor
    seq.setSSAPredictions(result)
    return predictor
def predict_solvent_accessibility(seq, aln, mutate_in_place=True):
    """
    Run `AccproPredictor` on a sequence.

    :param seq: Sequence to predict the solvent accessibility of
    :param aln: Alignment passed to the predictor
    :param mutate_in_place: Whether to store the result on `seq` through
        `setSolventAccessibilityPredictions`
    :return: The finished predictor task
    :rtype: AccproPredictor
    """
    predictor = AccproPredictor()
    result = _run_prediction(predictor, seq, aln)
    if not mutate_in_place:
        return predictor
    seq.setSolventAccessibilityPredictions(result)
    return predictor
def predict_disordered_regions(seq, aln, mutate_in_place=True):
    """
    Run `DisproPredictor` on a sequence.

    :param seq: Sequence to predict disordered regions for
    :param aln: Alignment passed to the predictor
    :param mutate_in_place: Whether to store the result on `seq` through
        `setDisorderedRegionsPredictions`
    :return: The finished predictor task
    :rtype: DisproPredictor
    """
    predictor = DisproPredictor()
    result = _run_prediction(predictor, seq, aln)
    if not mutate_in_place:
        return predictor
    seq.setDisorderedRegionsPredictions(result)
    return predictor
def predict_domain_arrangement(seq, aln, mutate_in_place=True):
    """
    Run `DomproPredictor` on a sequence.

    :param seq: Sequence to predict the domain arrangement of
    :param aln: Alignment passed to the predictor
    :param mutate_in_place: Whether to store the result on `seq` through
        `setDomainArrangementPredictions`
    :return: The finished predictor task
    :rtype: DomproPredictor
    """
    predictor = DomproPredictor()
    result = _run_prediction(predictor, seq, aln)
    if not mutate_in_place:
        return predictor
    seq.setDomainArrangementPredictions(result)
    return predictor
def predict_disulfide_bond(seq, aln, mutate_in_place=True):
    """
    Run `DiproPredictor` on a sequence.

    When `mutate_in_place` is true, every bond currently listed in
    `seq.pred_disulfide_bonds` is removed, a bond marked `known=False` is
    added for each predicted index pair, and `seq.predictionsChanged` is
    emitted.

    :param seq: Sequence to predict disulfide bonds for
    :param aln: Alignment passed to the predictor
    :param mutate_in_place: Whether to replace the predicted bonds on `seq`
    :return: The finished predictor task
    :rtype: DiproPredictor
    """
    pred = DiproPredictor()
    disulfide_predictions = _run_prediction(pred, seq, aln)
    if mutate_in_place:
        # Iterate over a snapshot so that removing bonds cannot disturb the
        # iteration should the property expose a live collection.
        for bond in tuple(seq.pred_disulfide_bonds):
            residue.remove_disulfide_bond(bond)
        for idx1, idx2 in disulfide_predictions:
            residue.add_disulfide_bond(seq[idx1], seq[idx2], known=False)
        seq.predictionsChanged.emit()
    return pred
# Dispatch table from a predicted-annotation type to the module-level function
# that computes it. PredictorWrapperTask looks its function up here.
PRED_ANNO_TO_PRED_FUNC = {
    SEQ_ANNO_TYPES.pred_secondary_structure: predict_secondary_structure,
    SEQ_ANNO_TYPES.pred_accessibility: predict_solvent_accessibility,
    SEQ_ANNO_TYPES.pred_domain_arr: predict_domain_arrangement,
    SEQ_ANNO_TYPES.pred_disordered: predict_disordered_regions,
    SEQ_ANNO_TYPES.pred_disulfide_bonds: predict_disulfide_bond
}
# Reverse lookup of SSA_MAP (value -> character). If several characters share
# a value, the one that comes last in SSA_MAP wins.
INVERSE_SSA_MAP = {v: k for k, v in SSA_MAP.items()}
def encode_ssa(ssa):
    """
    Encode secondary structure values as a string with one character per
    value, using `INVERSE_SSA_MAP`.

    :param ssa: Iterable of values that are keys of `INVERSE_SSA_MAP`
    :return: The concatenated characters
    :rtype: str
    :raise KeyError: If a value is not a key of `INVERSE_SSA_MAP`
    """
    chars = (INVERSE_SSA_MAP[value] for value in ssa)
    return ''.join(chars)