# python module
"""
Classes for reading and writing Glide XP Descriptor blocks from pose files.
Copyright Schrodinger LLC. All rights reserved.
"""
import enum
from past.utils import old_div
from schrodinger import structure
from schrodinger.infra import mm
from schrodinger.utils import log
# CT-level XP terms read (in this order) from the m_glide_XPvisualizer block
# of every pose by XpPVParser.  The parsers rely on the position of each
# entry: index 6 (mm.GLIDE_XP_ELECTRO) is used as the electrostatic score.
_GLIDE_TERM_NAMES = [
    mm.GLIDE_XP_GSCORE,
    mm.GLIDE_XP_LIPOPHILIC_EVDW,
    mm.GLIDE_XP_PHOBEN,
    mm.GLIDE_XP_PHOBEN_HB,
    mm.GLIDE_XP_PHOBEN_PAIRHB,
    mm.GLIDE_XP_HBOND,
    mm.GLIDE_XP_ELECTRO,
    mm.GLIDE_XP_SITEMAP,
    # mm.GLIDE_XP_PISTACK,  (disabled)
    mm.GLIDE_XP_PICAT,
    mm.GLIDE_XP_CL_BR_PACK,
    mm.GLIDE_XP_LOWMW,
    mm.GLIDE_XP_PENALS,
    mm.GLIDE_XP_PENAL_HB,
    mm.GLIDE_XP_PENAL_PHOBIC,
    mm.GLIDE_XP_PENAL_ROT
]
# Maps each supported XP-visualizer block name to the NAME of the module-level
# reader function that parses it (each reader is named after its block).
# Names rather than function objects are stored because the readers are
# defined further down in this module.
_ALL_BLOCKS = {
    block_name: block_name
    for block_name in (
        'm_glide_XPviz_hbonds',
        'm_glide_XPviz_phobcon_hb',
        'm_glide_XPviz_hexsp',
        'm_glide_XPviz_phobpack',
        'm_glide_XPviz_stacking',
        'm_glide_XPviz_rotbonds',
        'm_glide_XPviz_watmol',
        'm_glide_XPviz_hexpairs',
        'm_glide_XPviz_exposure',
        'm_glide_XPviz_penalties',
        'm_glide_XPviz_picat',
    )
}
logger = log.get_output_logger("xpdes.py")
def extract_Phob_pack_ligand_indices(data_row):
    """
    Pull the ligand atom indices out of a Phob_pack index row.

    The row is laid out as
    ``[pack value, n_ligand, n_protein, lig_0, ..., lig_{n-1}, prot_0, ...]``.
    The ligand indices stored there are zero based; each one is shifted by
    one so it can be used directly for ct atom access.

    :param data_row: second row in the Phob_pack data block
    :type data_row: list of int
    :return: one-based ligand atom indices
    :rtype: list of int
    """
    first = 3  # ligand indices start after the three header fields
    count = int(data_row[1])
    ligand_part = data_row[first:first + count]
    return [zero_based + 1 for zero_based in ligand_part]
def get_prop(cthandle, dn):
    """
    Read a scalar (non-indexed) property from the current m2io block.

    Helper for the m_glide_XPviz_* readers.  The getter is chosen from the
    data-name prefix: ``r_`` real, ``i_`` integer, ``b_`` boolean, and
    ``s_`` (standard Maestro string prefix) or ``x_`` string.

    :param cthandle: m2io handle positioned on the block to read
    :param dn: data name of the property
    :type dn: str
    :return: whatever the matching ``mm.m2io_get_*`` call returns (callers
        in this module treat it as a list), or None for an unrecognized
        prefix
    """
    if dn.startswith("r_"):
        return mm.m2io_get_real(cthandle, [dn])
    if dn.startswith("i_"):
        return mm.m2io_get_int(cthandle, [dn])
    if dn.startswith("b_"):
        return mm.m2io_get_boolean(cthandle, [dn])
    # "s_" is the Maestro string prefix; "x_" is kept for backward
    # compatibility with the original dispatch.
    if dn.startswith(("s_", "x_")):
        return mm.m2io_get_string(cthandle, [dn])
    return None
def get_prop_idx(cthandle, dn, dim):
    """
    Read one entry of an indexed property from the current m2io block.

    Helper for the m_glide_XPviz_* readers.  The getter is chosen from the
    data-name prefix: ``r_`` real, ``i_`` integer, ``b_`` boolean, and
    ``s_`` (standard Maestro string prefix) or ``x_`` string.

    :param cthandle: m2io handle positioned on an indexed block
    :param dn: data name of the property
    :type dn: str
    :param dim: index of the entry to read (callers pass 1-based indices)
    :type dim: int
    :return: whatever the matching ``mm.m2io_get_*_indexed`` call returns
        (callers in this module treat it as a list), or None for an
        unrecognized prefix
    """
    if dn.startswith("r_"):
        return mm.m2io_get_real_indexed(cthandle, dim, [dn])
    if dn.startswith("i_"):
        return mm.m2io_get_int_indexed(cthandle, dim, [dn])
    if dn.startswith("b_"):
        return mm.m2io_get_boolean_indexed(cthandle, dim, [dn])
    # "s_" is the Maestro string prefix; "x_" is kept for backward
    # compatibility with the original dispatch.
    if dn.startswith(("s_", "x_")):
        return mm.m2io_get_string_indexed(cthandle, dim, [dn])
    return None
#
# m_glide_XPviz_* functions from XP Visualizer (Blaine Bell)
#
# FIXME these functions have a lot of duplication that can be factored out
def m_glide_XPviz_hbonds(cthandle):
    """
    Read the current m_glide_XPviz_hbonds block.

    :param cthandle: m2io handle positioned on the block
    :return: ``['H_bonds', rows]`` where ``rows[0]`` is ``['H_bonds', dim]``
        and every later row holds, for one index, the values of all data
        names in the block, in data-name order.
    """
    names = mm.m2io_get_data_names(cthandle, -1)
    count = mm.m2io_get_index_dimension(cthandle)
    rows = [['H_bonds', count]]
    rows += [
        [value for name in names for value in get_prop_idx(cthandle, name, index)]
        for index in range(1, count + 1)
    ]
    return ['H_bonds', rows]
def m_glide_XPviz_phobcon_hb(cthandle):
    """
    Read the current m_glide_XPviz_phobcon_hb block.

    :param cthandle: m2io handle positioned on the block
    :return: ``['Phobcon_HB', rows]`` where ``rows[0]`` is
        ``['Phobcon_HB', dim]`` and every later row holds, for one index,
        the values of all data names in the block, in data-name order.
    """
    names = mm.m2io_get_data_names(cthandle, -1)
    count = mm.m2io_get_index_dimension(cthandle)
    table = [['Phobcon_HB', count]]
    for index in range(1, count + 1):
        values = []
        for name in names:
            values += get_prop_idx(cthandle, name, index)
        table.append(values)
    return ['Phobcon_HB', table]
def m_glide_XPviz_hexsp(cthandle):
    """
    Read the current m_glide_XPviz_hexsp block.

    Unlike most readers here, all indexed values are flattened into a single
    data row (index-major, then data-name order).

    :param cthandle: m2io handle positioned on the block
    :return: ``['Hex_sp', [['Hex_sp', dim], flat_values]]``
    """
    names = mm.m2io_get_data_names(cthandle, -1)
    count = mm.m2io_get_index_dimension(cthandle)
    flat_values = [
        value
        for index in range(1, count + 1)
        for name in names
        for value in get_prop_idx(cthandle, name, index)
    ]
    return ['Hex_sp', [['Hex_sp', count], flat_values]]
def m_glide_XPviz_phobpack(cthandle):
    """
    Read the current m_glide_XPviz_phobpack block and its two child blocks.

    :param cthandle: m2io handle positioned on the block
    :return: ``['Phob_pack', [header, index_row, values]]`` where

        - ``header`` is ``['Phob_pack']`` followed by every scalar property
          of the block (in data-name order);
        - ``index_row`` is ``[header[2], n_atoms, n_res, atom indices...,
          i_glide_XPviz_pack_resnum values...]`` (the layout consumed by
          ``extract_Phob_pack_ligand_indices``);
        - ``values`` holds r_glide_XPviz_phobatscpar for each atom of the
          m_glide_XPviz_phobatsc child block.
    """
    header = ['Phob_pack']
    for name in mm.m2io_get_data_names(cthandle, -1):
        header.extend(get_prop(cthandle, name))

    # Child block with per-atom indices and their packing values.
    mm.m2io_goto_next_block(cthandle, 'm_glide_XPviz_phobatsc')
    n_atoms = mm.m2io_get_index_dimension(cthandle)
    indices = []
    values = []
    for index in range(1, n_atoms + 1):
        indices.extend(
            get_prop_idx(cthandle, "i_glide_XPviz_phobatscat", index))
        values.extend(
            get_prop_idx(cthandle, "r_glide_XPviz_phobatscpar", index))
    mm.m2io_leave_block(cthandle)

    # Child block whose entries are appended after the atom indices.
    mm.m2io_goto_next_block(cthandle, 'm_glide_XPviz_pack_res')
    n_res = mm.m2io_get_index_dimension(cthandle)
    for index in range(1, n_res + 1):
        indices.extend(
            get_prop_idx(cthandle, "i_glide_XPviz_pack_resnum", index))
    mm.m2io_leave_block(cthandle)

    index_row = [header[2], n_atoms, n_res] + indices
    return ['Phob_pack', [header, index_row, values]]
def m_glide_XPviz_stacking(cthandle):
    """
    Placeholder reader for the m_glide_XPviz_stacking block.

    The block is not parsed: this always returns None, which the callers in
    this module treat as "nothing to record".

    :param cthandle: m2io handle (unused)
    """
def m_glide_XPviz_rotbonds(cthandle):
    """
    Read the current m_glide_XPviz_rotbonds block.

    All indexed values are appended to one row after its header.

    :param cthandle: m2io handle positioned on the block
    :return: ``['Rot_bonds', [['Rot_bonds', dim, values...]]]`` with values
        ordered index-major, then by data name
    """
    count = mm.m2io_get_index_dimension(cthandle)
    names = mm.m2io_get_data_names(cthandle, -1)
    row = ['Rot_bonds', count]
    for index in range(1, count + 1):
        for name in names:
            row += get_prop_idx(cthandle, name, index)
    return ['Rot_bonds', [row]]
def m_glide_XPviz_watmol(cthandle):
    """
    Read the current m_glide_XPviz_watmol block.

    All indexed values are appended to one row after its header.

    :param cthandle: m2io handle positioned on the block
    :return: ``['Watmol', [['Watmol', dim, values...]]]`` with values
        ordered index-major, then by data name
    """
    count = mm.m2io_get_index_dimension(cthandle)
    names = mm.m2io_get_data_names(cthandle, -1)
    values = [
        value
        for index in range(1, count + 1)
        for name in names
        for value in get_prop_idx(cthandle, name, index)
    ]
    return ['Watmol', [['Watmol', count] + values]]
def m_glide_XPviz_hexpairs(cthandle):
    """
    Read the current m_glide_XPviz_hexpairs block.

    :param cthandle: m2io handle positioned on the block
    :return: ``['Hex_pairs', rows]`` where ``rows[0]`` is
        ``['Hex_pairs', dim]`` and each later row is the values of
        ispec1a, ipspec, ispec2a, jpspec and hexspec for one index.
    """
    fields = (
        'i_glide_XPviz_ispec1a',
        'i_glide_XPviz_ipspec',
        'i_glide_XPviz_ispec2a',
        'i_glide_XPviz_jpspec',
        'r_glide_XPviz_hexspec',
    )
    count = mm.m2io_get_index_dimension(cthandle)
    rows = [['Hex_pairs', count]]
    for index in range(1, count + 1):
        rows.append([
            value for field in fields
            for value in get_prop_idx(cthandle, field, index)
        ])
    return ['Hex_pairs', rows]
def m_glide_XPviz_exposure(cthandle):
    """
    Read the current m_glide_XPviz_exposure block and its two child blocks.

    :param cthandle: m2io handle positioned on the block
    :return: ``['Exposure', rows]`` where ``rows[0]`` is ``['Exposure']``
        followed by every scalar property of the block, and each later row
        describes one group from m_glide_XPviz_expo_group_penal as
        ``[expo_nat, expo_penal, lig_atom, ...]``; the ligand atoms come from
        m_glide_XPviz_expo_groupat.
    """
    header = ['Exposure']
    for dn in mm.m2io_get_data_names(cthandle, -1):
        header.extend(get_prop(cthandle, dn))
    ret_data = [header]

    # One row per exposure group: [expo_nat, expo_penal].
    mm.m2io_goto_next_block(cthandle, 'm_glide_XPviz_expo_group_penal')
    dim = mm.m2io_get_index_dimension(cthandle)
    groups = []
    for i in range(1, dim + 1):
        group = []
        for dn in ('i_glide_XPviz_expo_nat', 'r_glide_XPviz_expo_penal'):
            group.extend(get_prop_idx(cthandle, dn, i))
        groups.append(group)
    mm.m2io_leave_block(cthandle)

    # Append each listed ligand atom to the row of the group it belongs to.
    mm.m2io_goto_next_block(cthandle, 'm_glide_XPviz_expo_groupat')
    dim = mm.m2io_get_index_dimension(cthandle)
    for i in range(1, dim + 1):
        group_num = get_prop_idx(cthandle, 'i_glide_XPviz_expo_liggroup', i)
        atoms = get_prop_idx(cthandle, 'i_glide_XPviz_expo_ligat', i)
        # NOTE(review): group numbers are treated as 1-based here; a 0 would
        # silently select the last group — confirm against the writer.
        groups[group_num[0] - 1].extend(atoms)
    ret_data.extend(groups)
    mm.m2io_leave_block(cthandle)
    return ['Exposure', ret_data]
def m_glide_XPviz_penal_blkpchg(cthandle):
    """
    Read a buried-polar-charge penalty block.

    Reads r_glide_XPviz_penal_blkppar from the current block, then collects
    i_glide_XPviz_penal_desolvat for every index of the child block
    m_glide_XPviz_penal_desolvlist.

    :param cthandle: m2io handle positioned on the block
    :return: ``[['Blkpchg', dim, blkppar, atom, ...]]``
    """
    blkppar = get_prop(cthandle, "r_glide_XPviz_penal_blkppar")
    mm.m2io_goto_next_block(cthandle, 'm_glide_XPviz_penal_desolvlist')
    count = mm.m2io_get_index_dimension(cthandle)
    row = ['Blkpchg', count, blkppar[0]]
    for index in range(1, count + 1):
        row += get_prop_idx(cthandle, 'i_glide_XPviz_penal_desolvat', index)
    mm.m2io_leave_block(cthandle)
    return [row]
def m_glide_XPviz_water_ligand(cthandle):
    """
    Read a water-ligand penalty block.

    Reads r_glide_XPviz_water_ligpar from the current block, then collects
    i_glide_XPviz_water_ligat for every index of the child block
    m_glide_XPviz_water_liglist.

    :param cthandle: m2io handle positioned on the block
    :return: ``[['Water-ligand', dim, ligpar, atom, ...]]``
    """
    ligpar = get_prop(cthandle, "r_glide_XPviz_water_ligpar")
    mm.m2io_goto_next_block(cthandle, 'm_glide_XPviz_water_liglist')
    dim = mm.m2io_get_index_dimension(cthandle)
    row = ['Water-ligand', dim, ligpar[0]]
    for i in range(1, dim + 1):
        row.extend(get_prop_idx(cthandle, 'i_glide_XPviz_water_ligat', i))
    mm.m2io_leave_block(cthandle)
    return [row]
def m_glide_XPviz_penal_polar(cthandle):
    """
    Read a polar penalty block.

    Reads r_glide_XPviz_penal_polpar from the current block, then collects
    i_glide_XPviz_penal_polarat for every index of the child block
    m_glide_XPviz_penal_pollist.

    :param cthandle: m2io handle positioned on the block
    :return: ``[['Polar', dim, polpar, atom, ...]]``
    """
    polpar = get_prop(cthandle, "r_glide_XPviz_penal_polpar")
    mm.m2io_goto_next_block(cthandle, 'm_glide_XPviz_penal_pollist')
    dim = mm.m2io_get_index_dimension(cthandle)
    row = ['Polar', dim, polpar[0]]
    for i in range(1, dim + 1):
        row.extend(get_prop_idx(cthandle, 'i_glide_XPviz_penal_polarat', i))
    mm.m2io_leave_block(cthandle)
    return [row]
def m_glide_XPviz_water_protein(cthandle):
    """
    Read a water-protein penalty block.

    Reads r_glide_XPviz_water_protpar from the current block, then collects
    i_glide_XPviz_water_protat for every index of the child block
    m_glide_XPviz_water_protlist.

    :param cthandle: m2io handle positioned on the block
    :return: ``[['Water-protein', dim, protpar, atom, ...]]``
    """
    protpar = get_prop(cthandle, "r_glide_XPviz_water_protpar")
    mm.m2io_goto_next_block(cthandle, 'm_glide_XPviz_water_protlist')
    dim = mm.m2io_get_index_dimension(cthandle)
    row = ['Water-protein', dim, protpar[0]]
    for i in range(1, dim + 1):
        row.extend(get_prop_idx(cthandle, 'i_glide_XPviz_water_protat', i))
    mm.m2io_leave_block(cthandle)
    return [row]
def m_glide_XPviz_penal_ddpen(cthandle):
    """
    Read a DDpen penalty block.

    Reads r_glide_XPviz_penal_ddpenpar from the current block, then appends
    every data-name value of every index of the child block
    m_glide_XPviz_penal_ddpenlist to a single row.

    :param cthandle: m2io handle positioned on the block
    :return: ``[['DDpen', dim, ddpenpar, values...]]``
    """
    ddpenpar = get_prop(cthandle, "r_glide_XPviz_penal_ddpenpar")
    mm.m2io_goto_next_block(cthandle, 'm_glide_XPviz_penal_ddpenlist')
    count = mm.m2io_get_index_dimension(cthandle)
    names = mm.m2io_get_data_names(cthandle, -1)
    row = ['DDpen', count, ddpenpar[0]]
    for index in range(1, count + 1):
        for name in names:
            row.extend(get_prop_idx(cthandle, name, index))
    mm.m2io_leave_block(cthandle)
    return [row]
def m_glide_XPviz_penal_twistam(cthandle):
    """
    Read a twisted-amide penalty block.

    Reads r_glide_XPviz_penal_ampar from the current block, then reads every
    index of the child block m_glide_XPviz_penal_amlist.

    NOTE(review): unlike the other penalty readers, each index becomes its
    own row instead of being appended to the header row — confirm consumers
    expect this layout.

    :param cthandle: m2io handle positioned on the block
    :return: ``[['Twisted_Amide', dim, ampar], row_1, ..., row_dim]``
    """
    ampar = get_prop(cthandle, "r_glide_XPviz_penal_ampar")
    mm.m2io_goto_next_block(cthandle, 'm_glide_XPviz_penal_amlist')
    count = mm.m2io_get_index_dimension(cthandle)
    names = mm.m2io_get_data_names(cthandle, -1)
    result = [['Twisted_Amide', count, ampar[0]]]
    result.extend(
        [value for name in names for value in get_prop_idx(cthandle, name, index)]
        for index in range(1, count + 1))
    mm.m2io_leave_block(cthandle)
    return result
def m_glide_XPviz_penal_charge(cthandle):
    """
    Read a charge penalty block.

    Reads r_glide_XPviz_penal_chgpar from the current block, then collects
    i_glide_XPviz_penal_chargeat for every index of the child block
    m_glide_XPviz_penal_chglist.

    :param cthandle: m2io handle positioned on the block
    :return: ``[['Charge', dim, chgpar, atom, ...]]``
    """
    chgpar = get_prop(cthandle, "r_glide_XPviz_penal_chgpar")
    mm.m2io_goto_next_block(cthandle, 'm_glide_XPviz_penal_chglist')
    dim = mm.m2io_get_index_dimension(cthandle)
    row = ['Charge', dim, chgpar[0]]
    for i in range(1, dim + 1):
        row.extend(get_prop_idx(cthandle, 'i_glide_XPviz_penal_chargeat', i))
    mm.m2io_leave_block(cthandle)
    return [row]
def m_glide_XPviz_penalties(cthandle):
    """
    Read the current m_glide_XPviz_penalties block and its sub-blocks.

    Each sub-block whose name is a key of ``_ALL_BLOCKS`` is dispatched to
    the reader function of that name; rows it returns are appended.  Sub-
    blocks with other names are skipped.  NOTE(review): the
    m_glide_XPviz_penal_* / m_glide_XPviz_water_* readers defined in this
    module are not registered in ``_ALL_BLOCKS`` — confirm that is intended.

    :param cthandle: m2io handle positioned on the block
    :return: ``['Penalties', [['Penalties', npenal], rows...]]``
    """
    npenal = get_prop(cthandle, 'i_glide_XPviz_npenal')
    ret_data = [['Penalties', npenal[0]]]
    for block_name in mm.m2io_get_block_names(cthandle, 1):
        mm.m2io_goto_next_block(cthandle, block_name)
        func_name = _ALL_BLOCKS.get(block_name)
        if func_name is not None:
            try:
                # Dispatch by name through the whitelist instead of eval().
                rows = globals()[func_name](cthandle)
                if rows:
                    ret_data.extend(rows)
            except Exception:
                # FIXME narrow this once the mm exception types are pinned
                # down; a malformed sub-block is skipped, not fatal.
                logger.debug("Could not read XP penalty sub-block %s",
                             block_name)
        mm.m2io_leave_block(cthandle)
    return ['Penalties', ret_data]
def _read_optional_indexed_block(cthandle, block_name):
    """
    Flatten every indexed value of an optional child block.

    Values are ordered index-major, then by data name.  If the block cannot
    be entered, an empty list is returned.  If reading fails part-way, the
    values read so far are returned.  Once entered, the block is always left
    again so the caller's block position is preserved.
    """
    try:
        mm.m2io_goto_next_block(cthandle, block_name)
    except Exception:
        # The child block is optional; nothing was entered.
        return []
    values = []
    try:
        data_names = mm.m2io_get_data_names(cthandle, -1)
        dim = mm.m2io_get_index_dimension(cthandle)
        for i in range(1, dim + 1):
            for dn in data_names:
                values.extend(get_prop_idx(cthandle, dn, i))
    except Exception:
        # FIXME narrow this once the mm exception types are pinned down.
        logger.debug("Could not fully read XP block %s", block_name)
    finally:
        mm.m2io_leave_block(cthandle)
    return values


def m_glide_XPviz_picat(cthandle):
    """
    Read the current m_glide_XPviz_picat block and its optional child blocks.

    :param cthandle: m2io handle positioned on the block
    :return: ``['Picat', [line]]`` where ``line`` is ``['Picat', npicatat,
        npiatres]`` followed by the flattened values of the
        m_glide_XPviz_picatat and then m_glide_XPviz_picatres child blocks
        (either may be absent).
    """
    line = ['Picat']
    for dn in ('i_glide_XPviz_npicatat', 'i_glide_XPviz_npiatres'):
        line.extend(get_prop(cthandle, dn))
    line.extend(_read_optional_indexed_block(cthandle, 'm_glide_XPviz_picatat'))
    line.extend(_read_optional_indexed_block(cthandle, 'm_glide_XPviz_picatres'))
    return ['Picat', [line]]
@enum.unique
class TermKeywords(str, enum.Enum):
    """
    Names of the XP-descriptor terms stored in ``XpPVParser.xpdes_table``.

    Because the enum mixes in ``str``, every member compares equal to the
    plain string it wraps (``TermKeywords.Picat == "Picat"``).  Declaration
    order is the order in which single-ligand parsing visits the terms.
    """
    Molecule = "Molecule"
    H_bonds = "H_bonds"
    Phob_pack = "Phob_pack"
    Hex_pairs = "Hex_pairs"
    Penalties = "Penalties"
    Phobcon_HB = "Phobcon_HB"
    Picat = "Picat"
    Exposure = "Exposure"
    Rot_bonds = "Rot_bonds"
    Hex_sp = "Hex_sp"
class XpPVParser:
    """
    Class for parsing a PV file with XP-descriptor information.

    After construction:

    - ``xpdes_table`` maps ``(ctnum, term_name)`` to the rows returned by the
      matching m_glide_XPviz_* reader (Watmol is not stored);
    - ``CT_level_terms_from_XPBlock`` maps ``ctnum`` to the CT-level XP
      terms, ordered as ``_GLIDE_TERM_NAMES``;
    - ``ctHasXPDes`` maps each fully processed ``ctnum`` to True.
    """

    def __init__(self, filename):
        """
        Read PV file and create a table of XP Descriptor terms.

        Structure 0 is skipped; every following structure is treated as a
        ligand pose.  If a pose lacks an m_glide_XPvisualizer block, an error
        is logged and parsing stops, leaving the tables filled only for the
        poses read so far.

        :param filename: path of the pose-viewer file
        :type filename: str
        """
        self.xpdes_table = {}
        self.ctHasXPDes = {}
        self.CT_level_terms_from_XPBlock = {}
        # Context manager so the file is closed even on the early return.
        with structure.StructureReader(filename) as reader:
            for ctnum, ligand_st in enumerate(reader):
                if ctnum == 0:
                    continue
                cthandle = mm.mmct_ct_m2io_get_unrequested_handle(ligand_st)
                try:
                    mm.m2io_get_number_blocks(cthandle, "m_glide_XPvisualizer")
                    mm.m2io_goto_next_block(cthandle, "m_glide_XPvisualizer")
                except Exception:
                    # FIXME need to catch a specific exception here
                    logger.error(
                        "Input file '%s' has no m_glide_XPvisualizer block" %
                        filename)
                    return
                self.CT_level_terms_from_XPBlock[ctnum] = (
                    mm.m2io_get_real_indexed(cthandle, 1, _GLIDE_TERM_NAMES))
                for block_name in mm.m2io_get_block_names(cthandle, 1):
                    mm.m2io_goto_next_block(cthandle, block_name)
                    func_name = _ALL_BLOCKS.get(block_name)
                    if func_name is not None:
                        try:
                            # Whitelisted name-based dispatch instead of eval().
                            result = globals()[func_name](cthandle)
                            if result and result[0] != 'Watmol':
                                self.xpdes_table[(ctnum, result[0])] = result[1]
                        except Exception:
                            # FIXME narrow once the mm exception types are
                            # pinned down; a bad block is skipped, not fatal.
                            logger.debug("Could not read XP block %s for ct %d",
                                         block_name, ctnum)
                    mm.m2io_leave_block(cthandle)
                self.ctHasXPDes[ctnum] = True

    def parseBlock_SL(self, ctnum):
        """
        XP parsing for single-ligand scoring.

        Convert the XP terms parsed for ``ctnum`` into per-atom contributions.

        :param ctnum: structure index used as key in ``__init__``
        :type ctnum: int
        :return: list of ``(atom_number, xp_score, property_name)`` tuples;
            property_name (e.g. ``r_xpdes_HBond``) is later set as an
            atom-level property
        :rtype: list of tuple
        :raise KeyError: if no XP block was read for ``ctnum``
        """
        try:
            ct_level_terms = self.CT_level_terms_from_XPBlock[ctnum]
        except KeyError:
            raise KeyError("Failed to locate block for ligand %i" % (ctnum))
        # Index 6 of _GLIDE_TERM_NAMES is mm.GLIDE_XP_ELECTRO.
        elec_score = ct_level_terms[6]
        XP_atoms_array = []
        for term in TermKeywords:
            block = self.xpdes_table.get((ctnum, term))
            if block is None:
                # no value for this term
                continue
            try:
                if term == TermKeywords.H_bonds:
                    for row in block[1:]:
                        hbond_xp = float(row[2])
                        XP_atoms_array.append(
                            (int(row[0]), hbond_xp, "r_xpdes_HBond"))
                elif term == TermKeywords.Phobcon_HB:
                    for row in block[1:]:
                        hbond_xp = float(row[2])
                        XP_atoms_array.append(
                            (int(row[0]), hbond_xp, "r_xpdes_PhobEnHB"))
                elif term == TermKeywords.Hex_pairs:
                    for row in block[1:]:
                        # The pair energy is split evenly between both atoms.
                        half_phobEbPair = old_div(float(row[4]), 2)
                        XP_atoms_array.append((int(row[0]), half_phobEbPair,
                                               "r_xpdes_PhobEnPairHB"))
                        XP_atoms_array.append((int(row[2]), half_phobEbPair,
                                               "r_xpdes_PhobEnPairHB"))
                elif term == TermKeywords.Hex_sp:
                    for row in block[1:]:
                        if not row:
                            # Avoid dividing by zero for an empty atom list.
                            continue
                        # The electrostatic score is spread evenly over atoms.
                        per_atom_score = old_div(elec_score, len(row))
                        for lig_atom in row:
                            XP_atoms_array.append(
                                (int(lig_atom), per_atom_score,
                                 "r_xpdes_Electro"))
                elif term == TermKeywords.Phob_pack:
                    for idx, row in enumerate(block):
                        if idx == 0:
                            num_phob_pack_sites = int(row[1])
                            atoms_row = True
                        elif num_phob_pack_sites >= 1 and atoms_row:
                            # Index row: pick out the ligand atoms.
                            phob_lig_atoms = extract_Phob_pack_ligand_indices(
                                row)
                            atoms_row = False
                        else:
                            # Value row: one packing score per ligand atom.
                            new_atom_xp = [float(a) for a in row]
                            for pos, lig_atom in enumerate(phob_lig_atoms):
                                XP_atoms_array.append(
                                    (int(lig_atom), new_atom_xp[pos],
                                     "r_xpdes_PhobEn"))
                            num_phob_pack_sites -= 1
            except Exception:
                # FIXME need to catch a specific exception here; a malformed
                # term keeps whatever was appended before the failure.
                logger.debug("Could not parse XP term %s for ct %d", term,
                             ctnum)
        return XP_atoms_array

    def parseBlock_F(self, ctnum, st):
        """
        XP parsing for fragment scoring.

        Convert the XP terms parsed for ``ctnum`` into per-atom contributions.

        :param ctnum: structure index used as key in ``__init__``
        :type ctnum: int
        :param st: structure of the pose (not used by the current
            implementation)
        :return: list of ``(atom_index, xp_score, property_name)`` tuples,
            with property_name prefixed by ``r_xpdes_`` (e.g.
            ``r_xpdes_HBond``), later set as an atom-level property
        :rtype: list of tuple
        :raise KeyError: if no XP block was read for ``ctnum``
        """
        try:
            ct_level_terms = self.CT_level_terms_from_XPBlock[ctnum]
        except KeyError:
            raise KeyError("Failed to locate block for ligand %i" % (ctnum))
        # Index 6 of _GLIDE_TERM_NAMES is mm.GLIDE_XP_ELECTRO.
        elec_score = ct_level_terms[6]
        atom_data = []
        for (ct, term), block in self.xpdes_table.items():
            # Only consider blocks corresponding to the current ct.  Compare
            # by value: identity ("is") is unreliable for ints above 256.
            if ct != ctnum:
                continue
            rows = block[1:]
            if term == TermKeywords.H_bonds:
                for row in rows:
                    atom_index, _, value = row
                    atom_data.append((atom_index, value, "HBond"))
            elif term == TermKeywords.Phobcon_HB:
                for row in rows:
                    atom_index, _, value = row
                    atom_data.append((atom_index, value, "PhobEnHB"))
            elif term == TermKeywords.Hex_pairs:
                for row in rows:
                    index1, _, index2, _, energy = row
                    half_energy = old_div(float(energy), 2.0)
                    label = "PhobEnPairHB"
                    atom_data.append((index1, half_energy, label))
                    atom_data.append((index2, half_energy, label))
            elif term == TermKeywords.Hex_sp:
                for row in rows:
                    if not row:
                        # Avoid dividing by zero for an empty atom list.
                        continue
                    score = old_div(elec_score, len(row))
                    for atom_index in row:  # Row contains only atom indices
                        atom_data.append((atom_index, score, "Electro"))
            elif term == TermKeywords.Phob_pack:
                row1, row2 = rows
                atom_indices = extract_Phob_pack_ligand_indices(row1)
                for atom_index, score in zip(atom_indices, row2):
                    atom_data.append((atom_index, score, "PhobEn"))
        # Format as (atom index, property value, property name)
        return [(int(index), float(value), "r_xpdes_" + label)
                for index, value, label in atom_data]