"""Transform a Structure into a standard nuclear orientation (SNO)."""
import numpy as np
from schrodinger.structutils import analyze
from schrodinger.structutils import transform
# Numerical tolerance below which a value (or a difference of values) is
# treated as zero throughout this module.
ZERO_THRESH = 1e-4
# Weighting selectors understood by _compute_weighted_moments.
SECOND_MASS = 0  # weight coordinates by atomic weight squared
ATOM_INDEX = 1  # weight coordinates by the 1-based atom index
def standard_nuclear_orientation(st):
    """
    Transform the structure to a standard nuclear orientation (SNO).

    SNO means the center of mass is at the origin and the principal moment
    of inertia axes are aligned with the coordinate (x,y,z) axes. Helper
    functions then pick a single orientation when degenerate principal
    moments or the sign of each axis would otherwise leave it ambiguous.

    The structure is modified in place; nothing is returned.

    :param st: structure to orient
    :type st: Structure instance
    """
    # Move the center of mass to the origin.
    com = analyze.center_of_mass(st)
    transform.translate_structure(st, x=-com[0], y=-com[1], z=-com[2])
    # Principal moments (vals) and the corresponding principal axes (vecs).
    vals, vecs = analyze.calculate_principal_moments(struct=st)
    # Size of any degenerate principal subspace; this may reorder vecs in
    # place so that a unique axis comes before the degenerate ones.
    degen_space_dim = _determine_degen_spaces(vals, vecs)
    # Make sure the axes describe a proper rotation (no reflection).
    _fix_handedness(vecs)
    # Build the 4x4 homogeneous matrix used by structutils.transform: the
    # upper-left 3x3 block holds the principal axes as rows, there is no
    # translation component, and the homogeneous corner is 1.
    transform_mat = np.zeros((4, 4))
    transform_mat[:3, :3] = vecs
    transform_mat[3, 3] = 1.0
    transform.transform_structure(st, transform_mat)
    # Pick a consistent basis inside any degenerate subspace.
    if degen_space_dim > 1:
        _resolve_degen_spaces(st, degen_space_dim)
    # Resolve the remaining sign ambiguity of the principal axes.
    _resolve_180_rotations(st)
def _determine_degen_spaces(vals, vecs):
    """
    Determine if any of the principal axes are in a degenerate space,
    that is, if they correspond to the same (non-trivial) eigenvalue.
    (The trivial eigenvalues are also degenerate, but there is also no use
    in trying to fix them, because being zero means there is nothing to do.)
    Something must be done to resolve the ambiguity and consistently pick
    specific axes (as any orthogonal set in the degenerate space are valid)
    For now, we just find out how big any degenerate spaces are, and ensure
    that any "special" axis, is given higher priority than degenerate axes.
    :type vals: list of floats
    :param vals: the principal moments
    :type vecs: list of 3 numpy arrays of length 3
    :param vecs: the principal axes
    """
    degen_space_dim = 1
    special_axis = None
    if abs(vals[0] - vals[1]) < ZERO_THRESH and abs(vals[0]) > ZERO_THRESH:
        degen_space_dim += 1
        special_axis = 2
    if abs(vals[1] - vals[2]) < ZERO_THRESH and abs(vals[1]) > ZERO_THRESH:
        degen_space_dim += 1
        special_axis = 0
    if special_axis == 2:
        vecs.reverse()
    return degen_space_dim
def _fix_handedness(moments):
    """
    Check the handedness of three vectors
    and try to invert the third vector to make a right handed coordinate system.
    :param moments:  Three vectors which should be orthogonal.
    :type moments: list of 3 numpy arrays of length 3
    """
    det = np.linalg.det(moments)
    # improper rotation
    if det < 0.0:
        x = moments[0]
        y = moments[1]
        z = moments[2]
        v = np.cross(x, y)
        dot = np.dot(v, z)
        #reflect the z axis
        if dot < 0.0:
            for i in range(3):
                moments[2][i] = -moments[2][i]
def _resolve_degen_spaces(st, dg_dim):
    """
    Pick a deterministic basis inside a degenerate principal-axis subspace.

    Each pass chooses a reference atom and rotates the whole structure
    around it:

    * 3-dim pass: a rotation about x followed by a rotation about z. The
      angles are chosen to bring the reference atom onto the x axis,
      possibly on its negative side.
    * 2-dim pass: a rotation about x only. The angle is chosen to zero the
      reference atom's z coordinate, i.e. its projection onto the degenerate
      yz plane is put along y.

    A 3-dim degenerate space is handled by a 3-dim pass followed by a
    2-dim pass. A 2-dim degenerate space gets only the 2-dim pass and
    assumes x is already the unique axis.

    The reference atom is the lowest-index atom among those with the
    smallest "weighted distance" (distance from the origin divided by the
    atom's atomic weight), excluding atoms at the origin. Distances within
    ZERO_THRESH of each other count as ties. In the 2-dim pass, atoms lying
    on the x axis are additionally skipped. If no atom qualifies, the
    remaining passes are skipped.

    The structure is assumed to have its center of mass at the origin. It is
    modified in place.

    :type st: Structure instance
    :param st: the structure to be rotated to unambiguous principal axes
    :type dg_dim: int
    :param dg_dim: the dimensionality of the degenerate space (2 or 3)
    """
    # i is the dimension handled in this pass: 3 then 2, or just 2.
    for i in range(dg_dim, 1, -1):
        # Re-read coordinates: the previous pass may have rotated them.
        atom_xyz = st.getXYZ()
        # Distance from the origin divided by atomic weight, so heavier
        # atoms rank as "closer".
        atom_dist = []
        for idx, at in enumerate(atom_xyz):
            atom_dist.append(
                np.sqrt(sum(at**2)) / st.atom[idx + 1].atomic_weight)
        # 2-dim pass only: distance from the x axis divided by atomic weight,
        # used to skip atoms that no rotation about x can move.
        cyl_dist = []
        if i == 2:
            for idx, at in enumerate(atom_xyz):
                cyl_dist.append(
                    np.sqrt(sum(at[1:]**2)) / st.atom[idx + 1].atomic_weight)
        high_prior_atom = None
        ## The following loop finds the closest atom to the origin (that is
        ## not at the origin) with lowest atom index.
        ## An atom replaces the current pick only if it is closer by more
        ## than ZERO_THRESH, so near-ties keep the lower index. In the 2-dim
        ## pass ranking still uses the full weighted distance, but atoms on
        ## the x axis are not eligible.
        close_dist = None
        for idx, at in enumerate(atom_dist):
            if at > ZERO_THRESH:
                if close_dist is None or (
                        at < close_dist and
                        not abs(at - close_dist) < ZERO_THRESH):
                    if i == 3 or cyl_dist[idx] > ZERO_THRESH:
                        close_dist = at
                        high_prior_atom = idx
        # No usable reference atom: nothing further can be disambiguated.
        if high_prior_atom is None:
            break
        sp_x, sp_y, sp_z = st.atom[high_prior_atom + 1].xyz
        # Rotation angle about x intended to zero the reference atom's z
        # coordinate (tan(theta) = -z / y); pi/2 when y is ~0.
        # NOTE(review): this relies on the sign convention of
        # transform.get_rotation_matrix -- confirm.
        theta = np.pi / 2
        ## We will want a slightly tighter ZERO_THRESH to check for
        ## zero division
        if abs(sp_y) > ZERO_THRESH * 1.0e-2:
            theta = np.arctan(-sp_z / sp_y)
        rot_decomp1 = transform.get_rotation_matrix(transform.X_AXIS, theta)
        rot_decomp2 = np.eye(4)
        if i == 3:
            # 3-dim pass: follow with a rotation about z intended to zero
            # the atom's y coordinate as it is after the first rotation
            # (y' = cos(theta) * y - sin(theta) * z); pi/2 when x is ~0.
            alpha = np.pi / 2
            if abs(sp_x) > ZERO_THRESH * 1.0e-2:
                alpha = np.arctan(
                    -(np.cos(theta) * sp_y - np.sin(theta) * sp_z) / sp_x)
            rot_decomp2 = transform.get_rotation_matrix(transform.Z_AXIS, alpha)
        # Compose so the x rotation is applied first, then the z rotation.
        transform_mat = np.dot(rot_decomp2, rot_decomp1)
        transform.transform_structure(st, transform_mat)
def _resolve_180_rotations(st):
    """
    Choose among orientations related by 180 degree rotations.

    Each principal axis is only defined up to sign, and negating two axes
    together is a proper 180 degree rotation. The signs of the x and y axes
    come from the mass-squared weighted coordinate sums along x and y, so a
    heavier side is pointed toward +x / +y. When either sum is ~0 (within
    ZERO_THRESH), only that sum is replaced by the atom-index weighted sum.
    This lets symmetric cases, e.g. two H2 molecules with the same bond
    length, end up with the same orientation down to atom numbering. The z
    axis is negated exactly when one of x and y is negated, so the applied
    transform is always a proper rotation.

    :param st: structure with its center of mass at the origin; rotated
        in place when a flip is needed
    :type st: Structure instance
    """
    sum_x, sum_y, _ = _compute_weighted_moments(st, SECOND_MASS)
    x_undecided = abs(sum_x) < ZERO_THRESH
    y_undecided = abs(sum_y) < ZERO_THRESH
    if x_undecided or y_undecided:
        # Mass weighting cannot settle at least one axis; fall back to
        # atom-index weighting for the undecided axes only.
        idx_x, idx_y, _ = _compute_weighted_moments(st, ATOM_INDEX)
        sum_x = idx_x if x_undecided else sum_x
        sum_y = idx_y if y_undecided else sum_y
    flip_x = sum_x < -ZERO_THRESH
    flip_y = sum_y < -ZERO_THRESH
    if not (flip_x or flip_y):
        return
    # Keep the determinant at +1: z is negated iff exactly one of x/y is.
    flip_z = flip_x != flip_y
    signs = [-1.0 if flag else 1.0 for flag in (flip_x, flip_y, flip_z)]
    transform.transform_structure(st, np.diag(signs + [1.0]))
def _compute_weighted_moments(st, weighting):
    """
    Computes the weighted distance along x, y, and z.
    We assume that the molecule's center of mass is already at the origin (and the
    x, y, and z axes are the principal moments, though this is not strictly required
    by the code), and so the sum of the mass-weighted x, y, and z components are 0.
    If we wish to break the symmetry of our principal axis system, we can look at
    second mass moments. That is, the sum of the x, y and z components of all atoms
    times the square of the atoms' masses. For all but truly symmetric systems, this
    breaks principal axis symm. weighting = SECOND_MASS
    If this is insufficient to break the symmetry, a last ditch effort to break
    molecular symmetry so that standard nuclear orientation of a highly symmetric
    molecule (e.g. H2) is the same even up to atom is to use the atom indices as the
    weights. weighting = ATOM_INDEX
    :param st: the structure to analyze, with COM at origin
    :type st: Structure instance
    :param weighting: the function to weight the coordinates by
    :type weighting: SECOND_MASS or ATOM_INDEX
    :return: (weight*x, weight*y, weight*z)
    :rtype: 3-tuple of floats
    """
    weight = lambda idx, at: 0.0
    if weighting == SECOND_MASS:
        weight = lambda idx, at: at.atomic_weight**2
    elif weighting == ATOM_INDEX:
        weight = lambda idx, at: idx + 1
    else:
        raise ValueError(f"{weighting} is not a valid weighting function")
    atom_xyz = np.zeros((len(st.atom), 3))
    for idx, at in enumerate(st.atom):
        atom_xyz[idx] = np.array(at.xyz) * weight(idx, at)
    mx, my, mz = np.sum(atom_xyz, axis=0)
    return mx, my, mz