import os
from collections import defaultdict
from typing import List
from schrodinger.application.msv import seqio
from schrodinger.job.util import hunt
from schrodinger.models import jsonable
from schrodinger.models import parameters
from schrodinger.protein import alignment
from schrodinger.protein import sequence
from schrodinger.tasks import jobtasks
from schrodinger.tasks import tasks
# BLAST program to run; BLAST maps to 'blastp' and PSIBLAST to 'psiblast'
# when the settings are converted to BlastPlus options.
BlastAlgorithm = jsonable.JsonableEnum('BlastAlgorithm', 'BLAST PSIBLAST')
# Database to search. PDB and NR are passed to BLAST by their lowercased
# member name; CUSTOM uses BlastSettings.custom_database_path instead.
BlastDatabase = jsonable.JsonableEnum('BlastDatabase', 'PDB NR CUSTOM')
# Scoring matrix; passed to BLAST by its member name.
SimilarityMatrix = jsonable.JsonableEnum(
    'SimilarityMatrix', 'BLOSUM45 BLOSUM62 BLOSUM80 PAM30 PAM70')
# Allowed values of BlastSettings.location
LOCAL = 'local'
REMOTE = 'remote'
class NoBlastHitsError(RuntimeError):
    """
    Raised when a BLAST search completes but returns no hits.
    """
def _create_inverse_dict(k_v_dict):
    """
    Given a dict with a set of values for each key, create an inverse dict
    where the values point to a set of keys
    """
    v_k_dict = defaultdict(set)
    for k, v_set in k_v_dict.items():
        for v in v_set:
            v_k_dict[v].add(k)
    return dict(v_k_dict)
def _find_closest(x, values):
    # This is O(N) but N is small
    return min(values, key=lambda v: abs(v - x))
class BlastSettings(parameters.CompoundParam):
    """
    User-configurable settings for a BLAST search.

    Gap costs are constrained by the selected similarity matrix: changing
    ``similarity_matrix`` resets the gap costs to that matrix's defaults, and
    changing ``gap_open_cost`` restricts ``gap_extend_cost`` to the values
    allowed for that open cost (see ``EXT_OPEN_VALUES``).
    """
    progname: BlastAlgorithm
    location: str = LOCAL  # LOCAL or REMOTE
    database_name: BlastDatabase
    custom_database_path: str  # Only valid when database_name is CUSTOM
    word_size: int = 3
    filter_query: bool = False
    gap_open_cost: int = 11
    gap_extend_cost: int = 1
    # Gap open costs allowed for the current similarity matrix
    possible_gap_open: set
    # Gap extend costs allowed for the current matrix and gap open cost
    allowed_gap_extend: set
    similarity_matrix: SimilarityMatrix = SimilarityMatrix.BLOSUM80
    num_iterations: int = 3
    # Exponent: the evalue passed to BLAST is 10**evalue_threshold
    evalue_threshold: int = 1
    inclusion_threshold: float = 0.005
    allow_multiple_chains: bool = True
    download_structures: bool = False
    align_after_download: bool = False

    # Default (gap_open_cost, gap_extend_cost) for each matrix
    DEFAULTS = {
        SimilarityMatrix.BLOSUM45: (15, 2),
        SimilarityMatrix.BLOSUM62: (9, 2),
        SimilarityMatrix.BLOSUM80: (10, 1),
        SimilarityMatrix.PAM30: (9, 1),
        SimilarityMatrix.PAM70: (10, 1),
    }
    # For each matrix: gap extend cost -> set of compatible gap open costs
    EXT_OPEN_VALUES = {
        SimilarityMatrix.BLOSUM45: {
            3: {13, 12, 11, 10},
            2: {15, 14, 13, 12},
            1: {19, 18, 17, 16}
        },
        SimilarityMatrix.BLOSUM62: {
            2: {11, 10, 9, 8, 7, 6},
            1: {13, 12, 11, 10, 9}
        },
        SimilarityMatrix.BLOSUM80: {
            2: {8, 7, 6},
            1: {11, 10, 9}
        },
        SimilarityMatrix.PAM30: {
            2: {7, 6, 5},
            1: {10, 9, 8}
        },
        SimilarityMatrix.PAM70: {
            2: {8, 7, 6},
            1: {11, 10, 9}
        },
        # Remote blast allows more combinations than local blast but for
        # simplicity we only allow local-compatible settings. The remote-only
        # combinations would be:
        #SimilarityMatrix.PAM30: {3: {15, 13}, 2: {14, 7, 6, 5}, 1: {14, 10, 9, 8}},
        #SimilarityMatrix.PAM70: {3: {12}, 2: {11, 8, 7, 6}, 1: {11, 10, 9}},
    }
    # For each matrix: gap open cost -> set of compatible gap extend costs
    OPEN_EXT_VALUES = {
        matrix: _create_inverse_dict(ext_open)
        for matrix, ext_open in EXT_OPEN_VALUES.items()
    }

    def initConcrete(self):
        super().initConcrete()
        # Keep the gap costs consistent with the matrix and with each other
        self.similarity_matrixChanged.connect(self._setMatrixDefaults)
        self.gap_open_costChanged.connect(self._updateAllowedExtendValues)

    def initializeValue(self):
        super().initializeValue()
        self._setMatrixDefaults()

    @classmethod
    def fromJsonImplementation(cls, json_obj):
        """
        Deserialize settings, then re-apply the gap costs in a fixed order.

        Setting ``gap_open_cost`` may move ``gap_extend_cost`` to the nearest
        allowed value (via ``_updateAllowedExtendValues``), so the saved
        ``gap_extend_cost`` is applied afterwards.
        """
        new_input = super().fromJsonImplementation(json_obj)
        # These params can affect each other so they need to be set again in
        # the correct order to reconstitute correctly
        order_dependent_params = ('gap_open_cost', 'gap_extend_cost')
        for param_name in order_dependent_params:
            setattr(new_input, param_name, json_obj[param_name])
        return new_input

    def getBlastSettings(self):
        """
        Returns BLAST settings as a dictionary.

        The keys are passed to BlastPlus as command-line options with only a
        leading hyphen added, and every value is converted to ``str``.

        :return: Dictionary of BLAST settings.
        :rtype: dict(str, str)
        :raise ValueError: If the CUSTOM database is selected without a
            custom database path, or a custom database path is given while
            another database is selected.
        """
        is_custom = self.database_name == BlastDatabase.CUSTOM
        custom_path = self.custom_database_path.strip()
        # A custom database path must be given if and only if the CUSTOM
        # database option is selected
        if is_custom != bool(custom_path):
            raise ValueError("Choose the correct database option and"
                             " provide path to the custom database")
        # Whitespace around the path is ignored by the check above, so it is
        # stripped from the path passed on as well
        database = (custom_path
                    if is_custom else self.database_name.name.lower())
        # these settings will be passed as arguments to BlastPlus with only a
        # leading hyphen added
        settings = {
            'progname': 'blastp' if self.progname is BlastAlgorithm.BLAST else
                        'psiblast',
            'location': self.location,
            'database': database,
            'word_size': self.word_size,
            'filter': 'yes' if self.filter_query else 'no',
            'gap_open_cost': self.gap_open_cost,
            'gap_extend_cost': self.gap_extend_cost,
            'matrix': self.similarity_matrix.name,
            'num_iterations': self.num_iterations,
            'evalue': 10**self.evalue_threshold,
            'expand_hits': self.allow_multiple_chains,
        }
        # The inclusion threshold is only sent for remote or PSI-BLAST runs
        if self.location == REMOTE or self.progname is BlastAlgorithm.PSIBLAST:
            settings['e_value_threshold'] = self.inclusion_threshold
        return {key: str(value) for key, value in settings.items()}

    def _setMatrixDefaults(self):
        """
        Refresh the allowed gap costs for the current matrix and apply the
        matrix's default gap open/extend costs.
        """
        self._updateOpenValues()
        gap_open, gap_extend = self.DEFAULTS[self.similarity_matrix]
        self._updateAllowedExtendValues(gap_open)
        self.gap_open_cost = gap_open
        self.gap_extend_cost = gap_extend

    def _updateOpenValues(self):
        """
        Update ``possible_gap_open`` for the current matrix and move
        ``gap_open_cost`` to the closest allowed value if it is not allowed.
        """
        values = set(self.OPEN_EXT_VALUES[self.similarity_matrix].keys())
        self.possible_gap_open = values
        if self.gap_open_cost not in values:
            self.gap_open_cost = _find_closest(self.gap_open_cost, values)

    def _updateAllowedExtendValues(self, gap_open):
        """
        Restrict ``gap_extend_cost`` to the values compatible with
        ``gap_open`` under the current matrix.

        If ``gap_open`` is not a known open cost for the matrix, every extend
        cost known for the matrix is allowed.

        :param gap_open: The gap open cost to restrict the extend cost for.
        :type gap_open: int
        """
        values = self.OPEN_EXT_VALUES[self.similarity_matrix].get(gap_open)
        if values is None:
            values = set(self.EXT_OPEN_VALUES[self.similarity_matrix].keys())
        if self.gap_extend_cost not in values:
            self.gap_extend_cost = _find_closest(self.gap_extend_cost, values)
        self.allowed_gap_extend = set(values)
class BlastTask(jobtasks.ComboJobTask):
    """
    This is a thin wrapper over BlastPlus object that implements job running
    and incorporation.

    The task input provides ``settings`` (used via ``getBlastSettings()``),
    ``query_sequence`` and ``getQueryName()``. After a successful run,
    ``output`` holds one dict per high-scoring pair of each BLAST hit.

    To enable DEBUG_MODE, set DEBUG_MODE to True at the bottom of this class.
    In DEBUG_MODE, no blast call will actually be made and the first top
    10 hits of a BLAST search with 1cmy:a will be returned as the output.
    """
    DEFAULT_TASKDIR_SETTING = tasks.TEMP_TASKDIR
    PROGRAM_NAME = "BLAST"
    output: List[dict]  # List of blast hits
[docs]    def getExpectedRuntime(self):
        """
        Return the expected runtime of the task based on the current settings
        :return: Expected runtime in seconds
        :rtype: int
        """
        settings = self.input.settings
        if settings.getBlastSettings()['location'] == REMOTE:
            minutes = 30
        elif settings.database_name is BlastDatabase.NR:
            minutes = 3 * 60  # Local NR is expected to take ~2.5 hours
        else:
            minutes = 5
        return minutes * 60 
[docs]    def getQueryName(self):
        return self.input.getQueryName() 
[docs]    def checkLocalDatabase(self):
        """
        Return True if the local database exists and is correctly configured.
        Return False if the local database is missing or truncated.
        """
        blast_plus = self._initBlastPlus()
        try:
            has_local_db = blast_plus.checkLocalBlastInstallation()
        except RuntimeError:
            return False
        else:
            return has_local_db 
    def _initBlastPlus(self):
        """
        Build a BlastPlus object from the current settings.

        Each settings entry becomes a ``-key value`` pair on the argument list
        that is parsed by the BlastPlus argument parser.
        """
        blast_settings = self.input.settings.getBlastSettings()
        options = [
            token for key, value in blast_settings.items()
            for token in ('-' + key, value)
        ]
        from schrodinger.application.prime.packages import blast_plus
        parser = blast_plus.blast_parser()
        return blast_plus.BlastPlus(parser.parse_args(options))
[docs]    def backendMain(self):
        self._pdb_header_info = {}
        blast_plus = self._initBlastPlus()
        hits = blast_plus.runBlast(seqio.to_biopython(
            self.input.query_sequence))
        if not hits:
            raise NoBlastHitsError("No relevant BLAST results were found.")
        self.output = self._parseHits(hits) 
[docs]    def getBlastAlignment(self):
        if self.status is self.FAILED:
            raise ValueError("Cannot get alignment. Blast task failed.")
        elif self.status is not self.DONE:
            raise ValueError("Can't get the blast alignment before the task "
                             "is run.")
        seqs = [self.input.query_sequence]
        seqs.extend(
            sequence.ProteinSequence(hit['sequence']) for hit in self.output)
        return alignment.ProteinAlignment(seqs) 
    def _parseHits(self, hits):
        """
        Convert BlastPlus hits into flat dicts, one per high-scoring pair.

        PDB header fields (compound, source, ...) are looked up by hit name in
        the header info loaded by ``_initializePDBHeaderInfo``; fields missing
        from it are reported as empty strings.

        :param hits: Hits returned by ``BlastPlus.runBlast``.
        :return: A list of parsed hits.
        :rtype: list of dict
        """
        if not self._pdb_header_info:
            self._initializePDBHeaderInfo()
        from schrodinger.application.prime.packages import blast_plus
        hit_list = []
        for hit in hits:
            name, description = self._getNameAndDescription(hit)
            database = blast_plus.get_database_from_title(hit.alignment.title)
            # These do not depend on the HSP, so compute them once per hit.
            # (A title-derived info lookup used to be computed here but was
            # always overwritten by the header info below, so it was dropped.)
            codes = str(hit.seq_io.seq)
            info = self._pdb_header_info.get(name, {})
            for hsp in hit.alignment.hsps:
                align_length = hsp.align_length
                hit_list.append({
                    'name': name,
                    'info': info,
                    'database': database,
                    'sequence': codes,
                    'score': hsp.score,
                    'evalue': hsp.expect,
                    'percent_id': 100 * hsp.identities / align_length,
                    'percent_pos': 100 * hsp.positives / align_length,
                    'percent_gaps': 100 * hsp.gaps / align_length,
                    'pdb_title': description,
                    'pdb_compound': info.get('COMPND:', ''),
                    'pdb_source': info.get('SOURCE:', ''),
                    'pdb_expdata': info.get('EXPDTA:', ''),
                    'pdb_resolution': info.get('RESOLUTION:', ''),
                    'pdb_hetname': info.get('HETNAM:', ''),
                    'pdb_pfam': info.get('PFAM:', ''),
                })
        return hit_list
    @staticmethod
    def _getNameAndDescription(hit):
        """
        Extract the hit's name and description.

        :param hit: Object representing a single BLAST hit
        :type hit: blast_plus.BlastHit
        :return: The hit name and its description. The description is an
            empty string if the first description entry has no text after
            its identifier.
        :rtype: tuple(str, str)
        """
        from schrodinger.application.prime.packages import blast_plus
        # When expand_hits is True, each hit has a reference to an identical
        # alignment object; only the biopython sequence description is updated
        # with the single sequence's name
        full_description = hit.seq_io.description
        name = blast_plus.get_name_from_title(full_description)
        # Only the text before the first '>' is used (presumably any further
        # '>'-separated entries describe other, identical sequences). Its
        # first word is the identifier; the rest is the description.
        first_entry = full_description.split('>', 1)[0]
        # partition() (unlike split()[1]) tolerates an entry with no space
        description = first_entry.partition(' ')[2].strip()
        return name, description
    @staticmethod
    def _getPDBHeaderFileName():
        """
        Locate ``headerinfo.dat`` in the psp data directory.

        :return: Path to the PDB header info file, or an empty string if the
            data directory or the file cannot be found.
        :rtype: str
        """
        data_dir = hunt('psp', 'data')
        if not data_dir:
            return ""
        candidate = os.path.join(data_dir, "headerinfo.dat")
        return candidate if os.path.isfile(candidate) else ""
    def _initializePDBHeaderInfo(self):
        """
        Initialize PDB header info. This fills up the pdb_header_info
        dictionary.

        The header file is parsed line by line as ``KEY TEXT``. An ``ID:``
        line starts a new entry keyed by its text; following lines are stored
        in that entry as ``{KEY: TEXT}``. Lines without a space, and lines
        before the first non-empty ``ID:``, are ignored. If no header file is
        found, the dictionary is left empty.
        """
        self._pdb_header_info = {}
        pdb_header_file_name = self._getPDBHeaderFileName()
        if not pdb_header_file_name:
            return
        pdb_id = ""
        # Stream the file rather than reading all lines into memory at once
        with open(pdb_header_file_name, "r") as header_file:
            for line in header_file:
                names = line.split(' ', 1)
                if len(names) < 2:
                    continue
                key = names[0]
                text = names[1].rstrip()
                if key == "ID:":
                    pdb_id = text
                    self._pdb_header_info[pdb_id] = {}
                elif pdb_id:
                    self._pdb_header_info[pdb_id][key] = text