"""
This module provides a base class for tensor-like objects and methods for
basic tensor manipulation.  It also provides a class, SquareTensor,
that provides basic methods for creating and manipulating rank 2 tensors.
Copyright Schrodinger, LLC. All rights reserved.
"""
# Based on pymatgen/analysis/elasticity/tensors.py (LEGAL-413)
# last updated from upstream: 9759a6e on May 9, 2018
import collections.abc
import itertools
import string
import warnings
from fractions import Fraction
import numpy as np
import spglib
from scipy.linalg import polar
from schrodinger.application.matsci.nano import xtal
# Position k in this list is the (i, j) index pair of a symmetric 3x3 tensor
# that Voigt index k stands for.
voigt_map = [
    (0, 0),
    (1, 1),
    (2, 2),
    (1, 2),
    (0, 2),
    (0, 1),
]
# Inverse lookup: reverse_voigt_map[i, j] is the Voigt index of the (i, j)
# component (symmetric in i and j).
reverse_voigt_map = np.array([
    [0, 5, 4],
    [5, 1, 3],
    [4, 3, 2],
])
class Tensor(np.ndarray):
    """
    Base class for doing useful general operations on Nth order tensors,
    without restrictions on the type (stress, elastic, strain, piezo, etc.)
    """

    def __new__(cls, input_array, vscale=None, check_rank=None):
        """
        Create a Tensor object.  Note that the constructor uses __new__
        rather than __init__ according to the standard method of
        subclassing numpy ndarrays.

        Args:
            input_array: (array-like with shape 3^N): array-like representing
                a tensor quantity in standard (i. e. non-voigt) notation
            vscale: (N x M array-like): a matrix corresponding
                to the coefficients of the voigt-notation tensor
            check_rank: (int): if given (and non-zero), the rank the input
                is required to have

        Raises:
            ValueError: if the rank does not match `check_rank`, if `vscale`
                does not have the voigt shape for this rank, or if any
                dimension of the input is not 3
        """
        obj = np.asarray(input_array).view(cls)
        obj.rank = len(obj.shape)

        if check_rank and check_rank != obj.rank:
            raise ValueError("{} input must be rank {}".format(
                obj.__class__.__name__, check_rank))

        # Voigt shape: one axis of 3 for an odd leftover index, then one axis
        # of 6 for every pair of indices.
        vshape = tuple([3] * (obj.rank % 2) + [6] * (obj.rank // 2))
        obj._vscale = np.ones(vshape)
        if vscale is not None:
            # Coerce so that plain lists/tuples (array-likes) have a .shape
            obj._vscale = np.asarray(vscale)
        if obj._vscale.shape != vshape:
            raise ValueError("Voigt scaling matrix must be the shape of the "
                             "voigt notation matrix or vector.")
        if not all(i == 3 for i in obj.shape):
            raise ValueError("Pymatgen only supports 3-dimensional tensors, "
                             "and default tensor constructor uses standard "
                             "notation.  To construct from voigt notation, use"
                             " {}.from_voigt".format(obj.__class__.__name__))
        return obj

    def __array_finalize__(self, obj):
        """
        Copy the tensor bookkeeping attributes from the array this one was
        derived from (view, slice, ufunc output, ...), defaulting to None
        """
        if obj is None:
            return
        self.rank = getattr(obj, 'rank', None)
        self._vscale = getattr(obj, '_vscale', None)
        self._vdict = getattr(obj, '_vdict', None)

    def __array_wrap__(self, obj, context=None, return_scalar=False):
        """
        Overrides __array_wrap__ methods in ndarray superclass to avoid errors
        associated with functions that return scalar values

        `context` and `return_scalar` are accepted (and ignored) so that the
        method can be called with either the one-argument or the
        three-argument form.
        """
        if len(obj.shape) == 0:
            # 0-d result: hand back the plain scalar instead of a 0-d Tensor
            return obj[()]
        return np.ndarray.__array_wrap__(self, obj)

    def __hash__(self):
        """
        define a hash function, since numpy arrays
        have their own __eq__ method
        """
        return hash(self.tobytes())

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self.__str__())

    def zeroed(self, tol=1e-3):
        """
        returns the matrix with all entries below a certain threshold
        (i.e. tol) set to zero
        """
        new_tensor = self.copy()
        new_tensor[abs(new_tensor) < tol] = 0
        return new_tensor

    def rotate(self, matrix, tol=1e-3):
        """
        Applies a rotation directly, and tests input matrix to ensure a valid
        rotation.

        Args:
            matrix (3x3 array-like): rotation matrix to be applied to tensor
            tol (float): tolerance for testing rotation matrix validity

        Raises:
            ValueError: if `matrix` fails SquareTensor.is_rotation
        """
        matrix = SquareTensor(matrix)
        if not matrix.is_rotation(tol):
            raise ValueError("Rotation matrix is not valid.")
        return self.transform((matrix), [0., 0., 0.])

    def einsum_sequence(self, other_arrays, einsum_string=None):
        """
        Calculates the result of an einstein summation expression

        Args:
            other_arrays (list): arrays to combine with this tensor
            einsum_string (str): explicit subscripts for np.einsum.  When
                omitted, this tensor gets the first `rank` lowercase letters
                and the other arrays get consecutive letters starting at
                index `rank - sum(other ranks)`.

        Raises:
            ValueError: if `other_arrays` is not a list
        """
        if not isinstance(other_arrays, list):
            raise ValueError("other tensors must be list of "
                             "tensors or tensor input")

        other_arrays = [np.array(a) for a in other_arrays]
        if not einsum_string:
            lc = string.ascii_lowercase
            einsum_string = lc[:self.rank]
            other_ranks = [len(a.shape) for a in other_arrays]
            idx = self.rank - sum(other_ranks)
            for length in other_ranks:
                einsum_string += ',' + lc[idx:idx + length]
                idx += length

        einsum_args = [self] + list(other_arrays)
        return np.einsum(einsum_string, *einsum_args)

    @property
    def symmetrized(self):
        """
        Returns a generally symmetrized tensor, calculated by taking
        the sum of the tensor and its transpose with respect to all
        possible permutations of indices
        """
        perms = list(itertools.permutations(range(self.rank)))
        return sum([np.transpose(self, ind) for ind in perms]) / len(perms)

    @property
    def voigt_symmetrized(self):
        """
        Returns a "voigt"-symmetrized tensor, i. e. a voigt-notation
        tensor such that it is invariant wrt permutation of indices

        Raises:
            ValueError: unless the rank is even and greater than 2
        """
        if not (self.rank % 2 == 0 and self.rank > 2):
            raise ValueError("V-symmetrization requires rank even and > 2")

        v = self.voigt
        perms = list(itertools.permutations(range(len(v.shape))))
        new_v = sum([np.transpose(v, ind) for ind in perms]) / len(perms)
        return self.__class__.from_voigt(new_v)

    def is_symmetric(self, tol=1e-5):
        """
        Tests whether a tensor is symmetric or not based on the residual
        with its symmetric part, from self.symmetrized

        Args:
            tol (float): tolerance to test for symmetry
        """
        # Compare the magnitude of the residual: a signed comparison would
        # let large negative deviations through.
        return (np.abs(self - self.symmetrized) < tol).all()

    def is_fit_to_structure(self, structure, tol=1e-2):
        """
        Tests whether a tensor is invariant with respect to the
        symmetry operations of a particular structure by testing
        whether the residual of the symmetric portion is below a
        tolerance

        Args:
            structure (Structure): structure to be fit to
            tol (float): tolerance for symmetry testing
        """
        # Magnitude of the residual, for the same reason as in is_symmetric
        return (np.abs(self - self.fit_to_structure(structure)) < tol).all()

    @property
    def voigt(self):
        """
        Returns the tensor in Voigt notation

        A warning is issued if the tensor fails is_voigt_symmetric, since
        distinct components are then written to the same voigt slot.
        """
        v_matrix = np.zeros(self._vscale.shape, dtype=self.dtype)
        this_voigt_map = self.get_voigt_dict(self.rank)
        for ind in this_voigt_map:
            v_matrix[this_voigt_map[ind]] = self[ind]
        if not self.is_voigt_symmetric():
            warnings.warn("Tensor is not symmetric, information may "
                          "be lost in voigt conversion.")
        return v_matrix * self._vscale

    def is_voigt_symmetric(self, tol=1e-6):
        """
        Tests symmetry of tensor to that necessary for voigt-conversion
        by grouping indices into pairs and constructing a sequence of
        possible permutations to be used in a tensor transpose
        """
        # First piece: the single unpaired leading axis (odd rank) or an
        # empty placeholder (even rank).
        transpose_pieces = [[[0 for i in range(self.rank % 2)]]]
        # One piece per index pair, initially holding only the identity order
        transpose_pieces += [
            [range(j, j + 2)] for j in range(self.rank % 2, self.rank, 2)
        ]
        # Add the swapped order to every two-index piece
        for n in range(self.rank % 2, len(transpose_pieces)):
            if len(transpose_pieces[n][0]) == 2:
                transpose_pieces[n] += [transpose_pieces[n][0][::-1]]
        # Every combination of "keep"/"swap" over the pairs must leave the
        # tensor unchanged to within tol
        for trans_seq in itertools.product(*transpose_pieces):
            trans_seq = list(itertools.chain(*trans_seq))
            if (np.abs(self - self.transpose(trans_seq)) > tol).any():
                return False
        return True

    @staticmethod
    def get_voigt_dict(rank):
        """
        Returns a dictionary that maps indices in the tensor to those
        in a voigt representation based on input rank

        Args:
            rank (int): Tensor rank to generate the voigt map
        """
        vdict = {}
        for ind in itertools.product(*[range(3)] * rank):
            # An odd rank keeps its first index as is; the remaining indices
            # are collapsed pairwise through reverse_voigt_map
            v_ind = ind[:rank % 2]
            for j in range(rank // 2):
                pos = rank % 2 + 2 * j
                v_ind += (reverse_voigt_map[ind[pos:pos + 2]],)
            vdict[ind] = v_ind
        return vdict

    @classmethod
    def from_voigt(cls, voigt_input):
        """
        Constructor based on the voigt notation vector or matrix.

        Args:
            voigt_input (array-like): voigt input for a given tensor

        Raises:
            ValueError: if the input does not have a voigt shape
        """
        voigt_input = np.array(voigt_input)
        # Each voigt axis of 6 stands for two tensor axes, each axis of 3
        # for one
        rank = sum(voigt_input.shape) // 3
        t = cls(np.zeros([3] * rank))
        if voigt_input.shape != t._vscale.shape:
            raise ValueError("Invalid shape for voigt matrix")
        voigt_input = voigt_input / t._vscale
        this_voigt_map = t.get_voigt_dict(rank)
        for ind in this_voigt_map:
            t[ind] = voigt_input[this_voigt_map[ind]]
        return cls(t)

    @classmethod
    def from_values_indices(cls,
                            values,
                            indices,
                            populate=False,
                            structure=None,
                            voigt_rank=None,
                            vsym=True,
                            verbose=False):
        """
        Creates a tensor from values and indices, with options
        for populating the remainder of the tensor.

        :param values: numbers to place at indices
        :type values: list[float]
        :param indices: array-like collection of indices to place values at
        :param populate: whether to populate the tensor
        :type populate: bool
        :param structure: structure to base population or fit_to_structure on
        :type structure: Structure
        :param voigt_rank: full tensor rank to indicate the shape of the
            resulting tensor. This is necessary if one provides a set of
            indices more minimal than the shape of the tensor they want, e.g.
            Tensor.from_values_indices([100], [(0, 0)], voigt_rank=4)
        :type voigt_rank: int
        :param vsym: whether to voigt symmetrize during the optimization
            procedure
        :type vsym: bool
        :param verbose: whether to populate verbosely
        :type verbose: bool
        """
        # auto-detect voigt notation
        # TODO: refactor rank inheritance to make this easier
        indices = np.array(indices)
        if voigt_rank:
            # Must be an ndarray (not a list) so .astype below works
            shape = np.array([3] * (voigt_rank % 2) + [6] * (voigt_rank // 2))
        else:
            # Round the largest index on each axis up to a multiple of 3
            shape = np.ceil(np.max(indices + 1, axis=0) / 3.) * 3
        base = np.zeros(shape.astype(int))
        for v, idx in zip(values, indices):
            base[tuple(idx)] = v
        # Any axis of length 6 means the data is in voigt notation
        if 6 in shape:
            obj = cls.from_voigt(base)
        else:
            obj = cls(base)
        if populate:
            assert structure, "Populate option must include structure input"
            obj = obj.populate(structure, vsym=vsym, verbose=verbose)
        elif structure:
            obj = obj.fit_to_structure(structure)
        return obj
class TensorCollection(collections.abc.Sequence):
    """
    A sequence of tensors that can be used for fitting data
    or for having a tensor expansion

    Most methods simply apply the Tensor method of the same name to every
    member and collect the results.
    """

    def __init__(self, tensor_list, base_class=Tensor):
        """
        :param tensor_list: tensors (or array-likes) to store
        :param base_class: class that every entry is converted to unless it
            is already an instance of it
        """
        self.tensors = [
            t if isinstance(t, base_class) else base_class(t)
            for t in tensor_list
        ]

    def __len__(self):
        return len(self.tensors)

    def __getitem__(self, ind):
        return self.tensors[ind]

    def __iter__(self):
        return iter(self.tensors)

    def zeroed(self, tol=1e-3):
        """Return a new collection with `zeroed(tol)` applied to each tensor"""
        return self.__class__([t.zeroed(tol) for t in self])

    def rotate(self, matrix, tol=1e-3):
        """Return a new collection with each tensor rotated by `matrix`"""
        return self.__class__([t.rotate(matrix, tol) for t in self])

    @property
    def symmetrized(self):
        """New collection holding the symmetrized form of each tensor"""
        return self.__class__([t.symmetrized for t in self])

    def is_symmetric(self, tol=1e-5):
        """True if every tensor in the collection is symmetric within tol"""
        return all(t.is_symmetric(tol) for t in self)

    def fit_to_structure(self, structure, symprec=0.1):
        """Return a new collection with each tensor fit to `structure`"""
        return self.__class__(
            [t.fit_to_structure(structure, symprec) for t in self])

    def is_fit_to_structure(self, structure, tol=1e-2):
        """True if every tensor is fit to `structure` within tol"""
        return all(t.is_fit_to_structure(structure, tol) for t in self)

    @property
    def voigt(self):
        """List of the voigt-notation form of each tensor"""
        return [t.voigt for t in self]

    @property
    def ranks(self):
        """List of the rank of each tensor"""
        return [t.rank for t in self]

    def is_voigt_symmetric(self, tol=1e-6):
        """True if every tensor is voigt-symmetric within tol"""
        return all(t.is_voigt_symmetric(tol) for t in self)

    @classmethod
    def from_voigt(cls, voigt_input_list, base_class=Tensor):
        """
        Build a collection from voigt-notation inputs, converting each one
        with `base_class.from_voigt`
        """
        return cls([base_class.from_voigt(v) for v in voigt_input_list])

    def convert_to_ieee(self,
                        structure,
                        initial_fit=True,
                        refine_rotation=True):
        """
        Return a new collection with `convert_to_ieee` applied to each
        tensor, forwarding all arguments unchanged
        """
        return self.__class__([
            t.convert_to_ieee(structure, initial_fit, refine_rotation)
            for t in self
        ])
class SquareTensor(Tensor):
    """
    Base class for doing useful general operations on second rank tensors
    (stress, strain etc.).
    """

    def __new__(cls, input_array, vscale=None):
        """
        Create a SquareTensor object.  Note that the constructor uses __new__
        rather than __init__ according to the standard method of
        subclassing numpy ndarrays.  Error is thrown when the class is
        initialized with non-square matrix.

        Args:
            input_array (3x3 array-like): the 3x3 array-like
                representing the content of the tensor
            vscale (6x1 array-like): 6x1 array-like scaling the
                voigt-notation vector with the tensor entries
        """
        obj = super().__new__(cls, input_array, vscale, check_rank=2)
        return obj.view(cls)

    @property
    def trans(self):
        """
        shorthand for transpose on SquareTensor
        """
        return SquareTensor(np.transpose(self))

    @property
    def inv(self):
        """
        shorthand for matrix inverse on SquareTensor

        Raises:
            ValueError: if the determinant is exactly zero
        """
        if self.det == 0:
            raise ValueError("SquareTensor is non-invertible")
        return SquareTensor(np.linalg.inv(self))

    @property
    def det(self):
        """
        shorthand for the determinant of the SquareTensor
        """
        return np.linalg.det(self)

    def is_rotation(self, tol=1e-3, include_improper=True):
        """
        Test to see if tensor is a valid rotation matrix, performs a
        test to check whether the inverse is equal to the transpose
        and if the determinant is equal to one within the specified
        tolerance

        Args:
            tol (float): tolerance to both tests of whether the
                the determinant is one and the inverse is equal
                to the transpose
            include_improper (bool): whether to include improper
                rotations in the determination of validity
        """
        det = np.linalg.det(self)
        if include_improper:
            # Improper rotations have det == -1; only accept them on request
            det = np.abs(det)
        is_orthogonal = (np.abs(self.inv - self.trans) < tol).all()
        return is_orthogonal and (np.abs(det - 1.) < tol)

    def refine_rotation(self):
        """
        Helper method for refining rotation matrix by ensuring
        that second and third rows are perpindicular to the first.
        Gets new y vector from an orthogonal projection of x onto y
        and the new z vector from a cross product of the new x and y

        Returns:
            new rotation matrix
        """
        new_x, y = get_uvec(self[0]), get_uvec(self[1])
        # Remove from y its component along new_x (result is not
        # re-normalised)
        new_y = y - np.dot(new_x, y) * new_x
        new_z = np.cross(new_x, new_y)
        return SquareTensor([new_x, new_y, new_z])

    def get_scaled(self, scale_factor):
        """
        Scales the tensor by a certain multiplicative scale factor

        Args:
            scale_factor (float): scalar multiplier to be applied to the
                SquareTensor object
        """
        return SquareTensor(self * scale_factor)

    @property
    def principal_invariants(self):
        """
        Returns a list of principal invariants for the tensor,
        which are the values of the coefficients of the characteristic
        polynomial for the matrix
        """
        # Drop the leading coefficient and undo the alternating signs
        return np.poly(self)[1:] * np.array([-1, 1, -1])

    def polar_decomposition(self, side='right'):
        """
        calculates matrices for polar decomposition

        Args:
            side (str): passed through to scipy.linalg.polar

        Returns:
            the result of scipy.linalg.polar for this tensor
        """
        return polar(self, side=side)
def get_uvec(vec):
    """
    Gets a unit vector parallel to input vector

    A vector whose norm is below 1e-8 is returned unchanged, which avoids
    dividing by (almost) zero.
    """
    norm = np.linalg.norm(vec)
    if norm < 1e-8:
        return vec
    return vec / norm
def get_spglib_symmops(spg_cell,
                       symprec=xtal.ASSIGN_SPG_SYMPREC,
                       cartesian=False):
    """
    Get the symmetry operations spglib finds for a cell.

    :param spg_cell: cell passed to spglib.get_symmetry; its first element
        is used as the array of lattice vectors
    :type spg_cell: tuple
    :param symprec: symmetry precision passed to spglib.get_symmetry
    :type symprec: float
    :param cartesian: whether to convert the operations using the lattice
        vectors instead of returning them as given by spglib
    :type cartesian: bool
    :returns: list of (rotation, translation) pairs
    :rtype: list[tuple]
    """
    ops = spglib.get_symmetry(spg_cell, symprec=symprec)
    # Use P1 translations/rotations if symmetry could not be found (MATSCI-8107)
    if ops is None:
        ops = spglib.get_symmetry_from_database(1)

    # Snap each translation component to the nearest fraction with a
    # denominator of at most 1000
    translations = np.array([
        [float(Fraction.from_float(c).limit_denominator(1000)) for c in t]
        for t in ops['translations']
    ])
    # fractional translations of 1 are more simply 0
    translations[np.abs(translations) == 1.] = 0.
    rotations = ops['rotations']

    if not cartesian:
        return list(zip(rotations, translations))

    vecs = np.asarray(spg_cell[0])
    vecs_t = vecs.T
    invmat = np.linalg.inv(vecs_t)
    return [(np.dot(vecs_t, np.dot(rot, invmat)), np.dot(trans, vecs))
            for rot, trans in zip(rotations, translations)]
def symmetry_reduce(tensors, struct, tol=1e-8, **kwargs):
    """
    Function that converts a list of tensors corresponding to a structure
    and returns a dictionary consisting of unique tensor keys with symmop
    values corresponding to transformations that will result in derivative
    tensors from the original list

    :param tensors: list of Tensor objects to test for symmetrically-equivalent
        duplicates
    :type tensors: list[Tensor]
    :param struct: structure from which to get symmetry
    :type struct: Structure
    :param tol: tolerance for tensor equivalence
    :type tol: float
    :param kwargs: accepted for compatibility; not used by this function
    :returns: dictionary consisting of unique tensors with symmetry operations
        corresponding to those which will reconstruct the remaining tensors as
        values
    """
    # Build the (lattice, fractional coordinates, atomic numbers) cell that
    # get_spglib_symmops expects
    vecs = np.array(xtal.get_vectors_from_chorus(struct))
    fcoords = xtal.trans_cart_to_frac_from_vecs(struct.getXYZ(), *vecs)
    anums = [a.atomic_number for a in struct.atom]
    spg_cell = (vecs, fcoords, anums)
    symmops = get_spglib_symmops(spg_cell, cartesian=True)

    unique_tdict = {}
    for tensor in tensors:
        is_unique = True
        for unique_tensor, symmop in itertools.product(unique_tdict, symmops):
            # NOTE(review): symmop is a (rotation, translation) tuple passed
            # as one argument, whereas Tensor.rotate calls transform with two
            # arguments -- confirm against the transform signature
            if (np.abs(unique_tensor.transform(symmop) - tensor) < tol).all():
                # tensor is reproduced from an already-known tensor: record
                # the operation that generates it and stop searching
                unique_tdict[unique_tensor].append(symmop)
                is_unique = False
                break
        if is_unique:
            unique_tdict[tensor] = []
    return unique_tdict
def get_tkd_value(tensor_keyed_dict, tensor, allclose_kwargs=None):
    """
    Helper function to find a value in a tensor-keyed-
    dictionary using an approximation to the key.  This
    is useful if rounding errors in construction occur
    or hashing issues arise in tensor-keyed-dictionaries
    (e. g. from symmetry_reduce).  Resolves most
    hashing issues, and is preferable to redefining
    eq methods in the base tensor class.

    :param tensor_keyed_dict: dict with Tensor keys
    :type tensor_keyed_dict: dict
    :param tensor: tensor to find value of in the dict
    :param allclose_kwargs: dict of keyword-args to pass to allclose.
    :type allclose_kwargs: dict
    :returns: the value of the first key that is close to `tensor`, or
        None if no key matches
    """
    if allclose_kwargs is None:
        allclose_kwargs = {}
    for tkey, value in tensor_keyed_dict.items():
        if np.allclose(tensor, tkey, **allclose_kwargs):
            return value
    return None
def get_symmetric(array):
    """
    Return a symmetric matrix from a square matrix.

    :param numpy.array array: Input array (any array-like is accepted)
    :return numpy.array: Symmetrized array, the mean of the input and its
        transpose
    """
    # asanyarray leaves ndarrays (and subclasses) untouched while letting
    # plain nested lists through
    array = np.asanyarray(array)
    return (array + array.T) * 0.5