"""
Core routines for wrapping SKA backend.
This module provides direct access to the functionality used by the
top-level ska program and structalign utility.  The usual keyword-value
options in the top-level input file are supplied as a python dict.
Wrappers are supplied to provide PDB filenames as input as well as
maestro CTs which are transformed in place.
For higher level API, use the schrodinger.structutils.structalign module.
Copyright Schrodinger, LLC. All rights reserved.
"""
import operator
import os
import re
import shutil
import tempfile
import numpy
from schrodinger import structure
from schrodinger.infra import _ska
from schrodinger.infra import mm
from schrodinger.job.util import hunt
from schrodinger.structutils import analyze
from schrodinger.ui.sequencealignment import constants as seqdata
from schrodinger.utils.fileutils import force_remove
# Matches keyword values treated as boolean "true": 1, TRUE or YES
# (case-insensitive, whole value).
_true_val = re.compile(r'^(1|TRUE|YES)$', re.I)
# Splits a fixed-column PDB ATOM/HETATM record: group 1 is everything before
# the coordinates (columns 1-30), groups 3-5 are the three 8-column x, y and z
# fields, and group 6 is the rest of the line (without the newline).
# NOTE: ' -.' inside the character classes is a range (space through '.'),
# so characters such as '+' or ',' are also admitted by the match.
_atom_rec = re.compile(
    r'^((ATOM  |HETATM).{24})([ -.\d]{8})([ -.\d]{8})([ -.\d]{8})(.*)$')
# SSE window lengths (backend '-w' flag), one backend run each, tried when
# automatic settings are requested; the best-scoring run is kept.
_winlen_vals = [2, 6, 9, 14, 17]
# Path to the backend topology file; resolved lazily on first use and cached.
_trolltop = None
# CT property holding a 4x4 transformation matrix, flattened row by row and
# joined with ';' (see _SkaData.setTransformationMatrixProperty).
SKA_TRANSFORMATION_MATRIX = "s_psp_ska_transformation_matrix"
class _SkaData(object):
    """
    Data structure for named backend return values
        rotmat: list of (U,x) rotation matrices and translations
        align: aligned sequences as a python dict
        psd: backend PSD score
        rmsd: superposition RMSD
        stdout: raw output from the backend
        sse: dict of secondary structures
    """
    def __init__(self, rotmat, align, index, psd, rmsd, stdout, sse=None):
        self.rotmat = rotmat
        self.align = align
        self.index = index
        self.psd = psd
        self.rmsd = rmsd
        self.stdout = stdout
        self.sse = sse
    def setTransformationMatrixProperty(self, st, index=0):
        """
        Sets ct-level properties that correspond to the 3x3 rotation and 3x1
        translation in the transformation of this data structure. For SkaData
        with more than one alignment structure, an index can be provided to
        obtain a specific structure's transformation.
        :param st: structure to add properties to
        :type st: structure.Structure
        :param index: index of the transformation to extract
        :type index: int
        """
        if not self.rotmat or self.rotmat == [None]:
            return
        matrix = numpy.identity(4)
        rotation_matrix, translation_vector = self.rotmat[index]
        matrix[0:3, 0:3] = rotation_matrix
        matrix[0:3, 3] = translation_vector
        st.property[SKA_TRANSFORMATION_MATRIX] = ';'.join(map(str, matrix.flat))
def _standardize(old_ct, rename, reorder):
    """
    Return the standardized version of the input structure, without modifying
    the input CT.
    :type rename: bool
    :param rename: Whether to convert the non-standard residue names to
            standard forms where possible (e.g. HID -> HIS).
    :type reorder: bool
    :param reorder: Whether to re-order the residues by sequence.
    If both options are False, will return the original input structure, without
    making a copy.
    """
    if not rename and not reorder:
        return old_ct
    if reorder:
        new_ct = structure.create_new_structure(0)
        mm.mmpdb_initialize(mm.error_handler)
        index_map = mm.mmpdb_get_sequence_order(old_ct.handle, 0)
        mm.mmpdb_terminate()
        mm.mmct_ct_reorder(new_ct, old_ct, index_map)
    else:
        # Re-naming but not re-ordering; make a copy:
        new_ct = old_ct.copy()
    if rename:
        for res in new_ct.residue:
            rname = res.pdbres[0:3]
            rcode = seqdata.AMINO_ACIDS_3_TO_1.get(rname, 'X')
            alias = seqdata.AMINO_ACIDS[rcode][0]
            if rname != alias:
                for at in res.atom:
                    at.pdbres = alias
    return new_ct
def pairwise_align_ct(query,
                      templist,
                      keywords=None,
                      log=None,
                      debug=False,
                      save_props=False,
                      std_res=False,
                      reorder=False,
                      asl=None,
                      asl_mobile=None):
    """
    Wrapper for a series of pairwise ska jobs with CTs as input.

    Each template is aligned to the query in its own backend run and is then
    transformed in place. A failed alignment does not abort the series: its
    entry in the returned list has rotmat == [None], no scores, and a warning
    in its stdout. CT titles are restored before returning, also on error.

    NOTE: For higher-level API, use schrodinger.structutils.structalign module.

    :param query: fixed reference structure
    :type query: (seqname, ct) tuple
    :param templist: one or more structures to be aligned
    :type templist: list of (seqname, ct) tuples
    :param keywords: top-level keyword-value input options (other than QUERY_FILE and TEMPLATE)
    :type keywords: dict
    :param log: active logger for diagnostic messages
    :type log: logging.logger
    :param debug: debug flag passed to backend
    :type debug: bool
    :param save_props: True if output data included as CT properties
    :type save_props: bool
    :param std_res: True if residue names are translated to standard forms
    :type std_res: bool
    :param reorder: True if residues need to be reordered by connectivity
    :type reorder: bool
    :param asl: ASL for reference substructure to align
    :type asl: str
    :param asl_mobile: ASL for mobile substructures to align
    :type asl_mobile: str
    :return: list of alignments corresponding to templist
    :rtype: list of `_SkaData`
    :raise ValueError: if `asl` matches no atoms in the query or `asl_mobile`
        matches no atoms in a template
    """
    # FIXME factor out common code with multiple_align_ct()
    if keywords is None:
        keywords = {}
    qtag, qct = query
    q_backup_title = qct.title
    # run_align() keys its results by CT title, so label each CT with its tag
    # for the duration of the call.
    qct.title = qtag
    results = []
    try:
        std_qct = _standardize(qct, std_res, reorder)
        if asl:
            align_atoms = analyze.evaluate_asl(std_qct, asl)
            if not align_atoms:
                st_str = "query structure: %s" % qtag
                raise ValueError("No atoms matched the ASL in the %s" % st_str)
            std_qct = std_qct.extract(align_atoms, copy_props=True)
        for tag, ct in templist:
            t_backup_title = ct.title
            ct.title = tag
            try:
                std_tct = _standardize(ct, std_res, reorder)
                if asl_mobile:
                    align_atoms = analyze.evaluate_asl(std_tct, asl_mobile)
                    if not align_atoms:
                        st_str = "template structure: %s" % tag
                        raise ValueError("No atoms matched the ASL in the %s" %
                                         st_str)
                    std_tct = std_tct.extract(align_atoms, copy_props=True)
                try:
                    output = run_align([std_qct, std_tct], keywords, log, debug)
                    (u, x) = output.rotmat[0]
                    # NOTE when -asl is used, "ct" still has the original full
                    # structure
                    transform_structure(ct, u, x)  # in-place transformation
                except Exception as err:
                    # Best effort: record the failure and continue with the
                    # remaining templates.
                    raw_output = ('\n\nWARNING: Skipping failed alignment: '
                                  '%s, %s\n' % (qct.title, ct.title))
                    raw_output += "  Exception: %s\n\n" % err
                    output = _SkaData([None], {}, {}, None, None, raw_output)
                if save_props:
                    # Compare against None: 0.0 is a valid score/RMSD.
                    if output.psd is not None:
                        ct.property['r_psp_StructAlign_Score'] = output.psd
                    if output.rmsd is not None:
                        ct.property['r_psp_StructAlign_RMSD'] = output.rmsd
                    ct.property['s_psp_StructAlign_Reference'] = qtag
                    output.setTransformationMatrixProperty(ct)
                results.append(output)
            finally:
                ct.title = t_backup_title
    finally:
        qct.title = q_backup_title
    return results
def multiple_align_ct(query,
                      templist,
                      keywords=None,
                      log=None,
                      debug=False,
                      save_props=False,
                      std_res=False,
                      reorder=False):
    """
    Wrapper for a single multiple-alignment ska job with CTs as input.

    All templates are aligned to the query in one backend run. Each template
    that took part in the alignment is then transformed in place. Templates
    that run_align() discards (no standard amino-acid residues) are left
    untouched. CT titles are restored before returning, also on error.

    NOTE: For higher-level API, use schrodinger.structutils.structalign module.
    Arguments are the same as in `pairwise_align_ct` above (the ASL options
    are not supported).

    :return: a single multiple alignment
    :rtype: `_SkaData`
    """
    # TODO Add asl option
    # FIXME factor out common code with pairwise_align_ct()
    if keywords is None:
        keywords = {}
    qtag, qct = query
    # templist is traversed more than once, so materialize it.
    templist = list(templist)
    # Snapshot every title before changing any, so that a CT listed more than
    # once still gets its original title back.
    saved_titles = [(qct, qct.title)] + [(ct, ct.title) for _, ct in templist]
    qct.title = qtag
    std_templates = []
    try:
        std_qct = _standardize(qct, std_res, reorder)
        strucs = [std_qct]
        for tag, ct in templist:
            std_tct = _standardize(ct, std_res, reorder)
            std_tct.title = tag
            strucs.append(std_tct)
            std_templates.append(std_tct)
        output = run_align(strucs, keywords, log, debug)
    finally:
        for ct, title in reversed(saved_titles):
            ct.title = title
    # run_align() removes structures without standard amino-acid residues from
    # `strucs` in place, and output.rotmat only covers the templates that
    # remain. Match those back to the caller's CTs by identity.
    remaining = {id(std) for std in strucs[1:]}
    aligned = [
        ct for (_, ct), std_tct in zip(templist, std_templates)
        if id(std_tct) in remaining
    ]
    # carry out corresponding rotation for each template
    for i, ((u, x), ct) in enumerate(zip(output.rotmat, aligned)):
        transform_structure(ct, u, x)  # perform in-place transformation
        if save_props:
            # Compare against None: 0.0 is a valid score/RMSD.
            if output.psd is not None:
                ct.property['r_psp_StructAlign_Score'] = output.psd
            if output.rmsd is not None:
                ct.property['r_psp_StructAlign_RMSD'] = output.rmsd
            ct.property['s_psp_StructAlign_Reference'] = qtag
            output.setTransformationMatrixProperty(ct, i)
    return output
# CT-level property used to stash a structure's title while run_align()
# temporarily substitutes a backend-safe one.
_ORIG_TITLE_PROP = 's_psp_ska_orig_title'

# Residue names (pdbres, space padded to 4 characters) that make a structure
# a valid SKA input.
_STANDARD_AMINO_ACIDS = frozenset([
    "ALA ", "ARG ", "ASN ", "ASP ", "CYS ", "GLN ", "GLU ", "GLY ", "HIS ",
    "ILE ", "LEU ", "LYS ", "MET ", "PHE ", "PRO ", "SER ", "THR ", "TRP ",
    "TYR ", "VAL "
])

# Header prepended to the backend output when structure titles were swapped.
_SWAPPED_TITLES_NOTICE = (
    "# Notice: the SKA output below utilizes a "
    "different set of structure titles\n#          "
    "than were originally specified.  These "
    "changes were made because\n#          the "
    "titles contained either spaces or unusually "
    "long names.  To\n#          ensure that SKA "
    "could properly process the structures, the\n"
    "#          following name key was implemented:\n")


def _keyword_is_true(keywords, key):
    """Return True if `key` is present in `keywords` with a true value."""
    value = keywords.get(key)
    return value is not None and _true_val.match(str(value)) is not None


def _swap_problem_titles(structs):
    """
    Replace every CT title with "ct<i>" if any title is 40 or more characters
    long, contains a space, or is duplicated.

    The original title is also stored in the _ORIG_TITLE_PROP property.

    :return: dict mapping substitute title -> original title, or None if no
        substitution was needed
    """
    titles = [ct.title for ct in structs]
    seen = set()
    need_swap = False
    for title in titles:
        if len(title) >= 40 or " " in title or title in seen:
            need_swap = True
        seen.add(title)
    if not need_swap:
        return None
    original_titles = {}
    for i, (ct, title) in enumerate(zip(structs, titles)):
        substitute = "ct" + str(i)
        ct.property[_ORIG_TITLE_PROP] = title
        ct.title = substitute
        original_titles[substitute] = title
    return original_titles


def _remove_invalid_structures(structs):
    """
    Remove, in place, the structures that contain no standard amino-acid
    residues (they would be invalid inputs to SKA).

    :return: titles of the removed structures
    :raises IOError: if the query (first structure) has no valid residues
    """
    invalid = []
    for i, ct in enumerate(structs):
        if any(res.pdbres in _STANDARD_AMINO_ACIDS for res in ct.residue):
            continue
        if i == 0:
            raise IOError("Query structure has no valid amino acid " +
                          "residues.")
        invalid.append(i)
    removed_titles = [structs[i].title for i in invalid]
    for i in reversed(invalid):
        del structs[i]
    return removed_titles


def _get_topology_file(log):
    """
    Return the backend topology file, locating it on first use.

    The path is only cached once it has been verified to exist.
    """
    global _trolltop
    if _trolltop is None:
        topology = os.environ.get('TROLLTOP')
        if topology is None:
            mmshare_data = hunt('mmshare', 'data')
            if not mmshare_data:
                raise RuntimeError('Unable to locate data files')
            topology = os.path.join(mmshare_data, 'ska', 'allh.top')
        if not os.path.exists(topology):
            raise RuntimeError('Topology file %s appears to be missing' %
                               topology)
        log.debug('Using topology file: ' + topology)
        _trolltop = topology
    return _trolltop


def _build_run_options(keywords, topology, auto, debug):
    """
    Convert top-level keywords to backend command-line flags.

    :return: list of option dicts, one per backend run (several when
        automatic settings scan the SSE window lengths)
    :raises RuntimeError: if a numeric keyword has an illegal value
    """
    options = {'top': topology}
    if debug:
        options['verr'] = True
    try:
        if 'PSD_THRESHOLD' in keywords:
            options['p'] = float(keywords['PSD_THRESHOLD'])
        if _keyword_is_true(keywords, 'SWITCH'):
            options['switch'] = True
        if keywords.get('ORDER') == 'seed':
            options['seed'] = True
        if _keyword_is_true(keywords, 'CLEAN'):
            options['clean'] = True
        if _keyword_is_true(keywords, 'RECKLESS'):
            options['reckless'] = True
        if auto:
            # use defaults for everything other than '-w'
            options['i'] = 2.0
            options['g'] = 1.0
            options['s'] = 1.0
            options['minlength'] = 2
            return [dict(options, w=winlen) for winlen in _winlen_vals]
        options['i'] = float(keywords.get('GAP_OPEN', 2.0))
        options['g'] = float(keywords.get('GAP_DEL', 1.0))
        options['s'] = float(keywords.get('SSE_MINSIM', 1.0))
        options['minlength'] = int(keywords.get('SSE_MINLEN', 2))
        # default is to not include '-w' flag on command-line
        if 'SSE_WINLEN' in keywords:
            options['w'] = int(keywords['SSE_WINLEN'])
        return [options]
    except ValueError as err:
        raise RuntimeError('Illegal value found in input file: %s' %
                           err.args[0].split()[-1]) from err


def _restore_titles_in_output(text, placeholder_to_title):
    """
    Rewrite backend output produced with placeholder titles so that it shows
    the real titles, widening the ruler margin to match.
    """
    title_width = max(len(title) for title in placeholder_to_title.values())
    placeholder_width = max(len(ph) for ph in placeholder_to_title)
    old_margin = " " * placeholder_width
    new_margin = " " * title_width
    lines = []
    for line in text.split("\n"):
        # Adjust the spacing of the ruler to match the change in the sequence
        if line.startswith(old_margin):
            line = line.replace(old_margin, new_margin, 1)
        # Adjust the names of the sequences
        for placeholder, title in placeholder_to_title.items():
            line = line.replace(placeholder, title.ljust(title_width))
        lines.append(line)
    return "\n".join(lines)


def _parse_ska_output(raw_output, tags):
    """
    Extract aligned sequences, secondary structure strings and scores from
    the backend output.

    :return: tuple (align, psd, rmsd, sse_map); psd/rmsd are None if absent
    :raises RuntimeError: if scores cannot be parsed or the aligned strings
        have inconsistent lengths
    """
    psd = None
    rmsd = None
    align = dict.fromkeys(tags, '')
    sse_map = dict.fromkeys(tags, '')
    sse_tags = {f'tc_sse:{tag}': tag for tag in tags}
    in_alignment = False
    for line in raw_output.splitlines():
        # in case of verbose output, scroll down to actual alignment
        if '..........+' in line:
            in_alignment = True
        if not in_alignment:
            continue
        fields = line.split()
        if not fields:
            continue
        first = fields[0]
        if first in align:
            align[first] += fields[-1]
        elif first in sse_tags:
            sse_map[sse_tags[first]] += fields[-1]
            continue
        try:
            if first == 'PSD:':
                psd = float(fields[1])
            elif first == 'RMSD:':
                rmsd = float(fields[1])
        except (IndexError, ValueError) as err:
            raise RuntimeError(
                'Error parsing alignment scores in output') from err
    # sanity check for aligned sequences, since arrays are initialized for
    # all template tags (leading empty entries are tolerated, as before)
    ali_len = 0
    for seq in align.values():
        if not ali_len:
            ali_len = len(seq)
        elif len(seq) != ali_len:
            raise RuntimeError('Error parsing sequence alignment in output')
    if any(len(seq) != ali_len for seq in sse_map.values()):
        raise RuntimeError('Error parsing secondary structure in output')
    return align, psd, rmsd, sse_map


def _run_backend_once(structs, tags, options, strucfile, log):
    """
    Run the SKA backend once with `options` and parse its output.

    :return: tuple (raw_output, transf, index, tmppdb, align, psd, rmsd,
        sse_map), with transf/index keyed by the titles in `structs`
    """
    if strucfile:
        # NOTE(review): tempfile.mktemp() is race-prone; it is kept because
        # the backend is given a file name (-o) rather than an open file.
        tmppdb = os.path.basename(tempfile.mktemp('.pdb', 'ska_', '.'))
        options['o'] = tmppdb
    else:
        tmppdb = None
    flags = []
    for key in sorted(options):
        if key == 'top':  # this is output in debug mode
            continue
        value = options[key]
        flags.append(' -%s' % key if value is True else ' -%s %s' %
                     (key, value))
    log.info('Running ska with arguments:' + ''.join(flags))
    # The backend appears to have problems if the number of characters in
    # the title is more than 39 characters (related code is enacted if
    # margin > 40 in SequenceList::BuildOutput), so run it on copies carrying
    # fixed-width placeholder titles.
    run_cts = [ct.copy() for ct in structs]
    placeholder_to_title = {}
    for i, ct in enumerate(run_cts):
        placeholder = "Structure%08d" % i
        placeholder_to_title[placeholder] = ct.title
        ct.title = placeholder
    # backend takes unordered list of CTs as input to multiple alignment;
    # options correspond to command-line flags for stand-alone executable
    ska_raw_output, ska_transf, ska_index = _ska.run_ska(
        [ct.handle for ct in run_cts], options)
    # Convert the outputs to look like what they would with original titles
    transf = {
        title: ska_transf[ph]
        for ph, title in placeholder_to_title.items()
        if ph in ska_transf
    }
    index = {
        title: ska_index[ph]
        for ph, title in placeholder_to_title.items()
        if ph in ska_index
    }
    raw_output = _restore_titles_in_output(ska_raw_output, placeholder_to_title)
    align, psd, rmsd, sse_map = _parse_ska_output(raw_output, tags)
    try:
        log.info('\tPSD: %f, RMSD: %f' % (psd, rmsd))
    except TypeError:
        log.info('\tNo alignment score available')
    return (raw_output, transf, index, tmppdb, align, psd, rmsd, sse_map)


def _select_best_result(results, log):
    """
    Return the backend result with the lowest PSD; results without a PSD
    rank last so that None is never compared with a float.
    """
    if len(results) == 1:
        return results[0]

    def psd_key(result):
        psd = result[5]
        return (psd is None, 0.0 if psd is None else psd)

    best = min(results, key=psd_key)
    if best[5] is not None:
        log.debug('Selecting best result by PSD: %f' % best[5])
    return best


def _compute_rotation_matrices(transf, tags):
    """
    Convert backend transformations into (U, x) pairs for each template,
    expressed in the query's frame.

    :raises RuntimeError: if a template's matrix is malformed or missing
    """
    try:
        refmat = numpy.array(transf[tags[0]])
    except KeyError:
        refmat = None
    rotmat = []
    for tag in tags[1:]:
        if tag in transf:
            mat = numpy.array(transf[tag])
            if mat.shape != (4, 3):
                raise RuntimeError('Invalid rotation matrix for template: ' +
                                   tag)
            rotmat.append((mat[:3], mat[3]))
        elif refmat is None:
            raise RuntimeError('No rotation matrix found for template: ' + tag)
        else:
            # this template was used as the reference for alignment
            rotmat.append((numpy.eye(3), numpy.zeros((3,))))
    # transform to query frame if alignment was reordered
    if refmat is not None:
        refu = refmat[:3].transpose()
        refx = refmat[3]
        rotmat = [(numpy.dot(refu, u), numpy.dot(refu, x - refx))
                  for u, x in rotmat]
    return rotmat


def _rotate_pdb_line(line, u, x, source_name):
    """
    Return `line` with its ATOM/HETATM coordinates transformed by (u, x);
    any other line is returned unchanged.
    """
    match = _atom_rec.match(line)
    if not match:
        return line
    try:
        crds = numpy.array([float(v) for v in match.group(3, 4, 5)])
        crds = numpy.inner(crds, u) + x
    except (TypeError, ValueError) as err:
        raise RuntimeError('Invalid data found in template file: ' +
                           source_name) from err
    pre, post = match.group(1, 6)
    return pre + '%8.3f%8.3f%8.3f' % tuple(crds.tolist()) + post + '\n'


def _write_rotated_templates(template_files, removed_titles, rotmat, log):
    """
    Write 'rot-<file>' copies of the template PDB files with transformed
    coordinates; everything else in each file is copied verbatim.
    """
    # FIXME removing by titles won't work correctly in cases where same
    # title is used by multiple CTs - or even when the beginning of a title
    # is shared.
    kept_files = [
        name for name in template_files
        if not any(name.startswith(title) for title in removed_titles)
    ]
    # NOTE(review): assumes TEMPLATE lists files in the same order as the
    # template structures, so zip() pairs each file with its transformation.
    for template_file, (u, x) in zip(kept_files, rotmat):
        rotated_file = 'rot-' + template_file
        try:
            tfile = open(template_file, 'r')
        except Exception as err:
            raise RuntimeError('Unable to read template file: ' +
                               template_file) from err
        with tfile:
            try:
                rfile = open(rotated_file, 'w')
            except Exception as err:
                raise RuntimeError('Unable to create rotated file: ' +
                                   rotated_file) from err
            with rfile:
                for line in tfile:
                    rfile.write(_rotate_pdb_line(line, u, x, template_file))
        log.debug('Creating rotated template file: ' + rotated_file)


def _run_align_core(structs, keywords, log, debug):
    """
    Validate the inputs, run the backend and post-process the best result.
    Titles in `structs` must already be backend-safe.

    :return: tuple (rotmat, align, index, psd, rmsd, raw_output, sse_map)
    """
    removed_titles = _remove_invalid_structures(structs)
    # for backwards compatibility even though no other formats
    # are implemented
    if keywords.get('ALIFORMAT', 'native') != 'native':
        raise RuntimeError('Unknown value specified for alignment format')
    if len(structs) < 2:
        raise RuntimeError('Query and template files not specified')
    # limitation imposed by use of unique chains in backend output
    if len(structs) > 52:
        raise RuntimeError('Maximum of 52 structures exceeded')
    strucfile = keywords.get('ALISTRUCS_OUTFILE', None)
    # extract template names (e.g. PDB codes) from CT titles
    tags = [ct.title for ct in structs]
    log.info('Using input structures: ' + ', '.join(tags))
    auto = False
    if _keyword_is_true(keywords, 'USE_AUTOMATIC_SETTINGS'):
        if len(structs) > 2:
            log.warning('Automatic settings ignored for multiple alignment')
        else:
            log.info('Using automatic settings')
            auto = True
    topology = _get_topology_file(log)
    results = []
    for options in _build_run_options(keywords, topology, auto, debug):
        try:
            results.append(
                _run_backend_once(structs, tags, options, strucfile, log))
        except SystemError as err:
            raise RuntimeError('Fatal error in ska backend') from err
    # pick best score if running multiple jobs with automatic settings
    (raw_output, transf, index, tmppdb, align, psd, rmsd,
     sse_map) = _select_best_result(results, log)
    rotmat = _compute_rotation_matrices(transf, tags)
    # only need to carry out rotations here if output structures requested
    # (for backwards compatibility, otherwise calling script can use returned
    # matrices directly)
    if strucfile:
        try:
            shutil.copy(tmppdb, strucfile)
            log.info('Raw structures written to file: ' + strucfile)
            _write_rotated_templates(keywords.get('TEMPLATE', []),
                                     removed_titles, rotmat, log)
        finally:
            if not debug:
                for result in results:
                    force_remove(result[3])
    else:
        log.info('Rotated PDB files not generated')
    return rotmat, align, index, psd, rmsd, raw_output, sse_map


def run_align(structs, keywords, log=None, debug=False):
    """
    Core driver routine, equivalent to `$SCHRODINGER/ska` with all top-level
    input options specified in `keywords`. Returns alignment, scores, and
    rotation matrices as well as backend stdout.

    NOTE: For higher-level API, use schrodinger.structutils.structalign module.

    Structures with no standard amino-acid residues are removed from
    `structs` in place before the backend runs. Titles that are too long,
    contain spaces or are duplicated are temporarily replaced for the backend
    run; the callers' titles are always restored, also on error.

    If keywords['ALISTRUCS_OUTFILE'] is set, the backend structure output is
    copied there, and each file in keywords['TEMPLATE'] is written with
    transformed coordinates to 'rot-<file>'. Rotated structures are not
    returned.

    :param structs: structures to align. First in list is query; the rest of
            CTs will be aligned to it (templates).
    :type structs: list(schrodinger.structure.Structure)
    :param keywords: top-level keyword-value input options
    :type keywords: dict
    :param log: active logger for diagnostic messages
    :type log: logging.logger
    :param debug: debug flag passed to backend
    :type debug: bool
    :return: a single multiple alignment
    :rtype: `_SkaData`
    :raises IOError: if the query has no standard amino-acid residues
    :raises RuntimeError: for invalid options, too few or too many
        structures, a missing topology file, a backend failure or output
        that cannot be parsed
    """
    # provide a default output stream if none provided
    if log is None:
        from schrodinger.utils.log import get_logger
        log = get_logger()
    all_structs = list(structs)
    saved_titles = [(ct, ct.title) for ct in structs]
    original_titles = _swap_problem_titles(structs)
    kept_swapped_titles = []
    try:
        (rotmat, align, index, psd, rmsd, raw_output,
         sse_map) = _run_align_core(structs, keywords, log, debug)
    finally:
        if original_titles is not None:
            kept_swapped_titles = [ct.title for ct in structs]
            # Restore every input CT, including removed ones; reversed so a
            # CT listed twice ends with its first saved title.
            for ct, title in reversed(saved_titles):
                ct.title = title
                try:
                    del ct.property[_ORIG_TITLE_PROP]
                except KeyError:
                    pass
    del all_structs
    if original_titles is not None:
        # Re-key the results by the original titles and document the
        # substitution so the raw output can still be parsed later.
        raw_output = (_SWAPPED_TITLES_NOTICE + ''.join(
            "#  " + swapped + ":  " + original_titles[swapped] + "\n"
            for swapped in kept_swapped_titles) + "\n" + raw_output)
        align = {original_titles.get(k, k): v for k, v in align.items()}
        index = {original_titles.get(k, k): v for k, v in index.items()}
        sse_map = {original_titles[k]: v for k, v in sse_map.items()}
    # NOTE: rotated structures are not returned; but are instead written to
    # the CWD with the rot-* prefix if ALISTRUCS_OUTFILE was set.
    log.info('SKA calculation successfully completed')
    return _SkaData(rotmat, align, index, psd, rmsd, raw_output, sse_map)