# -*- coding: utf-8 -*-
#
# ProDy: A Python Package for Protein Dynamics Analysis
#
# Copyright (C) 2010-2014 University of Pittsburgh
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
"""This module defines classes for principal component analysis (PCA) and
essential dynamics analysis (EDA) calculations."""
import time
from past.utils import old_div
import numpy as np
from .ensemble import Ensemble
from .ensemble import importLA
from .nma import NMA
__all__ = ['PCA', 'EDA']
class PCA(NMA):

    """A class for Principal Component Analysis (PCA) of conformational
    ensembles.  See examples in :ref:`pca`."""

    def __init__(self, name='Unknown'):
        """Instantiate with *name*, which is handed to :class:`.NMA`."""
        NMA.__init__(self, name)

    def setCovariance(self, covariance):
        """Set covariance matrix.

        ``self._reset()`` is called before the matrix is stored.  The number
        of rows is taken as the degrees of freedom, and the number of atoms
        is set to a third of that (integer division).

        :arg covariance: square, two-dimensional covariance matrix
        :type covariance: :class:`numpy.ndarray`
        :raises TypeError: when *covariance* is not an array, or is not a
            square two-dimensional array"""

        if not isinstance(covariance, np.ndarray):
            raise TypeError('covariance must be an ndarray')
        elif not (covariance.ndim == 2 and
                  covariance.shape[0] == covariance.shape[1]):
            raise TypeError('covariance must be square matrix')
        self._reset()
        self._cov = covariance
        self._dof = covariance.shape[0]
        # shape entries are plain integers, so floor division is exact here
        self._n_atoms = self._dof // 3
        self._trace = self._cov.trace()

    def buildCovariance(self, coordsets, **kwargs):
        """Build a covariance matrix for *coordsets* using mean coordinates
        as the reference.  *coordsets* argument may be one of the following:

        * :class:`.Ensemble`
        * :class:`numpy.ndarray` with shape ``(n_csets, n_atoms, 3)`` and
          dtype :class:`numpy.float32` or :class:`float`

        For ensemble objects, ``update_coords=True`` argument can be used to
        set the mean coordinates as the coordinates of the object.  The
        argument has no effect when *coordsets* is an array.

        :raises ValueError: when *coordsets* is an array with an unexpected
            shape or dtype, or when there are fewer than 3 coordinate sets
            or fewer than 3 atoms
        """

        mean = None
        # NOTE(review): *weights* is never assigned in this method, so the
        # weighted branch further down is currently unreachable.  It is kept
        # for the day weights are wired in from the ensemble again.
        weights = None
        ensemble = None
        if isinstance(coordsets, np.ndarray):
            if (coordsets.ndim != 3 or coordsets.shape[2] != 3 or
                    coordsets.dtype not in (np.float32, float)):
                raise ValueError('coordsets is not a valid coordinate array')
        elif isinstance(coordsets, Ensemble):
            ensemble = coordsets
            coordsets = coordsets._getCoordsets()

        update_coords = bool(kwargs.get('update_coords', False))
        n_confs = coordsets.shape[0]
        if n_confs < 3:
            raise ValueError('coordsets must have at least 3 coordinate '
                             'sets')
        n_atoms = coordsets.shape[1]
        if n_atoms < 3:
            raise ValueError('coordsets must have at least 3 atoms')
        dof = n_atoms * 3
        print('Covariance is calculated using coordinate sets.')

        s = (n_confs, dof)
        if weights is None:
            if coordsets.dtype == float:
                self._cov = np.cov(coordsets.reshape(s).T, bias=1)
            else:
                # Accumulate one conformation at a time into a double
                # precision matrix instead of upcasting the whole array.
                cov = np.zeros((dof, dof))
                flat = coordsets.reshape(s)
                mean = flat.mean(0)
                for coords in flat:
                    deviations = coords - mean
                    cov += np.outer(deviations, deviations)
                cov /= n_confs
                self._cov = cov
        else:
            # PDB ensemble case
            mean = np.zeros((n_atoms, 3))
            for i, coords in enumerate(coordsets):
                mean += coords * weights[i]
            mean /= weights.sum(0)
            d_xyz = ((coordsets - mean) * weights).reshape(s)
            divide_by = weights.astype(float).repeat(3, axis=2).reshape(s)
            self._cov = (np.dot(d_xyz.T, d_xyz) /
                         np.dot(divide_by.T, divide_by))

        if update_coords and ensemble is not None:
            if mean is None:
                mean = coordsets.mean(0)
            # The single precision path computes a flat mean of length
            # ``dof``; always hand over an ``(n_atoms, 3)`` array.
            ensemble.setCoords(mean.reshape((n_atoms, 3)))
        self._trace = self._cov.trace()
        self._dof = dof
        self._n_atoms = n_atoms

    def calcModes(self, n_modes=20, turbo=True):
        """Calculate principal (or essential) modes.  This method uses
        :func:`scipy.linalg.eigh`, or :func:`numpy.linalg.eigh`, function
        to diagonalize the covariance matrix.  Only eigenpairs with
        eigenvalue greater than ``1e-8`` are kept, in descending order of
        eigenvalue.

        :arg n_modes: number of non-zero eigenvalues/vectors to calculate,
            default is 20, for **None** all modes will be calculated
        :type n_modes: int

        :arg turbo: forwarded to :func:`scipy.linalg.eigh` only when the
            installed SciPy does not accept ``subset_by_index``, default is
            **True**
        :type turbo: bool

        :raises ValueError: when covariance matrix is not built or set"""

        linalg = importLA()
        if self._cov is None:
            raise ValueError('covariance matrix is not built or set')
        dof = self._dof
        if linalg.__package__.startswith('scipy'):
            # Eigenvalues come back in ascending order, so the largest
            # *n_modes* of them occupy the last indices.
            subset = None
            if n_modes is not None:
                n_modes = int(n_modes)
                if n_modes < dof:
                    subset = (dof - n_modes, dof - 1)
            try:
                values, vectors = linalg.eigh(self._cov,
                                              subset_by_index=subset)
            except TypeError:
                # Older SciPy releases only know the legacy keywords.
                values, vectors = linalg.eigh(self._cov,
                                              turbo=turbo,
                                              eigvals=subset)
        else:
            if n_modes is not None:
                print('Scipy is not found, all modes are calculated.')
            values, vectors = linalg.eigh(self._cov)
        # Order by descending eigenvalue
        values = values[::-1]
        vectors = vectors[:, ::-1]
        which = values > 1e-8
        self._eigvals = values[which]
        self._array = vectors[:, which]
        self._vars = self._eigvals
        self._n_modes = len(self._eigvals)

    def addEigenpair(self, eigenvector, eigenvalue=None):
        """Add eigen *vector* and eigen *value* pair(s) to the instance.
        If eigen *value* is omitted, it will be set to 1.  Eigenvalues
        are set as variances."""

        NMA.addEigenpair(self, eigenvector, eigenvalue)
        self._vars = self._eigvals

    def setEigens(self, vectors, values=None):
        """Set eigen *vectors* and eigen *values*.  If eigen *values* are
        omitted, they will be set to 1.  Eigenvalues are set as variances."""

        NMA.setEigens(self, vectors, values)
        self._vars = self._eigvals
class EDA(PCA):

    """A class for Essential Dynamics Analysis (EDA) [AA93]_.
    See examples in :ref:`eda`.

    This class adds nothing to :class:`PCA`; it only provides a separate
    name for the same calculations.

    .. [AA93] Amadei A, Linssen AB, Berendsen HJ. Essential dynamics of
       proteins. *Proteins* **1993** 17(4):412-25."""