"""
Classes and functions to deal with reading XML generated by Quantum Espresso.
Copyright Schrodinger, LLC. All rights reserved."""
import math
import operator
import os
import tarfile
from collections import defaultdict
from collections import namedtuple
from itertools import islice
from xml.etree import ElementTree
import numpy
import scipy.constants as const
from schrodinger import structure
from schrodinger.application.desmond.constants import FF_CUSTOM_CHARGE
from schrodinger.application.matsci import msutils
from schrodinger.application.matsci import msprops
from schrodinger.application.matsci.clusterstruct import DEFAULT_CHARGE_PROP
from schrodinger.application.matsci.espresso import utils as qeu
from schrodinger.application.matsci.nano import xtal
from schrodinger.structutils import color
# ---------------------------------------------------------------------------
# Paths in XML
# ---------------------------------------------------------------------------
# Paths below the <input> element
INPUT_TAG = 'input'
TITLE_TAG = '%s/control_variables/title' % INPUT_TAG
HAS_FORCES = '%s/control_variables/forces' % INPUT_TAG
KPTS_IBZ = '%s/k_points_IBZ' % INPUT_TAG
TOT_CHARGE_TAG = '%s/bands/tot_charge' % INPUT_TAG
ESM_BC_TAG = '%s/boundary_conditions/esm/bc' % INPUT_TAG
ESM_EFIELD_TAG = '%s/boundary_conditions/esm/efield' % INPUT_TAG
MD_TIMESTEP_TAG = '%s/ion_control/md/timestep' % INPUT_TAG
NOSYM_TAG = '%s/symmetry_flags/nosym' % INPUT_TAG
FREE_POS_TAG = '%s/free_positions' % INPUT_TAG
# Output element and paths that are not prefixed with INPUT_TAG
OUTPUT_TAG = 'output'
ATOMIC_STRUCTURE_TAG = 'atomic_structure'
ATOMIC_POSITIONS_TAG = '%s/atomic_positions/atom' % ATOMIC_STRUCTURE_TAG
# The trailing '%d' is kept as a format placeholder for the cell vector number
CELL_VECTORS_TAG = '%s/cell/a%%d' % ATOMIC_STRUCTURE_TAG
ATOMIC_SPECIES_TAG = 'atomic_species/species'
DFT_HU_TAG = 'dft/dftU/Hubbard_U'
TOT_ENERGY_TAG = 'total_energy/etot'
TOT_MAG_TAG = 'magnetization/total'
ECUTWFC_TAG = 'basis_set/ecutwfc'
ECUTRHO_TAG = 'basis_set/ecutrho'
DFT_FUNCT_TAG = 'dft/functional'
DFT_VDW_TAG = 'dft/vdW/vdw_corr'
FORCES_TAG = 'forces'
STRESS_TAG = 'stress'
# Paths below the <band_structure> element
BAND_STRUCTURE_TAG = 'band_structure'
NBND_TAG = '%s/nbnd' % BAND_STRUCTURE_TAG
NKS_TAG = '%s/nks' % BAND_STRUCTURE_TAG
NBND_UP_TAG = '%s/nbnd_up' % BAND_STRUCTURE_TAG
NBND_DW_TAG = '%s/nbnd_dw' % BAND_STRUCTURE_TAG
EFERMI_TAG = '%s/fermi_energy' % BAND_STRUCTURE_TAG
TWO_EFERMIS_TAG = '%s/two_fermi_energies' % BAND_STRUCTURE_TAG
HOMO_TAG = '%s/highestOccupiedLevel' % BAND_STRUCTURE_TAG
KS_ENERGIES_TAG = '%s/ks_energies' % BAND_STRUCTURE_TAG
# Timing
TIMING_TAG = 'timing_info'
TIMING_TOTAL_CPU_TAG = '%s/total/cpu' % TIMING_TAG
TIMING_TOTAL_WALL_TAG = '%s/total/wall' % TIMING_TAG
# ---------------------------------------------------------------------------
# Spin channel keys
# ---------------------------------------------------------------------------
SPIN_UP = 'up'
SPIN_DW = 'dw'
# ---------------------------------------------------------------------------
# Result dictionary keys
# ---------------------------------------------------------------------------
BAND_INDEX_KEY = 'band_index'
KPOINT_INDEX_KEY = 'kpoint_index'
KPOINT_KEY = 'kpoint'
ENERGY_KEY = 'energy'
DIRECT_KEY = 'direct'
IS_METAL_KEY = 'is_metal'
IS_DIRECT_KEY = 'is_direct'
BAND_GAP_KEY = 'band_gap'
TRANSITION_KEY = 'transition'
# ---------------------------------------------------------------------------
# Numerical constants and record types
# ---------------------------------------------------------------------------
# 1 / sqrt(pi): normalization factor of a Gaussian
SQRT_1PI = 1.0 / math.sqrt(math.pi)
DOS_ENERGY_GRID_SPACING = 0.01  # eV
KptLegend = namedtuple('KptLegend', 'label coords')
WfcType = namedtuple('WfcType', 'atom_idx atom_type n_qn l_qn m_qn')
EsmType = namedtuple('EsmType', 'data bc_type efield')
# Boltzmann constant in THz/K (scipy provides it in Hz/K)
BOLTZ_THZ_PER_K = const.value("Boltzmann constant in Hz/K") / const.tera
# Energy in J corresponding to a frequency of 1 THz
THZ_TO_J = const.value("hertz-joule relationship") * const.tera
def gaussian_delta(eigenval, energies, degauss):
    """
    Get unnormalized Gaussian-function values exp(-(E - eigenval)^2 / degauss^2)
    evaluated on an energy grid E.

    :type eigenval: float
    :param eigenval: Energy at which the Gaussian delta is centered
    :type energies: `numpy.array`
    :param energies: Energy grid
    :type degauss: float
    :param degauss: Broadening
    :rtype: `numpy.array`
    :return: delta values on the grid (peak value 1.0)
    """
    # Versions < 17-4 applied a cut-off when the exp() argument fell below -10,
    # because such tiny values produced 'spikes' in matplotlib DOS plots (the
    # '10' was found empirically). Since 17-4 no spikes are observed, so the
    # full exponential is returned.
    squared_dist = (energies - eigenval)**2
    return numpy.exp(-squared_dist / degauss**2)
class KPoint(object):
    """
    Class to hold information about a k-point.
    """

    def __init__(self, tag, vecs=None):
        """
        Initialize KPoint object from ElementTree element.

        The optional 'label' attribute is either 'NAME = kx ky kz' or just
        'kx ky kz' (the two forms produced by getKptFromCfg). The coordinates
        in the label are stored in self.frac_coords, the name (or '' when
        there is none) in self.label.

        :type tag: `xml.etree.ElementTree.Element`
        :param tag: k_point tag. Its text holds the Cartesian coordinates and
            its 'weight' attribute the k-point weight
        :type vecs: numpy array (3x3)
        :param vecs: Cell vectors. Used to derive self.frac_coords from the
            Cartesian coordinates when the label provides none
        :raise ValueError: If coordinates cannot be parsed as floats
        """
        self.cart_coords = numpy.array(self.getCoords(tag.text))
        self.weight = float(tag.attrib['weight'])
        label_attr = tag.attrib.get('label')
        if label_attr is None:
            self.label = ''
            self.frac_coords = []
        else:
            name, sep, coords = label_attr.partition('=')
            if not sep:
                # Unnamed k-point: the whole label is the coordinates
                name, coords = '', name
            self.label = name.strip()
            self.frac_coords = numpy.array(self.getCoords(coords))
        if vecs is not None and not len(self.frac_coords):
            alat = math.sqrt(numpy.dot(vecs[0], vecs[0]))
            frac = xtal.trans_cart_to_frac_from_vecs(self.cart_coords,
                                                     *vecs,
                                                     rec=True)
            # Scaled by the length of the first cell vector
            # (NOTE(review): presumably because cart_coords are in 2pi/alat)
            self.frac_coords = frac / alat

    def getCoords(self, coords_str):
        """
        Return list of coordinates.

        :type coords_str: str
        :param coords_str: Whitespace-separated K-point coordinates
        :rtype: list of float
        :return: K-point coordinates
        :raise ValueError: If a coordinate is not a float
        """
        return list(map(float, coords_str.strip().split()))

    def getCoordsStr(self, frac=False):
        """
        Get string representation of the coordinates.

        :type frac: bool
        :param frac: If True, self.frac_coords are returned, otherwise
            self.cart_coords
        :rtype: str
        :return: Coordinates formatted as 'x.xxx, y.yyy, z.zzz'
        """
        coords = self.frac_coords if frac else self.cart_coords
        return '%.3f, %.3f, %.3f' % tuple(coords)

    @staticmethod
    def getKptFromCfg(kpt):
        """
        Build k-point coordinates and label from a configuration entry.

        :type kpt: sequence
        :param kpt: [kx, ky, kz] or [kx, ky, kz, name]
        :rtype: (`numpy.array`, str)
        :return: Coordinates (first three entries) and a label of the form
            'kx ky kz' or 'name = kx ky kz' (coordinates with 3 decimals)
        """
        kpt_np = numpy.array(kpt[:3])
        label = '%.3f %.3f %.3f' % tuple(kpt_np)
        if len(kpt) == 4:
            label = str(kpt[3]) + ' = ' + label
        return kpt_np, label
class DOS(object):
    """
    Basic DOS class, based on the class from pymatgen (MIT license).
    """

    def __init__(self, band, dos_fn):
        """
        Initialize DOS object. 'band' can be None, in this case, dos_fn MUST
        point to the .dos file. If both 'band' and 'dos_fn' are present, the
        former has priority and dos_fn is not read.

        :type band: `BandStructure`
        :param band: BandStructure object to extract eigenvalues and k-points
            from
        :type dos_fn: str
        :param dos_fn: .dos filename. This file holds DOS plot data
        """
        self.dos = None
        if band is None:
            self.band = None
            self._getDOSFromFile(dos_fn)
        else:
            self.band = band
            self.efermi = self.band.efermi
            self.is_spin_polarized = band.is_spin_polarized

    def _getDOSFromFile(self, dos_fn):
        """
        Read energies and densities from a .dos file and set:
        self.energies_ev, self.energies (Ry), self.dos, self.idos,
        self.is_spin_polarized and self.efermi (Ry).

        :type dos_fn: str
        :param dos_fn: .dos filename. This file holds DOS plot data
        :raise ValueError: If the file has fewer than three data columns, its
            first line is not a '#' header, or the Fermi energy cannot be
            parsed from that header
        """
        # ndmin=2 keeps a single-row file two-dimensional, so the column
        # check and the column slicing below are always valid
        data = numpy.loadtxt(dos_fn, ndmin=2)
        if data.shape[1] < 3:
            raise ValueError('Unknown data format in %s file' % dos_fn)
        self.energies_ev = data[:, 0]
        # Same grid in Ry (the unit used by getDOS); needed by getCbmVbm
        self.energies = self.energies_ev / qeu.RY2EV
        self.dos = {SPIN_UP: data[:, 1]}
        # Spin-polarized case
        self.is_spin_polarized = data.shape[1] == 6
        if self.is_spin_polarized:
            self.dos[SPIN_DW] = data[:, 2]
            self.idos = {SPIN_UP: data[:, 3], SPIN_DW: data[:, 4]}
        else:
            self.idos = {SPIN_UP: data[:, 2]}
        with open(dos_fn) as file_fh:
            line = file_fh.readline()
        if not line.startswith('#'):
            raise ValueError('Could not find correct data in %s file.' %
                             dos_fn)
        # First line from .dos should look like:
        # '#  E (eV)   dos(E)     Int dos(E) EFermi =   4.597 eV'
        try:
            self.efermi = float(line.split()[-2]) / qeu.RY2EV
        except (TypeError, ValueError, IndexError) as err:
            raise ValueError('Could not parse Fermi energy from line: '
                             '"%s" in %s file.' % (line, dos_fn)) from err

    def canBroaden(self):
        """
        Can getDOS (broadening) be called on this object.

        :rtype: bool
        :return: True if it can, otherwise False
        """
        return self.band is not None

    def getDOS(self, degauss, delta_e=DOS_ENERGY_GRID_SPACING):
        """
        Broaden eigenvalues with Gaussians and set the DOS in self.dos and the
        integrated DOS in self.idos on a uniform grid (self.energies in Ry,
        self.energies_ev in eV). This requires self.band to be set in the
        constructor. Saves degauss (converted to Ry) in self.degauss.

        :type degauss: float
        :param degauss: Broadening (eV) for computing DOS
        :type delta_e: float
        :param delta_e: Energy grid spacing (in eV)
        :raise ValueError: If self.band is None
        """
        if self.band is None:
            raise ValueError('getDOS is called, however self.band is None')
        self.degauss = degauss / qeu.RY2EV
        # Grid spans the lowest/highest band extended by 3 broadening widths
        e_min = self.band.bands[SPIN_UP][0].min(0)
        if self.band.is_spin_polarized:
            e_min = min(e_min, self.band.bands[SPIN_DW][0].min(0))
        e_min -= 3.0 * self.degauss
        e_max = self.band.bands[SPIN_UP][-1].max(0)
        if self.band.is_spin_polarized:
            e_max = max(e_max, self.band.bands[SPIN_DW][-1].max(0))
        e_max += 3.0 * self.degauss
        delta_e /= qeu.RY2EV
        self.energies = numpy.arange(e_min - delta_e, e_max + delta_e, delta_e)
        self.energies_ev = self.energies * qeu.RY2EV
        self.dos = {SPIN_UP: numpy.zeros(shape=(len(self.energies)))}
        if self.band.is_spin_polarized:
            self.dos[SPIN_DW] = numpy.zeros(shape=(len(self.energies)))
        self.idos = {SPIN_UP: numpy.zeros(shape=(len(self.energies)))}
        if self.band.is_spin_polarized:
            self.idos[SPIN_DW] = numpy.zeros(shape=(len(self.energies)))
        for spin in self.band.bands:
            for iband in range(self.band.nb_bands):
                for jkpoint in range(len(self.band.kpoints)):
                    eigenval = self.band.bands[spin][iband][jkpoint]
                    kpt_weight = self.band.kpoints[jkpoint].weight
                    delta = gaussian_delta(eigenval, self.energies,
                                           self.degauss)
                    self.dos[spin] += (kpt_weight * delta * SQRT_1PI /
                                       self.degauss)
            # Integrate while the DOS is still per Ry (delta_e is in Ry)
            self.idos[spin] = numpy.cumsum(self.dos[spin]) * delta_e
            # Convert DOS from per Ry to per eV
            self.dos[spin] /= qeu.RY2EV

    def getDensities(self, spin=None):
        """
        Get density of states for a particular spin.

        :type spin: str or None
        :param spin: Can be SPIN_UP or SPIN_DW or None.
        :rtype: `numpy.array` or None
        :return: Copy of the density of states for a particular spin. If spin
            is None, the sum of all spins is returned. None if no DOS is set.
        """
        if self.dos is None:
            return None
        elif spin is None:
            if SPIN_DW in self.dos:
                result = self.dos[SPIN_UP] + self.dos[SPIN_DW]
            else:
                result = self.dos[SPIN_UP].copy()
        else:
            result = self.dos[spin].copy()
        return result

    def getCbmVbm(self, tol=0.001, abs_tol=False, spin=None):
        """
        Get Conduction Band Minimum (cbm) and Valence Band Maximum (vbm).

        :type tol: float
        :param tol: tolerance in occupations for determining the cbm/vbm
        :type abs_tol: bool
        :param abs_tol: An absolute tolerance (True) or a relative one (False)
        :type spin: str or None
        :param spin: Possible values are None - finds the cbm/vbm in the summed
            densities, SPIN_UP - finds the cbm/vbm in the up spin channel,
            SPIN_DW - finds the cbm/vbm in the down spin channel.
        :rtype: float, float
        :return: cbm and vbm in Ry corresponding to the gap. When there is no
            gap, cbm may be lower than vbm (see getGap)
        """
        tdos = self.getDensities(spin)
        # Determine tolerance
        if not abs_tol:
            tol = tol * tdos.sum() / tdos.shape[0]
        # Find index of the first grid point above the Fermi energy; stop at
        # the last point so a Fermi level above the grid cannot index past
        # the end of the array
        last = len(self.energies) - 1
        i_fermi = 0
        while i_fermi < last and self.energies[i_fermi] <= self.efermi:
            i_fermi += 1
        # Work backwards until tolerance is reached
        i_gap_start = i_fermi
        while i_gap_start - 1 >= 0 and tdos[i_gap_start - 1] <= tol:
            i_gap_start -= 1
        # Work forwards until tolerance is reached
        i_gap_end = i_gap_start
        while i_gap_end < tdos.shape[0] and tdos[i_gap_end] <= tol:
            i_gap_end += 1
        # Step back to the last in-gap point; clamp at 0 so that "no gap at
        # the first grid point" does not wrap around to the last element
        i_gap_end = max(i_gap_end - 1, 0)
        return self.energies[i_gap_end], self.energies[i_gap_start]

    def getGap(self, tol=0.001, abs_tol=False, spin=None):
        """
        Get the gap.

        :type tol: float
        :param tol: tolerance in occupations for determining the gap
        :type abs_tol: bool
        :param abs_tol: An absolute tolerance (True) or a relative one (False)
        :type spin: str or None
        :param spin: Possible values are None - finds the gap in the summed
            densities, SPIN_UP - finds the gap in the up spin channel,
            SPIN_DW - finds the gap in the down spin channel.
        :rtype: float
        :return: gap in Ry or 0.0, if it is a metal
        """
        (cbm, vbm) = self.getCbmVbm(tol, abs_tol, spin)
        return max(cbm - vbm, 0.0)
class PhDOS(object):
    """
    Phonon DOS class.

    Thermodynamic quantities are integrated over the non-negative part of the
    DOS and returned per mole-cell (mol-c), i.e. Avogadro's number times the
    atoms in a unit cell. To compare with experimental data the result should
    be divided by the number of formula units in the cell.
    """

    def __init__(self, file_fh):
        """
        Initialize PhDOS object.

        :type file_fh: file object
        :param file_fh: File handler of phonon DOS from matdyn.x. Column 0
            holds frequencies (cm-1), column 1 densities
        :raise IOError: If no non-negative frequencies are found
        """
        self.frequencies, self.densities = numpy.loadtxt(file_fh,
                                                         unpack=True,
                                                         usecols=(0, 1))
        # Convert from cm-1 (output of matdyn) to THz (expected by c_v, etc.)
        self.frequencies /= qeu.THZ2CM1
        # First index with frequency >= 0 (searchsorted assumes ascending
        # frequencies -- NOTE(review): confirm matdyn.x always sorts them)
        ind = numpy.searchsorted(self.frequencies, 0)
        if ind >= len(self.frequencies):
            raise IOError('No positive frequencies found.')
        self.positive_frequencies = self.frequencies[ind:]
        self.positive_densities = self.densities[ind:]

    @staticmethod
    def _log_2sinh(x):
        """
        Numerically stable log(2*sinh(x)) for x >= 0.

        Evaluated as x + log(1 - exp(-2x)), which stays finite for large x
        (high frequencies / low temperatures) where sinh(x) overflows. At
        x == 0 the result is -inf, as for the direct formula.

        :type x: `numpy.array`
        :param x: Non-negative arguments
        :rtype: `numpy.array`
        :return: log(2*sinh(x))
        """
        with numpy.errstate(divide='ignore'):
            return x + numpy.log(-numpy.expm1(-2. * x))

    def c_v(self, temperature):
        """
        Constant volume specific heat C_v at temperature T obtained from the
        integration of the DOS. Only non-negative frequencies are used.

        :type temperature: float
        :param temperature: Temperature at which to evaluate C_v, in K
        :rtype: float
        :return: Constant volume specific heat C_v in J/(K*mol-c)
        """
        # This function is based on pymatgen/phonon/dos.py (MIT, LEGAL-413)
        if temperature == 0:
            return 0.
        freqs = self.positive_frequencies
        dens = self.positive_densities
        wd2kt = freqs / (2. * BOLTZ_THZ_PER_K * temperature)
        # csch^2 is inf at zero frequency (reset below) and 0 where sinh
        # overflows; silence the expected warnings for both
        with numpy.errstate(divide='ignore', over='ignore'):
            csch2_val = 1. / (numpy.sinh(wd2kt)**2)
        csch2_val[csch2_val == numpy.inf] = 0.
        c_v = numpy.trapz(wd2kt**2 * csch2_val * dens, x=freqs)
        c_v *= const.Boltzmann * const.Avogadro
        return c_v

    def entropy(self, temperature):
        """
        Vibrational entropy at temperature T obtained from the integration of
        the DOS. Only non-negative frequencies are used.

        :type temperature: float
        :param temperature: Temperature at which to evaluate entropy, in K
        :rtype: float
        :return: Vibrational entropy in J/(K*mol-c)
        """
        # This function is based on pymatgen/phonon/dos.py (MIT, LEGAL-413)
        if temperature == 0:
            return 0.
        freqs = self.positive_frequencies
        dens = self.positive_densities
        wd2kt = freqs / (2. * BOLTZ_THZ_PER_K * temperature)
        with numpy.errstate(divide='ignore'):
            coth_val = 1. / numpy.tanh(wd2kt)
        coth_val[coth_val == numpy.inf] = 0.
        log_val = self._log_2sinh(wd2kt)
        log_val[log_val == -numpy.inf] = 0.
        ent = numpy.trapz((wd2kt * coth_val - log_val) * dens, x=freqs)
        ent *= const.Boltzmann * const.Avogadro
        return ent

    def internal_energy(self, temperature):
        """
        Phonon contribution to the internal energy at temperature T obtained
        from the integration of the DOS. Only non-negative frequencies are
        used. At T == 0 the zero point energy is returned.

        :type temperature: float
        :param temperature: Temperature at which to evaluate energy, in K
        :rtype: float
        :return: Phonon contribution to the internal energy, in J/mol-c
        """
        # This function is based on pymatgen/phonon/dos.py (MIT, LEGAL-413)
        if temperature == 0:
            return self.zero_point_energy()
        freqs = self.positive_frequencies
        dens = self.positive_densities
        wd2kt = freqs / (2. * BOLTZ_THZ_PER_K * temperature)
        with numpy.errstate(divide='ignore'):
            coth_val = 1. / numpy.tanh(wd2kt)
        coth_val[coth_val == numpy.inf] = 0.
        energy = numpy.trapz(freqs * coth_val * dens, x=freqs) / 2.
        energy *= THZ_TO_J * const.Avogadro
        return energy

    def helmholtz_free_energy(self, temperature):
        """
        Phonon contribution to the Helmholtz free energy at temperature T
        obtained from the integration of the DOS. Only non-negative
        frequencies are used. At T == 0 the zero point energy is returned.

        :type temperature: float
        :param temperature: Temperature at which to evaluate free energy, in K
        :rtype: float
        :return: Phonon contribution to the Helmholtz free energy, in J/mol-c
        """
        if temperature == 0:
            return self.zero_point_energy()
        freqs = self.positive_frequencies
        dens = self.positive_densities
        wd2kt = freqs / (2. * BOLTZ_THZ_PER_K * temperature)
        log_val = self._log_2sinh(wd2kt)
        log_val[log_val == -numpy.inf] = 0.
        fenergy = numpy.trapz(log_val * dens, x=freqs)
        fenergy *= const.Boltzmann * const.Avogadro * temperature
        return fenergy

    def zero_point_energy(self):
        """
        Zero point energy of the system. Only non-negative frequencies are
        used.

        :rtype: float
        :return: Phonon contribution to ZPE, in J/mol-c
        """
        freqs = self.positive_frequencies
        dens = self.positive_densities
        zpe = 0.5 * numpy.trapz(freqs * dens, x=freqs)
        zpe *= THZ_TO_J * const.Avogadro
        return zpe
class PDOS(object):
    """
    Class that holds partial DOS (PDOS) data. Call getPDOS to get broadened
    data.
    """
    # Projected DOS can be summed over different wavefunction descriptors.
    # Each kind is stored in self.pdos[<KIND>_IDX], keyed by a WfcType in
    # which the fields that are not summed over are kept:
    # - PDOS: original data - currently disabled, due to rare probable use
    # - LDOS: (atom_idx, atom_type, n, l) - summed over m (disabled)
    # - ADOS: (atom_type, n, l) - summed over atom index and m (disabled)
    # - EDOS: (atom_type) - summed over atom index and all quantum numbers
    #   (disabled)
    # - AIDOS: (atom_idx) - per atom, summed over all quantum numbers
    # - ALDOS: (atom_idx, n, l) - per atom, summed over m
    NUM_IDX = 5
    LDOS_IDX, ADOS_IDX, EDOS_IDX, AIDOS_IDX, ALDOS_IDX = list(range(NUM_IDX))

    def __init__(self, proj, wfc_types, efermi, band):
        """
        Initialize PDOS object. Constructor only assigns values, call getPDOS
        for broadening.

        :type proj: dict of 3d numpy.array
        :param proj: Dict with SPIN_UP, SPIN_DW (optional) keys, each containing
            a 3D array containing: index of projected atom wavefunction, index of
            k-point, index of band and WFC projection as value.
        :type wfc_types: list of `WfcType`
        :param wfc_types: List containing wavefunction types description, one
            per projected wavefunction index
        :type efermi: float
        :param efermi: Fermi energy in eV
        :type band: `BandStructure`
        :param band: Band structure object, needed for k-points weights and
            energies
        """
        self.proj = proj
        self.wfc_types = wfc_types
        self.efermi = efermi
        self.band = band
        self.energies = None
        self.kpts_weights = [kpt.weight for kpt in self.band.kpoints]
        self.is_spin_polarized = len(self.proj) == 2
        self.nwfc, self.nkpts, self.nbnd = proj[SPIN_UP].shape

    @staticmethod
    def _getPDOSDict(is_spin_polarized, nsteps):
        """
        Get an empty dictionary with correct number of spin keys and arrays for
        PDOS data.

        :type is_spin_polarized: bool
        :param is_spin_polarized: If True, create dict for holding spin
            polarized PDOS, otherwise create for closed-shell PDOS
        :type nsteps: int
        :param nsteps: 1D grid size
        :rtype: dict
        :return: Dictionary with zero-filled arrays to hold PDOS data
        """
        dos = {SPIN_UP: numpy.zeros(shape=(nsteps))}
        if is_spin_polarized:
            dos[SPIN_DW] = numpy.zeros(shape=(nsteps))
        return dos

    def getPDOS(self, degauss, delta_e=DOS_ENERGY_GRID_SPACING):
        """
        Calculate PDOS and set in self.pdos, with the energy grid (eV) in
        self.energies. Saves degauss (converted to Ry) in self.degauss.

        self.pdos is a list indexed by the *_IDX class constants; each entry
        maps a `WfcType` key to a dict of SPIN_UP/SPIN_DW arrays on the grid.
        Only the AIDOS_IDX and ALDOS_IDX entries are currently filled, and
        self.tdos is allocated but left at zero.

        :type degauss: float
        :param degauss: Broadening (eV) for computing PDOS
        :type delta_e: float
        :param delta_e: Energy grid spacing eV
        """
        self.degauss = degauss / qeu.RY2EV
        delta_e /= qeu.RY2EV
        # Grid spans the lowest/highest band extended by 3 broadening widths
        e_min = self.band.bands[SPIN_UP][0].min(0)
        if self.band.is_spin_polarized:
            e_min = min(e_min, self.band.bands[SPIN_DW][0].min(0))
        e_min -= 3.0 * self.degauss
        e_max = self.band.bands[SPIN_UP][-1].max(0)
        if self.band.is_spin_polarized:
            e_max = max(e_max, self.band.bands[SPIN_DW][-1].max(0))
        e_max += 3.0 * self.degauss
        nsteps = int(round((e_max - e_min) / delta_e + 0.5))
        # Grid in Ry; the per-band windows below are slices of this grid so
        # each Gaussian value lands exactly on the grid point it is added to
        grid_ry = e_min + delta_e * numpy.arange(nsteps)
        self.energies = grid_ry * qeu.RY2EV
        # defaultdict argument MUST be callable, that's why lambda
        self.pdos = [
            defaultdict(
                lambda: self._getPDOSDict(self.is_spin_polarized, nsteps))
            for _ in range(self.NUM_IDX)
        ]
        # TDOS is total PDOS (it is different from just DOS)
        self.tdos = {SPIN_UP: numpy.zeros(shape=(nsteps))}
        if SPIN_DW in self.proj:
            self.tdos[SPIN_DW] = numpy.zeros(shape=(nsteps))
        # Half-width of the evaluation window, in grid points: 5 widths
        ie_delta = 5 * int(round(self.degauss / delta_e))
        for spin in self.proj:
            for kpt in range(self.nkpts):
                kptw = self.kpts_weights[kpt]
                for iband in range(self.nbnd):
                    eband = self.band.bands[spin][iband][kpt]
                    ie_mid = int(round((eband - e_min) / delta_e))
                    start_idx = max(ie_mid - ie_delta, 0)
                    end_idx = min(ie_mid + ie_delta, nsteps)
                    delta = gaussian_delta(eband, grid_ry[start_idx:end_idx],
                                           self.degauss)
                    for nwfc in range(self.nwfc):
                        proj = self.proj[spin][nwfc, kpt, iband]
                        val = (delta * proj * kptw * SQRT_1PI / self.degauss /
                               qeu.RY2EV)
                        wfc_type = self.wfc_types[nwfc]
                        # AIDOS: per atom, summed over all quantum numbers
                        aiwfc_type = WfcType(wfc_type.atom_idx, None, None,
                                             None, None)
                        self.pdos[self.AIDOS_IDX][aiwfc_type][spin][
                            start_idx:end_idx] += val
                        # ALDOS: per atom and (n, l), summed over m
                        alwfc_type = WfcType(wfc_type.atom_idx, None,
                                             wfc_type.n_qn, wfc_type.l_qn, None)
                        self.pdos[self.ALDOS_IDX][alwfc_type][spin][
                            start_idx:end_idx] += val
                        # Currently disabled accumulations would use keys:
                        #   LDOS: WfcType(atom_idx, atom_type, n_qn, l_qn, None)
                        #   ADOS: WfcType(None, atom_type, n_qn, l_qn, None)
                        #   EDOS: WfcType(None, atom_type, None, None, None)
                        # and TDOS: self.tdos[spin][start_idx:end_idx] += val
[docs]class BandStructure(object):
    """
    This class is based on the class from pymatgen (MIT license).
    """
[docs]    def __init__(self, kpoints, eigenvals, efermi, struct=None):
        """
        Initialize BandStructure object.
        :type kpoints: list of Kpoint objects
        :param: List of k-points for this band structure
        :type eigenvals: dict
        :param eigenvals: Energies of the band structure in the shape::
            {SPIN_UP: numpy.array([iband, jkpoint]), \
             SPIN_DW: numpy.array([iband, jkpoint])}
        SPIN_DW key can be present or not depending on the calculation type
        :type efermi: float
        :param efermi: Fermi energy in Hartree
        :type struct: `structure.Structure`
        :param struct: Related structure
        """
        self.kpoints = kpoints
        self.bands = eigenvals
        self.efermi = efermi
        self.struct = struct
        self.nb_bands = len(self.bands[SPIN_UP])
        self.is_spin_polarized = len(self.bands) == 2 
[docs]    def getVbmCbm(self, vbm=True):
        """
        Return data about the valence band maximum (VBM) or conduction band
        minimum (CBM).
        :type vbm: bool
        :param vbm: If True calculates VBM, if False CBM
        :rtype: dict
        :return: dict with keys BAND_INDEX_KEY, KPOINT_INDEX_KEY, KPOINT_KEY,
            ENERGY_KEY where:
            - BAND_INDEX_KEY: A dict with spin keys pointing to a list of the
            indices of the band containing the VBM (please note that you
            can have several bands sharing the VBM) {SPIN_UP:[],
            SPIN_DW:[]}
            - KPOINT_INDEX_KEY: The list of indices in self.kpoints for the
            kpoint vbm. Please note that there can be several
            kpoint_indices relating to the same kpoint (e.g., Gamma can
            occur at different spots in the band structure line plot)
            - KPOINT_KEY: The kpoint (as a kpoint object)
            - ENERGY_KEY: The energy of the VBM
        """
        # Based on: pymatgen/electronic_structure/bandstructure.py
        # MIT license
        # Check if VBM is requested
        if vbm:
            op1 = operator.le
            op2 = operator.gt
        else:
            op1 = operator.gt
            op2 = operator.le
        null_ret = {
            BAND_INDEX_KEY: [],
            KPOINT_INDEX_KEY: [],
            KPOINT_KEY: [],
            ENERGY_KEY: None
        }
        if self.isMetal():
            return null_ret
        extreme_val = -numpy.inf if vbm else numpy.inf
        extreme = {SPIN_UP: extreme_val, SPIN_DW: extreme_val}
        index = None
        kpointvbm = {SPIN_UP: None, SPIN_DW: None}
        for iband in range(self.nb_bands):
            for jkpoint in range(len(self.kpoints)):
                for spin in self.bands:
                    if op1(self.bands[spin][iband][jkpoint], self.efermi):
                        if op2(self.bands[spin][iband][jkpoint], extreme[spin]):
                            extreme[spin] = self.bands[spin][iband][jkpoint]
                            index = jkpoint
                            kpointvbm[spin] = self.kpoints[jkpoint]
        if index is None:
            # All bands are above (for vbm=True) or below Fermi
            return null_ret
        list_ind_kpts = []
        list_ind_kpts.append(index)
        # get all other bands sharing the vbm
        list_ind_band = {SPIN_UP: []}
        if self.is_spin_polarized:
            list_ind_band[SPIN_DW] = []
        for spin in self.bands:
            for iband in range(self.nb_bands):
                diff = self.bands[spin][iband][index] - extreme[spin]
                if math.fabs(diff) < 0.001:
                    list_ind_band[spin].append(iband)
        return {
            BAND_INDEX_KEY: list_ind_band,
            KPOINT_INDEX_KEY: list_ind_kpts,
            KPOINT_KEY: kpointvbm,
            ENERGY_KEY: extreme
        } 
[docs]    def getBandGap(self):
        r"""
        Get band gap data.
        :rtype: dict
        :return: dict with keys ENERGY_KEY, DIRECT_KEY, TRANSITION_KEY:
                ENERGY_KEY: band gap energy
                DIRECT_KEY: A boolean telling if the gap is direct or not
                TRANSITION_KEY: kpoint labels of the transition (e.g., "\Gamma-X")
        """
        # Based on: pymatgen/pymatgen/electronic_structure/bandstructure.py
        # MIT license
        result = {
            ENERGY_KEY: {
                SPIN_UP: 0.0,
                SPIN_DW: 0.0
            },
            DIRECT_KEY: {
                SPIN_UP: False,
                SPIN_DW: False
            },
            TRANSITION_KEY: {
                SPIN_UP: '',
                SPIN_DW: ''
            },
            IS_METAL_KEY: None,
            IS_DIRECT_KEY: False,
            BAND_GAP_KEY: float('inf'),
        }
        result[IS_METAL_KEY] = self.isMetal()
        if result[IS_METAL_KEY]:
            result[BAND_GAP_KEY] = 0.0
            return result
        cbm = self.getVbmCbm(vbm=False)
        vbm = self.getVbmCbm()
        if vbm[ENERGY_KEY] is None or cbm[ENERGY_KEY] is None:
            return result
        for spin in self.bands:
            result[ENERGY_KEY][spin] = cbm[ENERGY_KEY][spin] - \
                                                           
vbm[ENERGY_KEY][spin]
            # Update "global" band gap with the smallest band gap
            result[BAND_GAP_KEY] = min(result[BAND_GAP_KEY],
                                       result[ENERGY_KEY][spin])
            if numpy.linalg.norm(cbm[KPOINT_KEY][spin].cart_coords -
                                 vbm[KPOINT_KEY][spin].cart_coords) < 0.01:
                result[DIRECT_KEY][spin] = True
                # Update is direct band gap if gap for this spin is the smallest one
                if result[ENERGY_KEY][spin] == result[BAND_GAP_KEY]:
                    result[IS_DIRECT_KEY] = result[DIRECT_KEY][spin]
            transition_str = '%.3f, %.3f, %.3f' % \
                                        
tuple(vbm[KPOINT_KEY][spin].frac_coords)
            if vbm[KPOINT_KEY][spin].label:
                label = vbm[KPOINT_KEY][spin].label.replace('\\', '')
                transition_str += ' (%s)' % label
            transition_str += ' -> %.3f, %.3f, %.3f' % \
                                        
tuple(cbm[KPOINT_KEY][spin].frac_coords)
            if cbm[KPOINT_KEY][spin].label:
                label = cbm[KPOINT_KEY][spin].label.replace('\\', '')
                transition_str += ' (%s)' % label
            result[TRANSITION_KEY][spin] = transition_str
        return result 
[docs]    def generatePlotData(self):
        """
        Generate distances between k-points (in self.distances)
        for plotting band structure.
        """
        # Based on: pymatgen/pymatgen/electronic_structure/bandstructure.py
        # MIT license
        self.distances = []
        self.xticks = []
        self.xticklabels = []
        self.kpts_legend = []
        previous_kpoint = self.kpoints[0]
        previous_distance = 0.0
        for indx in range(len(self.kpoints)):
            self.distances.append(
                numpy.linalg.norm(self.kpoints[indx].cart_coords -
                                  previous_kpoint.cart_coords) +
                previous_distance)
            previous_kpoint = self.kpoints[indx]
            previous_distance = self.distances[indx]
            if previous_kpoint.label:
                self.xticks.append(previous_distance)
                label = r'$%s$' % previous_kpoint.label
                coords = previous_kpoint.getCoordsStr(frac=True)
                try:
                    self.xticklabels.index(label)
                except ValueError:
                    self.kpts_legend.append(KptLegend(label, coords))
                self.xticklabels.append(label)  
[docs]class Output(object):
    """
    Class to deal with QE XML output parsing.
    """
    PROPERTIES = ('struct', 'band', 'dos', 'pdos', 'esm', 'neb', 'phdos',
                  'phband', 'dynamics', 'epsilon', 'hpu')
    EPS_METAL_ERR = 'Metallic system encountered in epsilon calculation.'
[docs]    def __init__(self, qegz_fn, tree=None, dos_fn=None, **kwargs):
        """
        Initialize Output object. Supported properties are requested in kwargs
        and defined in self.PROPERTIES.
        :param str qegz_fn: Archive name of the compressed .save folder
        :type tree: ElementTree or None
        :param tree: Use tree if not None otherwise read tree from the file
        :type dos_fn: str or None
        :param dos_fn: File to read dos property from, will be used if dos=True
        """
        self.properties = self.getProperties(**kwargs)
        self.ok = True
        self.error = ''
        self.struct = None
        self.band = None
        self.dos = None
        self.pdos = None
        self.lowdin = None
        self.esm = None
        self.phdos = None
        self.phband = None
        self.hpu = None
        self.epsilon = {}
        self.upfs = dict()
        self.forces = None
        self.neb_outs = []
        self.md_steps = []
        if self.properties.dos and dos_fn:
            try:
                self.setDOSFromFile(dos_fn)
            except Exception as err:
                # Don't ever raise
                self.ok = False
                self.error = str(err)
            return
        if tree is None:
            tree = self.getTree(qegz_fn, self.properties)
            if not self.ok:
                return
        if not tree:
            self.error = ('Could not find "data-file-schema.xml" file in %s' %
                          qegz_fn)
            self.ok = False
            return
        root = tree.getroot()
        self.title = root.find(TITLE_TAG).text
        output = root.find(OUTPUT_TAG)
        self.alat, self.vecs = self._getVecsFromTree(output)
        self._getBasicInfo(root, output)
        if self.properties.struct:
            self.struct = self._getStructFromTree(root, output)
            self.forces = self._getAtomsForces(root, output)
            if self.lowdin:
                for atom in self.struct.atom:
                    if atom.index in self.lowdin:
                        atom.property.update(self.lowdin[atom.index])
                    if msprops.QE_LOWDIN_TCHARGE in atom.property:
                        lcharge = atom.property[msprops.QE_LOWDIN_TCHARGE]
                        melement = qeu.get_mag_hubbu(atom)
                        for cprop in FF_CUSTOM_CHARGE, DEFAULT_CHARGE_PROP:
                            atom.property[cprop] = melement.zval - lcharge
        # Get band gap if structure is requested
        if (self.properties.struct or self.properties.band or
                self.properties.dos or self.properties.pdos):
            self._getBandFromTree(root, output)
        # Get band gap if structure is requested
        if self.properties.struct:
            result = self.band.getBandGap()
            self.struct.property[msprops.QE_BAND_GAP] = \
                
result[BAND_GAP_KEY] * qeu.RY2EV
            self.struct.property[msprops.QE_IS_METAL] = result[IS_METAL_KEY]
            self.struct.property[msprops.QE_IS_DIRECT_BAND_GAP] = \
                
result[IS_DIRECT_KEY]
        if self.properties.dos:
            self._getDOS()
        if self.properties.pdos:
            self._getPDOS(self.pdos_proj, self.pdos_wfc_types)
        if self.properties.dynamics:
            timestep = float(root.find(MD_TIMESTEP_TAG).text) * qeu.AU2PS
            for step in root.iter('step'):
                struct = self.getMDStepStruct(root, step, timestep)
                self.md_steps.append(struct) 
[docs]    @classmethod
    def getProperties(cls, **kwargs):
        """
        Get properties (namedtuple) from kwargs. Supported properties are
        defined in self.PROPERTIES.
        :rtype: namedtuple
        :return: Properties requested
        """
        Properties = namedtuple('Properties',
                                cls.PROPERTIES,
                                defaults=(False,) * len(cls.PROPERTIES))
        return Properties()._replace(**kwargs) 
[docs]    def setDOSFromFile(self, dos_fn):
        """
        Parse .dos file and set DOS into self.dos.
        :param str dos_fn: File to read dos property from
        """
        try:
            self.dos = DOS(None, dos_fn)
        except ValueError as err:
            self.dos = None
            self.ok = False
            self.error = 'Could not read DOS from %s.' % dos_fn 
[docs]    def processEpsilonFile(self, tgz, tgz_file):
        """
        Process file from archive, if it is epsilon, return parsed data
        :param TarFile tgz: Tar archive
        :param str tgz_file: Filename
        :rtype: dict or None
        :return: Dict with one key (real or imag) and corresponding data in
            value or None if extensions didn't match
        :raise ValueError: When numpy can't parse the data
        """
        exts = {'real': qeu.EPS_R_PREFIX, 'imag': qeu.EPS_I_PREFIX}
        for key, ext in exts.items():
            if tgz_file.name.startswith(ext):
                with tgz.extractfile(tgz_file) as file_fh:
                    try:
                        return {key: numpy.loadtxt(file_fh, unpack=True)}
                    except ValueError as err:
                        if 'could not convert string to float' in str(err):
                            raise ValueError(self.EPS_METAL_ERR)
                        else:
                            # Re-raise general ValueError
                            raise 
[docs]    def processFile(self, tgz, tgz_file, options):
        """
        Process file from archive and set object properties.
        :param TarFile tgz: Tar archive
        :param str tgz_file: Filename
        :param namedtuple options: Properties to get
        """
        if (options.esm and self.esm is None and
                tgz_file.name.endswith(qeu.ESM_EXT)):
            with tgz.extractfile(tgz_file) as file_fh:
                self.esm = EsmType(numpy.loadtxt(file_fh, unpack=True), '', 0.0)
            return
        if options.epsilon and tgz_file.name.endswith(qeu.DAT_EXT):
            epsilon_data = self.processEpsilonFile(tgz, tgz_file)
            if epsilon_data:
                self.epsilon.update(epsilon_data)
                return
        if (options.phdos and self.phdos is None and
                tgz_file.name.endswith(qeu.PHDOS_EXT)):
            with tgz.extractfile(tgz_file) as file_fh:
                self.phdos = PhDOS(file_fh)
            return
        if (options.phband and self.phband is None and
                tgz_file.name.endswith(qeu.GP_EXT)):
            with tgz.extractfile(tgz_file) as file_fh:
                tmp = numpy.genfromtxt(file_fh, dtype=None, encoding=None)
            tmp_len = len(tmp.dtype.names)
            tmp_names = tmp.dtype.names
            self.phband = numpy.zeros(shape=(tmp_len - 1, tmp.shape[0]))
            self.phband[0] = tmp[tmp.dtype.names[0]]
            for idx in range(2, tmp_len):
                self.phband[idx - 1] = tmp[tmp.dtype.names[idx]]
            self.phband_labels = []
            self.phband_ticks = []
            self.phband_fcoords = []
            for idx, label in enumerate(tmp[tmp_names[1]]):
                label = msutils.getstr(label)
                if label != '*':
                    llist = label.replace('_', ' ').split('=')
                    self.phband_labels.append(r'$%s$' % llist[0].strip())
                    self.phband_ticks.append(self.phband[0, idx])
                    self.phband_fcoords.append(llist[1].strip())
            return
        if options.pdos:
            if tgz_file.name.endswith(qeu.PROJWFC_UP_EXT):
                with tgz.extractfile(tgz_file) as file_fh:
                    self.pdos_proj[SPIN_UP], self.pdos_wfc_types = \
                        
self._parsePDOSProj(file_fh)
            if tgz_file.name.endswith(qeu.PROJWFC_DW_EXT):
                with tgz.extractfile(tgz_file) as file_fh:
                    self.pdos_proj[SPIN_DW], not_needed = \
                        
self._parsePDOSProj(file_fh)
            if tgz_file.name.endswith(qeu.LOWDIN_EXT):
                with tgz.extractfile(tgz_file) as file_fh:
                    self.lowdin = self._parseLowdin(file_fh)
            if tgz_file.name.endswith(qeu.UPF_EXT):
                tmp, file_fn = os.path.split(tgz_file.name)
                with tgz.extractfile(tgz_file) as file_fh:
                    self.upfs[file_fn] = qeu.UPFParser(None,
                                                       file_fh=file_fh,
                                                       binary=True).getPseudo()
            return
        if options.hpu and tgz_file.name.endswith(qeu.HP_EXT):
            with tgz.extractfile(tgz_file) as file_fh:
                self.hpu = self.parseHP(file_fh)
            return 
[docs]    def getTree(self, qegz_fn, options):
        """
        Get data in from of tree from archived file
        :param str qegz_fn: Archive name
        :param namedtuple options: Properties to get
        :return: Data from data-file-schema.xml
        :rtype: `xml.etree.ElementTree`
        """
        self.pdos_proj = {}
        self.pdos_wfc_types = None
        neb_error = 'Could not find NEB output file.'
        tree = None
        try:
            with tarfile.open(qegz_fn, 'r') as tgz:
                for tgz_file in tgz:
                    if (tree is None and
                            'data-file-schema.xml' in tgz_file.name):
                        with tgz.extractfile(tgz_file) as file_fh:
                            tree = ElementTree.parse(file_fh)
                            continue
                    if options.neb and tgz_file.name.endswith(qeu.NEB_EXT):
                        with tgz.extractfile(tgz_file) as file_fh:
                            tree = ElementTree.parse(file_fh)
                            self.neb_outs.append(
                                Output(None, tree=tree, struct=True))
                        if not self.neb_outs[-1].ok:
                            raise IOError(self.neb_outs[-1].error)
                        continue
                    if options.neb and tgz_file.name.endswith(qeu.NEB_OUT_EXT):
                        with tgz.extractfile(tgz_file) as file_fh:
                            for line in file_fh:
                                line = msutils.getstr(line).strip()
                                if line.startswith('neb: convergence achieved'):
                                    neb_error = ''
                                    break
                                else:
                                    neb_error = ('NEB calculation did not '
                                                 'converge.')
                        continue
                    self.processFile(tgz, tgz_file, options)
            if options.pdos and (len(self.pdos_proj) not in (1, 2) or
                                 SPIN_UP not in self.pdos_proj):
                raise IOError('PDOS is requested, however wrong number of '
                              'PDOS files (.up, .down) has been found.')
            if options.esm and self.esm is None:
                raise IOError('ESM is requested, however *%s file could '
                              'not be found.' % qeu.ESM_EXT)
            if options.epsilon and not self.epsilon:
                raise IOError('Epsilon is requested, however *%s file '
                              'could not be found.' % qeu.DAT_EXT)
            if options.phdos and self.phdos is None:
                raise IOError('Phonon DOS is requested, however *%s file '
                              'could not be found.' % qeu.PHDOS_EXT)
            if options.phband and self.phband is None:
                raise IOError('Phonon dispersion is requested, however *%s '
                              'file could not be found.' % qeu.GP_EXT)
            if options.hpu and self.hpu is None:
                raise IOError('HP is requested, however *%s file could not '
                              'be found.' % qeu.HP_EXT)
            if options.neb and neb_error:
                raise ValueError(neb_error)
        # Except Exception to ensure that driver doesn't crash in a workflow
        except (IOError, ElementTree.ParseError, ValueError, Exception) as err:
            self.ok = False
            self.error = str(err)
        return tree 
[docs]    def getMDStepStruct(self, root, step, timestep):
        """
        Extract MD step structure from step XML.
        :param xml.etree.ElementTree.Element root: Root XML element
        :param xml.etree.ElementTree.Element step: Step element
        :param float timestep: MD time step
        :rtype: `structure.Structure`, `structure.Structure` or None
        :return: Structure of the MD step, standardized structure if requested
            and found
        """
        struct = self._getStructFromTree(root, step)
        struct = xtal.chorus_to_lower_triangle(struct)
        nstep = int(step.attrib['n_step'])
        struct.property[msprops.QE_MD_NSTEP] = nstep
        struct.property[msprops.QE_MD_CURTIME] = nstep * timestep
        struct.property[msprops.QE_ETOT] = \
            
float(step.find(TOT_ENERGY_TAG).text) * qeu.HA2RY
        return struct 
    def _convertTagToCoords(self, element):
        """
        Parse the whitespace-separated numbers in an element's text and scale
        each by qeu.B2A (presumably a Bohr -> Angstrom factor).

        :type element: `xml.etree.ElementTree.Element`
        :param element: Element whose text looks like 'float float float'
        :rtype: list of floats
        :return: Scaled values in the order they appear in the text
        """
        return [float(token) * qeu.B2A for token in element.text.split()]
    def _getVecsFromTree(self, root):
        """
        Read the lattice parameter and lattice vectors from XML tree.

        Nothing is stored on self; the caller assigns the returned values.

        :type root: `xml.etree.ElementTree.Element`
        :param root: Output (or MD step) element that contains the atomic
            structure
        :rtype: float, numpy.array
        :return: lattice parameter (alat, as read from the XML attribute, not
            rescaled) and 3x3 array of lattice vectors a1, a2, a3 as rows
            (scaled by _convertTagToCoords)
        """
        alat = float(root.find(ATOMIC_STRUCTURE_TAG).attrib['alat'])
        vecs = [
            self._convertTagToCoords(root.find(CELL_VECTORS_TAG % indx))
            for indx in (1, 2, 3)
        ]
        return alat, numpy.array(vecs)
    def _getMagHubbUSpecies(self, root):
        """
        Get a dict with species (element name + a number) as keys and starting
        magnetizations, Hubbard U and valence charges as values.

        :type root: `xml.etree.ElementTree.Element`
        :param root: Element that contains required information
        :rtype: dict
        :return: dict with species as keys and MagElement namedtuples as
            values. Every species in the atomic species section gets an
            entry, starting from qeu.get_null_mag_element():
            - mag is set from starting_magnetization, when present.
            - zval is taken from the matching parsed UPF file in self.upfs.
            - hubb_u is set from the Hubbard U section, converted with
              qeu.RY2EV. A species that only appears there also gets an
              entry carrying its U.
        """
        ret = {}
        for atom_species in root.iterfind(ATOMIC_SPECIES_TAG):
            species = atom_species.attrib['name']
            mag_element = qeu.get_null_mag_element()
            mag_tag = atom_species.find('starting_magnetization')
            if mag_tag is not None:
                mag_element = mag_element._replace(
                    mag=float(mag_tag.text.strip()))
            upf_fn = atom_species.find('pseudo_file').text.strip()
            if upf_fn in self.upfs:
                mag_element = mag_element._replace(
                    zval=self.upfs[upf_fn].zval)
            ret[species] = mag_element
        for hubb_u in root.iterfind(DFT_HU_TAG):
            val = float(hubb_u.text.strip()) * qeu.RY2EV
            species = hubb_u.attrib['specie']
            mag_element = ret.get(species)
            if mag_element is None:
                mag_element = qeu.get_null_mag_element()
            # Keep the U value even for species missing from the list above
            ret[species] = mag_element._replace(hubb_u=val)
        return ret
[docs]    @staticmethod
    def getFreePositions(root, natom):
        """
        Get atom free_positions from the XML data
        :param xml.etree.ElementTree.Element root: Root XML element
        :param int natom: Number of atoms in the structure
        :rtype: numpy.array
        :return: Atom free positions
        """
        free_pos = root.find(FREE_POS_TAG)
        if free_pos is None:
            # Missing from NEB output, needs fixing
            return numpy.ones((natom, 3), dtype=int)
        free_pos = list(map(int, free_pos.text.split()))
        free_pos = numpy.array(free_pos, dtype=int).reshape(natom, 3)
        return free_pos 
    def _getStructFromTree(self, root, output):
        """
        Build a structure from XML tree.

        What is set:
        - Atoms: element, starting magnetization/Hubbard U and Cartesian
          constraints.
        - Structure: bonding, element colors, PBC and physical properties,
          space group.
        - Run data parsed by _getBasicInfo: energies, cutoffs, Fermi energy,
          functional, k-points, bands and timings.
        - If self.hpu is set, its Hubbard U values override those of the
          atoms with the matching indices.

        :param xml.etree.ElementTree.Element root: Root XML element
        :param xml.etree.ElementTree.Element output: Output (or MD step) XML
            element with the atomic structure
        :rtype: `structure.Structure`
        :return: Structure generated from XML data
        """
        mag_hubbu_species = self._getMagHubbUSpecies(output)

        def get_structure(vecs, fcoords, species_names, free_pos):
            """
            Get structure from vectors, coordinates and species.

            :type vecs: numpy.array(3x3 float)
            :param vecs: lattice vectors
            :type fcoords: numpy.array(int x 3 float)
            :param fcoords: Fractional atomic coordinates
            :type species_names: list(str)
            :param species_names: Species name of each atom
            :param numpy.array free_pos: Free positions for all the atoms
            :rtype: `structure.Structure`
            :return: Generated structure
            """
            coords = xtal.trans_frac_to_cart_from_vecs(fcoords, *vecs)
            struct = structure.create_new_structure()
            xtal.set_pbc_properties(struct, vecs.flat)
            for xyz, species, atom_free_pos in zip(coords, species_names,
                                                   free_pos):
                # Placeholder element, replaced from the species name below
                atom = struct.addAtom('C', *xyz)
                mag_hubbu = mag_hubbu_species.get(species)
                if mag_hubbu is None:
                    mag_hubbu = qeu.get_null_mag_element()
                atom.element = qeu.get_element(species)
                qeu.set_mag_hubbu(atom, mag_hubbu)
                qeu.set_atom_cart_constr(atom, atom_free_pos)
            # Generate bonding
            xtal.connect_atoms(struct, pbc_bonding=True)
            struct.retype()
            # Color atoms
            color.apply_color_scheme(struct, 'element')
            # Set physical properties
            xtal.set_physical_properties(struct)
            # Set total energy, energy and density cutoffs
            struct.property[msprops.QE_ETOT] = self.etot
            struct.property[msprops.QE_MTOT] = self.mtot
            struct.property[msprops.QE_ECUTWFC] = self.ecutwfc
            struct.property[msprops.QE_ECUTRHO] = self.ecutrho
            struct.property[msprops.QE_EFERMI] = self.efermi
            struct.property[msprops.QE_DFT_FUNCT_STR] = self.dft_funct
            if self.dft_vdw_corr:
                struct.property[msprops.QE_DFT_VDW_CORR] = self.dft_vdw_corr
            struct.property[msprops.QE_NKS] = self.nks
            if self.nbnd is not None:
                struct.property[msprops.QE_NBND] = self.nbnd
            else:
                struct.property[msprops.QE_NBND] = (self.nbnd_up + self.nbnd_dw)
                struct.property[msprops.QE_NBND_DW] = self.nbnd_dw
            # Assign space group and space group ID
            xtal.assign_space_group(struct, xtal.ASSIGN_SPG_SYMPREC)
            # Anchor to origin
            struct.property[xtal.PBC_POSITION_KEY] = (
                xtal.ANCHOR_PBC_POSITION % ('0', '0', '0'))
            if self.timing_cpu is not None:
                struct.property[msprops.QE_TIMING_CPU] = self.timing_cpu
            if self.timing_wall is not None:
                struct.property[msprops.QE_TIMING_WALL] = self.timing_wall
            return struct

        _, vecs = self._getVecsFromTree(output)
        species_names = []
        coords = []
        for atom_tag in output.iterfind(ATOMIC_POSITIONS_TAG):
            coords.append(self._convertTagToCoords(atom_tag))
            species_names.append(atom_tag.attrib['name'].strip())
        fcoords = xtal.trans_cart_to_frac_from_vecs(coords, *vecs)
        free_pos = self.getFreePositions(root, len(fcoords))
        struct = get_structure(vecs, fcoords, species_names, free_pos)
        if self.hpu:
            # Use Hubbard U if hpu has been parsed
            for idx, hubb_u in self.hpu.items():
                atom = struct.atom[idx]
                mag_element = qeu.get_mag_hubbu(atom)
                mag_element = mag_element._replace(hubb_u=hubb_u, hubb_j0=0.)
                qeu.set_mag_hubbu(atom, mag_element)
        return struct
[docs]    def getMaeStructure(self):
        """
        Get structure with lower triangular form of lattice vectors matrix
        :rtype: structure.Structure
        :return: Updated structure
        """
        assert self.struct
        struct = xtal.chorus_to_lower_triangle(self.struct.copy())
        xtal.set_physical_properties(struct)
        return struct 
    def _getInputKpoints(self, root):
        """
        Return the k-points listed in the input section of the output schema.

        :type root: `xml.etree.ElementTree.Element`
        :param root: Element with required information
        :rtype: list of KPoint
        :return: KPoint objects built with self.vecs, in document order
        """
        kpts_ibz = root.find(KPTS_IBZ)
        return [
            KPoint(kpoint_tag, self.vecs)
            for kpoint_tag in kpts_ibz.iter('k_point')
        ]
    def _getInputKpointsMesh(self, root):
        """
        Return the Monkhorst-Pack k-point mesh from the input section of the
        output schema.

        :type root: `xml.etree.ElementTree.Element`
        :param root: Element with required information
        :rtype: str
        :return: 'nk1,nk2,nk3' from the monkhorst_pack attributes, or an
            empty string if there is no k-points block or no Monkhorst-Pack
            mesh
        """
        kpts_ibz = root.find(KPTS_IBZ)
        if kpts_ibz is None:
            return ''
        pack_tag = kpts_ibz.find('monkhorst_pack')
        if pack_tag is None:
            return ''
        return ','.join(pack_tag.attrib[nki] for nki in ('nk1', 'nk2', 'nk3'))
    @staticmethod
    def _getTiming(root):
        """
        Get timings (WALL, CPU) from the output.

        :type root: `xml.etree.ElementTree.Element`
        :param root: Element that contains required information
        :rtype: (float or None, float or None)
        :return: Total wall and cpu time in s. A value is None if its tag,
            or the whole timing section, is absent
        """
        # Compare with None explicitly: the truth value of an Element depends
        # on whether it has children (and is deprecated), not on presence
        if root.find(TIMING_TAG) is None:
            return None, None
        wall_tag = root.find(TIMING_TOTAL_WALL_TAG)
        cpu_tag = root.find(TIMING_TOTAL_CPU_TAG)
        wall = None if wall_tag is None else float(wall_tag.text)
        cpu = None if cpu_tag is None else float(cpu_tag.text)
        return wall, cpu
    def _getBasicInfo(self, root, output):
        """
        Parse general calculation data and store it in attributes.

        Attributes set:
        - nbnd, nbnd_up, nbnd_dw
        - timing_wall, timing_cpu
        - nks, etot, mtot, ecutwfc, ecutrho, tot_charge
        - dft_funct, dft_vdw_corr
        - kpts_mesh, stress, efermi, efermi2

        When self.esm is already set, its bc_type and efield fields are
        filled from the input section.

        :type root: `xml.etree.ElementTree.Element`
        :param root: Root element that contains required information
        :type output: `xml.etree.ElementTree.Element`
        :param output: Output element that contains required information
        """
        nbnd_up_tag = output.find(NBND_UP_TAG)
        nbnd_dw_tag = output.find(NBND_DW_TAG)
        if nbnd_up_tag is None or nbnd_dw_tag is None:
            # Not open shell: a single band count
            self.nbnd = int(output.find(NBND_TAG).text)
            self.nbnd_up = self.nbnd
            self.nbnd_dw = 0
        else:
            # Open shell: both spin-resolved counts are present
            self.nbnd_up = int(nbnd_up_tag.text)
            self.nbnd_dw = int(nbnd_dw_tag.text)
            self.nbnd = None
        self.timing_wall, self.timing_cpu = self._getTiming(root)
        self.nks = int(output.find(NKS_TAG).text)
        self.etot = float(output.find(TOT_ENERGY_TAG).text) * qeu.HA2RY
        self.mtot = float(output.find(TOT_MAG_TAG).text)
        self.ecutwfc = float(output.find(ECUTWFC_TAG).text) * qeu.HA2RY
        self.ecutrho = float(output.find(ECUTRHO_TAG).text) * qeu.HA2RY
        self.tot_charge = float(root.find(TOT_CHARGE_TAG).text)
        self.dft_funct = output.find(DFT_FUNCT_TAG).text.strip()
        vdw_tag = output.find(DFT_VDW_TAG)
        # A missing or empty tag means no vdW correction
        if vdw_tag is None or vdw_tag.text is None:
            self.dft_vdw_corr = None
        else:
            self.dft_vdw_corr = vdw_tag.text.strip()
        self.kpts_mesh = self._getInputKpointsMesh(root)
        self.stress = self._getStress(output)
        self.efermi = 0.0
        # E Fermi 2 appears if tot_magnetization != -1.0
        self.efermi2 = 0.0
        efermi_tag = output.find(EFERMI_TAG)
        if efermi_tag is not None:
            self.efermi = float(efermi_tag.text) * qeu.HA2RY
        else:
            two_efermis_tag = output.find(TWO_EFERMIS_TAG)
            if two_efermis_tag is not None and two_efermis_tag.text is not None:
                self.efermi, self.efermi2 = [
                    float(value) * qeu.HA2RY
                    for value in two_efermis_tag.text.split()
                ]
            else:
                # Fall back to the HOMO energy
                self.efermi = float(output.find(HOMO_TAG).text) * qeu.HA2RY
        # Parse ESM related stuff, if requested
        if self.esm:
            bc_tag = root.find(ESM_BC_TAG)
            esm_bc_type = '' if bc_tag is None else bc_tag.text
            efield_tag = root.find(ESM_EFIELD_TAG)
            esm_efield = 0.0 if efield_tag is None else float(efield_tag.text)
            self.esm = self.esm._replace(bc_type=esm_bc_type, efield=esm_efield)
    def _getStress(self, output):
        """
        Read the stress tensor from the output XML node.

        :type output: `xml.etree.ElementTree.Element`
        :param output: Output element that contains required information
        :rtype: numpy.array((3, 3))
        :return: Stress tensor in kBar (values scaled by qeu.HA2KBAR); a zero
            matrix if the stress tag is missing or empty
        """
        stress_tag = output.find(STRESS_TAG)
        if stress_tag is None or stress_tag.text is None:
            return numpy.zeros((3, 3)) * qeu.HA2KBAR
        values = numpy.array(stress_tag.text.split(), dtype=float)
        return values.reshape((3, 3)) * qeu.HA2KBAR
    def _getAtomsForces(self, root, output):
        """
        Get forces on atoms. QE input should contain tprnfor = .true.

        :type root: `xml.etree.ElementTree.Element`
        :param root: Element that contains required information
        :type output: `xml.etree.ElementTree.Element`
        :param output: Output element that contains required information
        :rtype: numpy.array or None
        :return: (natom, 3) array of atom forces in Ry/au (values scaled by
            qeu.HA2RY) if forces were requested in the input, otherwise None
        """
        forces_requested = root.find(HAS_FORCES).text.strip()
        if not msutils.setting_to_bool(forces_requested):
            return None
        values = [float(token) for token in output.find(FORCES_TAG).text.split()]
        return numpy.array(values).reshape(-1, 3) * qeu.HA2RY
    def _getBandFromTree(self, root, output):
        """
        Parse Kohn-Sham eigenvalues and set a BandStructure in self.band.

        Output k-points are replaced by the input k-point with the same
        Cartesian coordinates (within 0.01), when one exists, so that input
        labels are kept. Eigenvalues are scaled by qeu.HA2RY and split into
        spin up (first nbnd_up values) and spin down (next nbnd_dw values).

        :type root: `xml.etree.ElementTree.Element`
        :param root: Element that contains required information
        :type output: `xml.etree.ElementTree.Element`
        :param output: Output element that contains required information
        """
        up_start = 0
        dw_start = self.nbnd_up
        dw_end = self.nbnd_up + self.nbnd_dw
        bnds_up = numpy.zeros(shape=(self.nbnd_up, self.nks))
        bnds_dw = numpy.zeros(shape=(self.nbnd_dw, self.nks))
        # Input k-points may carry labels
        input_kpts = self._getInputKpoints(root)
        kpts = []
        for kindx, energy in enumerate(output.iterfind(KS_ENERGIES_TAG)):
            kpoint = KPoint(energy.find('k_point'), self.vecs)
            same_input_kpt = next(
                (kpt for kpt in input_kpts if numpy.linalg.norm(
                    kpt.cart_coords - kpoint.cart_coords) < 0.01), None)
            if same_input_kpt is not None:
                # Weight comes from the output: it can differ from the input
                # in the lsda case
                same_input_kpt.weight = kpoint.weight
                kpoint = same_input_kpt
            kpts.append(kpoint)
            eigenvalues = [
                float(token)
                for token in energy.find('eigenvalues').text.split()
            ]
            for row, value in enumerate(eigenvalues[up_start:dw_start]):
                bnds_up[row, kindx] = value * qeu.HA2RY
            # Empty slice (no loop iterations) in the closed-shell case
            for row, value in enumerate(eigenvalues[dw_start:dw_end]):
                bnds_dw[row, kindx] = value * qeu.HA2RY
        bands = {SPIN_UP: bnds_up}
        if bnds_dw.shape[0] > 0:
            bands[SPIN_DW] = bnds_dw
        self.band = BandStructure(kpts, bands, self.efermi)
    def _getPDOS(self, proj, wfc_types):
        """
        Create the projected DOS object and store it in self.pdos.
        :type proj: numpy.array
        :param proj: Projections of the atomic wavefunctions (WFC), indexed by
            WFC and k-point (NOTE(review): `_parsePDOSProj` produces a 3D
            array that also has a band axis — confirm which one is passed)
        :type wfc_types: list or OrderedDict
        :param wfc_types: Description of each projected WFC (NOTE(review):
            `_parsePDOSProj` returns a list of `WfcType` — confirm with
            callers)
        """
        # self.efermi is converted with RY2EV before being handed to PDOS
        efermi_ev = self.efermi * qeu.RY2EV
        self.pdos = PDOS(proj, wfc_types, efermi_ev, self.band)
    def _getDOS(self):
        """
        Create the DOS object from self.band and store it in self.dos.
        """
        band_structure = self.band
        self.dos = DOS(band_structure, None)
    def _parsePDOSProj(self, proj_fh):
        """
        Parse and return data from .projwfc_* file. This is generated by
        projwfc.x.
        :type proj_fh: File handler object
        :param proj_fh: Handler to the .projwfc_* file
        :rtype: (numpy.array, list)
        :return: 3D array with shape (number of WFCs, number of k-points,
            number of bands) containing the projection of each atomic
            wavefunction (WFC), and a list with one `WfcType` per WFC (in file
            order) describing that projected WFC.
        """
        def skip_lines(count):
            """Read and discard the next `count` lines of proj_fh."""
            for _ in range(count):
                proj_fh.readline()
        proj_fh.seek(0)
        skip_lines(1)  # Title
        # Grid dimensions followed by nat and ntyp as the last two fields
        tmp = proj_fh.readline().split()
        nat = int(tmp[-2])
        ntyp = int(tmp[-1])
        # ibrav / celldm, three lattice vector lines, ecutwfc
        skip_lines(5)
        skip_lines(ntyp)  # atom types
        skip_lines(nat)  # atom coordinates
        nwfc, nkpt, nbnd = map(int, proj_fh.readline().split())
        skip_lines(1)  # noncolin, lspinorb
        proj = numpy.zeros(shape=(nwfc, nkpt, nbnd))
        wfc_types = []
        for iwfc in range(nwfc):
            tmp = proj_fh.readline().split()
            # Example:
            # IDX  ADX ANAME WFC  n    l    m or j
            # 1    1   Si    3S   1    0    1
            # The n column always starts from 1 (which is wrong: MATSCI-3883),
            # so the principal quantum number is read from the WFC label
            atom_idx = int(tmp[1])
            atom_type = msutils.getstr(tmp[2])
            try:
                # Assumes the principal quantum number is smaller than 10
                orb_name = msutils.getstr(tmp[3])
                n_qn = int(orb_name[0])
            except (ValueError, TypeError):
                # Some UPFs don't have labels (MATSCI-4777)
                n_qn = 0
            l_qn = int(tmp[5])
            # This could be either m (int) or j (float)
            m_qn = float(tmp[6])
            wfc_types.append(WfcType(atom_idx, atom_type, n_qn, l_qn, m_qn))
            # nbnd lines per k-point; the third column holds the projection
            for ikpt in range(nkpt):
                proj[iwfc, ikpt] = numpy.genfromtxt(islice(proj_fh, nbnd),
                                                    usecols=2,
                                                    encoding=None)
        return proj, wfc_types
    def _parseLowdin(self, lowdin_fh):
        """
        Read the Lowdin charges from the .lowdin file written by projwfc.x.
        The first line is skipped. Each 'Atom #' line starts a new atom entry;
        the 'total charge', 'spin up' and 'spin down' values (the token after
        the first '=') are stored for the current atom, and a 'Spilling
        Parameter:' line is copied to every atom parsed so far.
        :type lowdin_fh: File handler object
        :param lowdin_fh: Handler to the .lowdin file
        :rtype: dict
        :return: Dictionary with atom indexes as keys and, as values,
            dictionaries mapping charge property names to floats.
        """
        # (text to look for in a line, property name to store the value as)
        atom_props = (
            ('total charge', msprops.QE_LOWDIN_TCHARGE),
            ('spin up', msprops.QE_LOWDIN_UP),
            ('spin down', msprops.QE_LOWDIN_DW),
        )
        lowdin_fh.seek(0)
        lowdin_fh.readline()  # Header
        charges = {}
        current = None
        for raw_line in lowdin_fh:
            text = msutils.getstr(raw_line)
            fields = text.split()
            if 'Atom #' in text:
                current = int(fields[2].replace(':', ''))
                charges[current] = {}
            if current is None:
                continue
            for marker, prop_name in atom_props:
                if marker not in text:
                    continue
                value = fields[fields.index('=') + 1].replace(',', '')
                charges[current][prop_name] = float(value)
            if 'Spilling Parameter:' in text:
                spill = float(fields[2])
                for atom_data in charges.values():
                    atom_data[msprops.QE_LOWDIN_SPILL] = spill
        return charges
[docs]    @staticmethod
    def parseHP(hp_fh):
        """
        Parse and return data from the Hubbard_parameters.dat.
        :type hp_fh: File handler object
        :param hp_fh: Handler of the Hubbard_parameters.dat
        :rtype: dict
        :return: Dictionary with atom indexes as keys and Hubbard U as values
        """
        ret = dict()
        for line in hp_fh:
            line = msutils.getstr(line).strip()
            # site n.  type  label  spin  new_type  new_label  Hubbard U (eV)
            if line.startswith('site n.'):
                for line in hp_fh:
                    line = msutils.getstr(line).strip()
                    # return on the next empty line
                    if not line:
                        return ret
                    # 1        1    Co      1      1         Co         8.8661
                    tmp = line.split()
                    ret[int(tmp[0])] = float(tmp[-1])
        return ret