"""
Functions and classes for Prime molecule preparation.
It is deprecated: starting with Suite2008, Prime no longer needs structure
"pre-fixing".
The wrapper for this module is in: `$SCHRODINGER/mmshare-v*/python/scripts/primefix.py`
Copyright Schrodinger, LLC. All rights reserved.
"""
# Contributors: Matvey Adzhigirey, Woody Sherman
import sys
import schrodinger.protein.findhets as findhets
import schrodinger.structure as structure
import schrodinger.utils.log as log
# Module-level logger for the fixing routines in this file. By default only
# messages at WARNING level and above are emitted; the per-fix details are
# logged at DEBUG level.
logger = log.get_output_logger("schrodinger.application.prime.primefix")
logger.setLevel(log.WARNING)
# For verbose tracing of every individual fix, use: logger.setLevel(log.DEBUG)
# Het residue names that are to be excluded from the X01, X02... renaming.
excluded_hets = ["NAD", "NAG", "SO4"]

# Metals grouped by category; every value is an atomic number.
alkali_metals = [3, 11, 19, 37, 55, 87]  # Li Na K Rb Cs Fr
alkaline_earth_metals = [4, 12, 20, 38, 56, 88]  # Be Mg Ca Sr Ba Ra
# La-Lu (57-71). Sc (21) and Y (39), which are group-3 elements rather than
# lanthanides, are grouped here as well.
lanthanide_series = [21, 39] + list(range(57, 72))
actinide_series = list(range(89, 104))  # Ac-Lr
# Ti-Zn, Zr-Cd, Hf-Hg and Rf-Bh.
transition_metals = (list(range(22, 31)) + list(range(40, 49)) +
                     list(range(72, 81)) + list(range(104, 108)))
metalloids = [32, 33, 51, 52, 84, 85]  # Ge As Sb Te Po At
other_metals = [13, 31, 49, 50, 81, 82, 83]  # Al Ga In Sn Tl Pb Bi
# Every category above, concatenated in the order listed.
all_metals = [
    atomic_number
    for category in (alkali_metals, alkaline_earth_metals, lanthanide_series,
                     actinide_series, transition_metals, metalloids,
                     other_metals)
    for atomic_number in category
]

# PDB residue names to give to the metals after breaking bonds, keyed by
# atomic number. Comments list the oxidation states of each metal.
metal_pdbnames = {
    11: ' NA ',  # Sodium +1
    12: ' MG ',  # Magnesium +2
    19: '  K ',  # Potassium +1
    20: ' CA ',  # Calcium +2
    25: ' MN ',  # Manganese +2 through +7
    26: ' FE ',  # Iron +2, +3
    27: ' CO ',  # Cobalt +2, +3
    28: ' NI ',  # Nickel +2, +3
    29: ' CU ',  # Copper +1, +2, +3
    30: ' ZN ',  # Zinc +2
    48: ' CD ',  # Cadmium +2 (no MacroModel parameters)
    80: ' HG ',  # Mercury +1, +2 (no MacroModel parameters)
}

# Possible ionization states (formal charges) for each metal, keyed by
# atomic number. Cadmium (48: [2]) and mercury (80: [1, 2]) are not listed
# because MacroModel has no parameters for them.
metal_charges = {
    11: [1],  # Sodium
    12: [2],  # Magnesium
    19: [1],  # Potassium
    20: [2],  # Calcium
    25: [2, 3, 4, 5, 6, 7],  # Manganese
    26: [2, 3],  # Iron
    27: [2, 3],  # Cobalt
    28: [2, 3],  # Nickel
    29: [1, 2, 3],  # Copper
    30: [2],  # Zinc
}

# Default formal charge for each metal, keyed by atomic number.
metal_default_charges = {
    11: 1,  # Sodium
    12: 2,  # Magnesium
    19: 1,  # Potassium
    20: 2,  # Calcium
    25: 2,  # Manganese
    26: 2,  # Iron (except in FeS clusters, where it is +3)
    27: 2,  # Cobalt
    28: 2,  # Nickel
    29: 3,  # Copper  NOTE(review): +3 is an unusual default for Cu; confirm
    30: 2,  # Zinc
    48: 2,  # Cadmium
    80: 2,  # Mercury
}
###############################################################################
class FixForPrime:
    """
    Wrapper class to house the combination of methods for fixing a
    structure for Prime.

    Raises RuntimeError if it can't determine a chain name, or general
    Exception if it can't build a unique PDB atom name because of PDB atom
    name field size restrictions.

    Call instance.fixAll() after creating the instance to drive all
    fixing methods. renumberByMolecule() (run by fixAll()) rebuilds the
    structure, so afterwards the fixed structure is instance.st, which may
    be a different object from the one passed to the constructor.
    """

    def __init__(self, st):
        """
        :param st: Structure to fix. The individual fixes modify it in place
            until renumberByMolecule() replaces self.st.
        """
        self.st = st
        # Atoms whose residue numbers reassignResNums() must reassign
        # (key: atom index, value: unused).
        self.reassign_resnum = {}
        # Chain names that reassignHets() must not use for het groups.
        # NOTE(review): nothing in this class adds to this list, so het
        # groups currently always go to chain 'X'.
        self.chains_used = []

    def fixSpecialResidues(self):
        """
        Assigns PDB atom names and formal charges for residue names:
        HOH, DOD, TIP, SPC
        ASP, GLU
        NMA, NME
        ACE
        HEM, HEC

        Changes residue name from DOD, TIP, or SPC to HOH (moving waters
        with a blank chain name to chain 'W') and renames NME to NMA.
        Heme iron atoms are recorded in self.reassign_resnum so that
        reassignResNums() gives them their own residue number.
        """
        # Side-chain carboxylate oxygen names for each acidic residue.
        carboxylate_oxygens = {
            'ASP': (' OD1', ' OD2'),
            'GLU': (' OE1', ' OE2'),
        }
        changed_waters = []
        for res in self.st.residue:
            # According to specification PDB residue names should always
            # contain 3 characters; PDB atom names, on the other hand, are
            # 4 characters long.
            pdbres = res.pdbres.strip()
            if pdbres in ['HOH', 'DOD', 'TIP', 'SPC']:
                # Fix waters - resname to 'HOH' and fix PDB atom names.
                changed_water = False
                for a in res.atom:
                    if pdbres != 'HOH':
                        a.pdbres = 'HOH'
                        changed_water = True
                    if a.chain == " ":
                        a.chain = "W"
                        changed_water = True
                    if a.element == "O" and a.pdbname != ' O  ':
                        a.pdbname = ' O  '
                        changed_water = True
                    elif a.pdbname == " H1 ":
                        a.pdbname = "1H  "
                        changed_water = True
                    elif a.pdbname == " H2 ":
                        a.pdbname = "2H  "
                        changed_water = True
                if changed_water:
                    changed_waters.append(res.resnum)
            elif pdbres in carboxylate_oxygens:
                # Remove Asp/Glu formal charge from the double bonded
                # oxygen and put -1 on the single bonded one.
                for a in res.atom:
                    if a.pdbname not in carboxylate_oxygens[pdbres]:
                        continue
                    if len(a.bond) != 1:
                        # Not exactly one neighbor (e.g. a hydrogen is
                        # attached): leave the charge alone.
                        continue
                    # Only one neighbor - no hydrogen bound. atom.bond is
                    # indexed from 1, so bond[1] is that single bond.
                    charge = 0 if a.bond[1].order == 2 else -1
                    a.formal_charge = charge
                    msg = "Atom %i(O) of %s: set charge to %i" % (
                        int(a), pdbres, charge)
                    logger.debug(msg)
            elif pdbres == 'NMA' or pdbres == 'NME':
                # Fix terminal cap: every atom gets resname 'NMA'; the
                # first ' C  ' atom is renamed to ' CA '.
                renamed_carbon = False
                for a in res.atom:
                    a.pdbres = 'NMA'
                    if not renamed_carbon and a.pdbname == ' C  ':
                        a.pdbname = ' CA '
                        renamed_carbon = True
                        msg = "Atom %i(C) of NMA: set pdbname to ' CA '" % int(
                            a)
                        logger.debug(msg)
            elif pdbres == 'ACE':
                # Fix terminal cap: the ' CA ' atom of ACE becomes ' CH3'.
                for a in res.atom:
                    if a.pdbname == ' CA ':
                        a.pdbname = ' CH3'
                        msg = "Atom %i(C) of ACE: set pdbname to ' CH3'" % int(
                            a)
                        logger.debug(msg)
                        break
            elif pdbres in ['HEM', 'HEC']:
                # Fix charges on HEM Nitrogens that were bonded to Fe.
                # Also change the PDB residue name of Iron to ' FE '.
                # Make Fe atoms of Hemes orange (so that prepwizard
                # recognizes them as hets). At this point the bonds to FE
                # are already broken.
                for a in res.atom:
                    if a.pdbname in [' FE ', 'FE  ']:  # Iron
                        a.pdbres = ' FE '
                        # Mark for resnum to be reassigned:
                        self.reassign_resnum[int(a)] = None
                        a.property['i_m_pdb_convert_problem'] = 4  # orange
                        msg = "Atom %i(Fe) of %s: set pdbres to ' FE '" % (
                            int(a), pdbres)
                        logger.debug(msg)
                    elif a.pdbname in [' N B', ' N D']:
                        # Nitrogens needing 0 charge
                        a.formal_charge = 0
                        msg = "Atom %i(N) of %s: set charge to 0" % (int(a),
                                                                     pdbres)
                        logger.debug(msg)
                    elif a.pdbname in [' N A', ' N C']:
                        # Nitrogens needing -1 charge
                        a.formal_charge = -1
                        msg = "Atom %i(N) of %s: set charge to -1" % (int(a),
                                                                      pdbres)
                        logger.debug(msg)
        if changed_waters:
            msg = "Water residues %s: Changed chain:pdbres to W:HOH" % changed_waters
            logger.debug(msg)

    def fixDNACaps(self):
        """
        Prime atom parameters need to be different for terminal DNA bases
        than the atom parameters for the DNA bases that are in the center
        of the strand. This is accomplished by giving terminal bases a new
        residue name. Ev:60484

        HXL is the 5' cap (-OH bound to the phosphate of the base)
        POT is the 3' cap (a phosphate cap)

        Protocol: find all HXL and POT caps; the cap and the single-letter
        base residue bonded to it are both renamed to <base>5T (HXL cap) or
        <base>3T (POT cap), and the cap takes the base's resnum and chain.
        A cap bonded to a residue whose name is not one character long is
        reported with a warning and left unchanged.
        """
        cap_names = ('HXL', 'POT')
        # (cap residue, cap name, base residue, base name) per cap found
        caps = []
        for cap_res in self.st.residue:
            cap_pdbres = cap_res.pdbres.strip()
            if cap_pdbres not in cap_names:
                continue
            # The first atom bonded to the cap that is not itself part of a
            # cap belongs to the neighboring base.
            neighbor = next((nb for atom in cap_res.atom
                             for nb in atom.bonded_atoms
                             if nb.pdbres.strip() not in cap_names), None)
            if neighbor is None:
                continue
            base_pdbres = neighbor.pdbres.strip()
            if len(base_pdbres) != 1:
                logger.warning('WARNING: Invalid residue (%s) bound to %s' %
                               (base_pdbres, cap_pdbres))
                continue
            caps.append(
                (cap_res, cap_pdbres, neighbor.getResidue(), base_pdbres))

        # Do the actual renaming in a separate loop:
        for cap_res, cap_pdbres, base_res, base_pdbres in caps:
            # The suffix depends on the CAP type, not on the base name.
            suffix = '5T' if cap_pdbres == 'HXL' else '3T'  # else: POT
            new_pdbres = base_pdbres + suffix
            msg = "Residues %s[%s] and %s[%s] " % (base_res, base_pdbres,
                                                   cap_res, cap_pdbres)
            base_res.pdbres = new_pdbres
            cap_res.pdbres = new_pdbres
            cap_res.resnum = base_res.resnum
            cap_res.chain = base_res.chain  # Just in case
            msg += "merged into a new residue %s[%s]" % (cap_res, new_pdbres)
            logger.debug(msg)

    @staticmethod
    def _make_pdbname(element, num):
        """
        Return the 4-character PDB atom name for the num-th atom of the
        given element within one het group.

        :param element: Element symbol, e.g. 'C', 'Cl', 'H'.
        :param num: 1-based count of this element within the group.
        :raise Exception: if num >= 100 for a non-hydrogen, or num >= 1000
            for a hydrogen, because the name would not fit in 4 characters.
        """
        if num < 100:
            return "%2s%-2s" % (element.upper(), str(num))
        if element == 'H' and num < 1000:
            # e.g. 123 -> '1H23': leading digit, 'H', remaining two digits.
            digits = str(num)
            return "%sH%s" % (digits[0], digits[1:])
        # Non-hydrogen with >= 100 OR H with >= 1000 atoms:
        raise Exception(
            "Problem setting unique PDB atom name; # of element [%s] exeeds limits"
            % element)

    def reassignHets(self, het_atom_groups):
        """
        Assign chain, pdbres, and resnum to het groups.

        Group N (1-based) is put into the first preferred chain not in
        self.chains_used, gets residue name X01, X02, ... and resnum N, and
        each of its atoms gets a unique PDB atom name. The original pdbres,
        chain, and resnum are backed up in the s_ppw_pdbres_bu,
        s_ppw_chain_bu, and i_ppw_resnum_bu atom properties the first time.

        :param het_atom_groups: Iterable of groups of atom indices into
            self.st (as produced by findhets.find_hets in fixHets()).
        :raise RuntimeError: if no chain name is available.
        :raise Exception: if a unique PDB atom name can't be built because
            of PDB atom name field size restrictions.
        """
        # Determine an available chain: 'X', or if taken 'Y', 'Z', 'W', ...
        preferred_het_chain = ['X', 'Y', 'Z', 'W', 'V', 'U', 'T', 'S', 'R', 'Q']
        hetchain = next(
            (c for c in preferred_het_chain if c not in self.chains_used),
            None)
        if hetchain is None:
            raise RuntimeError(
                'Could not find next available chain name for het groups')

        for hetnum, het_atoms in enumerate(het_atom_groups, start=1):
            # Residue name X01, X02, ..., X99, X100, ...
            pdbres = 'X%02i' % hetnum
            # Per-element atom counts used to build unique PDB atom names.
            unique_counter = {}
            for ai in het_atoms:
                a = self.st.atom[ai]
                # Back up the original residue information, only the first
                # time this structure goes through fix_for_prime():
                if not a.property.get('s_ppw_pdbres_bu'):
                    a.property['s_ppw_pdbres_bu'] = a.pdbres
                    a.property['s_ppw_chain_bu'] = a.chain
                    a.property['i_ppw_resnum_bu'] = a.resnum
                # Set the new residue information:
                a.chain = hetchain
                a.pdbres = pdbres
                a.resnum = hetnum
                # Set unique PDB name for each atom in this residue:
                element = a.element
                num = unique_counter.get(element, 0) + 1
                unique_counter[element] = num
                a.pdbname = self._make_pdbname(element, num)
            msg = "Atoms %s: assigned a new residue of %s:%i[%s]" % (
                het_atoms, hetchain, hetnum, pdbres)
            logger.debug(msg)

    def reassignResNums(self):
        """
        Assigns new residue numbers to the atoms in self.reassign_resnum,
        making sure a residue number only appears in a chain once.

        Each such atom gets the lowest residue number (starting at 1) not
        already used in its chain by atoms outside self.reassign_resnum or
        by previously reassigned atoms. self.reassign_resnum is currently
        filled only by fixSpecialResidues() (heme iron atoms).
        """
        # chain -> set of residue numbers used by atoms that keep theirs
        resnums_used = {}
        for a in self.st.atom:
            if int(a) not in self.reassign_resnum:
                resnums_used.setdefault(a.chain, set()).add(a.resnum)

        for ai in self.reassign_resnum:
            a = self.st.atom[ai]
            used = resnums_used.setdefault(a.chain, set())
            # Find the first unused resnum in the chain:
            resnum = 1
            while resnum in used:
                resnum += 1
            a.resnum = resnum
            used.add(resnum)  # this resnum is now used
            msg = "Atom %i: assigned a new resnum of %i" % (ai, resnum)
            logger.debug(msg)

    def renumberByMolecule(self):
        """
        Break up the structure into one structure per molecule, then
        recombine them, adding a molecule at a time. The purpose of this is
        so that all atoms from each molecule appear together in the
        connection table.

        self.st is replaced by the new combined structure, which keeps the
        structure-level properties of the original. An empty structure is
        left untouched.
        """
        molecules = list(self.st.molecule)
        if not molecules:
            return
        # Copy structure-level properties (title, etc.) along with the first
        # molecule so they are not lost in the rebuild.
        combined_st = molecules[0].extractStructure(copy_props=True)
        for mol in molecules[1:]:
            combined_st.extend(mol.extractStructure())
        self.st = combined_st

    def fixPdbresAndPdbnames(self):
        """
        Normalizes each PDB residue name to its first 3 characters,
        right-justified in a 3-character field, followed by a space.
        Changes all PDB atom names to upper case.
        """
        for res in self.st.residue:
            oldr = res.pdbres
            # Strip 4th char (if any), right justify, and append a space:
            pdbres = oldr.rjust(3)[:3] + ' '
            if pdbres != oldr:
                res.pdbres = pdbres
                msg = "%s: changed pdbres from '%s' to '%s'" % (res, oldr,
                                                                pdbres)
                logger.debug(msg)
            # Capitalize the PDB atom names:
            changed_pdbnames = False
            for a in res.atom:
                pdbname = a.pdbname
                upper = pdbname.upper()
                if pdbname != upper:
                    changed_pdbnames = True
                    a.pdbname = upper
            if changed_pdbnames:
                msg = "%s: capitalized PDB atom names" % res
                logger.debug(msg)

    def fixHets(self):
        """
        Finds het groups with findhets.find_hets() (metals excluded,
        hydrogens included, residue names in the module-level excluded_hets
        list passed as excluded) and renames them with reassignHets().
        """
        hets = findhets.find_hets(self.st,
                                  include_metals=False,
                                  include_hydrogens=True,
                                  excluded_hets=excluded_hets)
        # Assign chain/resnum/pdbres/pdbname to het atoms:
        self.reassignHets(hets)

    def fixAll(self):
        """
        Driver method that makes all the fixes, in order.

        :return: The fixed structure (self.st). renumberByMolecule() may
            have replaced it with a new structure object.
        """
        self.fixPdbresAndPdbnames()
        # NOTE(review): fixMetals() is not defined in this class as shown;
        # confirm it is provided elsewhere, otherwise this call raises
        # AttributeError.
        self.fixMetals()
        self.fixSpecialResidues()
        self.fixDNACaps()
        self.fixHets()
        # Assign residue numbers to atoms in reassign_resnum dict:
        self.reassignResNums()
        self.renumberByMolecule()
        self.st.retype()  # for some reason needs to be called twice Ev:51529
        self.st.retype()  # for some reason needs to be called twice Ev:51529
        return self.st
def fix_for_prime(st):
    """
    Preferred way to use the FixForPrime class.

    Fixes structure.Structure for Prime calculations.  This is the
    preferred API for scripts implementing primefix.py.

        1. Assign a unique, single residue name, residue number, and chain
           name to all ligand residues. Make sure that new name doesn't
           conflict with amino acid resname. Back up the old residue name,
           chain, and residue number as the s_ppw_pdbres_bu,
           s_ppw_chain_bu, and i_ppw_resnum_bu atom properties.
        2. Right-justify the residue name in a 3-character field
           (truncating longer names), followed by a space.
        3. Set pdbname to all upper case.
        4. Delete all bonds to metals and assign formal charges to the metal
           and previously attached atoms.
        5. Water molecules should have a residue name of "HOH " and pdb atom
           names " O  ", "1H  "  and "2H  ". Currently, residues TIP, TIP3,
           TIP4, SPC, and HOH and DOD are treated.
        6. FeS clusters are fixed with the appropriate bonds and formal
           charges.
        7. Terminal caps are fixed (rename "NME " to "NMA ", pdbname " C  "
           of NMA to " CA ", and pdbname " CA " of ACE to " CH3").
        8. Atoms from the same molecule are placed together in the
           connection table.

    NOTE: Prime can still fail if a residue is not in order of connectivity.
    If this is the case in your structure, it should be exported as a PDB
    file and then reimported as a Maestro file before running this script.

    :param st: Structure to fix. The early steps modify it in place.
    :return: The fully fixed structure. Step 8 rebuilds the structure, so
        this can be a different object from st; use the returned structure
        rather than st.
    """
    ffp = FixForPrime(st)
    ffp.fixAll()
    # renumberByMolecule() replaces ffp.st; return it so the molecule
    # reordering and the final retype are not lost.
    return ffp.st


if __name__ == '__main__':
    if len(sys.argv) < 3:
        print("Usage: primefix.py <input.mae> <output.mae>")
        sys.exit(1)  # invalid invocation: report failure
    infile = sys.argv[1]
    outfile = sys.argv[2]
    for index, st in enumerate(structure.StructureReader(infile)):
        fixed_st = fix_for_prime(st)
        # The first structure creates/overwrites the output file; the rest
        # are appended to it.
        if index == 0:
            fixed_st.write(outfile)
        else:
            fixed_st.append(outfile)
#EOF