from typing import TYPE_CHECKING
from typing import List
from typing import Optional
from typing import Tuple
from schrodinger.application.desmond import constants
from schrodinger.application.desmond import struc
from schrodinger.application.desmond.struc import struc_iter
from schrodinger.structure import Structure
from schrodinger.structutils import transform
from schrodinger.structutils.analyze import evaluate_asl
from schrodinger.structutils.rmsd import get_super_transformation_matrix
if TYPE_CHECKING:
    from schrodinger.application.scisol.packages.fep import graph
def align_and_update_graph(
        cms_fname: str,
        input_graph: "graph.Graph",
        align_asl: str,
        additional_structures: Optional[List[Structure]] = None) -> None:
    """
    Update and align the input graph in addition to optionally aligning
    additional_structures with the output from the membrane building and
    relaxation workflow.

    :param cms_fname: Name of the CMS file to extract the protein, membrane,
        solvent and ion CTs from.
    :param input_graph: Graph to update. It is modified in place.
    :param align_asl: ASL selecting the atoms used both for centering the
        system and for superimposing the structures. For covalent ligand
        FEP it is narrowed further (see below).
    :param additional_structures: Optional extra structures that receive the
        same transformation as the graph's node structures; they are
        transformed in place.
    """
    if input_graph.fep_type == constants.FEP_TYPES.COVALENT_LIGAND:
        # For covalent ligand FEP, restrict the selection to atoms whose
        # FEP_COVALENT_PROTEIN_BB_ATOM property equals 1.
        align_asl = f'({align_asl}) and atom.{constants.FEP_COVALENT_PROTEIN_BB_ATOM} 1'
    protein_ct, membrane_ct, solvent_ct, ion_cts = _extract_from_cms(
        cms_fname, align_asl)
    _update_graph(input_graph,
                  protein_ct,
                  membrane_ct,
                  solvent_ct,
                  ion_cts,
                  align_asl,
                  additional_structures=additional_structures)
def _extract_from_cms(
    cms_fname: str,
    align_asl: str,
    membrane_asl: str = 'membrane'
) -> Tuple[Structure, Structure, Structure, List[Structure]]:
    """
    Given a CMS with a membrane-bound protein, re-center it and split it into
    its component CTs.

    The system is centered such that the atoms matching `membrane_asl` are at
    the center of the Z-axis and the atoms matching `align_asl` are centered
    in the X-Y plane.

    :param cms_fname: Name of the CMS file to read.
    :param align_asl: ASL selecting the atoms to center in the X-Y plane.
    :param membrane_asl: ASL selecting the membrane atoms to center along Z.
    :return: A tuple of (protein CT, membrane CT, solvent CT, list of ion
        CTs). The ion list gathers the ion, positive salt and negative salt
        components, in that order.
    """
    from schrodinger.application.desmond.packages import topo
    msys_model, cms_model = topo.read_cms(cms_fname)
    # center cms such that the center of the bilayer is at Z=0, and the
    # receptor is centered in X-Y plane.
    aids = cms_model.select_atom(align_asl)
    gids = topo.aids2gids(cms_model, aids, include_pseudoatoms=True)
    membrane_aids = cms_model.select_atom(membrane_asl)
    membrane_gids = topo.aids2gids(cms_model, membrane_aids)
    topo.center_cms(msys_model, membrane_gids, cms_model, [2])
    topo.center_cms(msys_model, gids, cms_model, [0, 1])
    comp_cms = struc.component_structures(cms_model)
    solvent_ct, membrane_ct = comp_cms.solvent, comp_cms.membrane
    # pick a protein structure for alignment (the first receptor component);
    # work on a copy so the CMS component itself is left untouched.
    ct = next(struc_iter(comp_cms.receptor))
    protein_ct = struc.delete_ffio_ff(ct.copy())
    # Need to convert FFIOStructure -> Structure to serialize
    membrane_ct = Structure(membrane_ct.handle)
    solvent_ct = Structure(solvent_ct.handle)
    # Ion CTs are returned as yielded by struc_iter; no conversion is applied.
    ion_cts = [
        *struc_iter(comp_cms.ion),
        *struc_iter(comp_cms.positive_salt),
        *struc_iter(comp_cms.negative_salt),
    ]
    # When extracting from FFIOStructure, the atoms are not visible
    # in Maestro. Set property to make them visible here.
    for atom in protein_ct.atom:
        atom.visible = 1
    # Remove property set during "group_waters" stage. Test for presence
    # rather than truthiness so that a stored value of 0 is removed as well.
    for atom in solvent_ct.atom:
        if constants.FEP_ABSOLUTE_ENERGY in atom.property:
            del atom.property[constants.FEP_ABSOLUTE_ENERGY]
    return protein_ct, membrane_ct, solvent_ct, ion_cts
def _update_graph(g: "graph.Graph",
                  protein_ct: Structure,
                  membrane_ct: Structure,
                  solvent_ct: Structure,
                  ion_cts: List[Structure],
                  align_asl: str,
                  additional_structures: Optional[List[Structure]] = None):
    """
    Attach the membrane and solvent CTs (and, for some FEP types, the relaxed
    receptor merged with the ion CTs) to the graph, then apply one common
    superposition transform to the node structures, the existing receptor
    where it is kept, and any additional_structures.

    :param g: Graph to update in place.
    :param protein_ct: Receptor CT taken from the relaxed system.
    :param membrane_ct: Membrane CT to store on the graph.
    :param solvent_ct: Solvent CT to store on the graph.
    :param ion_cts: Ion CTs merged into the receptor for small molecule and
        metalloprotein FEP.
    :param align_asl: ASL evaluated on both the relaxed receptor and the
        reference input structure to pick the atoms used for superposition.
    :param additional_structures: Optional extra structures to transform
        along with the graph's structures.
    """
    from schrodinger.application.scisol.packages.fep import hot_atom

    # Drop CT-level properties that are not needed on the graph.
    unneeded_props = [
        constants.TRJ_POINTER, constants.M_SUBGROUP_TITLE,
        constants.M_SUBGROUPID, constants.M_SUBGROUP_COLLAPSED
    ]
    struc.delete_structure_properties([membrane_ct, solvent_ct, protein_ct],
                                      unneeded_props)

    # Fetch the reference input structure before the graph is modified.
    ref_struc, _ = get_membrane_launcher_input(g)
    g.membrane_struc = membrane_ct
    g.solvent_struc = solvent_ct

    # Rotation/translation matrix between the equilibrated receptor
    # conformation and the input receptor conformation.
    relaxed_atoms = evaluate_asl(protein_ct, align_asl)
    reference_atoms = evaluate_asl(ref_struc, align_asl)
    superpose = get_super_transformation_matrix(protein_ct, relaxed_atoms,
                                                ref_struc, reference_atoms)

    for node in g.nodes:
        transform.transform_structure(node.struc, superpose)
        if node.protein_struc:
            transform.transform_structure(node.protein_struc, superpose)

    relaxed_receptor_types = (
        constants.FEP_TYPES.SMALL_MOLECULE,
        constants.FEP_TYPES.METALLOPROTEIN,
    )
    if g.fep_type in relaxed_receptor_types:
        # These types take the relaxed structure as-is, with the ions merged
        # in. Other FEP types will be handled in a future release.
        g.receptor_struc = protein_ct
        for ion_ct in ion_cts:
            g.receptor_struc = g.receptor_struc.merge(ion_ct)
    elif g.receptor_struc:
        transform.transform_structure(g.receptor_struc, superpose)

    for extra_ct in additional_structures or ():
        transform.transform_structure(extra_ct, superpose)

    # Adding the membrane and solvent structures changed the number of
    # environment CTs, so the HotAtomManager on each edge has to be updated:
    # reset every edge to its default first, then re-apply the rules.
    ligand_rule = g.hotatoms_rule_ligand
    if ligand_rule is None:
        leg_solvent = leg_complex = constants.REST_REGION_RULE.DEFAULT
    else:
        leg_solvent, leg_complex = ligand_rule
    for edge in g.edges:
        edge.hot_atoms.set_default()
    hot_atom.overwrite_hotatoms(g.edges(),
                                protein_asl=g.hotatoms_rule_protein,
                                leg_solvent=leg_solvent,
                                leg_complex=leg_complex)