# Copyright Schrodinger, LLC. All rights reserved.
import copy
import glob
import os
import subprocess
from . import constants
from .jobs import blast_get_default_settings
from .jobs import blast_run
from .sequence import Sequence
from .sequence_group import SequenceGroup
# Set to True by init_predictors() once the Prime library path has been added
# to the environment, so the work is done only once per process.
predictors_initialized = False


def init_predictors():
    """
    Add the Prime (psp) library directory to the dynamic-library search path.

    The platform directory name is taken from $OS_CPU, falling back to the
    last path component of $MMSHARE_EXEC. The first directory matching
    $SCHRODINGER/psp-v*/lib/<os_cpu> is appended to PATH on Windows and on
    "Darwin*" platforms, and to LD_LIBRARY_PATH everywhere else.

    :return: True if the library path is (or already was) set up; False if the
        environment lacks OS_CPU/MMSHARE_EXEC or SCHRODINGER, or no Prime
        library directory exists.
    :rtype: bool
    """
    global predictors_initialized
    if predictors_initialized:
        return True
    os_cpu = os.getenv("OS_CPU")
    if not os_cpu:
        mmshare_exec = os.getenv("MMSHARE_EXEC")
        if mmshare_exec:
            os_cpu = os.path.basename(mmshare_exec)
    schrodinger = os.getenv("SCHRODINGER")
    if not os_cpu or not schrodinger:
        return False
    # Get Prime library path. If several psp-v* installations match, the
    # first glob result is used (glob order is arbitrary).
    path_list = glob.glob(os.path.join(schrodinger, 'psp-v*', 'lib', os_cpu))
    if not path_list:
        return False
    if os.name == 'nt' or os_cpu.startswith("Darwin"):
        env_var = "PATH"
    else:
        env_var = "LD_LIBRARY_PATH"
    # The variable may be unset (common for LD_LIBRARY_PATH) or empty; join
    # only non-empty entries so no empty search-path element is created.
    # os.pathsep is ';' on Windows and ':' elsewhere.
    entries = [entry for entry in (os.getenv(env_var), path_list[0]) if entry]
    os.environ[env_var] = os.pathsep.join(entries)
    predictors_initialized = True
    return True
def get_predictor_path(predictor_name):
    """
    Return the path of a predictor executable shipped with Prime.

    Looks for $SCHRODINGER/psp-v*/bin/<os_cpu>/<predictor_name> (with an
    ".exe" suffix on Windows). <os_cpu> comes from $OS_CPU or, failing that,
    the last path component of $MMSHARE_EXEC.

    :param predictor_name: Executable base name, e.g. "sspro4".
    :type predictor_name: str
    :return: Path of the first match (glob order is arbitrary if several
        psp-v* installations match), or None if the environment is incomplete
        or no executable is found.
    :rtype: str or None
    """
    os_cpu = os.getenv("OS_CPU")
    if not os_cpu:
        # Fall back to the platform directory name at the end of MMSHARE_EXEC.
        os_cpu = os.path.basename(os.getenv("MMSHARE_EXEC") or "")
    schrodinger = os.getenv("SCHRODINGER")
    if not os_cpu or not schrodinger:
        return None
    exe_name = predictor_name + ".exe" if os.name == "nt" else predictor_name
    matches = glob.glob(
        os.path.join(schrodinger, 'psp-v*', 'bin', os_cpu, exe_name))
    return matches[0] if matches else None
def get_predictor_data_path(predictor_name):
    """
    Return the data (model) directory of a predictor shipped with Prime.

    Looks for $SCHRODINGER/psp-v*/data/predictors/<predictor_name>. The
    platform name is not part of this path, but the lookup still requires it
    to be resolvable from $OS_CPU or $MMSHARE_EXEC, like get_predictor_path.

    :param predictor_name: Predictor name, e.g. "sspro".
    :type predictor_name: str
    :return: Path of the first match (glob order is arbitrary if several
        psp-v* installations match), or None if the environment is incomplete
        or the directory does not exist.
    :rtype: str or None
    """
    os_cpu = os.getenv("OS_CPU")
    if not os_cpu:
        os_cpu = os.path.basename(os.getenv("MMSHARE_EXEC") or "")
    schrodinger = os.getenv("SCHRODINGER")
    if not os_cpu or not schrodinger:
        return None
    matches = glob.glob(
        os.path.join(schrodinger, 'psp-v*', 'data', 'predictors',
                     predictor_name))
    return matches[0] if matches else None
# Text following the model count on the header line written by
# make_model_def_file, keyed by predictor name. Predictors not listed get no
# header line at all.
_MODEL_DEF_HEADER_SUFFIXES = {
    "sspro": " 3",
    "accpro": " 20",
    "dispro": " 2",
    "dompro": " 2",
    "dipro": " 0.5",
    "betapro": "",
}


def make_model_def_file(model_data_dir, model_file_name, predictor):
    """
    Write a model definition file listing every "*.model" file in a directory.

    The file starts with a header line "<count><suffix>", where the suffix
    depends on *predictor* (see _MODEL_DEF_HEADER_SUFFIXES; unknown predictors
    get no header). It is followed by one line per model path, sorted, each
    terminated by " \\n".

    :param model_data_dir: Directory searched for "*.model" files.
    :param model_file_name: Path of the definition file to write.
    :param predictor: Predictor name selecting the header format.
    :return: 0 on success, -1 if no model files were found, -2 if the output
        file could not be opened, -3 if writing or closing it failed.
    :rtype: int
    """
    model_list = sorted(
        model + " \n"
        for model in glob.glob(os.path.join(model_data_dir, "*.model")))
    if not model_list:
        return -1
    try:
        model_file = open(model_file_name, "w")
    except OSError:
        return -2
    suffix = _MODEL_DEF_HEADER_SUFFIXES.get(predictor)
    try:
        # The with-block guarantees the handle is closed even if a write fails.
        with model_file:
            if suffix is not None:
                model_file.write(str(len(model_list)) + suffix + "\n")
            model_file.writelines(model_list)
    except (OSError, ValueError):
        return -3
    return 0
def generate_flat_alignment(sequence,
                            alignment_file_name,
                            progress_dialog=None,
                            remote_query_dialog=None):
    """
    Run a BLAST search on the "nr" database and write the flat alignment file
    required by the predictors.

    A quiet search is tried first. If it does not return "ok", the search is
    repeated with the "remotely" setting enabled (and without ``quiet``). On
    success, the group is minimized to the query, and every sequence in it
    that is a valid protein is written to the file. The group holds the query
    and, presumably, the hits added by blast_run. The file format is one line
    with the number of sequences, then one sequence per line with '~'
    replaced by '.'.

    Nothing is written if the search fails or yields no valid protein.

    :param sequence: Query sequence; it is placed in a new SequenceGroup as the
        reference, so callers may want to pass a copy.
    :param alignment_file_name: Path of the alignment file to write.
    :param progress_dialog: Passed through to blast_run.
    :param remote_query_dialog: Passed through to blast_run.
    :return: True if the alignment file was written, False otherwise.
    :rtype: bool
    """
    group = SequenceGroup()
    group.sequences.append(sequence)
    group.reference = sequence
    settings = blast_get_default_settings()
    # PDB could be searched instead; the database should eventually become a
    # user-controllable setting.
    settings["database"] = "nr"
    result = blast_run(group,
                       settings,
                       progress_dialog=progress_dialog,
                       quiet=True,
                       remote_query_dialog=remote_query_dialog)
    if result != "ok":
        # Retry as a remote search.
        settings["remotely"] = True
        result = blast_run(group,
                           settings,
                           progress_dialog=progress_dialog,
                           remote_query_dialog=remote_query_dialog)
    if result != "ok":
        return False
    group.minimizeAlignment(query_only=True)
    texts = [
        member.text() for member in group.sequences if member.isValidProtein()
    ]
    if not texts:
        return False
    try:
        with open(alignment_file_name, "w") as alignment_file:
            alignment_file.write(str(len(texts)) + "\n")
            alignment_file.writelines(
                text.replace('~', '.') + "\n" for text in texts)
    except OSError:
        return False
    return True
def has_predictor(predictor):
    """
    Report whether both the executable and the data directory of a predictor
    can be found in the Prime installation.

    :param predictor: Predictor name, e.g. "sspro" or "accpro".
    :type predictor: str
    :return: True if get_predictor_path and get_predictor_data_path both
        return a path, False otherwise.
    :rtype: bool
    """
    # The SSpro executable is actually named "sspro4".
    exe_name = predictor + '4' if predictor == "sspro" else predictor
    return bool(
        get_predictor_path(exe_name) and get_predictor_data_path(predictor))
# Minimum number of output lines a predictor must produce for the lines that
# run() indexes to exist; predictors not listed need at least one line.
_MIN_PREDICTOR_OUTPUT_LINES = {"sspro": 2, "dompro": 2, "dispro": 3}


def _order_predictors(predictor_list):
    """
    Return the predictors to execute, in execution order.

    SSpro and ACCpro outputs are fed to the other predictors, so both run
    first. The exception is when exactly one of them was requested on its
    own; then only that one runs.
    """
    ordered = []
    if len(predictor_list) > 1 or "accpro" not in predictor_list:
        ordered.append("sspro")
    if len(predictor_list) > 1 or "sspro" not in predictor_list:
        ordered.append("accpro")
    ordered.extend(name for name in ("dompro", "dipro", "dispro", "betapro")
                   if name in predictor_list)
    return ordered


def _build_predictor_command(predictor, exe_path, model_file_name,
                             input_file_name, alignment_file_name):
    """Return the argument list used to launch *predictor*."""
    if predictor == "betapro":
        return [exe_path, model_file_name, input_file_name, alignment_file_name]
    if predictor == "dipro":
        return [
            exe_path, model_file_name, input_file_name, alignment_file_name,
            '2'
        ]
    # The remaining predictors take "./" as a fourth argument (presumably an
    # output directory) plus predictor-specific options.
    cmd = [exe_path, model_file_name, input_file_name, "." + os.sep]
    if predictor == "sspro":
        cmd.append("0")
    elif predictor == "accpro":
        cmd += ["2", "5"]
    return cmd


def _remove_temp_files(file_names):
    """Delete each file, skipping any that cannot be removed."""
    for file_name in file_names:
        try:
            os.remove(file_name)
        except OSError:
            pass


def _new_annotation(annotation_type, residues, name, short_name):
    """Create an annotation Sequence of *annotation_type* holding *residues*."""
    seq = Sequence()
    seq.type = constants.SEQ_ANNOTATION
    seq.annotation_type = annotation_type
    seq.appendResidues(residues)
    seq.name = name
    seq.short_name = short_name
    return seq


def _parse_dipro_bonds(lines):
    """
    Return (first, second) index pairs from the rows following DiPro's
    "Bond_Index" header line. Rows with fewer than three fields are ignored.
    Indices are reduced by one, presumably converting 1-based DiPro indices
    to 0-based ones.
    """
    bonds = []
    in_bond_table = False
    for line in lines:
        if line.startswith("Bond_Index"):
            in_bond_table = True
            continue
        if in_bond_table:
            fields = line.split()
            if len(fields) >= 3:
                bonds.append((int(fields[1]) - 1, int(fields[2]) - 1))
    return bonds


def _compute_beta_map(pred_ss, betapro_lines, length):
    """
    Build a flattened length x length list from BetaPro output.

    Residues predicted as strand ('E' in *pred_ss*) are numbered in order,
    and BetaPro's whitespace-separated floats are read as a row-major
    n_beta x n_beta matrix over those residues. Entry (y, x) of the result is
    1.0 minus that value when both residues are strands, and 1.0 otherwise.
    Returns None if no residue is predicted as a strand.
    """
    beta_index = []
    n_beta = 0
    for index in range(length):
        if pred_ss[index] == 'E':
            beta_index.append(n_beta)
            n_beta += 1
        else:
            beta_index.append(-1)
    if n_beta == 0:
        return None
    values = [float(val) for line in betapro_lines for val in line.split()]
    beta_map = []
    for row in beta_index:
        for col in beta_index:
            if row >= 0 and col >= 0:
                beta_map.append(1.0 - values[row * n_beta + col])
            else:
                beta_map.append(1.0)
    return beta_map


def run(sequence,
        predictor_list,
        progress_dialog=None,
        remote_query_dialog=None):
    """
    Run the requested structure predictors on *sequence*.

    The predictors are executed in the order given by _order_predictors, with
    job files written to the current working directory. A flat BLAST
    alignment is generated once, just before the first predictor is launched.
    All temporary job files are removed before returning, including on the
    "cancelled" path and when an exception propagates.

    :param sequence: Sequence to predict for; its gapless text is used.
    :param predictor_list: Requested predictor names, any of "sspro",
        "accpro", "dompro", "dipro", "dispro", "betapro".
    :param progress_dialog: Passed through to the BLAST job.
    :param remote_query_dialog: Passed through to the BLAST job.
    :return: "failed" if *sequence* is empty or has no gapless text.
        "cancelled" if only DiPro was requested and the sequence has fewer
        than two cysteines, or if no alignment could be generated.
        Otherwise, a list of annotation Sequence objects; a BetaPro result,
        if any, is stored on ``sequence.beta_map`` instead.
    """
    if not sequence:
        return "failed"
    init_predictors()
    sequence_string = sequence.gaplessText()
    if not sequence_string:
        return "failed"
    c_count = sequence_string.count('C')
    # A disulfide bridge prediction needs at least two cysteines.
    if (len(predictor_list) == 1 and predictor_list[0] == "dipro" and
            c_count < 2):
        return "cancelled"

    results = {}
    alignment_file_name = "predictor_alignment"
    pred_ss = pred_acc = None
    has_alignment = False
    # Temporary job files, removed in the finally clause below.
    tmp_files = []
    try:
        for predictor in _order_predictors(predictor_list):
            # Ignore S-S prediction if not enough cysteines are present.
            if predictor == "dipro" and c_count < 2:
                continue
            # The SSpro executable is actually named "sspro4".
            predictor_exe = (predictor + '4'
                             if predictor == "sspro" else predictor)
            exe_path = get_predictor_path(predictor_exe)
            data_path = get_predictor_data_path(predictor)
            if not exe_path or not data_path:
                continue
            model_file_name = predictor + "_model.def"
            if make_model_def_file(data_path, model_file_name, predictor) != 0:
                continue
            tmp_files.append(model_file_name)
            input_file_name = write_input_file(predictor,
                                               alignment_file_name,
                                               model_file_name,
                                               sequence_string,
                                               pred_ss=pred_ss,
                                               pred_acc=pred_acc)
            if not input_file_name:
                continue
            tmp_files.append(input_file_name)
            cmd = _build_predictor_command(predictor, exe_path,
                                           model_file_name, input_file_name,
                                           alignment_file_name)
            if not has_alignment:
                has_alignment = generate_flat_alignment(
                    copy.deepcopy(sequence),
                    alignment_file_name,
                    progress_dialog=progress_dialog,
                    remote_query_dialog=remote_query_dialog)
                if not has_alignment:
                    return "cancelled"
                tmp_files.append(alignment_file_name)
            output_file_name = predictor + ".out"
            try:
                with open(output_file_name, 'w') as output_file:
                    # Track the file as soon as it exists so it is cleaned up
                    # even if the predictor cannot be launched.
                    tmp_files.append(output_file_name)
                    subprocess.call(cmd, stdout=output_file)
            except (OSError, ValueError):
                continue
            try:
                with open(output_file_name) as output_file:
                    output_lines = output_file.readlines()
            except (OSError, ValueError):
                continue
            if len(output_lines) < _MIN_PREDICTOR_OUTPUT_LINES.get(
                    predictor, 1):
                continue
            results[predictor] = output_lines
            # These predictions are inputs for the predictors that follow.
            if predictor == "sspro":
                pred_ss = output_lines[1].strip()
            elif predictor == "accpro":
                pred_acc = output_lines[0].strip()
    finally:
        _remove_temp_files(tmp_files)

    final_sequences = []
    if "sspro" in predictor_list and "sspro" in results:
        seq = _new_annotation(constants.ANNOTATION_SSP,
                              results["sspro"][1].strip(),
                              "Secondary Structure Prediction", "SSP")
        for res in seq.residues:
            if res.code == 'H':
                res.color = (240, 96, 64)
                res.inverted = True
            elif res.code == 'E':
                res.color = (128, 240, 240)
                res.inverted = True
            else:
                res.code = '-'
                res.color = (255, 255, 255)
                res.inverted = True
        final_sequences.append(seq)
    if "accpro" in predictor_list and "accpro" in results:
        seq = _new_annotation(constants.ANNOTATION_ACC,
                              results["accpro"][0].strip(),
                              "Solvent Accessibility Prediction", "ACC")
        for res in seq.residues:
            if res.code == 'B':
                res.code = 'b'
                res.color = (64, 127, 191)
                res.inverted = True
            elif res.code == 'E':
                res.code = 'e'
                res.color = (255, 255, 64)
                res.inverted = True
        final_sequences.append(seq)
    if "dompro" in results:
        seq = _new_annotation(constants.ANNOTATION_DOM,
                              results["dompro"][1].strip(),
                              "Domain Arrangement Prediction", "DOM")
        final_sequences.append(seq)
        for res in seq.residues:
            if res.code == 'N':
                res.color = (127, 127, 127)
                res.inverted = True
            else:
                res.color = (255, 0, 0)
                res.inverted = True
            res.code = ' '
    if "dispro" in results:
        # split() tolerates repeated whitespace between values.
        values = results["dispro"][2].split()
        seq = _new_annotation(constants.ANNOTATION_DIS,
                              results["dispro"][1].strip(),
                              "Disordered Regions Prediction", "DIS")
        for index, res in enumerate(seq.residues):
            val = float(values[index])
            res.value = val
            if res.code == 'N':
                res.color = (191, 191, 191)
                res.inverted = True
            else:
                if val > 0.9:
                    res.color = (255, 0, 0)
                else:
                    res.color = (255, 127, 0)
                res.inverted = True
            res.code = ' '
        final_sequences.append(seq)
    if "dipro" in results:
        bond_list = _parse_dipro_bonds(results["dipro"])
        if bond_list:
            # One placeholder residue per SS-prediction position. If SSpro
            # produced nothing, fall back to the gapless sequence length
            # (assumed to match; TODO confirm).
            ss_length = (len(pred_ss)
                         if pred_ss is not None else len(sequence_string))
            seq = _new_annotation(constants.ANNOTATION_CCB, 'A' * ss_length,
                                  "Disulfide Bridges Prediction", "SSBPRED")
            seq.bond_list = bond_list
            seq.height = int(0.5 * len(bond_list) + 0.5)
            for res in seq.residues:
                res.code = ' '
            final_sequences.append(seq)
    if "betapro" in results and pred_ss:
        beta_map = _compute_beta_map(pred_ss, results["betapro"],
                                     sequence.gaplessLength())
        if beta_map is not None:
            sequence.beta_map = beta_map
    return final_sequences