# -*- coding: utf-8 -*-
"""This module defines functions for calculating atomic properties from normal
modes."""
from past.utils import old_div
import numpy as np
from schrodinger.trajectory.prody.mode import Mode
from schrodinger.trajectory.prody.modeset import ModeSet
from schrodinger.trajectory.prody.pca import NMA
def calcCrossCorr(modes, n_cpu=1):
    """Return cross-correlations matrix.

    For a 3-d model, cross-correlations matrix is an NxN matrix, where N is
    the number of atoms.  Each element of this matrix is the trace of the
    submatrix corresponding to a pair of atoms.  Covariance matrix may be
    calculated using all modes or a subset of modes of an NMA instance.  For
    large systems, calculation of cross-correlations matrix may be time
    consuming.  Optionally, multiple processors may be employed to perform
    calculations by passing ``n_cpu=2`` or more.

    :arg modes: modes to use, a :class:`Mode`, :class:`NMA`, or
        :class:`ModeSet` instance
    :arg n_cpu: number of worker processes requested; the number actually
        used never exceeds the machine's CPU count or the number of modes
    :returns: the covariance matrix divided element-wise by the outer
        product of the square roots of its diagonal
    :raises TypeError: if *n_cpu* is not an integer or *modes* is not one of
        the accepted types
    :raises ValueError: if *n_cpu* is less than 1
    """
    if not isinstance(n_cpu, int):
        raise TypeError('n_cpu must be an integer')
    elif n_cpu < 1:
        raise ValueError('n_cpu must be equal to or greater than 1')
    if not isinstance(modes, (Mode, NMA, ModeSet)):
        raise TypeError('modes must be a Mode, NMA, or ModeSet instance, '
                        'not {0}'.format(type(modes)))
    if modes.is3d():
        # Resolve the parent model and which of its modes were selected.
        if isinstance(modes, Mode):
            model = modes._model
            indices = [modes.getIndex()]
            n_modes = 1
        elif isinstance(modes, ModeSet):
            model = modes._model
            indices = modes.getIndices()
            n_modes = len(modes)
        else:
            model = modes
            n_modes = len(modes)
            indices = np.arange(n_modes)
        array = model._array
        n_atoms = model._n_atoms
        variances = model._vars

        if n_cpu > 1:
            import multiprocessing
            # A worker without at least one mode has nothing to compute, so
            # cap the worker count by the mode count as well as the CPUs.
            n_cpu = max(1, min(multiprocessing.cpu_count(), n_cpu, n_modes))

        if n_cpu == 1:
            shape = (n_modes, n_atoms, 3)
            arvar = (array[:, indices] * variances[indices]).T.reshape(shape)
            vectors = array[:, indices].T.reshape(shape)
            # Contract over modes and x/y/z components, leaving atom x atom.
            covariance = np.tensordot(vectors.transpose(2, 0, 1),
                                      arvar.transpose(0, 2, 1),
                                      axes=([0, 1], [1, 0]))
        else:
            queue = multiprocessing.Queue()
            size = n_modes // n_cpu
            processes = []
            for i in range(n_cpu):
                start = i * size
                # The last worker also takes the remainder of the modes.
                stop = n_modes if i == n_cpu - 1 else start + size
                # Slice the same index selection the serial path uses.
                chunk = indices[start:stop]
                process = multiprocessing.Process(target=_crossCorrelations,
                                                  args=(queue, n_atoms, array,
                                                        variances, chunk))
                process.start()
                processes.append(process)
            # Collect exactly one partial result per worker.  Polling
            # ``qsize()`` would drop results from workers that have not
            # finished yet.  Drain the queue before joining: a child cannot
            # exit until the data it queued has been consumed.
            covariance = queue.get()
            for _ in range(n_cpu - 1):
                covariance += queue.get()
            for process in processes:
                process.join()
    else:
        covariance = calcCovariance(modes)
    # Normalise by the per-row/column standard deviations.
    diag = np.power(covariance.diagonal(), 0.5)
    return covariance / np.outer(diag, diag)
def _crossCorrelations(queue, n_atoms, array, variances, indices):
    """Calculate covariance-matrix for a subset of modes."""
    n_modes = len(indices)
    arvar = (array[:, indices] * variances[indices]).T.reshape(
        (n_modes, n_atoms, 3))
    array = array[:, indices].T.reshape((n_modes, n_atoms, 3))
    covariance = np.tensordot(array.transpose(2, 0, 1),
                              arvar.transpose(0, 2, 1),
                              axes=([0, 1], [1, 0]))
    queue.put(covariance)
def calcCovariance(modes):
    """Return covariance matrix calculated for given *modes*.

    :arg modes: a :class:`Mode`, :class:`ModeSet`, or :class:`NMA` instance
    :returns: for a :class:`Mode`, the outer product of its array with
        itself scaled by its variance; for a :class:`ModeSet`, the sum of
        those variance-weighted outer products over its modes; for an
        :class:`NMA`, whatever ``modes.getCovariance()`` returns
    :raises TypeError: if *modes* is none of the accepted types
    """
    if isinstance(modes, Mode):
        array = modes._getArray()
        return np.outer(array, array) * modes.getVariance()
    elif isinstance(modes, ModeSet):
        array = modes._getArray()
        # Scale each mode (column) by its variance via broadcasting instead
        # of multiplying by a dense diagonal matrix; the scaled matrix is the
        # same, without the n_modes x n_modes temporary.
        # NOTE(review): assumes modes are the columns of ``array`` and
        # ``getVariances()`` is 1-d, as the original ``np.diag`` form did.
        return np.dot(array, (array * modes.getVariances()).T)
    elif isinstance(modes, NMA):
        return modes.getCovariance()
    else:
        raise TypeError('modes must be a Mode, NMA, or ModeSet instance')