"""
The script truncates the protein beyond truncate_distance from the ligand atoms, restrains the remaining protein heavy atoms beyond restrain_distance from the ligand atoms, and solvates the system using the solvent_buffer distance. The truncation is done with a residue-based ASL.
Copyright Schrodinger, LLC. All rights reserved.
"""
# Contributors: Byungchan Kim
import math
import os
import sys
import numpy
import schrodinger.application.desmond.system_builder_inp as system_builder_inp
import schrodinger.structure as structure
import schrodinger.structutils.transform as transform
import schrodinger.utils.cmdline as cmdline
from schrodinger.application.desmond import ffiostructure
from schrodinger.application.desmond.constants import CT_TYPE
from schrodinger.structutils import analyze
from schrodinger.structutils import build
class DesmondBoxSize:
    """
    Describe a simulation box and compute its vectors and volume.

    The box edges are given either as buffer distances added around a
    structure (`use_buffer` true) or as absolute lengths (`use_buffer`
    false). `alpha`, `beta` and `gamma` are the cell angles in degrees.
    `rotational_matrix` holds the rotation selected by the most recent
    `findMinVolume` call.
    """

    # Names accepted by `update`; any other keyword is silently ignored.
    _SETTING_NAMES = ('use_buffer', 'all_equal', 'a', 'b', 'c', 'alpha',
                      'beta', 'gamma')

    def __init__(self, **kwargs):
        """
        Set the default box (5.0 buffer on all sides, 90 degree angles) and
        then apply any settings given in `kwargs` (see `update`).
        """
        self.use_buffer = True
        self.all_equal = True
        self.a = 5.0
        self.b = 5.0
        self.c = 5.0
        self.alpha = 90.0
        self.beta = 90.0
        self.gamma = 90.0
        self.rotational_matrix = transform.get_rotation_matrix_from_eulers(
            0, 0, 0)
        self.update(**kwargs)

    def update(self, **kwargs):
        """
        Overwrite the box settings present in `kwargs`.

        Recognized keys are `use_buffer`, `all_equal`, `a`, `b`, `c`,
        `alpha`, `beta` and `gamma`; unrecognized keys are ignored.
        """
        for name in self._SETTING_NAMES:
            if name in kwargs:
                setattr(self, name, kwargs[name])

    def getStructureSize(self, st):
        """
        Return the extent of `st` along the x, y and z axes.

        :param st: Structure providing `getXYZ`.
        :return: Tuple `(a, b, c)` of max-minus-min coordinates, or None if
            the structure has no atoms.
        """
        atom_xyz_array = st.getXYZ(copy=False)
        if len(atom_xyz_array) == 0:
            # numpy.min/max raise on an empty array; report "no size" instead
            return None
        xmin, ymin, zmin = numpy.min(atom_xyz_array, 0)
        xmax, ymax, zmax = numpy.max(atom_xyz_array, 0)
        return (xmax - xmin, ymax - ymin, zmax - zmin)

    def getBoxVectors(self, st):
        """
        Return the vectors representing the box as a 3x3 numpy array.

        Origin is the back face bottom left.

        :param st: Structure the buffer is measured around. It is only
            inspected when `use_buffer` is true.
        :return: The box vectors, or None if a buffer was requested and `st`
            has no atoms.
        :raise FloatingPointError: If alpha, beta and gamma cannot form a
            valid cell.
        """
        # Below, a, b, and c represent the absolute box size
        if self.use_buffer:
            system_dimension = self.getStructureSize(st)
            if system_dimension is None:
                return None
            # NOTE(review): the extra 1.0 is added on top of the two buffer
            # layers; its purpose is not stated in this file.
            if self.all_equal:
                a = self.a * 2.0 + max(system_dimension) + 1.0
                c = b = a
            else:
                a = self.a * 2.0 + system_dimension[0] + 1.0
                b = self.b * 2.0 + system_dimension[1] + 1.0
                c = self.c * 2.0 + system_dimension[2] + 1.0
        elif self.all_equal:
            c = b = a = self.a
        else:
            a = self.a
            b = self.b
            c = self.c

        cosa = math.cos(math.radians(self.alpha))
        cosb = math.cos(math.radians(self.beta))
        gamma = math.radians(self.gamma)
        cosg = math.cos(gamma)
        sing = math.sin(gamma)
        # Must be positive for the three angles to describe a real cell; it
        # is also non-positive whenever sin(gamma) is zero, so the divisions
        # below are safe once this check passes.
        tmp = (1 - cosa * cosa - cosb * cosb - cosg * cosg +
               2 * cosa * cosb * cosg)
        if tmp <= 0:
            msg = "ERROR: alpha=%f, beta=%f, gamma=%f are inconsistent." % (
                self.alpha, self.beta, self.gamma)
            raise FloatingPointError(msg)
        box_vectors = numpy.array(
            [[a, b * cosg, c * cosb],
             [0, b * sing, c * (cosa - cosb * cosg) / sing],
             [0, 0, c * math.sqrt(tmp) / sing]])
        return box_vectors

    def calculateVolume(self, st):
        """
        Return the volume of the box around `st`.

        :return: Absolute value of the scalar triple product of the box
            vectors, or 0 if the box vectors could not be determined.
        """
        box_vectors = self.getBoxVectors(st)
        if box_vectors is None:
            return 0
        col1 = box_vectors[:, 0]
        col2 = box_vectors[:, 1]
        col3 = box_vectors[:, 2]
        return math.fabs(numpy.dot(col1, numpy.cross(col2, col3)))

    def findMinVolume(self, st):
        """
        Search a coarse grid of Euler rotations for the smallest box volume.

        `transform.translate_centroid_to_origin` is applied to `st` first.
        Each candidate rotation is applied to a copy of `st`; `st` itself is
        not rotated. The best rotation is stored in `self.rotational_matrix`
        (identity when no candidate beats the starting orientation), so the
        caller can apply it afterwards.

        :return: The smallest volume found, or None if the initial volume
            is 0.
        """
        transform.translate_centroid_to_origin(st)
        # Start from identity so a matrix left over from an earlier search is
        # never reported for this structure.
        self.rotational_matrix = transform.get_rotation_matrix_from_eulers(
            0, 0, 0)
        step = 60.0 / 360.0 * 2 * math.pi
        # Calculate the original volume:
        minvol = self.calculateVolume(st)
        if minvol == 0:
            return
        for i in range(2, 7, 2):
            phi = step * i
            for j in range(2, 7, 2):
                theta = step * j / 2.0
                for k in range(2, 7, 2):
                    psi = step * k
                    rot_matrix = transform.get_rotation_matrix_from_eulers(
                        phi, theta, psi)
                    temp_st = st.copy()
                    transform.transform_structure(temp_st, rot_matrix)
                    volume = self.calculateVolume(temp_st)
                    if volume < minvol:
                        minvol = volume
                        self.rotational_matrix = rot_matrix
        return minvol

    @staticmethod
    def translateCentroidToOrigin(strucs, skip_solvent=True, solvent_asl=None):
        """
        Re-zeros `strucs`.

        The centroid is computed over all of `strucs` combined and every
        structure is translated by its negative.

        :param skip_solvent: If true, solvent atoms are excluded when the
            centroid is computed.
        :type skip_solvent: bool
        :param solvent_asl: Solvent ASL (relevant for truthy `skip_solvent`).
        :type solvent_asl: str or NoneType
        """
        temp_st = structure.create_new_structure()
        for st in strucs:
            temp_st.extend(st)
        if skip_solvent:
            solvent_atoms = \
                system_builder_inp.SystemBuilderInput.identifySolventAtoms(
                    temp_st, solvent_asl)
            temp_st.deleteAtoms(solvent_atoms)
        center = transform.get_centroid(temp_st)
        for st in strucs:
            transform.translate_structure(st, -center[0], -center[1],
                                          -center[2])

    def minimizeVolume(self, strucs):
        """
        Rotate `strucs` into the orientation with the smallest box volume.

        The search runs on a merged copy of `strucs`; the selected rotation
        is then applied to each structure and the set is re-centered with
        `translateCentroidToOrigin`. Initial and final volumes are printed.

        :return: The value returned by `findMinVolume`.
        """
        temp_st = structure.create_new_structure()
        for st in strucs:
            temp_st.extend(st)
        print("  Initial volume: ", self.calculateVolume(temp_st))
        min_vol = self.findMinVolume(temp_st)
        print("  Final volume:   ", min_vol)
        for st in strucs:
            transform.transform_structure(st, self.rotational_matrix)
        self.translateCentroidToOrigin(strucs)
        return min_vol
def truncateProtein(protein_st,
                    ligand_st,
                    retain_ligand=False,
                    truncate_distance=0,
                    restrain_distance=-1):
    """
    Return a copy of the protein truncated around the ligand.

    Residues with no atom within `truncate_distance` of the ligand are
    deleted, atoms selected by `restrain_distance` are tagged with the
    `i_ffio_restraint` property, and hydrogens are added to atoms whose bond
    count changed because of the truncation. The input structures are not
    modified.

    :param protein_st: Protein structure.
    :param ligand_st: Ligand structure.
    :param retain_ligand: If true, the ligand atoms are appended after the
        protein atoms in the returned structure.
    :param truncate_distance: Residues beyond this distance from the ligand
        are removed; a value <= 0 disables truncation.
    :param restrain_distance: -1 means no restrain, 0 means restrain all atoms
    :return: The truncated structure.
    """
    # Values may arrive as strings from the command line; fractional
    # distances must also survive into the ASL (a '%d' format drops them).
    truncate_distance = float(truncate_distance)
    restrain_distance = float(restrain_distance)

    prot_num = protein_st.atom_total
    truncated_st = protein_st.copy()
    truncated_st.extend(ligand_st)
    # Ligand atoms were appended after the protein atoms
    ligand_asl = 'atom.num > %d' % prot_num

    # Treat boundary atoms: remember bond counts so that atoms losing a
    # neighbor to the truncation can be found afterwards
    for a in truncated_st.atom:
        a.property["i_old_bond_total"] = a.bond_total

    if restrain_distance == 0.0:
        asl = 'all'
    elif restrain_distance > 0.0:
        asl = '(not (fillres within %f (%s)))' % (restrain_distance,
                                                  ligand_asl)
    else:
        asl = 'not (all)'
    for i in analyze.evaluate_asl(truncated_st, asl):
        truncated_st.atom[i].property['i_ffio_restraint'] = 1

    # Truncate protein with residue based ASL. The ligand copy is always
    # removed here; it is re-appended below when requested.
    truncate_asl = ligand_asl
    if truncate_distance > 0.0:
        truncate_asl = 'not (fillres within %f (%s))  or %s' % (
            truncate_distance, ligand_asl, ligand_asl)
    deleted_atoms = analyze.evaluate_asl(truncated_st, truncate_asl)
    truncated_st.deleteAtoms(deleted_atoms)

    # Add hydrogens only if interface residues
    atom_list = []
    for a in truncated_st.atom:
        if a.property["i_old_bond_total"] != a.bond_total:
            atom_list.append(a.index)
        del a.property["i_old_bond_total"]
    build.add_hydrogens(truncated_st, "All-atom with No-Lp", atom_list)

    # keep ligand atoms to the end of solute CT
    if retain_ligand:
        truncated_st.extend(ligand_st)
    return truncated_st
def _read_merged_structure(fname):
    """Return all structures in `fname` merged into one, or None if empty."""
    merged_st = None
    for st in structure.StructureReader(fname):
        if merged_st is None:
            merged_st = st
        else:
            merged_st.extend(st)
    return merged_st


def _add_position_restraints(st):
    """Add a 'harm' restraint for each tagged heavy atom of `st`."""
    restrained_atoms = [
        a.index
        for a in st.atom
        if 'i_ffio_restraint' in a.property and a.atomic_number > 1
    ]
    st.ffio.addRestraints(len(restrained_atoms))
    # Restraint entries are addressed starting from 1
    for i, ai in enumerate(restrained_atoms, start=1):
        r = st.ffio.restraint[i]
        r.ai = ai
        r.funct = 'harm'
        r.c1 = 5.0
        r.c2 = 5.0
        r.c3 = 5.0
        # Restrain each atom to its current position
        r.t1 = st.atom[ai].x
        r.t2 = st.atom[ai].y
        r.t3 = st.atom[ai].z


def truncate_solvate_protein(opt):
    """
    Truncate the protein around the ligand, solvate it and add restraints.

    Reads `opt.protein` and `opt.ligand`, writes `<jobname>-in.mae`, runs the
    system builder and rewrites `<jobname>-out.cms` with restraints on the
    tagged heavy atoms of the solute CT. Exits with status 1 when an input
    or the system builder output is missing.

    :param opt: Parsed options providing `protein`, `ligand`,
        `retain_ligand`, `truncate_distance`, `restrain_distance`,
        `boundary_condition`, `jobname` and, optionally, `solvent_buffer`
        (5.0 is used when it is absent).
    """
    if not os.path.exists(opt.protein):
        print('%s does not exists.' % opt.protein)
        print('Use -protein to specify a protein file name.')
        sys.exit(1)
    if not os.path.exists(opt.ligand):
        print('Ligand, %s,  does not exists.' % opt.ligand)
        print('Use -ligand to specify a ligand file name.')
        sys.exit(1)

    protein_st = _read_merged_structure(opt.protein)
    if protein_st is None:
        print('%s contains no structure.' % opt.protein)
        sys.exit(1)
    ligand_st = _read_merged_structure(opt.ligand)
    if ligand_st is None:
        print('%s contains no structure.' % opt.ligand)
        sys.exit(1)

    # Option values may be strings when they come from the command line
    solvent_buffer = float(getattr(opt, 'solvent_buffer', 5.0))

    # Truncate protein
    truncated_st = truncateProtein(protein_st, ligand_st, opt.retain_ligand,
                                   float(opt.truncate_distance),
                                   float(opt.restrain_distance))

    # Report box volume and structure size before and after the search.
    # Note that findMinVolume re-centers truncated_st.
    desmondbox = DesmondBoxSize(use_buffer=True,
                                all_equal=True,
                                a=solvent_buffer)
    print(desmondbox.calculateVolume(truncated_st))
    print(desmondbox.getStructureSize(truncated_st))
    print(desmondbox.findMinVolume(truncated_st))
    print(desmondbox.calculateVolume(truncated_st))
    print(desmondbox.getStructureSize(truncated_st))

    jobname = opt.jobname
    if not jobname:
        jobname = 'desmond_setup'
    solute_fname = jobname + '-in.mae'

    # Solvate using system builder
    truncated_st.write(solute_fname)
    sys_build_inp = system_builder_inp.SystemBuilderInput()
    sys_build_inp.setSolute(solute_fname)
    sys_build_inp.setBoundaryCondition(
        boundary_condition=opt.boundary_condition, a=solvent_buffer)
    sys_build_inp.run(jobname, ['-WAIT'])
    cms_fname = jobname + '-out.cms'
    if not os.path.exists(cms_fname):
        print('%s does not exists.' % cms_fname)
        print('Please check "%s.log" file.' % jobname)
        sys.exit(1)

    # Add restraints. CTs are appended to a temporary file which then
    # replaces the system builder output.
    temp_fname = jobname + '-temp.cms'
    if os.path.exists(temp_fname):
        os.remove(temp_fname)
    for st in ffiostructure.CMSReader(cms_fname):
        if st.property[CT_TYPE] == CT_TYPE.VAL.SOLUTE:
            _add_position_restraints(st)
        # The tag is only a marker for this script; strip it from every CT
        for a in st.atom:
            if 'i_ffio_restraint' in a.property:
                del a.property['i_ffio_restraint']
        st.append(temp_fname, format='CMS')
    os.rename(temp_fname, cms_fname)
def find_equivalent_st(my_st, st_list, pname):
    """
    Find a structure equivalent to `my_st` in the iterable `st_list` and
    copy its `pname` atom property values to 'r_ffio_custom_charge' of the
    mapped atoms of `my_st`.

    Only the first equivalent structure is used. If none is found, a message
    is printed and the process exits with status 1.

    :param my_st: Structure that receives the property values.
    :param st_list: Iterable of reference structures.
    :param pname: Name of the atom property to copy from the reference.
    """
    import pymmlibs
    from schrodinger.infra import mmlist
    for st in st_list:
        (status, equiv_ct,
         atom_map) = pymmlibs.mmstereo_find_ct_eq_and_map(my_st, st, 1)
        if not equiv_ct:
            continue
        map_list = mmlist._mmlist_to_pylist(atom_map)
        # map_list[i] is the reference atom index for the i-th atom of my_st
        for i, a in enumerate(my_st.atom):
            iatom = map_list[i]
            a.property['r_ffio_custom_charge'] = st.atom[iatom].property[
                pname]
        return
    print('No matching CT found in a reference file')
    sys.exit(1)
# Command line entry point: parse the options and run the workflow.
if __name__ == "__main__":
    usage = '''
  %prog [options]
Description:
  %prog is a tool to solvate binding site.
Example:
  %prog -protein protein.mae -ligand ligand.mae
'''
    parser = cmdline.SingleDashOptionParser(usage)
    parser.add_option('-JOBNAME', dest='jobname', default='', help='JOBNAME')
    parser.add_option('-HOST', dest='host', default='', help='host name')
    parser.add_option('-protein',
                      dest='protein',
                      default='',
                      help='protein file name')
    parser.add_option('-ligand',
                      dest='ligand',
                      default='',
                      help='ligand file name')
    parser.add_option('-retain_ligand',
                      action='store_true',
                      dest='retain_ligand',
                      default=False,
                      help='retain ligand')
    # The distance options are declared as floats so that values given on
    # the command line have the same type as the defaults.
    parser.add_option('-truncate_distance',
                      dest='truncate_distance',
                      type='float',
                      default=12.0,
                      help='truncate distance')
    parser.add_option('-restrain_distance',
                      dest='restrain_distance',
                      type='float',
                      default=8.0,
                      help='restrain distance')
    parser.add_option('-boundary_condition',
                      dest='boundary_condition',
                      default='truncated_octahedron',
                      help='boundary condition')
    parser.add_option('-solvent_buffer',
                      dest='solvent_buffer',
                      type='float',
                      default=5.0,
                      help='solvent buffer')
    options, args = parser.parse_args()
    truncate_solvate_protein(options)