"""
Module for using APBS to find electrostatic potential on the protein surface to
analyze protein-protein (or protein-ligand) interactions. There are two types
of analyses that can be performed here:
(1) Electrostatic complementarity. The reference: J. Mol. Biol. (1997) 268, 570-584.
(2) Residual potential. The reference: Prot. Sci. (2001) 10, 362-377 and also the website:
    http://web.mit.edu/tidor/www/residual/description.html
Electrostatic complementarity (EC) defined in (1) provides a single quantity to
describe the interface complementarity, and it is extended here to assign a
quantity for each residue or each atom. Residual potential (RP) defined in (2)
focuses on the ligand design, providing a map of residual (non-ideal)
electrostatic potential on the ligand surface. It would be better
used as a visualization tool.
Example usage to get EC::
    ct = structure.Structure.read('1brs.maegz')
    # make sure force field is assigned to initialize the partial charge
    assign_ff(ct)
    lig_atoms = analyze.evaluate_asl(ct, 'chain.name D')
    ec = calc_total_complementarity(ct, lig_atoms) # the overall EC
    print(f"Overall EC: {ec}")
    pots_by_atoms = calc_complementarity_by_atom(ct, lig_atoms)
    # now get EC by residue
    for res in ct.residue:
        pots_by_res = {}
        for atom in res.atom:
            if atom.index in pots_by_atoms:
                pots_by_res[atom.index] = pots_by_atoms[atom.index]
        if pots_by_res:
            print(f"Residue EC: {res} {-1.0 * pearson_by_set(pots_by_res)}")
To get RP::
    jobname = 'test'
    rp = ResidualPotential(ct, lig_atoms, jobname = jobname)
    residual_potential = rp.getResidualPotential()
    # write out the surface and color it with residual potential
    rp.ligct.write(jobname+'.maegz')
    color_potential_surface(rp.ligsurf, residual_potential)
    rp.ligsurf.write(jobname+'_residual.vis')
    # also possible to visualize two components of residual potential
    inter_potential = rp.getInteractionPotential()
    color_potential_surface(rp.ligsurf, inter_potential)
    rp.ligsurf.write(jobname+'_inter.vis')
    desolv_potential = rp.getDesolvationPotential()
    color_potential_surface(rp.ligsurf, desolv_potential)
    rp.ligsurf.write(jobname+'_desolv.vis')
Or simply get electrostatic potential by APBS::
    pg = get_APBS_potential_grid(ct) # potential on the grid
    surf = surface.Surface.newMolecularSurface(ct, 'Surface')
    # potential on the surface points
    pots = pg.getSurfacePotential(surf.vertex_coords)
Copyright Schrodinger, LLC. All rights reserved.
"""
import os
import subprocess
from collections import defaultdict
from past.utils import old_div
import numpy as np
import psutil
from scipy.interpolate import RegularGridInterpolator
import schrodinger.utils.log as log
from schrodinger import surface
from schrodinger.forcefield import common as ffcommon
from schrodinger.infra import mm
from schrodinger.job import queue
# Default force-field version handed to assign_ff().
OPLS_VERSION = 14
# NOTE(review): the dielectric constants and grid buffers below are not
# referenced by any code visible in this part of the module; presumably they
# are consumed when the APBS input file is written -- confirm.
DIEL_PRO = '2.0'
DIEL_WATER = '78.0'
COARSE_BUFFER = 40.0
FINE_BUFFER = 20.0
# Default saturation limits for color_potential_surface(): potentials at or
# beyond these values are drawn in the pure color.
POT_CUTOFF_POSITIVE = 5.0  # kT/e
POT_CUTOFF_NEGATIVE = -5.0  # kT/e
# Radius of the distance cell built over the complex-surface vertices in
# calc_complementarity(); a partner-surface vertex with no complex-surface
# vertex within this distance is treated as buried.
# NOTE(review): units are presumably Angstrom -- confirm.
NON_BURIED_DIST = 1.5
# Module-level logger; emits everything from DEBUG upwards.
logger = log.get_output_logger(name="escomp")
logger.setLevel(log.DEBUG)
def color_potential_surface(surf,
                            vertex_pots,
                            negative_cutoff=POT_CUTOFF_NEGATIVE,
                            positive_cutoff=POT_CUTOFF_POSITIVE):
    """
    Color the surface according to the potential at surface points.

    Negative potentials fade from white (at 0) to pure red (at or below
    `negative_cutoff`); non-negative potentials fade from white (at 0) to pure
    blue (at or above `positive_cutoff`). Colors are passed to
    `surf.setColoring` as one (r, g, b) tuple per vertex, with each channel
    expressed as a float in the range 0.0-1.0.

    :type surf: `Surface<schrodinger.surface.Surface>`
    :param surf: the input surface object; its coloring is replaced in place

    :type vertex_pots: List of floats
    :param vertex_pots: the potential values on surface points

    :type negative_cutoff: float
    :param negative_cutoff: the cutoff value for negative potential coloring.
            At or below this cutoff, surface will be colored pure red.

    :type positive_cutoff: float
    :param positive_cutoff: the cutoff value for positive potential coloring.
            At or above this cutoff, surface will be colored pure blue.
    """

    def _rgb(pot):
        # Map one potential value onto the red-white-blue ramp.
        if pot < negative_cutoff:
            return (1.0, 0.0, 0.0)
        if pot < 0.0:
            # fraction of "white" left: 1 at pot == 0, 0 at the cutoff
            fade = 1.0 - pot / negative_cutoff
            return (1.0, fade, fade)
        if pot < positive_cutoff:
            fade = 1.0 - pot / positive_cutoff
            return (fade, fade, 1.0)
        return (0.0, 0.0, 1.0)

    surf.setColoring([_rgb(pot) for pot in vertex_pots])
def assign_ff(ct, ff_version=OPLS_VERSION):
    """
    Assign force field to get the atom property "partial_charge".

    This is a thin wrapper around `ffcommon.generate_partial_charges`; the
    structure is handed to that function together with the force-field
    version.

    :type ct: `Structure<schrodinger.structure.Structure>`
    :param ct: the structure that should receive partial charges

    :type ff_version: int
    :param ff_version: the force-field version forwarded to
            `generate_partial_charges` (defaults to `OPLS_VERSION`)
    """
    ffcommon.generate_partial_charges(ct, ff_version)
def get_center_gridlen(ct):
    """
    Compute the center and grid size for the input structure.

    The center is the mean of all atom coordinates; the size along each axis
    is the distance between the smallest and the largest coordinate on that
    axis.

    :type ct: `Structure<schrodinger.structure.Structure>`
    :param ct: the input structure

    :rtype: two lists of floats: (x, y, z), (xlen, ylen, zlen)
    :return: center position, and size in three dimensions
    """
    xyz = ct.getXYZ()
    center = list(np.mean(xyz, axis=0))
    # one vectorised pass over the columns instead of a per-axis Python loop
    extent = np.max(xyz, axis=0) - np.min(xyz, axis=0)
    gridlen = list(np.fabs(extent))
    return center, gridlen
def get_APBS_potential_grid(ct,
                            center=None,
                            gridlen=None,
                            jobname='apbs_potgrid'):
    """
    Compute the APBS electrostatic potential on a 3D grid.

    The partial charge in the ct will be used in APBS calculation. So care should be taken
    before passing in CT if for example the ligand charge should be disabled. The vdW Radii
    are used to construct the molecular surface.

    The job files (`<jobname>.pqr`, `.in`, `.out`, `.log`, `.dx`) are written
    relative to the current working directory. They are removed after the
    potential grid has been read successfully and are left in place when the
    run fails, so that the failure can be inspected.

    :type ct: `Structure<schrodinger.structure.Structure>`
    :param ct: the input structure

    :type center: List of three floats
    :param center: the center of the grid

    :type gridlen: List of three floats
    :param gridlen: the grid size in three dimensions

    :type jobname: String
    :param jobname: the basename for temporary APBS files

    :rtype: Object of PotGrid class
    :return: the potential grid

    :raise RuntimeError: if APBS exits with a non-zero status or the
            potential map file is not found afterwards
    :raise KeyError: if the SCHRODINGER environment variable is not set
    """
    # Derive every file name from the jobname directly. Deriving them with
    # in_file.replace('.in', ...) would also rewrite an inner ".in" that is
    # part of the jobname itself.
    pqr_file = jobname + '.pqr'
    in_file = jobname + '.in'
    out_file = jobname + '.out'
    log_file = jobname + '.log'
    dx_file = jobname + '.dx'

    write_pqr(ct, pqr_file)
    write_input(ct, in_file, pqr_file, jobname, center=center, gridlen=gridlen)

    cmd = [
        os.path.join(os.environ['SCHRODINGER'], "utilities", "apbs"), in_file,
        '--output-file=%s' % out_file
    ]
    # stdout and stderr both go to the log file, so nothing is captured here;
    # the error message points at the log instead.
    with open(log_file, 'w') as log_fh:
        apbs_proc = subprocess.run(cmd,
                                   stderr=subprocess.STDOUT,
                                   stdout=log_fh,
                                   universal_newlines=True)
    if apbs_proc.returncode:
        raise RuntimeError(
            "APBS job did not end normally (exit code %s); see %s" %
            (apbs_proc.returncode, log_file))

    if not os.path.exists(dx_file):
        raise RuntimeError("Potential map file missing: %s" % dx_file)

    potgrid = PotGrid(dx_file=dx_file)

    # clean up the job files
    for path in (pqr_file, in_file, log_file, out_file, dx_file):
        os.remove(path)

    return potgrid
def write_pqr(ct, filename):
    """
    Write the .PQR file for APBS job.

    One ATOM record is written per atom, numbered from 1 in iteration order.
    Spaces are stripped from the PDB atom and residue names. An atom without a
    PDB atom name is labelled with its element plus atom index; an atom
    without a residue name inherits the most recent non-empty residue name
    ("NA" if none has been seen yet). The charge and radius columns come from
    `atom.partial_charge` and `atom.radius`.

    :type ct: `Structure<schrodinger.structure.Structure>`
    :param ct: the input structure

    :type filename: String
    :param filename: PQR file name
    """
    last_res_name = "NA"
    with open(filename, 'w') as f:
        for serial, atom in enumerate(ct.atom, start=1):
            atom_name = atom.property["s_m_pdb_atom_name"].replace(" ", "")
            if not atom_name:
                atom_name = atom.element + "%d" % atom.index

            res_name = atom.property["s_m_pdb_residue_name"].replace(" ", "")
            if res_name:
                last_res_name = res_name
            else:
                res_name = last_res_name
            # a 4-character residue name takes over the leading blank column
            if len(res_name) == 4:
                res_field = res_name
            else:
                res_field = " " + res_name.ljust(3)[:3]

            chain = atom.property["s_m_chain_name"].ljust(1)[:1]
            res_num = ("%d" % atom.property["i_m_residue_number"]).rjust(4)[:4]
            if atom.inscode != "":
                ins_field = "%s   " % atom.inscode
            else:
                ins_field = "    "

            xyz_qr = "  %-9.3f %-9.3f %-9.3f %7.4f %7.4f\n" % (
                atom.x, atom.y, atom.z, atom.partial_charge, atom.radius)

            f.write("".join([
                "ATOM".ljust(6),
                "%5d" % serial,
                " ",
                atom_name[:3].ljust(4),
                res_field,
                " ",
                chain,
                res_num,
                ins_field,
                xyz_qr,
            ]))
class PotGrid:
    """
    The container that holds the potential grid from APBS calculation.
    The potential has the unit of kT/e.

    Attributes set by both construction paths:

    - `xmin`, `ymin`, `zmin`: coordinates of the grid origin
    - `hx`, `hy`, `hz`: grid spacing along each axis
    - `pots`: 3D Numpy array of potential values indexed as [ix, iy, iz]
    """

    def __init__(self, **kwargs):
        """
        There are two ways to initialize the object: read from a DX file, or
        copy from an existing object with the option of using another 3D potential map.

        :type dx_file: String
        :param dx_file: (keyword) the DX file to read the grid from

        :type copy_from: PotGrid object
        :param copy_from: (keyword) existing grid whose geometry is copied;
                only used when `dx_file` is not given

        :type pots: 3D Numpy array
        :param pots: (keyword, optional) potential map to use together with
                `copy_from` instead of a copy of its potentials

        :raise RuntimeError: if neither `dx_file` nor `copy_from` is given
        """
        if 'dx_file' in kwargs:
            self._read(kwargs['dx_file'])
        elif 'copy_from' in kwargs:
            self._new(kwargs['copy_from'], pots=kwargs.get('pots'))
        else:
            raise RuntimeError('Wrong initialization method.')

    def _read(self, dx_file):
        """
        Read the potential from a DX file and store the data.

        The header is read by fixed position: the grid counts come from the
        fifth line, the origin from the sixth, the three spacings from the
        following three lines, and the values start on the twelfth line and
        run until a line starting with "attribute" (or the end of the file).

        :type dx_file: String
        :param dx_file: the DX file from APBS calculation

        :raise RuntimeError: if the number of values read does not match the
                grid dimensions
        """
        with open(dx_file) as f:
            lines = f.readlines()

        nx, ny, nz = (int(count) for count in lines[4].split()[-3:])
        self.xmin, self.ymin, self.zmin = (
            float(value) for value in lines[5].split()[1:4])
        # each "delta" line carries the spacing of one axis on its diagonal
        self.hx = float(lines[6].split()[1])
        self.hy = float(lines[7].split()[2])
        self.hz = float(lines[8].split()[3])

        potlist = []
        for line in lines[11:]:
            if line.startswith('attribute'):
                break
            potlist.extend(line.split())

        # sanity check
        if len(potlist) != nx * ny * nz:
            raise RuntimeError("Potential map incomplete: %d != %d * %d * %d" %
                               (len(potlist), nx, ny, nz))

        # The values are listed with z varying fastest, which is exactly the
        # row-major order reshape() uses. np.float64 replaces the np.float
        # alias, which no longer exists in current NumPy releases.
        values = np.array([float(value) for value in potlist],
                          dtype=np.float64)
        self.pots = values.reshape((nx, ny, nz))

    def _new(self, pg, pots=None):
        """
        Create a new object from an existing object

        :type pg: PotGrid object
        :param pg: the existing PotGrid object

        :type pots: 3D Numpy array
        :param pots: the potential map on the grid; it is stored as given
                (not copied). When omitted, a copy of `pg.pots` is used.
        """
        self.xmin = pg.xmin
        self.ymin = pg.ymin
        self.zmin = pg.zmin
        self.hx = pg.hx
        self.hy = pg.hy
        self.hz = pg.hz
        self.pots = pots if pots is not None else np.copy(pg.pots)

    def getSurfacePotential(self, vertex_coords):
        """
        Interpolate the potential from the grid to the surface.

        :type vertex_coords: 2D Numpy array (N x 3)
        :param vertex_coords: list of coordinates of surface vertex points

        :rtype: 1D Numpy array (N)
        :return: list of potential values on surface points

        NOTE(review): the interpolator is built with its default settings,
        which presumably reject points outside the grid -- confirm against
        the SciPy documentation.
        """
        nx, ny, nz = self.pots.shape
        xs = self.xmin + self.hx * np.arange(nx)
        ys = self.ymin + self.hy * np.arange(ny)
        zs = self.zmin + self.hz * np.arange(nz)
        interpolator = RegularGridInterpolator(points=[xs, ys, zs],
                                               values=self.pots)
        return interpolator(vertex_coords)
class ResidualPotential:
    """
    Calculator of the residual potential on the ligand surface. The two
    components of the residual potential, interaction potential and desolvation
    potential, can also be reported and visualized on the surface.

    Attributes set by the constructor:

    - `ligct`: the ligand extracted from the complex
    - `pg1`: potential grid of the complex with the ligand charges zeroed
    - `pg2`: potential grid of the complex with the receptor charges zeroed
    - `pg3`: potential grid of the ligand alone
    - `ligsurf`: molecular surface of the ligand
    """

    def __init__(self, ct, lig_atoms, jobname="residual"):
        """
        Run the three APBS calculations and build the ligand surface.

        :type ct: `Structure<schrodinger.structure.Structure>`
        :param ct: the input complex structure

        :type lig_atoms: list or set
        :param lig_atoms: atom numbers that define the ligand

        :type jobname: String
        :param jobname: the basename for temporary APBS files
        """
        # lig_atoms is walked several times below; materialise it once so a
        # one-shot iterable cannot silently come back empty on a later pass.
        lig_atoms = list(lig_atoms)

        self.ligct = ct.extract(lig_atoms)

        # make sure the next three APBS calculations use the same grid
        center, gridlen = get_center_gridlen(ct)

        # interaction potential - complex bound state potential due to receptor charge
        receptor_charged = ct.copy()
        for ia in lig_atoms:
            receptor_charged.atom[ia].partial_charge = 0.0
        self.pg1 = get_APBS_potential_grid(receptor_charged,
                                           center=center,
                                           gridlen=gridlen,
                                           jobname=jobname + '_recb')

        # desolvation potential part 1 - complex bound state potential due to ligand charge
        ligand_charged = ct.copy()
        rec_atoms = set(range(1, ligand_charged.atom_total + 1)) - set(lig_atoms)
        for ia in rec_atoms:
            ligand_charged.atom[ia].partial_charge = 0.0
        self.pg2 = get_APBS_potential_grid(ligand_charged,
                                           center=center,
                                           gridlen=gridlen,
                                           jobname=jobname + '_ligb')

        # desolvation potential part 2 - ligand only, unbound state potential
        self.pg3 = get_APBS_potential_grid(self.ligct,
                                           center=center,
                                           gridlen=gridlen,
                                           jobname=jobname + '_ligub')

        self.ligsurf = surface.Surface.newMolecularSurface(
            self.ligct,
            jobname + '_Ligand_Surface',
            mol_surf_type=surface.MolSurfType.molecular)
        self.ligsurf.setTransparency(30)

    def getResidualPotential(self):
        """
        return the "residual potential" on the ligand surface, computed from
        the grid sum `pg1 + pg2 - pg3`.

        :rtype: 1D Numpy array
        :return: a list of potential values on ligand surface points
        """
        pg_residual = PotGrid(copy_from=self.pg1,
                              pots=self.pg1.pots + self.pg2.pots -
                              self.pg3.pots)
        return pg_residual.getSurfacePotential(self.ligsurf.vertex_coords)

    def getInteractionPotential(self):
        """
        return the "interaction potential" (grid `pg1`) on the ligand surface.

        :rtype: 1D Numpy array
        :return: a list of potential values on ligand surface points
        """
        return self.pg1.getSurfacePotential(self.ligsurf.vertex_coords)

    def getDesolvationPotential(self):
        """
        return the "desolvation potential" on the ligand surface, computed
        from the grid difference `pg2 - pg3`.

        :rtype: 1D Numpy array
        :return: a list of potential values on ligand surface points
        """
        pg_desolv = PotGrid(copy_from=self.pg2,
                            pots=self.pg2.pots - self.pg3.pots)
        return pg_desolv.getSurfacePotential(self.ligsurf.vertex_coords)

    def getResidualPotentialByAtom(self):
        """
        return the residual potential grouped by atom.

        :rtype: Dict of lists
        :return: dict key is the ligand atom index in the ligand ct, dict value is a list of
                potential values on the surface points that belong to this atom.
        """
        rp_by_atom = defaultdict(list)
        residual_potential = self.getResidualPotential()
        # read the per-vertex owner list once rather than once per vertex
        nearest_atoms = self.ligsurf.nearest_atom_indices
        for ia, pot in zip(nearest_atoms, residual_potential):
            rp_by_atom[ia].append(pot)
        return rp_by_atom

    def getResidualPotentialByResidue(self):
        """
        return the residual potential grouped by residue.

        Every residue of the ligand gets an entry, which is empty when none
        of its atoms owns a surface point.

        :rtype: Dict of lists
        :return: dict key is the ligand residue string, dict value is a list of
                potential values on the surface points that belong to this residue.
        """
        rp_by_atom = self.getResidualPotentialByAtom()
        rp_by_residue = defaultdict(list)
        for res in self.ligct.residue:
            res_pots = rp_by_residue[str(res)]
            for atom in res.atom:
                res_pots.extend(rp_by_atom.get(atom.index, ()))
        return rp_by_residue
def calc_total_complementarity(ct, atoms1, atoms2=None):
    """
    Return the total electrostatic complementarity between the specified surfaces.

    The value is the negated mean of the two Pearson correlation coefficients
    obtained from the buried surface points of each atom set, so that
    anti-correlated potentials give a positive complementarity.

    :type ct: structure._Structure object
    :param ct: Structure to which <atoms1> and <atoms2> are indices in.

    :type atoms1: Iterable of atom indices
    :param atoms1: Atom numbers from the surface for which to calculate the
            complementarity.

    :type atoms2: Iterable of atom indices
    :param atoms2: Atom numbers for the other surface. if not specified, use
        all other atoms from the CT.

    :rtype: float
    :return: the electrostatic complementarity between the 2 surfaces.
    """
    pots_by_set1, pots_by_set2 = calc_complementarity(ct, atoms1, atoms2)
    correlations = (pearson_by_set(pots_by_set1), pearson_by_set(pots_by_set2))
    return -0.5 * (correlations[0] + correlations[1])
def calc_complementarity_by_atom(ct, atoms1, atoms2=None):
    """
    Return the pairs of potential values used for calculating electrostatic complementarity
    between the specified surfaces, grouped by atom, in one dict.

    :type ct: structure._Structure object
    :param ct: Structure to which <atoms1> and <atoms2> are indices in.

    :type atoms1: Iterable of atom indices
    :param atoms1: Atom numbers from the surface for which to calculate the
        complementarity.

    :type atoms2: Iterable of atom indices
    :param atoms2: Atom numbers for the other surface. if not specified, use
        all other atoms from the CT.

    :rtype: dict of lists.
    :return: dict key is the index of the atom from the given list, dict value
        is a list of potential pairs on the surface points that belong to the
        buried surface of this atom. The correlation between the pair of
        potentials on one atom will give the complementarity measurement of
        that atom. Similarly, the correlation between the pair of potentials on
        one residue will give the complementarity of that residue, etc.
    """
    pots_by_set1, pots_by_set2 = calc_complementarity(ct, atoms1, atoms2)
    # Merge the two per-set dicts into one; on a shared key the entry from
    # the second set wins.
    merged = dict(pots_by_set1)
    merged.update(pots_by_set2)
    return merged
def _copy_with_zeroed_charges(ct, atom_indices):
    """Return a copy of `ct` whose listed atoms have partial charge 0.0."""
    neutralized = ct.copy()
    for ia in atom_indices:
        neutralized.atom[ia].partial_charge = 0.0
    return neutralized


def _extract_subset(ct, keep_atoms):
    """
    Return a copy of `ct` reduced to `keep_atoms`, together with the inverted
    renumbering map returned by `deleteAtoms` (value -> key).
    """
    subset_ct = ct.copy()
    del_atoms = [
        a for a in range(1, subset_ct.atom_total + 1) if a not in keep_atoms
    ]
    renumber_dict = subset_ct.deleteAtoms(del_atoms, renumber_map=True)
    reverse_renumber_dict = {
        value: key for key, value in renumber_dict.items()
    }
    return subset_ct, reverse_renumber_dict


def calc_complementarity(ct, atoms1, atoms2=None):
    """
    Return the pairs of potential values used for calculating electrostatic complementarity
    between the specified surfaces, grouped by atom, in two dicts.

    Two APBS runs are made on the full structure: one with the charges of
    <atoms2> set to zero and one with the charges of <atoms1> set to zero.
    Both potentials are then sampled at the buried surface points of each
    atom set.

    :type ct: structure._Structure object
    :param ct: Structure to which <atoms1> and <atoms2> are indices in.

    :type atoms1: Iterable of atom indices
    :param atoms1: Atom numbers from the surface for which to calculate the
        complementarity.

    :type atoms2: Iterable of atom indices
    :param atoms2: Atom numbers for the other surface. if not specified, use
        all other atoms from the CT.

    :rtype: two dicts of lists.
    :return: Each dict corresponds to one of atom sets <atoms1> and <atoms2>.
        For each dict, dict key is the index of the atom from the given list,
        dict value is a list of potential pairs on the surface points that
        belong to the buried surface of this atom. The correlation between the
        pair of potentials on one atom will give the complementarity
        measurement of that atom. Similarly, the correlation between the pair
        of potentials on one residue will give the complementarity of that
        residue, etc.

    :raise ValueError: if <atoms2> is not given and <atoms1> covers every
        atom of the structure
    """
    # Both atom sets are iterated and membership-tested several times; sets
    # make that safe for one-shot iterables and keep each test O(1).
    atoms1 = set(atoms1)
    if atoms2 is None:
        atoms2 = set(range(1, ct.atom_total + 1)) - atoms1
        if not atoms2:
            raise ValueError("The <atoms1> list matched all atoms in CT")
    else:
        atoms2 = set(atoms2)

    # pg1: only <atoms1> keep their charges; pg2: only <atoms2> keep theirs
    pg1 = get_APBS_potential_grid(_copy_with_zeroed_charges(ct, atoms2))
    pg2 = get_APBS_potential_grid(_copy_with_zeroed_charges(ct, atoms1))

    prot1_ct, reverse_prot1_renumber_dict = _extract_subset(ct, atoms1)
    prot2_ct, reverse_prot2_renumber_dict = _extract_subset(ct, atoms2)

    # create the surfaces
    total_surf = surface.Surface.newMolecularSurface(
        ct, 'Total_Surface', mol_surf_type=surface.MolSurfType.molecular)
    prot1_surf = surface.Surface.newMolecularSurface(
        prot1_ct, 'Prot1_Surface', mol_surf_type=surface.MolSurfType.molecular)
    prot2_surf = surface.Surface.newMolecularSurface(
        prot2_ct, 'Prot2_Surface', mol_surf_type=surface.MolSurfType.molecular)

    # prepare to get buried surface points
    nb_coords = total_surf.vertex_coords.astype(
        np.float32)  # single precision required
    # NOTE(review): this handle is never released here; check whether the mm
    # library expects an explicit delete call for distance cells.
    nb_cell_handle = mm.mmct_create_distance_cell_xyz2(nb_coords,
                                                       NON_BURIED_DIST)

    # get the potential on buried surface points, grouped by atoms
    pots_by_atoms1 = _buried_surface_pots(prot1_surf, nb_cell_handle, pg1, pg2)
    pots_by_atoms2 = _buried_surface_pots(prot2_surf, nb_cell_handle, pg1, pg2)

    # convert to original atom indices
    pots_by_atoms1 = {
        reverse_prot1_renumber_dict[a]: v
        for a, v in pots_by_atoms1.items()
        if any(v)
    }
    pots_by_atoms2 = {
        reverse_prot2_renumber_dict[a]: v
        for a, v in pots_by_atoms2.items()
        if any(v)
    }
    return pots_by_atoms1, pots_by_atoms2
def _buried_surface_pots(prot_surface, nb_cell_handle, pg1, pg2):
    """
    Collect, per atom, the potential pairs sampled on buried surface points.

    A vertex of `prot_surface` counts as buried when the distance-cell query
    against `nb_cell_handle` reports zero neighbours for its coordinates.

    :type prot_surface: `Surface<schrodinger.surface.Surface>`
    :param prot_surface: the surface from <atoms1> or <atoms2>

    :type nb_cell_handle: Object of distance cell
    :param nb_cell_handle: the NB distance cell created from the total surface
            of the complex, used to decide whether a vertex of the partner
            surface is buried in the complex surface.

    :type pg1: Object of PotGrid
    :param pg1: the potential grid resulting only from the charges from <atoms1>

    :type pg2: Object of PotGrid
    :param pg2: the potential grid resulting only from the charges from <atoms2>

    :rtype: dict of lists.
    :return: dict key is the atom index, dict value is a list of
            (potential from pg1, potential from pg2) pairs on the buried
            surface points nearest to this atom. The correlation between the
            pair of potentials measures the complementarity.
    """
    # (vertex number, coordinate) for every vertex the cell reports as buried
    buried = [(vertex, xyz)
              for vertex, xyz in enumerate(prot_surface.vertex_coords)
              if mm.mmct_query_distance_cell_count(nb_cell_handle, *xyz) == 0]
    buried_vertices = [vertex for vertex, _ in buried]
    buried_xyz = [xyz for _, xyz in buried]

    pots_from_pg1 = pg1.getSurfacePotential(buried_xyz)
    pots_from_pg2 = pg2.getSurfacePotential(buried_xyz)

    grouped = defaultdict(list)
    for vertex, first, second in zip(buried_vertices, pots_from_pg1,
                                     pots_from_pg2):
        owner = prot_surface.nearest_atom_indices[vertex]
        grouped[owner].append((first, second))
    return grouped
def pearson_by_set(pots_by_set):
    """
    Compute Pearson Correlation Coefficient for the pair of surface potentials for a set of atoms.

    All pairs of all atoms are pooled before the correlation is computed.
    The result comes straight from `np.corrcoef`, so degenerate input (for
    example no pairs at all, or a constant series) yields NaN rather than an
    exception.

    :type pots_by_set: Dict of lists
    :param pots_by_set: the pair of surface potentials for a set of atoms. Dict key is atom index,
            dict value is a list of potential pairs on the buried surface points of that atom.

    :rtype: float
    :return: Pearson Correlation Coefficient
    """
    pairs = [pair for atom_pairs in pots_by_set.values() for pair in atom_pairs]
    x = [p1 for p1, _ in pairs]
    y = [p2 for _, p2 in pairs]
    return np.corrcoef(x, y)[0, 1]