"""
Perform a pairwise superposition of multiple structures using the
C-alpha atoms of selected residues.
"""
#- Imports -------------------------------------------------------------------
import csv
from itertools import chain
from itertools import combinations
from math import ceil
from schrodinger import structure
from schrodinger.infra.mm import MMCT_ATOM_BALLNSTICK
from schrodinger.protein import findhets
from schrodinger.structutils import analyze
from schrodinger.structutils import color
from schrodinger.structutils import measure
from schrodinger.structutils import rmsd
from schrodinger.utils import csv_unicode
#- Globals --------------------------------------------------------------------
# Display style given to the reference binding site atoms, and to mobile
# backbone atoms when residues are colored by RMSD.
BINDING_SITE_STYLE = MMCT_ATOM_BALLNSTICK
# Display style and color given to mobile C-alpha atoms that were matched to
# a reference binding site atom (used when not coloring by RMSD).
ALIGNED_STYLE, ALIGNED_COLOR = MMCT_ATOM_BALLNSTICK, 'green'
#- Functions -----------------------------------------------------------------
def get_ligand_asl(cutoff=5.0, molnum=None, fillres=True):
    """
    Returns an ASL expression defining binding site atoms based on the
    ligand molecule number and cutoff distance. This ASL specifically
    identifies the C-alpha atoms (ligand atoms themselves are excluded).

    :param cutoff: Cutoff distance (Angstroms) from the ligand used to assign
                   binding site atoms
    :type  cutoff: float
    :param molnum: The molecule number of the ligand. If None, every molecule
                   containing 10-120 atoms is treated as the ligand.
    :type  molnum: int
    :param fillres: If True, select a residue's C-alpha when any atom of that
                    residue is within the cutoff; if False, the C-alpha atom
                    itself must be within the cutoff.
    :type  fillres: bool
    :returns: ASL expression for binding site atoms
    :rtype: str
    """
    if molnum is None:
        # Size-based heuristic used when no ligand molecule is specified.
        lig_asl = '(mol.atoms 10-120)'
    else:
        lig_asl = "mol.num %i" % molnum
    within = 'fillres within' if fillres else 'within'
    asl = '(atom.ptype " CA ") AND NOT ( %s )' % lig_asl
    # Leading space keeps "AND" from being fused onto the closing paren.
    asl += ' AND (%s %f ( %s ))' % (within, cutoff, lig_asl)
    return asl


def get_residue_asl(residues):
    """
    Returns an ASL expression defining binding site atoms based on the residue
    strings passed in. This ASL specifically identifies the C-alpha atoms.

    :param residues: Residues used to define binding site atoms. Underscores
                     in the chain part are converted to spaces, so "_:42"
                     selects residue 42 of the blank chain.
    :type  residues: list of residue strings (<chain>:<resnum>)
    :returns: ASL expression for binding site atoms
    :rtype: str
    """
    asls = []
    for residue in residues:
        chain_name, resnum = residue.split(":")
        # str.replace returns a new string; keep the result.
        chain_name = chain_name.replace('_', ' ')
        asls.append('(res.num %s AND chain.name "%s")' % (resnum, chain_name))
    return '(atom.ptype " CA ") AND (%s)' % ' OR '.join(asls)


#- Classes --------------------------------------------------------------------
class SiteMatchLookup(object):
    """
    Lookup for any reference and mobile structures that have matching
    binding site atoms. This stores a reference structure and a mobile
    structure, along with their mapped C-alpha atoms.

    The reason the terms reference and mobile are used is because when
    superimposing 2 structures for RMSD calculations the "reference" structure
    will always be the stationary structure by default. The "mobile" structure
    will have its coordinates changed to get an iteratively lower RMSD.
    """
    SUPERIMPOSED_PROPERTY = 'r_sitealign_Superimposed_RMSD'
    INPLACE_PROPERTY = 'r_sitealign_In_Place_RMSD'
    # Heavy backbone atoms used for per-residue RMSD in colorByRMSD.
    _BACKBONE_ASL = '( backbone ) AND NOT (atom.ele H)'

    def __init__(self, ref_st, mobile_st, atom_map):
        """
        :param ref_st: Reference structure
        :type  ref_st: `structure<schrodinger.structure.Structure>`
        :param mobile_st: Mobile structure
        :type  mobile_st: `structure<schrodinger.structure.Structure>`
        :param atom_map: Reference and mobile binding site atom map
        :type  atom_map: dict where keys are ref atom indices and values are
                         the mobile atom indices that overlap the ref indices
                         (None when a ref atom has no mobile match).
        """
        self.ref_st = ref_st
        self.mobile_st = mobile_st
        self.atom_map = atom_map
        # Most recently computed RMSDs for this pair; None until the
        # corresponding getInPlaceRMSD/getSuperimposedRMSD call is made.
        self.inplace_rmsd = None
        self.superimposed_rmsd = None

    def getMobileStRMSD(self):
        """
        Gets the in-place and superimposed RMSDs stored as properties on the
        mobile structure.

        :returns: (in-place RMSD, superimposed RMSD); either may be None if it
                  was never set
        :rtype: tuple
        """
        props = self.mobile_st.property
        return (props.get(self.INPLACE_PROPERTY, None),
                props.get(self.SUPERIMPOSED_PROPERTY, None))

    def _setSuperimposedRMSD(self, new_rmsd):
        """ Sets the superimposed RMSD for the mobile st """
        self.superimposed_rmsd = new_rmsd
        self.mobile_st.property[self.SUPERIMPOSED_PROPERTY] = new_rmsd

    def _setInPlaceRMSD(self, new_rmsd):
        """ Sets the inplace RMSD for the mobile st """
        self.inplace_rmsd = new_rmsd
        self.mobile_st.property[self.INPLACE_PROPERTY] = new_rmsd

    @property
    def ref_keys(self):
        """
        Getter for reference keys. The key is a list of residue strings defining
        the residues belonging to the reference structure.

        :returns: List of residue strings defining residues in reference structure
        :rtype: list
        """
        return [str(self.ref_st.atom[ref_idx].getResidue())
                for ref_idx in self.atom_map]

    @property
    def mob_keys(self):
        """
        Getter for mobile keys. The key is a list of residue strings defining
        the residues belonging to the mobile structure. If the matching mobile
        key is None, an empty string is appended.

        :returns: List of residue strings defining residues in mobile structure
        :rtype: list
        """
        return [
            '' if mob_idx is None else str(
                self.mobile_st.atom[mob_idx].getResidue())
            for mob_idx in self.atom_map.values()
        ]

    def getMatchingAtomMap(self):
        """
        :returns: The subset of `self.atom_map` whose reference atoms have a
                  mobile match (entries mapped to None are dropped)
        :rtype: dict
        """
        return {ref_idx: mob_idx
                for ref_idx, mob_idx in self.atom_map.items()
                if mob_idx is not None}

    def getSuperimposedRMSD(self):
        """
        Superimpose the mobile structure onto the reference structure using the
        matched atoms. The whole mobile structure is moved
        (`move_which=rmsd.CT`).

        :returns: The RMSD after superimposing
        :rtype: float
        """
        atom_map = self.getMatchingAtomMap()
        superimposed_rmsd = rmsd.superimpose(self.ref_st,
                                             list(atom_map),
                                             self.mobile_st,
                                             list(atom_map.values()),
                                             move_which=rmsd.CT)
        superimposed_rmsd = float(superimposed_rmsd)
        self._setSuperimposedRMSD(superimposed_rmsd)
        return superimposed_rmsd

    def getInPlaceRMSD(self, **kwargs):
        """
        Get the in-place RMSD of the reference structure to the mobile
        structure over the matched atoms. No coordinates are changed.

        :returns: The in-place RMSD
        :rtype: float
        :see: `rmsd.calculate_in_place_rmsd` for available kwargs
        """
        atom_map = self.getMatchingAtomMap()
        inplace_rmsd = rmsd.calculate_in_place_rmsd(self.ref_st, list(atom_map),
                                                    self.mobile_st,
                                                    list(atom_map.values()),
                                                    **kwargs)
        inplace_rmsd = float(inplace_rmsd)
        self._setInPlaceRMSD(inplace_rmsd)
        return inplace_rmsd

    @staticmethod
    def _rmsdToColor(rmsd_value):
        """
        Map a per-residue RMSD onto a color name.

        RMSD <= 0.1 gives "blue". (0.1, 1.7] gives "blue2".."blue32".
        (1.7, 3.3] gives "red32".."red2". Anything above 3.3 gives "red".
        """
        color_val = int(ceil(rmsd_value * 10))
        # "<= 1" (not "== 1") so an RMSD of exactly 0 does not fall through
        # to the invalid name "blue-2".
        if color_val <= 1:
            return 'blue'
        if color_val >= 34:
            return 'red'
        if color_val <= 17:
            return 'blue%d' % ((color_val * 2) - 2)
        return 'red%d' % (32 - ((color_val - 18) * 2))

    @staticmethod
    def _namedAtoms(st, res, allowed):
        """
        Map stripped PDB atom name -> atom index for atoms of `res` that are in
        `allowed`.
        """
        return {st.atom[idx].pdbname.strip(): idx
                for idx in res.getAtomIndices() if idx in allowed}

    def colorByRMSD(self):
        """
        Color the mobile structure by per-residue, heavy backbone atom RMSD
        against the reference structure. The reference is not modified.

        This will change:

        - Color of all non-het atoms in the mobile structure, using the
          "element" color scheme
        - Color of all non-het carbons in the mobile structure to "user12"
        - Style of the backbone atoms of matched mobile residues to
          `BINDING_SITE_STYLE`
        - Color of matched mobile residues on a blue-red scale that depends
          on per-residue RMSD (see `_rmsdToColor`):
            - blue for RMSD of 0.1 or below
            - red for RMSD above 3.3
        """
        mob_st = self.mobile_st
        # Non-het atoms are all atoms minus those reported by find_hets.
        het_atoms = set(chain.from_iterable(findhets.find_hets(mob_st)))
        prot_atoms = sorted(
            a.index for a in mob_st.atom if a.index not in het_atoms)
        if prot_atoms:
            color.apply_color_scheme(mob_st, 'element', prot_atoms)
        for idx in prot_atoms:
            atom = mob_st.atom[idx]
            if atom.element == "C":
                atom.color = "user12"

        # Select backbone atoms once per structure, then intersect with each
        # residue. This respects chains, unlike a selection by res.num alone.
        ref_bb_all = set(analyze.evaluate_asl(self.ref_st, self._BACKBONE_ASL))
        mob_bb_all = set(analyze.evaluate_asl(mob_st, self._BACKBONE_ASL))

        seen_residues = set()
        for ref_idx, mob_idx in self.getMatchingAtomMap().items():
            ref_res = self.ref_st.atom[ref_idx].getResidue()
            mob_res = mob_st.atom[mob_idx].getResidue()
            # Several reference atoms may map into the same mobile residue;
            # process each mobile residue only once.
            res_key = tuple(mob_res.getAtomIndices())
            if res_key in seen_residues:
                continue
            seen_residues.add(res_key)

            # Pair backbone atoms by PDB name so the RMSD compares like with
            # like, even if a residue is missing an atom.
            ref_named = self._namedAtoms(self.ref_st, ref_res, ref_bb_all)
            mob_named = self._namedAtoms(mob_st, mob_res, mob_bb_all)
            names = sorted(set(ref_named) & set(mob_named))
            if not names:
                continue
            ref_atoms = [ref_named[name] for name in names]
            mob_atoms = [mob_named[name] for name in names]
            rmsd_value = rmsd.calculate_in_place_rmsd(self.ref_st, ref_atoms,
                                                      mob_st, mob_atoms)
            color_str = self._rmsdToColor(rmsd_value)

            # Color the whole mobile residue; restyle its backbone atoms.
            mob_bb_res = set(mob_named.values())
            for idx in mob_res.getAtomIndices():
                atom = mob_st.atom[idx]
                atom.color = color_str
                if idx in mob_bb_res:
                    atom.style = BINDING_SITE_STYLE


class BindingSiteAligner(object):
    """
    Align structures by matching binding site atoms.
    """
    # Minimum number of matched C-alpha atoms needed to superimpose / compare.
    _MIN_ATOMS = 3

    def __init__(self, ref_st, mobile_sts, binding_site_asl, ignore_dist=5.0):
        """
        :param ref_st: The reference structure
        :type ref_st: `schrodinger.structure.Structure`
        :param mobile_sts: List of structures to align to the reference
        :type mobile_sts: list of `structures<schrodinger.structure.Structure>`
        :param binding_site_asl: The asl used to define the binding site atoms
        :type  binding_site_asl: str
        :param ignore_dist: Maximum distance (Angstroms) between a reference
                            binding site atom and a mobile C-alpha for them to
                            be matched. Atoms must be strictly closer than this.
        :type ignore_dist: float
        """
        self.reference = ref_st
        self.mobile_sts = mobile_sts
        self.asl = binding_site_asl
        self.ignore_dist = float(ignore_dist)
        # Reference atom indices selected by the binding site ASL; filled in
        # by setBindingSiteAtoms().
        self.binding_site_atoms = []
        # One SiteMatchLookup per mobile structure whose C-alphas match at
        # least 3 distinct reference binding site atoms (reference -> mobile).
        self.site_matches = []
        # One SiteMatchLookup per pair of mobile structures that share at
        # least 3 matched C-alpha atoms with each other and the reference.
        self.matrix_matches = []

    def setBindingSiteAtoms(self, asl=None):
        """
        Sets the binding site atoms to the `asl` provided.

        :param asl: The asl used to define the binding site atoms. If None
                    `self.asl` will be used (from asl passed in in init).
        :type  asl: str
        :raise RuntimeError: If the asl does not return at least 3 atom indices
        """
        if asl is None:
            asl = self.asl
        self.binding_site_atoms = analyze.evaluate_asl(self.reference, asl)
        # We need at least 3 atoms to superimpose
        if len(self.binding_site_atoms) < self._MIN_ATOMS:
            err_msg = "Less than 3 receptor atoms found using ASL:\n%s" % asl
            raise RuntimeError(err_msg)
        # Give the atoms used in the superposition BINDING_SITE_STYLE
        for a in (self.reference.atom[i] for i in self.binding_site_atoms):
            a.style = BINDING_SITE_STYLE

    @classmethod
    def _hasEnoughMatches(cls, match_map):
        """
        True if `match_map` maps onto at least `_MIN_ATOMS` distinct mobile
        atoms. None (unmatched) values are ignored.
        """
        unique = {idx for idx in match_map.values() if idx is not None}
        return len(unique) >= cls._MIN_ATOMS

    def _buildMatchMap(self, mob_st):
        """
        Map each reference binding site atom index to the index of the closest
        mobile C-alpha strictly within `self.ignore_dist`, or to None if none.
        """
        mobile_atoms = [
            (idx, mob_st.atom[idx])
            for idx in analyze.evaluate_asl(mob_st, '(atom.ptype " CA ")')
        ]
        match_map = {}
        for ref_idx in self.binding_site_atoms:
            ref_atom = self.reference.atom[ref_idx]
            best_dist = self.ignore_dist
            closest_atom = None
            for mob_idx, mob_atom in mobile_atoms:
                new_dist = measure.measure_distance(ref_atom, mob_atom)
                if new_dist < best_dist:
                    best_dist = new_dist
                    closest_atom = mob_idx
            match_map[ref_idx] = closest_atom
        return match_map

    def setSiteMatches(self):
        """
        Set the `self.site_matches` variable, replacing any previous value.
        This stores all the `lookups<SiteMatchLookup>` that will be used to
        calculate the RMSDs. No `SiteMatchLookup` is created for mobile
        structures with fewer than 3 distinct C-alpha atoms matching the
        reference binding site atoms.

        :raise RuntimeError: If no mobile structure has enough matches
        """
        self.site_matches = []
        for mob_st in self.mobile_sts:
            match_map = self._buildMatchMap(mob_st)
            if self._hasEnoughMatches(match_map):
                self.site_matches.append(
                    SiteMatchLookup(self.reference, mob_st, match_map))
        # If there are no sites to align raise an error
        if not self.site_matches:
            err = 'No mobile structures have at least 3 atoms that '
            err += 'match the reference binding site atoms within '
            err += '%.2f Angstroms' % self.ignore_dist
            raise RuntimeError(err)

    def alignSites(self, inplace=False, color_by_rmsd=False):
        """
        Align all structures from `self.site_matches`. This will return the
        aligned mobile structures.

        Every matched mobile structure gets a `r_sitealign_In_Place_RMSD`
        property giving the in-place RMSD between the reference and the
        mobile structure.

        If `inplace=False`, the mobile structure is superimposed onto the
        reference. It also gets a `r_sitealign_Superimposed_RMSD` property
        giving the RMSD after superposition.

        :param inplace: Whether to run RMSD calculations in-place or not
        :type  inplace: bool
        :param color_by_rmsd: Color the mobile atoms by RMSD. Can only be used
                              if inplace is False (ignored otherwise)
        :type color_by_rmsd: bool
        :returns: The mobile structures (aligned if `inplace=False`)
        :rtype: list(structures<schrodinger.structure.Structure>)
        """
        # Get the binding site atoms and the match lookups
        self.setBindingSiteAtoms()
        self.setSiteMatches()
        if inplace:
            color_by_rmsd = False
        for match in self.site_matches:
            # Measure in-place RMSD before superposition moves the mobile st.
            match.getInPlaceRMSD()
            if not inplace:
                match.getSuperimposedRMSD()
            if color_by_rmsd:
                match.colorByRMSD()
            else:
                # Without RMSD coloring, highlight only the mobile C-alphas
                # that were matched to a reference atom.
                for mob_idx in match.getMatchingAtomMap().values():
                    atom = match.mobile_st.atom[mob_idx]
                    atom.style = ALIGNED_STYLE
                    atom.color = ALIGNED_COLOR
        return self.getAlignedMobileStructures()

    def calculateMatrix(self):
        """
        Generates an in-place RMSD matrix between all mobile structures in
        `self.site_matches`, replacing any previous `self.matrix_matches`.

        This method goes through all combinations of two mobile structures in
        `self.site_matches`. For each pair it looks for at least 3 common
        C-alpha atoms that are matched to the same reference atoms. Example:
        `self.site_matches` has 3 `matches<SiteMatchLookup>`, A, B and C. The
        matrix calculation will loop over:

        1. A and B, then
        2. A and C, then
        3. B and C

        Pairs of mobile structures that do not share 3 atoms with each other
        **and** with the reference will be skipped. The structures' RMSD
        properties (set by `alignSites`) are not modified.

        The in-place RMSD can be accessed in two ways. One way is to write the
        data out to a CSV-formatted file with `self.writeMatrixData()`. That
        file has a header, then one row per RMSD calculated. The second way
        is to get it from the returned lookups::

            matrix_matches = aligner.calculateMatrix()
            for match in matrix_matches:
                mobile_st1   = match.ref_st
                mobile_st2   = match.mobile_st
                inplace_rmsd = match.inplace_rmsd

        :rtype: list(SiteMatchLookup)
        :returns: Lookups for each pair of mobile structures sharing at least
                  3 atoms with each other and the reference
        :raise RuntimeError: If less than 2 mobile structures were aligned to
                             reference
        """
        # Raise error if less than two mobile structures were aligned to reference
        if len(self.site_matches) < 2:
            raise RuntimeError(
                'At least 2 mobile structures need to be aligned to reference')
        self.matrix_matches = []
        for match1, match2 in combinations(self.site_matches, 2):
            atom_pair1 = match1.getMatchingAtomMap()
            atom_pair2 = match2.getMatchingAtomMap()
            # Reference atoms matched in both mobile structures
            shared_atoms = sorted(set(atom_pair1) & set(atom_pair2))
            pair_map = {atom_pair1[ref_key]: atom_pair2[ref_key]
                        for ref_key in shared_atoms}
            if len(pair_map) < self._MIN_ATOMS:
                continue
            lookup = SiteMatchLookup(match1.mobile_st, match2.mobile_st,
                                     pair_map)
            # Compute directly instead of calling lookup.getInPlaceRMSD().
            # That method stores the value as a property on match2.mobile_st.
            # Doing so would overwrite its RMSD to the reference and any
            # earlier matrix value for the same structure.
            inplace_rmsd = rmsd.calculate_in_place_rmsd(
                match1.mobile_st, list(pair_map), match2.mobile_st,
                list(pair_map.values()))
            lookup.inplace_rmsd = float(inplace_rmsd)
            self.matrix_matches.append(lookup)
        return self.matrix_matches

    def writeStructures(self, filename):
        """
        Write reference and mobile structures to a file. This will write the
        structures to a file based on their current state. If this is called
        right after initializing the class the structures will be unaltered,
        and all mobile structures are written. If this is called after sites
        are matched, only the matched (and possibly superimposed) mobile
        structures are written.

        :param filename: File name to write structures to
        :type  filename: str
        """
        if self.site_matches:
            mobile_sts = self.getAlignedMobileStructures()
        else:
            mobile_sts = self.mobile_sts
        # Context manager guarantees the writer is closed and flushed.
        with structure.StructureWriter(filename) as writer:
            writer.append(self.reference)
            for st in mobile_sts:
                writer.append(st)

    def getAlignedMobileStructures(self):
        """
        :returns: The mobile structures from `self.site_matches`
        :rtype: list
        """
        return [match.mobile_st for match in self.site_matches]

    def _writeCSVData(self, filename, csv_data):
        """  Private method to write out csv data """
        with csv_unicode.writer_open(filename) as fh:
            writer = csv.writer(fh)
            writer.writerows(csv_data)

    def writeAlignData(self, filename):
        """
        Write out all RMSD data for mobile structures aligned to the
        reference structure. Missing RMSDs are written as empty cells. Nothing
        is written if there are no site matches.

        :param filename: File name to write csv data to
        :type  filename: str
        """
        data = [["Reference", "Mobile", "Inplace_RMSD", "Superimposed_RMSD"]]
        for match in self.site_matches:
            inplace_rmsd, superimposed_rmsd = match.getMobileStRMSD()
            # Compare against None: an RMSD of 0.0 is a valid value.
            if inplace_rmsd is None:
                inplace_rmsd = ''
            if superimposed_rmsd is None:
                superimposed_rmsd = ''
            data.append([match.ref_st.title, match.mobile_st.title,
                         inplace_rmsd, superimposed_rmsd])
        if len(data) > 1:
            self._writeCSVData(filename, data)

    def writeMatrixData(self, filename):
        """
        Write out all RMSD data for mobile structures aligned to other
        mobile structures (see `calculateMatrix`). Nothing is written if
        there are no matrix matches.

        :param filename: File name to write csv data to
        :type  filename: str
        """
        data = [["First structure", "Second structure", "RMSD"]]
        for match in self.matrix_matches:
            # Prefer the per-pair value stored by calculateMatrix; the mobile
            # structure's property may hold a different pair's RMSD.
            inplace_rmsd = getattr(match, 'inplace_rmsd', None)
            if inplace_rmsd is None:
                inplace_rmsd = match.getMobileStRMSD()[0]
            data.append(
                [match.ref_st.title, match.mobile_st.title, inplace_rmsd])
        if len(data) > 1:
            self._writeCSVData(filename, data)