"""
This module provides a class used to describe the elastic tensor,
including methods used to fit the elastic tensor from linear response
stress-strain data.
Copyright Schrodinger, LLC. All rights reserved.
"""
# Based on pymatgen/analysis/elasticity/elastic.py (LEGAL-413)
# last updated from upstream: 342de8d on Aug 20, 2018
import warnings
from collections import OrderedDict
import numpy as np
from scipy.special import factorial
from scipy.stats import linregress
from schrodinger.application.matsci.elasticity.strain import Strain
from schrodinger.application.matsci.elasticity.stress import Stress
from schrodinger.application.matsci.elasticity.tensors import Tensor
from schrodinger.utils.units import GIGA_TO_MEGA
class ComplianceTensor(Tensor):
    """
    This class represents the compliance tensor, and exists
    primarily to keep the voigt-conversion scheme consistent
    since the compliance tensor has a unique vscale
    """

    def __new__(cls, s_array):
        """
        Create a ComplianceTensor object.

        :param s_array: array-like compliance tensor, forwarded unchanged to
            the Tensor constructor together with the compliance-specific
            voigt scale matrix
        """
        # Voigt scale matrix: 1 in the upper-left 3x3 block, 2 in the two
        # off-diagonal 3x3 blocks and 4 in the lower-right 3x3 block.
        vscale = np.ones((6, 6))
        vscale[3:] *= 2
        vscale[:, 3:] *= 2
        obj = super(ComplianceTensor, cls).__new__(cls, s_array, vscale=vscale)
        return obj.view(cls)
class NthOrderElasticTensor(Tensor):
    """
    An object representing an nth-order tensor expansion
    of the stress-strain constitutive equations
    """

    def __new__(cls, input_array, check_rank=None, tol=1e-4):
        """
        Create an NthOrderElasticTensor object.

        :param input_array: array-like tensor data, forwarded to the Tensor
            constructor
        :param check_rank: rank check forwarded to the Tensor constructor
        :param tol: tolerance used for the voigt symmetry test
        :type tol: float
        :raise ValueError: if the rank of the resulting tensor is odd
        """
        obj = super(NthOrderElasticTensor, cls).__new__(cls,
                                                        input_array,
                                                        check_rank=check_rank)
        if obj.rank % 2 != 0:
            raise ValueError("ElasticTensor must have even rank")
        # A symmetry violation is only reported, the tensor is still returned
        if not obj.is_voigt_symmetric(tol):
            warnings.warn("Input elastic tensor does not satisfy "
                          "standard voigt symmetries")
        return obj.view(cls)

    @property
    def order(self):
        """
        Order of the elastic tensor
        """
        return self.rank // 2

    def calculate_stress(self, strain):
        """
        Calculates a given elastic tensor's contribution to the
        stress using Einstein summation

        :param strain: 3x3 matrix corresponding to strain, or the equivalent
            6-element voigt-notation vector
        :return: the stress contribution, wrapped in a Stress object
        :rtype: Stress
        """
        strain = np.array(strain)
        if strain.shape == (6,):
            strain = Strain.from_voigt(strain)
        assert strain.shape == (3, 3), "Strain must be 3x3 or voigt-notation"
        # Contract the tensor with (order - 1) copies of the strain and divide
        # by (order - 1)!
        n_strains = self.order - 1
        stress_matrix = (self.einsum_sequence([strain] * n_strains) /
                         factorial(n_strains))
        return Stress(stress_matrix)
class ElasticTensor(NthOrderElasticTensor):
    """
    This class extends Tensor to describe the 3x3x3x3
    second-order elastic tensor, C_{ijkl}, with various
    methods for estimating other properties derived from
    the second order elastic tensor
    """

    def __new__(cls, input_array, tol=1e-4):
        """
        Create an ElasticTensor object.  The constructor throws an error if
        the shape of the input_matrix argument is not 3x3x3x3, i. e. in true
        tensor notation.  Issues a warning if the input_matrix argument does
        not satisfy standard symmetries.  Note that the constructor uses
        __new__ rather than __init__ according to the standard method of
        subclassing numpy ndarrays.

        :param input_array: The 3x3x3x3 array-like representing the elastic
            tensor
        :param tol: tolerance for initial symmetry test of tensor
        :type tol: float
        """
        obj = super(ElasticTensor, cls).__new__(cls,
                                                input_array,
                                                check_rank=4,
                                                tol=tol)
        return obj.view(cls)

    @property
    def compliance_tensor(self):
        """
        returns the Voigt-notation compliance tensor,
        which is the matrix inverse of the
        Voigt-notation elastic tensor

        :raise ValueError: if the Voigt-notation matrix cannot be inverted
        """
        try:
            s_voigt = np.linalg.inv(self.voigt)
        except np.linalg.LinAlgError as err:
            raise ValueError("Failed to invert matrix: " + str(err))
        return ComplianceTensor.from_voigt(s_voigt)

    @property
    def k_voigt(self):
        """
        returns the K_v bulk modulus
        """
        return self.voigt[:3, :3].mean()

    @property
    def g_voigt(self):
        """
        returns the G_v shear modulus
        """
        normal = self.voigt[:3, :3]
        shear = self.voigt[3:, 3:]
        return (2. * normal.trace() - np.triu(normal).sum() +
                3 * shear.trace()) / 15.

    @property
    def k_reuss(self):
        """
        returns the K_r bulk modulus
        """
        return 1. / self.compliance_tensor.voigt[:3, :3].sum()

    @property
    def g_reuss(self):
        """
        returns the G_r shear modulus
        """
        # Invert the elastic tensor only once per evaluation
        s_voigt = self.compliance_tensor.voigt
        normal = s_voigt[:3, :3]
        shear = s_voigt[3:, 3:]
        return 15. / (8. * normal.trace() - 4. * np.triu(normal).sum() +
                      3. * shear.trace())

    @property
    def k_vrh(self):
        """
        returns the K_vrh (Voigt-Reuss-Hill) average bulk modulus
        """
        return 0.5 * (self.k_voigt + self.k_reuss)

    @property
    def g_vrh(self):
        """
        returns the G_vrh (Voigt-Reuss-Hill) average shear modulus
        """
        return 0.5 * (self.g_voigt + self.g_reuss)

    @property
    def y_mod(self):
        """
        Calculates Young's modulus (in SI units) using the
        Voigt-Reuss-Hill averages of bulk and shear moduli
        """
        # NOTE(review): the 1e9 factor presumably converts GPa to Pa, i.e. it
        # assumes the tensor is stored in GPa - confirm
        k_vrh = self.k_vrh
        g_vrh = self.g_vrh
        return 9.e9 * k_vrh * g_vrh / (3. * k_vrh + g_vrh)

    @property
    def universal_anisotropy(self):
        """
        returns the universal anisotropy value
        """
        return (5. * self.g_voigt / self.g_reuss +
                self.k_voigt / self.k_reuss - 6.)

    @property
    def homogeneous_poisson(self):
        """
        returns the homogeneous poisson ratio
        """
        ratio = 2. / 3. * self.g_vrh / self.k_vrh
        return (1. - ratio) / (2. + ratio)

    @property
    def lam(self):
        """
        Returns lambda = (C11 + C22 + C33) / 3 - 2 * Mu in GPa.

        :return float: Lambda (GPa)
        """
        return self.voigt[:3, :3].trace() / 3 - 2 * self.mu

    @property
    def mu(self):
        """
        Returns mu = (C44 + C55 + C66) / 3 in GPa.

        :return float: mu (GPa)
        """
        return self.voigt[3:, 3:].trace() / 3.

    @classmethod
    def from_independent_strains(cls,
                                 strains,
                                 stresses,
                                 eq_stress=None,
                                 vasp=False,
                                 tol=1e-10,
                                 dump_full_tensor=False):
        """
        Constructs the elastic tensor least-squares fit of independent strains

        :param strains: list of strain objects to fit
        :type strains: list[Strain]
        :param stresses: list of stress objects to use in fit corresponding to
            the list of strains
        :type stresses: list
        :param eq_stress: equilibrium stress to use in fitting
        :type eq_stress: Stress
        :param vasp: flag for whether the stress tensor should be converted
            based on vasp units/convention for stress
        :type vasp: bool
        :param tol: tolerance for removing near-zero elements of the resulting
            tensor
        :type tol: float
        :param dump_full_tensor: dump complete 6*6 stress and strain matrix
        :type dump_full_tensor: bool
        :raise ValueError: if any of the six independent strain states is
            missing from the input data
        :return: the fitted tensor, carrying two extra attributes:
            `r_squared` (6x6 array of squared correlation coefficients of the
            linear fits) and `log` (text log of the fitted data points)
        """
        strain_states = [tuple(ss) for ss in np.eye(6)]
        ss_dict = get_strain_state_dict(strains, stresses, eq_stress=eq_stress)
        missing_states = set(strain_states) - set(ss_dict)
        if missing_states:
            raise ValueError("Missing independent strain states: "
                             "{}".format(missing_states))
        if set(ss_dict) - set(strain_states):
            warnings.warn("Extra strain states in strain-stress pairs "
                          "are neglected in independent strain fitting")
        # Convert units/sign convention of vasp stress tensor
        fac = -0.1 if vasp else 1.
        msg = 'Data points (Epsilon; MPa) for calculating term: C%d%d\n'
        msg_pfit = 'Polynomial fit:, %f, %f\n'
        msg_slope = 'Stress method strain of deformation (MPa) =, %f\n'
        # Collect log fragments and join once at the end
        log_parts = []
        c_ij = np.zeros((6, 6))
        r_squared_ij = np.zeros((6, 6))
        for i in range(6):
            state_data = ss_dict[strain_states[i]]
            istrains = state_data["strains"]
            istresses = state_data["stresses"]
            xdata = istrains[:, i]
            for j in range(6):
                slope, intercept, r_value, _, _ = linregress(
                    xdata, istresses[:, j])
                c_ij[i, j] = slope
                r_squared_ij[i, j] = r_value**2
                # Write to the log, get the upper triangle constants
                if not dump_full_tensor and i > j:
                    continue
                # Logged values use the same sign/unit factor as the tensor,
                # additionally scaled by GIGA_TO_MEGA
                pfit = np.array([slope, intercept]) * fac * GIGA_TO_MEGA
                ydata = istresses[:, j] * fac * GIGA_TO_MEGA
                # Absolute deviation of every data point from the fitted line
                errors = np.abs(ydata - (pfit[0] * xdata + pfit[1]))
                log_parts.append(msg % (i + 1, j + 1))
                log_parts.append(msg_pfit % tuple(pfit))
                log_parts.append(msg_slope % pfit[0])
                for strain, stress, err in zip(xdata, ydata, errors):
                    log_parts.append('%f, %f, %f\n' % (strain, stress, err))
                log_parts.append('\n')
        c_ij *= fac
        c = cls.from_voigt(c_ij)
        c = c.zeroed(tol)
        c.r_squared = r_squared_ij
        c.log = ''.join(log_parts)
        return c
def find_eq_stress(strains, stresses, tol=1e-10):
    """
    Finds stress corresponding to zero strain state in stress-strain list

    :param strains: Nx3x3 array-like array corresponding to strains
    :param stresses: Nx3x3 array-like array corresponding to stresses
    :param tol: tolerance to find zero strain state
    :type tol: float
    :raise ValueError: if several zero-strain entries carry stresses that
        differ from each other
    :rtype: Stress
    :return: the stress of the zero strain state, or a zero stress (with a
        warning) if no zero strain state is present
    """
    stress_array = np.array(stresses)
    strain_array = np.array(strains)
    # A strain is an equilibrium state if all of its 3x3 entries are below tol
    is_equilibrium = np.all(abs(strain_array) < tol, axis=(1, 2))
    eq_stresses = stress_array[is_equilibrium]
    if eq_stresses.size == 0:
        warnings.warn("No eq state found, returning zero voigt stress")
        return Stress(np.zeros((3, 3)))
    all_same = (abs(eq_stresses - eq_stresses[0]) < 1e-8).all()
    if len(eq_stresses) > 1 and not all_same:
        raise ValueError("Multiple stresses found for equilibrium strain"
                         " state, please specify equilibrium stress or  "
                         " remove extraneous stresses.")
    # np.array() above drops the Stress type, so re-wrap the selected entry;
    # both return paths then yield a Stress
    return Stress(eq_stresses[0])
def get_strain_state_dict(strains,
                          stresses,
                          eq_stress=None,
                          tol=1e-10,
                          add_eq=True,
                          sort=True):
    """
    Creates a dictionary of voigt-notation stress-strain sets
    keyed by "strain state", i. e. a tuple corresponding to
    the non-zero entries in ratios to the lowest nonzero value,
    e.g. [0, 0.1, 0, 0.2, 0, 0] -> (0,1,0,2,0,0)
    This allows strains to be collected in stencils as to
    evaluate parameterized finite difference derivatives

    :param strains: Nx3x3 array-like strain matrices
    :param stresses: Nx3x3 array-like stress matrices
    :param eq_stress: 3x3 array-like equilibrium stress; if None and add_eq
        is set, it is looked up in the input via find_eq_stress
    :param tol: tolerance below which strain and stress components are
        treated as zero
    :type tol: float
    :param add_eq: flag for whether to add eq_strain to stress-strain sets for
        each strain state
    :type add_eq: bool
    :param sort: flag for whether to sort strain states
    :type sort: bool
    :rtype: OrderedDict
    :return: OrderedDict with strain state keys and dictionaries with
        stress-strain data corresponding to strain state
    """
    # Recast stress/strains
    vstrains = np.array([Strain(s).zeroed(tol).voigt for s in strains])
    vstresses = np.array([Stress(s).zeroed(tol).voigt for s in stresses])
    # Components below tol were zeroed above, so an exact comparison is the
    # single definition of "nonzero" used for both grouping and matching
    nonzero_mask = vstrains != 0
    # Collect independent strain states:
    independent = {tuple(np.nonzero(row)[0].tolist()) for row in nonzero_mask}
    # An all-zero strain has no nonzero component to normalize by; it is not
    # a strain state of its own and only enters the sets via add_eq below
    independent.discard(())
    strain_state_dict = OrderedDict()
    if add_eq:
        if eq_stress is not None:
            veq_stress = Stress(eq_stress).voigt
        else:
            # Wrap in Stress so that .voigt is available whatever array type
            # find_eq_stress hands back
            veq_stress = Stress(find_eq_stress(strains, stresses)).voigt
    for ind in independent:
        # match strains with templates
        template = np.zeros(6, dtype=bool)
        np.put(template, ind, True)
        mode = (nonzero_mask == template).all(axis=1)
        mstresses = vstresses[mode]
        mstrains = vstrains[mode]
        # Get "strain state", i.e. ratio of each value to minimum strain
        active = np.take(mstrains[-1], ind)
        min_nonzero_val = active[np.argmin(np.abs(active))]
        strain_state = tuple(mstrains[-1] / min_nonzero_val)
        if add_eq:
            # add zero strain state
            mstrains = np.vstack([mstrains, np.zeros(6)])
            mstresses = np.vstack([mstresses, veq_stress])
        # sort strains/stresses by strain values
        if sort:
            order = mstrains[:, ind[0]].argsort()
            mstresses = mstresses[order]
            mstrains = mstrains[order]
        strain_state_dict[strain_state] = {
            "strains": mstrains,
            "stresses": mstresses
        }
    return strain_state_dict