import schrodinger.utils.log as log
from schrodinger import structure
from schrodinger.infra import phase
# Module-level logger; used, for example, to report fragment pairs that are
# skipped during snapcore.
logger = log.get_output_logger(name="fragment_snapcore")
def get_aligned_coordinates_from_fragments(ct1, ct2, match1, match2,
                                           broken_bond_ct1, broken_bond_ct2):
    """
    Snap the coordinates of each connected fragment pair onto its partner.

    Both structures are copied, the given bonds are deleted from the copies,
    and each resulting molecule (fragment) of ct2 is paired with the first
    not-yet-used fragment of ct1 that shares at least one mapped atom with
    it.  For every such pair:

    * with more than two matched atoms, each fragment is aligned onto the
      other via `get_aligned_fragment_coords`;
    * with at most two matched atoms and every atom of both fragments
      matched, the partner's original coordinates are copied directly;
    * otherwise the pair is skipped and an info message is logged.

    :param ct1: Structure from first molecule
    :param ct2: Structure from second molecule
    :param match1: list of matched atom indices in ct1
    :param match2: list of matched atom indices in ct2, parallel to match1
    :param broken_bond_ct1: list of (atom_i, atom_j) bond tuples to delete
        from (a copy of) ct1 before splitting it into fragments
    :param broken_bond_ct2: list of (atom_i, atom_j) bond tuples to delete
        from (a copy of) ct2 before splitting it into fragments
    :return: tuple (coords1, coords2) of coordinate arrays (as returned by
        `getXYZ`) for ct1 and ct2.  An entry is None unless at least one
        fragment alignment was performed and none of the alignments for that
        structure failed (e.g. stereo failure or phase exception).  Pairs
        handled only by the direct-copy branch do not count as an alignment.
    """
    ORIG_ATOM_INDEX = "i_fep_original_atom_index"
    # Original ct1 atom index -> original ct2 atom index.
    atom_map = dict(zip(match1, match2))
    st1 = ct1.copy()
    st2 = ct2.copy()
    # Tag every atom with its original index so that atoms of the fragments
    # extracted below can be traced back to the parent structures.
    for st in (st1, st2):
        for atom in st.atom:
            atom.property[ORIG_ATOM_INDEX] = atom.index
    for st, bonds in zip((st1, st2), (broken_bond_ct1, broken_bond_ct2)):
        for ai, aj in bonds:
            st.deleteBond(ai, aj)

    frag1 = [mol.extractStructure(copy_props=True) for mol in st1.molecule]
    # Per ct1 fragment (same order as frag1):
    # original atom index -> fragment atom index.
    all_wt_orig2frag = [{
        atom.property[ORIG_ATOM_INDEX]: atom.index for atom in frag.atom
    } for frag in frag1]
    frag_used = set()
    coord1 = st1.getXYZ()
    coord2 = st2.getXYZ()
    # A single failed fragment alignment invalidates that structure's
    # coordinates, regardless of the order in which fragments are processed.
    aligned_any = False
    failed1 = False
    failed2 = False
    frag2 = [mol.extractStructure(copy_props=True) for mol in st2.molecule]
    for fmut in frag2:
        mut_orig2frag = {
            atom.property[ORIG_ATOM_INDEX]: atom.index for atom in fmut.atom
        }
        matched1 = []
        matched2 = []
        fwt = None
        wt_orig2frag = None
        # Pair fmut with the first unused ct1 fragment (in index order) that
        # has at least one atom mapped into fmut.
        for fwt_idx, wt_map in enumerate(all_wt_orig2frag):
            if fwt_idx in frag_used:
                continue
            for orig_idx in wt_map:
                if orig_idx in atom_map and atom_map[orig_idx] in mut_orig2frag:
                    matched2.append(mut_orig2frag[atom_map[orig_idx]])
                    matched1.append(wt_map[orig_idx])
            if matched1:
                frag_used.add(fwt_idx)
                fwt = frag1[fwt_idx]
                wt_orig2frag = wt_map
                break
        if not matched1:
            continue
        if len(matched1) > 2:
            aligned_any = True
            if not get_aligned_fragment_coords(fwt, fmut, matched1, matched2,
                                               mut_orig2frag, coord2):
                failed2 = True
            if not get_aligned_fragment_coords(fmut, fwt, matched2, matched1,
                                               wt_orig2frag, coord1):
                failed1 = True
        elif (len(matched1) == fwt.atom_total and
              len(matched2) == fmut.atom_total):
            # Too few atoms to align, but every atom is mapped: just copy
            # the partner's original coordinates.
            for orig_idx in wt_orig2frag:
                mut_idx = atom_map[orig_idx]
                coord1[orig_idx - 1] = ct2.atom[mut_idx].xyz
                coord2[mut_idx - 1] = ct1.atom[orig_idx].xyz
        else:
            logger.info(
                "skipping some fragments for snapcore because there are less than 3 atoms"
            )
    ret_val1 = coord1 if aligned_any and not failed1 else None
    ret_val2 = coord2 if aligned_any and not failed2 else None
    return (ret_val1, ret_val2)
def get_aligned_fragment_coords(fwt, fmut, matched1, matched2, map_orig2frag,
                                coord):
    """
    Align `fmut` onto the template `fwt` and write the aligned positions of
    fmut's atoms into `coord`.

    :param fwt: template fragment Structure
    :param fmut: fragment Structure whose coordinates are aligned onto fwt
    :param matched1: list of matched atom indices in fwt
    :param matched2: list of matched atom indices in fmut, parallel to
        matched1
    :param map_orig2frag: dictionary mapping original (parent structure)
        atom indices of fmut's atoms to their indices within fmut
    :param coord: coordinate array of fmut's parent structure (as returned
        by getXYZ()), indexed by original atom index - 1.  Rows for fmut's
        atoms are overwritten in place when the alignment succeeds; the
        array is left untouched otherwise.
    :return: False if the alignment failed, otherwise True
    """
    try:
        aligned_coord = get_aligned_coordinates(fwt, fmut, matched1, matched2)
    except Exception as err:
        # PhpException from phase or RuntimeError for a stereo failure:
        # report it and leave the coordinates unchanged.  (A bare except
        # would also swallow KeyboardInterrupt/SystemExit.)
        logger.warning(
            "Snapcore alignment failed; coordinates left unchanged: %s", err)
        return False
    frag2orig = {
        frag_idx: orig_idx for orig_idx, frag_idx in map_orig2frag.items()
    }
    for atom in fmut.atom:
        row = frag2orig[atom.index] - 1
        coord[row][:3] = aligned_coord[atom.index - 1][:3]
    return True
def get_aligned_coordinates(ct1,
                            ct2,
                            matched_atom1,
                            matched_atom2,
                            broken_bond_ct2=None):
    """
    Wrapper for PhpAlignCore: align ct2 onto the template ct1 using the
    given atom correspondence.

    :param ct1: template Structure
    :param ct2: Structure to be aligned onto ct1
    :param matched_atom1: list of matched atoms in ct1
    :param matched_atom2: list of matched atoms in ct2, parallel to
        matched_atom1
    :param broken_bond_ct2: list of broken bonds in ct2; None (the default)
        means an empty list
    :return: numpy array of aligned coordinates for ct2 atoms
    :raise RuntimeError: if PhpAlignCore reports a stereo failure.  Exceptions
        raised by phase itself are not caught here.
    """
    if broken_bond_ct2 is None:
        # Fresh list per call rather than a shared mutable default.
        broken_bond_ct2 = []
    align_options = phase.PhpAlignCoreOptions()
    align_options.stereo_change_action = phase.PhpStereoChangeAction_SAVE_MAPPINGS
    align_options.stereo_change = phase.PhpStereoChange_INVERSION_ONLY
    aligned = phase.PhpAlignCore(ct1, ct2, matched_atom1, matched_atom2,
                                 broken_bond_ct2, align_options)
    if aligned.stereoFailure():
        raise RuntimeError(
            'WARNING: Call to phase_align_core failed due to stereoFailure: template ct: %s, other ct:%s'
            % (ct1.title, ct2.title))
    # Only the first aligned structure returned by PhpAlignCore is used.
    aligned_ct = structure.Structure(aligned.getAlignedB()[0])
    return aligned_ct.getXYZ()