"""
Collection of modules to repair structures to be used for desmond simulation
Copyright Schrodinger, LLC. All rights reserved.
"""
# Contributors: Byungchan Kim
import os
import numpy
import schrodinger.application.desmond.smarts as smarts
import schrodinger.infra.mm as mm
import schrodinger.structure as structure
import schrodinger.structutils.analyze as analyze
from schrodinger.application.desmond import constants
# Define water SMARTS patterns
# Pattern for a water molecule; passed as the first pattern list to
# smarts.evaluate_net_smarts by _extract_solvent when no solvent ASL is given.
water_smarts = ["[H]O[H]"]
# Extra water patterns passed as the second pattern list to
# smarts.evaluate_net_smarts by _extract_solvent. The "_" bond presumably
# denotes a zero-order bond from the water to another atom — TODO confirm.
zob_water_smarts = ["[H]O([H])_[*]", "[H]O[H]_[*]"]
def create_cms_from_mae(input_fname, output_fname, membrane_asl, solvent_asl):
    """
    Convert a structure file into a file of component CTs.

    Input CTs are read in order. If one is marked as the full system (its
    `constants.CT_TYPE` property equals `constants.CT_TYPE.VAL.FULL_SYSTEM`),
    that CT is used, its "i_ffio_num_component" property is removed, and any
    CTs after it are not read. Otherwise every CT read is merged into the
    first one. The resulting structure is split by `decompose_to_comp_ct`,
    and the resulting CTs are appended to `output_fname`. Any existing file
    at `output_fname` is deleted first.

    :param input_fname: path of the structure file to read
    :param output_fname: path of the file to write (replaced if it exists)
    :param membrane_asl: ASL selecting membrane atoms; "" uses the default
        lipid residue names of `_extract_membrane`
    :param solvent_asl: ASL selecting solvent atoms; "" uses the default
        water SMARTS of `_extract_solvent`
    :raises ValueError: if `input_fname` contains no structures
    """
    full_st = None
    comp_sts = []
    # Context manager so the reader's file handle is released even on error.
    with structure.StructureReader(input_fname) as reader:
        for st in reader:
            if (constants.CT_TYPE in st.property and
                    st.property[constants.CT_TYPE] ==
                    constants.CT_TYPE.VAL.FULL_SYSTEM):
                full_st = st
                break
            comp_sts.append(st)

    if full_st is None:
        if not comp_sts:
            raise ValueError("No structures found in %s" % input_fname)
        # No full-system CT: build one by merging the component CTs into the
        # first CT (modified in place; it is not written out directly).
        full_st = comp_sts[0]
        for st in comp_sts[1:]:
            full_st.extend(st)
    elif "i_ffio_num_component" in full_st.property:
        del full_st.property["i_ffio_num_component"]

    out_sts = decompose_to_comp_ct(full_st, membrane_asl, solvent_asl)
    # Start from an empty file because the CTs are appended one by one.
    if os.path.exists(output_fname):
        os.remove(output_fname)
    for st in out_sts:
        st.append(output_fname)
def create_full_system_from_comp(input_fname, output_fname):
    """
    Rebuild the full-system CT from the component CTs of a file.

    Any CT already marked as the full system is discarded. The remaining
    component CTs are merged, in file order, into a copy of the first one,
    and the copy is marked as the full system. The output file is written
    with the new full-system CT first, followed by every component CT.

    :param input_fname: path of the structure file to read
    :param output_fname: path of the file to write
    :raises ValueError: if `input_fname` contains no component CTs
    """
    # Context manager so the reader's file handle is released.
    with structure.StructureReader(input_fname) as reader:
        comp_sts = [
            st for st in reader
            if not (constants.CT_TYPE in st.property and
                    st.property[constants.CT_TYPE] ==
                    constants.CT_TYPE.VAL.FULL_SYSTEM)
        ]
    if not comp_sts:
        raise ValueError("No component structures found in %s" % input_fname)

    # Copy the first component so that the component CTs written below are
    # left unmodified.
    full_st = comp_sts[0].copy()
    for st in comp_sts[1:]:
        full_st.extend(st)
    full_st.property[constants.CT_TYPE] = constants.CT_TYPE.VAL.FULL_SYSTEM

    full_st.write(output_fname)
    for st in comp_sts:
        st.append(output_fname)
def decompose_to_comp_ct(in_st, membrane_asl="", solvent_asl=""):
    """
    Split a structure into component CTs: solute, membrane, ion and solvent.

    `in_st` is copied and is not modified. Box vectors of the copy are
    regenerated by `repair_box_vector` when any are missing. Membrane, ion
    and solvent atoms are then extracted, in that order. Whatever remains is
    the solute.

    :param in_st: structure to decompose
    :param membrane_asl: ASL selecting membrane atoms; "" uses the default
        lipid residue names of `_extract_membrane`
    :param solvent_asl: ASL selecting solvent atoms; "" uses the default
        water SMARTS of `_extract_solvent`
    :return: list of CTs. The first item is always the full-system CT. It is
        followed by whichever of the solute, membrane, ion and solvent CTs
        exist, in that order.
    """
    st = in_st.copy()
    repair_box_vector(st)
    membrane_st = _extract_membrane(st, membrane_asl)
    ion_st = _extract_ion(st)
    solvent_st = _extract_solvent(st, solvent_asl)
    # Add the solvent atoms to the atom group for the biasing potential in
    # membrane systems. Check solvent_st as well: a membrane system with no
    # solvent would otherwise dereference None.
    if membrane_st and solvent_st:
        for atom in solvent_st.atom:
            atom.property[constants.FEP_ABSOLUTE_ENERGY] = 1

    # Atoms left after the extractions above form the solute.
    solute_st = st
    solute_st.property[constants.CT_TYPE] = constants.CT_TYPE.VAL.SOLUTE
    if solute_st:
        full_st = solute_st.copy()
    else:
        # NOTE(review): confirm structure.Structure() can be constructed
        # without arguments in this API version.
        full_st = structure.Structure()
    for comp_st in (membrane_st, ion_st, solvent_st):
        if comp_st:
            full_st.extend(comp_st)
    full_st.property[constants.CT_TYPE] = constants.CT_TYPE.VAL.FULL_SYSTEM

    components = (solute_st, membrane_st, ion_st, solvent_st)
    return [full_st] + [comp_st for comp_st in components if comp_st]
def repair_box_vector(st):
    """
    Regenerate simulation box properties when any of them is missing.

    When every property named in `constants.SIM_BOX` is present, `st` is left
    as is. Otherwise:

    - a warning naming the first missing property is printed;
    - every `SIM_BOX` property is reset to 0.0;
    - the `constants.SIM_BOX_DIAGONAL` properties are set to the extent
      (max - min) of the atom coordinates along each axis.

    :param st: structure whose properties are updated in place
    """
    missing = [name for name in constants.SIM_BOX if name not in st.property]
    if not missing:
        return
    print(
        "WARNING: Property '%s' is missing. Generating box vectors based on system dimension."
        % missing[0])
    for name in constants.SIM_BOX:
        st.property[name] = 0.0
    coords = st.getXYZ(copy=False)
    extents = numpy.max(coords, 0) - numpy.min(coords, 0)
    for prop, extent in zip(constants.SIM_BOX_DIAGONAL, extents):
        st.property[prop] = extent
def _extract_solvent(st, asl=""):
    """
    Move solvent atoms out of `st` into a new solvent CT.

    Solvent atoms are chosen by `asl` when it is non-empty. Otherwise they are
    found by matching the module-level water SMARTS patterns. Matched atoms
    are deleted from `st`.

    :param st: structure to extract from; modified in place
    :param asl: optional ASL selecting the solvent atoms
    :return: the solvent CT, or None when no atoms matched
    """
    if asl:
        solvent_atoms = analyze.evaluate_asl(st, asl)
    else:
        solvent_atoms = smarts.evaluate_net_smarts(st, water_smarts,
                                                   zob_water_smarts)
    if len(solvent_atoms) == 0:
        return None
    extracted = st.extract(solvent_atoms, True)
    st.deleteAtoms(solvent_atoms)

    # Order atoms molecule by molecule so that each molecule's atoms are
    # contiguous. The leading 0 is presumably a placeholder for the 1-based
    # order list expected by mmct_ct_reorder — confirm.
    order = [0]
    for mol in extracted.molecule:
        order.extend(atom.index for atom in mol.atom)
    solvent_st = extracted.copy()
    mm.mmct_ct_reorder(solvent_st, extracted, order)

    solvent_st.title = solvent_st.atom[1].pdbres + " water box"
    solvent_st.property[constants.CT_TYPE] = constants.CT_TYPE.VAL.SOLVENT
    solvent_st.property["i_ffio_num_component"] = 1
    # viparr fails when two or more molecules share a residue number, so
    # reuse each atom's molecule number as its residue number.
    for atom in solvent_st.atom:
        atom.resnum = atom.molecule_number
    return solvent_st
def _extract_membrane(st, asl=""):
    """
    Move membrane (lipid) atoms out of `st` into a new membrane CT.

    :param st: structure to extract from; matched atoms are deleted from it
    :param asl: ASL selecting the membrane atoms. When empty, residues whose
        ptype is one of DPPC, DPPE, DPPS, POPC, POPE, POPS, DOPC, DOPE, DOPS,
        DMPC or PIP are selected.
    :return: the membrane CT, or None if no atoms matched. The CT is titled
        with its first atom's residue name plus " bilayer". Its
        "i_ffio_num_component" property is set to 1 only when all of its atoms
        share a single residue name.
    """
    if not asl:
        lipid_resnames = ("DPPC", "DPPE", "DPPS", "POPC", "POPE", "POPS",
                          "DOPC", "DOPE", "DOPS", "DMPC", "PIP")
        # Produces: res.ptype "DPPC" "DPPE" ... "PIP" (with a trailing space)
        asl = 'res.ptype ' + ''.join('"%s" ' % name for name in lipid_resnames)
    atoms = analyze.evaluate_asl(st, asl)
    if len(atoms) == 0:
        return None
    new_st = st.extract(atoms, True)
    st.deleteAtoms(atoms)
    resname = new_st.atom[1].pdbres
    new_st.title = resname + " bilayer"
    new_st.property[constants.CT_TYPE] = constants.CT_TYPE.VAL.MEMBRANE
    # Treat the membrane as a single ffio component only if it holds exactly
    # one lipid type.
    if len({atom.pdbres for atom in new_st.atom}) == 1:
        new_st.property["i_ffio_num_component"] = 1
    return new_st
def _extract_ion(st):
    """
    Move single-atom molecules (ASL 'm.atoms 1') out of `st` into an ion CT.

    :param st: structure to extract from; matched atoms are deleted from it
    :return: the ion CT, titled with its first atom's residue name, or None
        when nothing matched
    """
    ion_atoms = analyze.evaluate_asl(st, 'm.atoms 1')
    if len(ion_atoms) == 0:
        return None
    ion_st = st.extract(ion_atoms, True)
    st.deleteAtoms(ion_atoms)
    ion_st.title = ion_st.atom[1].pdbres
    ion_st.property[constants.CT_TYPE] = constants.CT_TYPE.VAL.ION
    # viparr fails when two or more molecules share a residue number, so
    # reuse each atom's molecule number as its residue number.
    for ion in ion_st.atom:
        ion.resnum = ion.molecule_number
    return ion_st