"""
Classes and functions to deal with XML input generation for Quantum Espresso.
Copyright Schrodinger, LLC. All rights reserved."""
import json
import os
from collections import OrderedDict
from collections import namedtuple
from collections.abc import Iterable
from math import ceil
from math import sqrt
import numpy
import spglib
from typing import NamedTuple
from schrodinger import structure
from schrodinger.application.matsci import msprops
from schrodinger.application.matsci.espresso import utils as qeu
from schrodinger.application.matsci.espresso.qeoutput import KPoint
from schrodinger.application.matsci.nano import xtal
from schrodinger.application.matsci import msutils
from schrodinger.infra import mm
from schrodinger.infra.mm import \
    mmelement_get_atomic_weight_by_atomic_number as get_atomic_weight
# File/folder extensions used in IO
IN_EXT = '.qegz'
OUT_SAVE_EXT = '.save'
# OUT_EXT is the concatenation of the two above: '.save.qegz'
OUT_EXT = OUT_SAVE_EXT + IN_EXT
DOS_EXT = '.dos'
PAW_EXT = '.paw'
DYN_EXT = '.dyn'
# Exchange-correlation functional names, grouped by family; DFT_FUNCTIONALS
# below is the flat list of every group
GAU_PBE_FUNCT, EV93_FUNCT, PBE_FUNCT = 'GauPBE', 'EV93', 'PBE'
LDA_FUNCTIONALS = ('PZ',)
GGA_FUNCTIONALS = (PBE_FUNCT, 'PBESOL', 'BP', 'REVPBE', 'BLYP', 'PW91', 'WC',
                   EV93_FUNCT)
VDW_FUNCTIONALS = ('VDW-DF', 'VDW-DF-CX', 'VDW-DF-C09', 'VDW-DF-OB86',
                   'VDW-DF-OBK8', 'VDW-DF2', 'VDW-DF2-C09', 'VDW-DF2-B86R',
                   'RVV10', 'BEEF')
EXX_FUNCTIONALS = ('B3LYP', 'PBE0', GAU_PBE_FUNCT)
LRC_FUNCTIONALS = ('HSE',)
# All hybrid functionals: the EXX group plus the LRC group (HybridType writes
# <screening_parameter> only when it is flagged as LRC)
HYBRID_FUNCTIONALS = EXX_FUNCTIONALS + LRC_FUNCTIONALS
# NOTE(review): presumably functionals that are still experimental - confirm
TEST_FUNCTIONALS = ('SCAN', 'R2SCAN')
DFT_FUNCTIONALS = list(LDA_FUNCTIONALS + GGA_FUNCTIONALS + VDW_FUNCTIONALS +
                       EXX_FUNCTIONALS + LRC_FUNCTIONALS + TEST_FUNCTIONALS)
# NOTE(review): presumably functionals that must not be combined with the
# DFT-D3 correction - confirm against the callers
DFT_D3_BAD_FUNCTIONALS = (EV93_FUNCT,)
# Van der Waals correction identifiers
DFT_D_VDW_CORR, DFT_D3_VDW_CORR = 'DFT-D', 'DFT-D3'
XDM_VDW_CORRECTION = 'XDM'
TS_VDW_CORRECTION = 'TS'
MBD_VDW_CORRECTION = 'MBD'
# Display name -> value written to <vdw_corr>; '' means no correction, in
# which case VdwType writes no <vdW> section at all
VDW_CORRECTION = OrderedDict([('None', ''), (DFT_D_VDW_CORR, DFT_D_VDW_CORR),
                              (DFT_D3_VDW_CORR, DFT_D3_VDW_CORR),
                              (TS_VDW_CORRECTION, TS_VDW_CORRECTION),
                              (XDM_VDW_CORRECTION, XDM_VDW_CORRECTION)])
# Corrections accepted by VdwType validation but not listed in VDW_CORRECTION
VDW_CORRECTION_TEST = (MBD_VDW_CORRECTION,)
# Display name -> value written inside the <smearing> element; the first
# entry is the BandsType default
SMEARING_TYPE = OrderedDict([('Gaussian', 'gaussian'),
                             ('Methfessel-Paxton', 'mp'),
                             ('Marzari-Vanderbilt', 'mv'),
                             ('Fermi-Dirac', 'fd')])
OCCUPATIONS_SMEARING = 'smearing'
OCCUPATIONS_TETRAHEDRA = 'tetrahedra_opt'
OCCUPATIONS_FIXED = 'fixed'
# Allowed <occupations> values; the first one is the BandsType default
OCCUPATIONS = (OCCUPATIONS_SMEARING, OCCUPATIONS_TETRAHEDRA, OCCUPATIONS_FIXED)
# Display name -> <diagonalization> value; the first entry ('davidson') is the
# ElectronControlType default
DIAG_TYPE = {
    'Davidson': 'davidson',
    'Conjugate Gradient': 'cg',
    'Iterative PPCG': 'ppcg',
    'Iterative ParO': 'paro'
}
# Display name -> <mixing_mode> value; the first entry is the
# ElectronControlType default
MIXING_MODE = OrderedDict([('Broyden', 'plain'), ('Simple Thomas-Fermi', 'TF'),
                           ('Local Thomas-Fermi', 'local-TF')])
# Spin treatments. Except for SPIN_CS, each value is also used by SpinType as
# the name of a boolean XML tag inside <spin>
SPIN_CS, SPIN_LSDA, SPIN_NCL, SPIN_SO = 'cs', 'lsda', 'noncolin', 'spinorbit'
SPIN_TYPE = OrderedDict([('Non-polarized', SPIN_CS),
                         ('Spin-polarized', SPIN_LSDA),
                         ('Non-collinear', SPIN_NCL),
                         ('Spin-orbit', SPIN_SO)]) # yapf: disable
# DFTU_NONE ('') disables DFT+U (DftUType then writes nothing); other values
# are written to <lda_plus_u_kind>
DFTU_NONE, DFTU_0, DFTU_1 = '', '0', '1'
DFTU_TYPE = OrderedDict([('Disabled', DFTU_NONE),
                         ('Enabled', DFTU_0)]) # yapf: disable
# Supported calculation types
SCF_TYPE, RELAX_TYPE, VCRELAX_TYPE, MD_TYPE = 'scf', 'relax', 'vc-relax', 'md'
CALCULATION_TYPE = (SCF_TYPE, 'nscf', 'bands', RELAX_TYPE, VCRELAX_TYPE,
                    MD_TYPE, 'vc-md', 'phonon')
# The display name -> value maps below are presumably used to build the
# corresponding pw.x settings; their consumers are not shown in this section
ION_DYNAMICS_BFGS = 'bfgs'
ION_DYNAMICS_DAMP = 'damp'
ION_DYNAMICS = OrderedDict([('BFGS quasi-newton algorithm', ION_DYNAMICS_BFGS),
                            ('Damped relaxation', ION_DYNAMICS_DAMP)])
CELL_DYNAMICS_BFGS = 'bfgs'
CELL_DYNAMICS_DAMP_PR = 'damp-pr'
CELL_DYNAMICS = OrderedDict([
    ('BFGS quasi-newton algorithm', CELL_DYNAMICS_BFGS),
    ('Damped dynamics of P-R lagrangian', CELL_DYNAMICS_DAMP_PR),
    ('Damped dynamics of Wentzcovitch lagrangian', 'damp-w')
])
CELL_DOFREE_XY = '2Dxy'
CELL_DOFREE = OrderedDict(
    [('All axis and angles are free', 'all'),
     ('All axis and angles are free, volume fixed', 'shape'),
     ('Only x and y components are free', CELL_DOFREE_XY),
     ('Only x and y components are free, x-y plane fixed', '2Dshape')]) # yapf: disable
ION_TEMP = OrderedDict([('Velocity rescaling', 'rescale-v'),
                        ('Temperature rescaling', 'rescale-T'),
                        ('Berendsen thermostat', 'berendsen'),
                        ('Andersen thermostat', 'andersen'),
                        ('Initialized but uncontrolled', 'initial'),
                        ('Uncontrolled', 'not_controlled')])
MD_DYNAMICS = OrderedDict([('Verlet algorithm', 'verlet'),
                           ('Over-damped Langevin', 'langevin')])
# Allowed isolated-system treatments; NOTE(review): '' presumably means none
ASSUME_ISOLATED_ESM = 'esm'
ASSUME_ISOLATED = ('', 'makov-payne', 'martyna-tuckerman', ASSUME_ISOLATED_ESM)
# Boundary-condition options used together with ASSUME_ISOLATED_ESM
ESM_BC_BC1 = 'bc1'
ESM_BC_BC2 = 'bc2'
ESM_BC_BC3 = 'bc3'
ESM_BC = (ESM_BC_BC1, ESM_BC_BC2, ESM_BC_BC3)
# Setting keys (usage not shown in this section)
WF_COLLECT_KEY = 'wf_collect'
WF_KEEP_KEY = 'wf_keep'
# Default band path as [kx, ky, kz, label] entries, presumably in reciprocal
# fractional coordinates - confirm. KpointsType.getNKpts only uses its length
# to estimate the number of k-points of the 'automatic' band path.
DEFAULT_KPTS_BAND_PATH = [[0.0, 0.0, 0.0, 'Gamma'],
    [0.5, 0.0, 0.5, 'X'], [0.25, 0.5, 0.0, 'W'], [0.375, 0.375, 0.75, 'K'],
    [0.0, 0.0, 0.0, 'Gamma'], [0.25, 0.25, 0.25, 'L'], [0.375, 0.25, 0.625,'U'],
    [0.5, 0.25, 0.75, 'W'], [0.5, 0.5, 0.5, 'L'], [0.375, 0.375, 0.75, 'K'],
    [0.375, 0.25, 0.625, 'U'], [0.5, 0.0, 0.5, 'X']] # yapf: disable
# 1x1x1 mesh without shift, in the kpts_mesh layout (nk1, nk2, nk3, sk1, sk2,
# sk3)
G_ONLY_MESH = [1, 1, 1, 0, 0, 0]
# NEB settings: display name -> value
NEB_STRING_TYPE = OrderedDict([('Nudged elastic band', 'neb'),
                               ('String method dynamics', 'smd')])
NEB_OPT_SCHEME_TYPE = OrderedDict([("Quasi-Newton Broyden's method", 'broyden'),
                                   ('Projected velocity Verlet scheme',
                                    'quick-min'), ('Steepest descent', 'sd')])
NEB_CI_TYPE = OrderedDict([('Not used', 'no-CI'), ('Auto', 'auto')])
# Upper limit on the number of NEB images (usage not shown in this section)
NEB_MAX_NIMAGES = 10000
# Display name -> method keyword for the custom NEB workflow (consumer not
# shown in this section)
CUSTOM_NEB_METHOD = {
    'Nudged elastic band': 'improvedtangent',
    'Classical NEB': 'aseneb',
    'Elastic band': 'eb',
    'Spline method': 'spline',
    'String method': 'string'
}
CUSTOM_NEB_OPT_METHOD = {'ODE': 'ode', 'Static': 'static'}
# Named (engine, path) pair
DataType = namedtuple('DataType', ['engine', 'path'])
# Default HybridType exxdiv treatment; same value as the 'Gygi Baldereschi'
# entry of EXX_TREATMENTS
EXX_GYGI_TREATMENT = 'gygi-baldereschi'
# Display name -> <exxdiv_treatment> value (validated by HybridType)
EXX_TREATMENTS = {
    'Gygi Baldereschi': 'gygi-baldereschi',
    'Spherical': 'vcut_spherical',
    'Anisotropic': 'vcut_ws',
    'None': 'none'
}
# Display name -> sum-rule keyword; NOTE(review): presumably used for phonon
# (dynamical matrix) post-processing input - confirm
DYNMAT_ASR = {
    'Crystal': 'crystal',
    'No': 'no',
    'Simple': 'simple',
    'One dimension': 'one-dim',
    'Zero dimensions': 'zero-dim'
}
def validate_value(name, value, allowed_values):
    """
    Check that value is in list of allowed values.

    :type name: str
    :param name: Name of the variable
    :type value: str or int or float
    :param value: Value of the variable
    :type allowed_values: iterable
    :param allowed_values: Allowed values. Any iterable is accepted (list,
        tuple, dict view, generator). It is materialized once, so a one-shot
        iterator is not exhausted before the error message is built.
    :raise ValueError: If value is not in allowed_values
    """
    allowed = list(allowed_values)
    if value not in allowed:
        # str() every entry so that non-string allowed values (ints, floats)
        # still produce the intended ValueError instead of a TypeError raised
        # by str.join
        raise ValueError('%s value error: %s. Allowed types: ' % (name, value) +
                         ', '.join(str(item) for item in allowed))
class GenericType:
    """
    Base class shared by the QE XML input section classes.

    Subclasses provide a ``DEFAULTS`` dictionary. Every key of it becomes an
    instance attribute that mirrors the matching entry of ``self.data``.
    """

    def __init__(self):
        """
        Start from an empty settings dictionary and load the class DEFAULTS.
        """
        self._attributes = list(self.DEFAULTS)
        self.data = {}
        self.updateWithData(self.DEFAULTS)

    def updateWithData(self, data):
        """
        Merge data into self.data (via msutils.deep_update_dict), give
        subclasses a chance to clean it, then refresh attributes.

        Each attribute named in DEFAULTS is set to self.data.get(name), i.e.
        None when the key is missing.

        :type data: dict
        :param data: Dictionary of settings
        """
        merged = msutils.deep_update_dict(self.data, data)
        self.data = merged
        self._cleanData(data)
        for attr_name in self._attributes:
            attr_value = self.data.get(attr_name)
            setattr(self, attr_name, attr_value)

    def _cleanData(self, data):
        """
        Hook for subclasses to adjust self.data before attributes are set.
        The base implementation does nothing.

        :type data: dict
        :param data: The settings passed to updateWithData
        """
class BandsType(GenericType):
    """
    Class to generate QE input XML section related to bands.
    """
    TOT_CHARGE_KEY, TOT_MAG_KEY = 'total_charge', 'total_magnetization'
    # Minimum number of empty bands added whenever nbnd_empty_percent is
    # non-zero
    MIN_EMPTY_BANDS = 10
    DEFAULTS = {
        'nbnd': 0,
        'nbnd_empty_percent': 20.0,
        'smearing_degauss': 0.01,
        TOT_CHARGE_KEY: 0.0,
        TOT_MAG_KEY: -1.0,
        'smearing_type': list(SMEARING_TYPE.values())[0],
        'occupations': OCCUPATIONS[0],
    }

    def __str__(self):
        """
        Return XML.

        The written <nbnd> is self.nbnd plus nbnd_empty_percent percent of it
        as empty bands (at least MIN_EMPTY_BANDS when that percentage is
        non-zero), rounded up. The <smearing> element is only written for
        smearing occupations.

        :rtype: str
        :return: XML string
        :raise ValueError: If smearing_type or occupations has an unknown
            value, or if self.nbnd is not set (or zero)
        """
        validate_value('smearing_type', self.smearing_type,
                       list(SMEARING_TYPE.values()))
        validate_value('occupations', self.occupations, OCCUPATIONS)
        if not self.nbnd:
            raise ValueError('nbnd is not defined or zero')
        empty_bands = self.nbnd * self.nbnd_empty_percent / 100.0
        # For systems with small number of valence electrons, when small
        # nbnd_empty_percent is picked (but larger than zero), empty_bands might
        # be 0, this should be prevented
        if self.nbnd_empty_percent:
            empty_bands = max(empty_bands, self.MIN_EMPTY_BANDS)
        nbnd = int(numpy.ceil(self.nbnd + empty_bands))
        parts = ['<bands>\n', '<nbnd>%d</nbnd>\n' % nbnd]
        # Compare against the shared module constant so this check cannot
        # drift from the OCCUPATIONS values validated above
        if self.occupations == OCCUPATIONS_SMEARING:
            parts.append('<smearing degauss="%f">%s</smearing>\n' %
                         (self.smearing_degauss, self.smearing_type))
        parts.append('<tot_charge>%f</tot_charge>\n' % self.total_charge)
        parts.append('<tot_magnetization>%f</tot_magnetization>\n' %
                     self.total_magnetization)
        parts.append('<occupations>%s</occupations>\n' % self.occupations)
        parts.append('</bands>\n')
        return ''.join(parts)
class BasisType(GenericType):
    """
    Class to generate QE input XML section related to the plane-wave basis
    (cutoffs, optional FFT grids, gamma-only and spline flags).
    """
    GAMMA_ONLY_KEY, SPLINE_PS_KEY = 'gamma_only', 'spline_ps'
    DEFAULTS = {
        'ecutwfc': 0.,
        'ecutrho': 0.,
        GAMMA_ONLY_KEY: False,
        'fft_grid': [],
        'fft_smooth_grid': [],
        SPLINE_PS_KEY: False
    }

    def _validateGrid(self, grid):
        """
        Check a grid definition and convert its entries to int.

        :type grid: list
        :param grid: Either empty (no grid) or three values
        :rtype: list
        :return: The grid values converted to int
        :raise: ValueError if grid is neither empty nor three values long
        """
        if len(grid) in (0, 3):
            return [int(value) for value in grid]
        raise ValueError('grid must have three integer values')

    def __str__(self):
        """
        Return XML.

        Note that fft_grid and fft_smooth_grid are replaced by their
        validated integer versions as a side effect.

        :rtype: str
        :return: XML string
        :raise: ValueError if a grid is malformed or if self.ecutwfc or
            self.ecutrho is not set (or zero)
        """
        self.fft_grid = self._validateGrid(self.fft_grid)
        self.fft_smooth_grid = self._validateGrid(self.fft_smooth_grid)
        if not self.ecutwfc:
            raise ValueError('ecutwfc is not defined or zero')
        if not self.ecutrho:
            raise ValueError('ecutrho is not defined or zero')
        pieces = [
            '<basis>\n',
            '<gamma_only>%s</gamma_only>\n' % str(self.gamma_only).lower(),
            '<ecutwfc>%f</ecutwfc>\n' % self.ecutwfc,
            '<ecutrho>%f</ecutrho>\n' % self.ecutrho,
        ]
        # Empty grids are simply omitted from the XML
        if self.fft_grid:
            pieces.append('<fft_grid nr1="%d" nr2="%d" nr3="%d" />\n' %
                          tuple(self.fft_grid))
        if self.fft_smooth_grid:
            pieces.append('<fft_smooth nr1="%d" nr2="%d" nr3="%d" />\n' %
                          tuple(self.fft_smooth_grid))
        pieces.append('<spline_ps>%s</spline_ps>\n' %
                      str(self.spline_ps).lower())
        pieces.append('</basis>\n')
        return ''.join(pieces)
class ElectronControlType(GenericType):
    """
    Class to generate QE input XML section related to electron control.
    """
    DEFAULTS = {
        'diagonalization': list(DIAG_TYPE.values())[0],
        'mixing_mode': list(MIXING_MODE.values())[0],
        'mixing_beta': 0.7,
        'conv_thr': 1.e-6,
        'mixing_ndim': 8,
        'max_steps': 100,
        'real_space_q': False,
        'diago_thr_init': 0.0,
        'diago_full_acc': False,
        'diago_cg_maxiter': 20
    }

    def __str__(self):
        """
        Return XML.

        The thresholds conv_thr and diago_thr_init are written in scientific
        notation so that small values keep their precision.

        :rtype: str
        :return: XML string
        :raise ValueError: If diagonalization or mixing_mode has an unknown
            value
        """
        validate_value('diagonalization', self.diagonalization,
                       list(DIAG_TYPE.values()))
        validate_value('mixing_mode', self.mixing_mode,
                       list(MIXING_MODE.values()))
        parts = [
            '<electron_control>\n',
            '<diagonalization>%s</diagonalization>\n' % self.diagonalization,
            '<mixing_mode>%s</mixing_mode>\n' % self.mixing_mode,
            '<mixing_beta>%f</mixing_beta>\n' % self.mixing_beta,
            '<conv_thr>%.5e</conv_thr>\n' % self.conv_thr,
            '<mixing_ndim>%d</mixing_ndim>\n' % self.mixing_ndim,
            # The setting is named max_steps here, the XML tag is max_nstep
            '<max_nstep>%d</max_nstep>\n' % self.max_steps,
            ('<real_space_q>%s</real_space_q>\n' %
             str(self.real_space_q).lower()),
            # These two flags are not configurable and always written false
            '<tq_smoothing>false</tq_smoothing>\n',
            '<tbeta_smoothing>false</tbeta_smoothing>\n',
            # '%f' keeps only 6 decimals, so any threshold below 5e-7 would
            # be written as 0.000000; use the same format as conv_thr
            '<diago_thr_init>%.5e</diago_thr_init>\n' % self.diago_thr_init,
            ('<diago_full_acc>%s</diago_full_acc>\n' %
             str(self.diago_full_acc).lower()),
            ('<diago_cg_maxiter>%d</diago_cg_maxiter>\n' %
             self.diago_cg_maxiter),
            '</electron_control>\n',
        ]
        return ''.join(parts)
class KpointsType(GenericType):
    """
    Class to generate QE input XML section related to k-points.

    K-points are described by one of the KPTS_TYPES settings: an explicit
    mesh (kpts_mesh), an explicit list (kpts_list), a density (kpts_dens), a
    spacing (kpts_spacing) or a band path (kpts_band). Whenever an update
    supplies one of them, _cleanData removes the others.
    """
    KPTS_MESH_KEY, IS_FOR_ESM = 'kpts_mesh', 'is_for_esm'
    AUTOMATIC_BAND = 'automatic'
    DEFAULTS = {
        KPTS_MESH_KEY: [6, 6, 6, 1, 1, 1],
        'kpts_list': [],
        'kpts_dens': 0,
        'kpts_spacing': 0,
        'kpts_band': False,
        'kpts_dens_force_gamma': False,
        'kpts_spacing_force_gamma': False,
        'kpts_band_line_density': 20,
        IS_FOR_ESM: False
    }
    KPTS_TYPES = {
        'kpts_mesh', 'kpts_list', 'kpts_dens', 'kpts_spacing', 'kpts_band'
    }
    # Fixed order in which _cleanData looks for the top-priority k-point type.
    # KPTS_TYPES is a set of strings, whose iteration order depends on string
    # hashing (randomized per interpreter run), so it must not decide which
    # type wins when several are supplied at once.
    _KPTS_TYPES_PRIORITY = ('kpts_mesh', 'kpts_list', 'kpts_dens',
                            'kpts_spacing', 'kpts_band')
    # Monkhorst-Pack element, formatted with 6 ints (nk1, nk2, nk3, k1, k2, k3)
    _MP_XML = ('<monkhorst_pack nk1="%d" nk2="%d" nk3="%d"'
               ' k1="%d" k2="%d" k3="%d">K-point mesh</monkhorst_pack>\n')

    def setStructure(self, struct):
        """
        Set structure in self.struct. If it is truthy, also set self.vecs
        (lattice vectors), self.rvecs (reciprocal lattice vectors) and
        self.alat (length of the first lattice vector).

        :type struct: `structure.Structure`
        :param struct: Structure used for k-point generation (needed by the
            'kpts_dens', 'kpts_spacing', 'kpts_list' and 'kpts_band' cases)
        """
        self.struct = struct
        if self.struct:
            self.vecs = xtal.get_vectors_from_chorus(self.struct)
            self.rvecs = xtal.get_reciprocal_lattice_vectors(*self.vecs)
            self.alat = sqrt(numpy.dot(self.vecs[0], self.vecs[0]))

    def _cleanData(self, data):
        """
        Remove mutually exclusive kpoint types. Top-priority type comes from
        data variable.

        If data sets several k-point types, the first truthy one in
        _KPTS_TYPES_PRIORITY wins, so the outcome is reproducible between
        runs.

        :type data: dict
        :param data: Top-priority dictionary of settings
        """
        kpts_key = next(
            (key for key in self._KPTS_TYPES_PRIORITY if data.get(key)), None)
        if not kpts_key:
            return
        # Delete the other kpoint types since a 'top-priority' key is present
        for kpts_type in self.KPTS_TYPES.difference({kpts_key}):
            self.data.pop(kpts_type, None)
            # Also unset the instance attribute
            setattr(self, kpts_type, None)

    def _validateKpoints(self):
        """
        Validate the k-point definition. See PW input file schema
        documentation.

        The first truthy attribute among kpts_mesh, kpts_list, kpts_dens,
        kpts_spacing and kpts_band is checked: kpts_mesh must hold 6 ints
        (nk1, nk2, nk3 and sk1, sk2, sk3), every kpts_list item must hold 4
        floats (xk_x, xk_y, xk_z, wk), kpts_dens must be >= 1, kpts_spacing
        must be > 0 and kpts_band is accepted as is.

        :raise: ValueError if the checked definition is malformed
        :raise: ValueError if no k-point definition is set
        """
        # This 'BIG'-if ensures that at least one option is present
        if self.kpts_mesh:
            if len(self.kpts_mesh) != 6:
                raise ValueError('kpts_mesh must have exactly six values')
        elif self.kpts_list:
            for kpt in self.kpts_list:
                if len(kpt) != 4:
                    raise ValueError('kpts_list must have lists of four values')
        elif self.kpts_dens:
            if not self.kpts_dens >= 1:
                raise ValueError('kpts_dens must be positive integer')
        elif self.kpts_spacing:
            if not self.kpts_spacing > 0.:
                raise ValueError('kpts_spacing must be positive number')
        elif self.kpts_band:
            pass
        else:
            raise ValueError('k-points have to be defined')

    def _generateBandPath(self, path=None, line_density=20.0):
        """
        Get list of k-points based on the path.

        :type path: list or None
        :param path: Path vertices in reciprocal fractional coordinates, each
            parsed by KPoint.getKptFromCfg. If None, the path comes from
            qeu.HighSymmetryKPath for self.struct.
        :type line_density: float
        :param line_density: Density of the k-points between 'edges'
        :rtype: list of lists
        :return: k-point reciprocal Cartesian coordinates (3 floats); path
            vertices additionally carry their label as a 4th item
        """
        # Based on: pymatgen/symmetry/bandstructure.py :: get_kpoints
        # MIT license
        if not path:
            kpath = qeu.HighSymmetryKPath(self.struct, is_2d=self.is_for_esm)
            path = kpath.kpath
        k_points = []
        for indx in range(1, len(path)):
            start, start_label = KPoint.getKptFromCfg(path[indx - 1])
            start_c = self._toCart(start)
            end, end_label = KPoint.getKptFromCfg(path[indx])
            end_c = self._toCart(end)
            distance = numpy.linalg.norm(start_c - end_c)
            nkpts = int(ceil(distance * line_density))
            # Skip segments that produce no k-points (zero length, e.g. the
            # same vertex repeated)
            if nkpts == 0:
                continue
            for jndx in range(nkpts + 1):
                coord = start_c + float(jndx) / float(nkpts) * (end_c - start_c)
                coord = coord.tolist()
                if jndx == 0:
                    coord.append(start_label)
                elif jndx == nkpts:
                    coord.append(end_label)
                # Prevents adding 'edge' point twice
                if k_points:
                    kpt_np = numpy.array(k_points[-1][:3])
                    coord_np = numpy.array(coord[:3])
                    if numpy.linalg.norm(kpt_np - coord_np) < 0.001:
                        continue
                k_points.append(coord)
        return k_points

    def _getBandKpoints(self):
        """
        Generate the band-path k-points described by self.kpts_band (the
        automatic path or an explicit path) using kpts_band_line_density.

        :rtype: list of lists
        :return: k-points as returned by _generateBandPath
        """
        if self.kpts_band == self.AUTOMATIC_BAND:
            return self._generateBandPath(
                line_density=self.kpts_band_line_density)
        return self._generateBandPath(path=self.kpts_band,
                                      line_density=self.kpts_band_line_density)

    @staticmethod
    def getMeshShift(do_shift):
        """
        Get MP mesh shift.

        :param bool do_shift: Whether to do shift or not (Gamma centered)
        :rtype: list[int]
        :return: K-point mesh shift
        """
        return [1, 1, 1] if do_shift else [0, 0, 0]

    def _generateDensity(self, ngrid, force_gamma=False):
        """
        Generate k-point mesh based on unit cell dimensions. A simple approach
        of scaling the number of divisions along each reciprocal lattice vector
        proportional to its length. The mesh is shifted only when every
        division count is even and neither force_gamma nor is_for_esm is set
        (hexagonal detection is not implemented yet).

        :type ngrid: int
        :param ngrid: Grid density
        :type force_gamma: bool
        :param force_gamma: Enforce Gamma-centered mesh
        :rtype: list of 6 ints
        :return: First 3 values represent the number of kpoints in grid
            directions, next 3 values, offsets in the corresponding direction
        """
        # Based on: pymatgen/io/vasp/inputs.py :: automatic_density
        # MIT license
        lengths = xtal.get_lattice_param_properties(self.struct)[:3]
        ngrid /= self.struct.atom_total
        mult = (ngrid * lengths[0] * lengths[1] * lengths[2])**(1. / 3.)
        num_div = [int(numpy.floor(max(mult / l, 1))) for l in lengths]
        if self.is_for_esm:
            # ESM requires kz to be 1 without shift (Gamma point in z)
            num_div[2] = 1
        # TODO: Set hexagonal correctly from the structure
        is_hexagonal = False
        has_odd = any(i % 2 == 1 for i in num_div)
        do_shift = not (has_odd or is_hexagonal or force_gamma or
                        self.is_for_esm)
        shift = self.getMeshShift(do_shift)
        return num_div + shift

    @staticmethod
    def getFromSpacing(spacing, rvecs, force_gamma=False, is_for_esm=False):
        """
        Get MP mesh from K-point spacing.

        :type spacing: float
        :param spacing: K-point spacing in 1/A
        :type rvecs: numpy.array
        :param rvecs: Reciprocal lattice vectors
        :type force_gamma: bool
        :param force_gamma: Enforce Gamma-centered mesh
        :param bool is_for_esm: Whether this is for ESM (the z mesh is fixed
            to 1 division and no shift is applied)
        :rtype: list of 6 ints
        :return: First 3 values represent the number of kpoints in grid
            directions, next 3 values, offsets in the corresponding direction
        """
        # See MATSCI-6285 for implementation details.
        rvol = xtal.get_volume_from_vecs(rvecs)
        mesh = [0, 0, 0]
        for idx, vecs in enumerate([[rvecs[1], rvecs[2]], [rvecs[2], rvecs[0]],
                                    [rvecs[0], rvecs[1]]]):
            val = (rvol / numpy.linalg.norm(numpy.cross(*vecs), ord=2) /
                   spacing) + 1
            # Hard-code z k-point for the esm case (MATSCI-11377)
            val = 1 if (is_for_esm and idx == 2) else int(val)
            mesh[idx] = val
        do_shift = not (force_gamma or is_for_esm)
        shift = KpointsType.getMeshShift(do_shift)
        return mesh + shift

    def getNKpts(self):
        """
        Get total number of k-points based on k-points definition. If no
        k-points type is defined, raise a ValueError. For band paths the
        value is an estimate (number of path points times the line density).

        :rtype: int
        :return: Number of k-points
        :raise ValueError: If k-points type is not defined
        """
        ret = 0
        if self.kpts_mesh:
            ret = numpy.prod(self.kpts_mesh[:3])
        elif self.kpts_dens:
            kpts_mesh = self._generateDensity(self.kpts_dens,
                                              self.kpts_dens_force_gamma)
            ret = numpy.prod(kpts_mesh[:3])
        elif self.kpts_spacing:
            kpts_mesh = self.getFromSpacing(self.kpts_spacing, self.rvecs,
                                            self.kpts_spacing_force_gamma,
                                            self.is_for_esm)
            ret = numpy.prod(kpts_mesh[:3])
        elif self.kpts_band:
            if self.kpts_band == self.AUTOMATIC_BAND:
                ret = len(DEFAULT_KPTS_BAND_PATH) * self.kpts_band_line_density
            else:
                ret = len(self.kpts_band) * self.kpts_band_line_density
        elif self.kpts_list:
            ret = len(self.kpts_list)
        else:
            raise ValueError('Cannot compute number of k-points.')
        return int(round(ret))

    def matdynStr(self):
        """
        Get kpoint list suitable for matdyn input: the number of k-points on
        the first line, then one 'x y z label' line per k-point ('*' when the
        k-point has no label, spaces in labels replaced by '_').

        :rtype: str
        :return: Kpoint list
        :raise ValueError: If kpts_band is not set
        """
        if not self.kpts_band:
            raise ValueError('matdynStr requires kpts_band.')
        kpts_list = self._getBandKpoints()
        lines = ['%d\n' % len(kpts_list)]
        for kpt in kpts_list:
            if len(kpt) == 3:
                lines.append('%f %f %f *\n' % tuple(kpt))
            else:
                label = kpt[-1].replace(' ', '_')
                lines.append('%f %f %f %s\n' % (kpt[0], kpt[1], kpt[2], label))
        return ''.join(lines)

    def __str__(self):
        """
        Return XML.

        Exactly one k-point definition is written, using the same precedence
        as _validateKpoints and getNKpts for kpts_mesh.

        :rtype: str
        :return: XML string
        :raise ValueError: If the k-point definition is invalid or missing
        """
        self._validateKpoints()
        ret = '<k_points_IBZ>\n'
        # A single if/elif chain: writing two definitions (e.g. both a mesh
        # and a density) would produce an invalid section
        if self.kpts_mesh:
            ret += self._MP_XML % tuple(self.kpts_mesh)
        elif self.kpts_dens:
            kpts_mesh = self._generateDensity(self.kpts_dens,
                                              self.kpts_dens_force_gamma)
            ret += self._MP_XML % tuple(kpts_mesh)
        elif self.kpts_spacing:
            kpts_mesh = self.getFromSpacing(self.kpts_spacing, self.rvecs,
                                            self.kpts_spacing_force_gamma,
                                            self.is_for_esm)
            ret += self._MP_XML % tuple(kpts_mesh)
        elif self.kpts_list:
            ret += '<nk>%d</nk>\n' % len(self.kpts_list)
            for kpt in self.kpts_list:
                kpt_c = self._toCart(kpt[:3])
                ret += ('<k_point weight="%f">%f %f %f</k_point>\n' %
                        (kpt[-1], kpt_c[0], kpt_c[1], kpt_c[2]))
        elif self.kpts_band:
            kpts_list = self._getBandKpoints()
            ret += '<nk>%d</nk>\n' % len(kpts_list)
            for kpt in kpts_list:
                if len(kpt) == 3:
                    ret += ('<k_point weight="1.0">%f %f %f</k_point>\n' %
                            tuple(kpt))
                else:
                    ret += ('<k_point label="%s" weight="1.0">%f %f %f'
                            '</k_point>\n' % (kpt[-1], kpt[0], kpt[1], kpt[2]))
        ret += '</k_points_IBZ>\n'
        return ret

    def _toCart(self, coords):
        """
        Transform k-point coordinates from fractional reciprocal to Cartesian
        reciprocal in the 2pi / alat units.

        :type coords: list of 3 floats
        :param coords: Coordinates of the k-point
        :rtype: numpy.array
        :return: Cartesian coordinates of the k-point
        """
        coords_np = numpy.array(coords)
        return xtal.trans_frac_to_cart_from_vecs(
            coords_np, *self.vecs, rec=True) * self.alat
class SpinType(GenericType):
    """
    Class to generate QE input XML section related to spin.
    """
    DEFAULTS = {'spin_type': list(SPIN_TYPE.values())[0]}
    SPIN_NON_LSDA_ID, SPIN_LSDA_ID = 1, 2

    def __init__(self):
        """
        Initialize SpinType with the known spin treatments (the values of
        SPIN_TYPE) and the default settings.
        """
        self._spin_types = list(SPIN_TYPE.values())
        GenericType.__init__(self)

    def _validateSpin(self, spin):
        """
        Validate spin type.

        :type spin: str
        :param spin: Spin treatment
        :rtype: str
        :return: The same spin treatment
        :raise: ValueError if spin is of an unknown type
        """
        if spin in self._spin_types:
            return spin
        raise ValueError('spin value error: %s' % str(spin))

    def _get_tag(self, tag, val):
        """
        Generate an XML tag.

        :type tag: str
        :param tag: Tag name (one of the spin treatments)
        :type val: bool or str
        :param val: Tag value, written lowercased
        :rtype: str
        :return: Empty str for the non-polarized treatment (SPIN_CS), which
            has no tag of its own; otherwise the tag
        """
        if tag != SPIN_CS:
            return '<%s>%s</%s>\n' % (tag, str(val).lower(), tag)
        return ''

    def getNSpin(self):
        """
        Get nspin based on the spin settings.

        :return: SPIN_LSDA_ID for spin-polarized (lsda) calculations,
            SPIN_NON_LSDA_ID otherwise
        :rtype: int
        """
        is_lsda = self._validateSpin(self.spin_type) == SPIN_LSDA
        return self.SPIN_LSDA_ID if is_lsda else self.SPIN_NON_LSDA_ID

    def __str__(self):
        """
        Return XML.

        :rtype: str
        :return: XML string
        :raise: ValueError if spin_type is of an unknown type
        """
        self.spin_type = self._validateSpin(self.spin_type)
        if self.spin_type == SPIN_SO:
            # Spin-orbit requires a non collinear calculation
            flags = [(SPIN_LSDA, False), (SPIN_NCL, True), (SPIN_SO, True)]
        else:
            flags = [(name, self.spin_type == name)
                     for name in self._spin_types]
        body = ''.join(self._get_tag(tag, val) for tag, val in flags)
        return '<spin>\n' + body + '</spin>\n'
class VdwType(GenericType):
    """
    Class to generate QE input XML section related to vdw type.
    """
    DEFAULTS = {
        'correction': '',
        'london_s6': 0.75,
        'london_rcut': 200,
        'dftd3_version': 3,
        'dftd3_threebody': True,
        'xdm_a1': 0.6836,
        'xdm_a2': 1.5045
    }

    def __init__(self):
        """
        Initialize VdwType with the default settings.
        """
        super().__init__()

    def _validateVdw(self):
        """
        Validate vdW type: correction must be a VDW_CORRECTION value or one of
        VDW_CORRECTION_TEST.

        :raise: ValueError if vdw correction is of an unknown type
        """
        known = list(VDW_CORRECTION.values()) + list(VDW_CORRECTION_TEST)
        if self.correction in known:
            return
        raise ValueError('vdw type value error: %s' % str(self.correction))

    def __str__(self):
        """
        Return XML. An empty correction produces an empty string; otherwise
        the <vdW> section with the parameters of the chosen correction.

        :rtype: str
        :return: XML string
        :raise: ValueError if vdw correction is of an unknown type
        """
        self._validateVdw()
        if not self.correction:
            return ''
        chunks = ['<vdW>\n', '<vdw_corr>%s</vdw_corr>\n' % self.correction]
        if self.correction == DFT_D_VDW_CORR:
            chunks += ['<london_s6>%f</london_s6>\n' % self.london_s6,
                       '<london_rcut>%f</london_rcut>\n' % self.london_rcut]
        elif self.correction == DFT_D3_VDW_CORR:
            chunks += [('<dftd3_version>%d</dftd3_version>\n' %
                        self.dftd3_version),
                       ('<dftd3_threebody>%s</dftd3_threebody>\n' %
                        str(self.dftd3_threebody).lower())]
        elif self.correction == XDM_VDW_CORRECTION:
            chunks += ['<xdm_a1>%f</xdm_a1>\n' % self.xdm_a1,
                       '<xdm_a2>%f</xdm_a2>\n' % self.xdm_a2]
        chunks.append('</vdW>\n')
        return ''.join(chunks)
class HybridType(GenericType):
    """
    Class to generate QE input XML section related to hybrid functional type.
    """
    SCREEN_PARAM_KEY = 'screening_parameter'
    DEFAULTS = {
        'qpts_mesh': [1, 1, 1],
        'ecutvcut': 0.7,
        'x_gamma_extrapolation': True,
        'exxdiv_treatment': EXX_GYGI_TREATMENT,
        SCREEN_PARAM_KEY: 0.106,
        'ecutfock': None,
    }

    def __init__(self):
        """
        Initialize object with default settings; the functional is treated as
        non-LRC until setLRC is called.
        """
        super().__init__()
        self._is_lrc = False

    def setLRC(self, is_lrc):
        """
        Set _is_lrc value.

        :param bool is_lrc: Whether this belongs to LRC or hybrid functional
        """
        self._is_lrc = is_lrc

    def __str__(self):
        """
        Return XML from attributes.

        :rtype: str
        :return: XML string
        :raise ValueError: If exxdiv_treatment is not an EXX_TREATMENTS value
        """
        validate_value('exxdiv_treatment', self.exxdiv_treatment,
                       EXX_TREATMENTS.values())
        segments = ['<hybrid>\n',
                    '<qpoint_grid nqx1="%d" nqx2="%d" nqx3="%d" />\n' %
                    tuple(self.qpts_mesh)]
        if self.ecutfock is not None:
            # ecutfock is expected in Ha (stored value is divided by
            # qeu.HA2RY; presumably it is held in Ry - confirm)
            segments.append('<ecutfock>%f</ecutfock>\n' %
                            (self.ecutfock / qeu.HA2RY))
        if self._is_lrc:
            # The screening parameter is only written for LRC functionals
            segments.append('<screening_parameter>%f</screening_parameter>\n' %
                            self.screening_parameter)
        segments.append('<exxdiv_treatment>%s</exxdiv_treatment>\n' %
                        self.exxdiv_treatment)
        segments.append('<x_gamma_extrapolation>%s</x_gamma_extrapolation>\n' %
                        str(self.x_gamma_extrapolation).lower())
        # ecutvcut is written for every treatment, even though
        # exxdiv_treatment 'gygi-baldereschi' doesn't require it (MATSCI-5034).
        # ecutvcut is expected in Ha
        segments.append('<ecutvcut>%f</ecutvcut>\n' %
                        (self.ecutvcut / qeu.HA2RY))
        segments.append('</hybrid>\n')
        return ''.join(segments)
class DftUType(GenericType):
    """
    Class to generate QE input XML section related to Dft U.
    """
    DEFAULTS = {'structure_type': None, 'dftu_type': DFTU_NONE}

    def __str__(self):
        """
        Return XML for the <dftU> section.

        A Hubbard_U / Hubbard_J0 entry is written for every magnetic species
        whose value is non-zero. The whole section is omitted when no
        structure type is set, DFT+U is disabled, or no species has a
        non-zero Hubbard value.

        :rtype: str
        :return: XML string (possibly empty)
        """
        structure_type = getattr(self, 'structure_type', None)
        if structure_type is None or self.dftu_type == DFTU_NONE:
            return ''
        mag_species = structure_type.mag_species
        hubbard_lines = []
        for mag_name, element in mag_species.species.items():
            element_mag = mag_species.getMag(element, mag_name)
            if element_mag.hubb_u != 0.0:
                hubbard_lines.append(
                    '<Hubbard_U specie="%s" label="label1">%.7e</Hubbard_U>\n' %
                    (mag_name, element_mag.hubb_u))
            if element_mag.hubb_j0 != 0.0:
                hubbard_lines.append(
                    '<Hubbard_J0 specie="%s" label="label1">%.7e</Hubbard_J0>\n'
                    % (mag_name, element_mag.hubb_j0))
        # Without any Hubbard parameters the section is meaningless
        if not hubbard_lines:
            return ''
        ret = '<dftU>\n'
        ret += '<lda_plus_u_kind>%s</lda_plus_u_kind>\n' % self.dftu_type
        ret += ''.join(hubbard_lines)
        ret += '<U_projection_type>atomic</U_projection_type>\n'
        ret += '</dftU>\n'
        return ret
class DftType(GenericType):
    """
    Class to generate QE input XML section related to dft type.
    """
    DEFAULTS = {
        'functional': '',
        'vdw': VdwType.DEFAULTS,
        'hybrid': HybridType.DEFAULTS,
        'dftu': DftUType.DEFAULTS
    }

    @property
    def vdw(self):
        """
        Get vdW attribute (self._vdw).

        :return: vdW attribute
        :rtype: `VdwType`
        """
        return self._vdw

    @vdw.setter
    def vdw(self, data):
        """
        Set vdW attribute (self._vdw), creating it on first use.

        :param data: Dictionary with data
        :type data: dict
        """
        # hasattr() is False until self._vdw exists (the getter raises
        # AttributeError), so the VdwType object is created only once
        if not hasattr(self, 'vdw'):
            self._vdw = VdwType()
        self.vdw.updateWithData(data)

    @property
    def hybrid(self):
        """
        Get hybrid attribute (self._hybrid).

        :return: Hybrid attribute
        :rtype: `HybridType`
        """
        return self._hybrid

    @hybrid.setter
    def hybrid(self, data):
        """
        Set hybrid attribute (self._hybrid), creating it on first use.

        :param data: Dictionary with data
        :type data: dict
        """
        if not hasattr(self, 'hybrid'):
            self._hybrid = HybridType()
        self.hybrid.updateWithData(data)
        self.hybrid.setLRC(self.functional in LRC_FUNCTIONALS)

    @property
    def dftu(self):
        """
        Get DftU attribute from self._dftu.

        :return: DftU attribute
        :rtype: `DftUType`
        """
        return self._dftu

    @dftu.setter
    def dftu(self, data):
        """
        Set DftU attribute in self._dftu, creating it on first use.

        :param data: Dictionary with data
        :type data: dict
        """
        if not hasattr(self, 'dftu'):
            self._dftu = DftUType()
        self.dftu.updateWithData(data)

    def __str__(self):
        """
        Return XML for the <dft> section.

        Note that for the Gau-PBE functional this method overrides some
        settings of self.hybrid in place.

        :rtype: str
        :return: XML string
        :raise ValueError: Presumably from validate_value when the
            functional is not one of DFT_FUNCTIONALS
        """
        validate_value('functional', self.functional, DFT_FUNCTIONALS)
        ret = '<dft>\n'
        ret += '<functional>%s</functional>\n' % self.functional
        if self.functional in HYBRID_FUNCTIONALS:
            # Special treatment for Gau-PBE functional
            if self.functional == GAU_PBE_FUNCT:
                self.hybrid.exxdiv_treatment = 'none'
                self.hybrid.x_gamma_extrapolation = False
                self.hybrid.ecutvcut = 0.0
            # The hybrid setter syncs the LRC flag only with the functional
            # known at that moment; re-sync here so that a functional set
            # after the hybrid data still controls the screening parameter
            self.hybrid.setLRC(self.functional in LRC_FUNCTIONALS)
            ret += '%s' % self.hybrid
        ret += '%s' % self.dftu
        ret += '%s' % self.vdw
        ret += '</dft>\n'
        return ret
class PseudopotentialsType(GenericType):
    """
    Class to generate QE input XML section related to pseudopotentials.
    """
    DEFAULTS = {'species': {}}

    def setStructureType(self, struct_type):
        """
        Store the structure type in self.stype.

        :type struct_type: StructureType
        :param struct_type: Structure type used to calculate number of valence
            electrons in the cell
        """
        self.stype = struct_type

    def getData(self):
        """
        Get pseudopotential file paths and related data from self.species
        for the elements present in the structure type (self.stype).

        Each self.species value is read as a sequence
        (ecutwfc, ecutrho, z_valence, pp_path).

        :rtype: dict, float, float, float
        :return: Dict with paths to PPs ({'element': 'path'}), number of valence
            electrons, max ecutwfc, max ecutrho for elements
        :raise ValueError: If self.species lacks data for an element present
            in the structure
        """
        present_elems = set(self.stype.mag_species.species.values())
        missing_elems = present_elems.difference(self.species)
        if missing_elems:
            # Sorted for a deterministic error message
            missing_elems_str = ', '.join(sorted(missing_elems))
            raise ValueError('Pseudopotentials are missing for the following '
                             'elements: %s' % missing_elems_str)
        ppfiles = {}
        ecutwfcs = []
        ecutrhos = []
        zvals = {}
        # Iterate over elements present in structure only
        for element in self.stype.mag_species.species.values():
            ecutwfc, ecutrho, zval, ppfile = self.species[element][:4]
            ecutwfcs.append(ecutwfc)
            ecutrhos.append(ecutrho)
            zvals[element] = zval
            ppfiles[element] = ppfile
        max_ecutwfc = max(ecutwfcs)
        max_ecutrho = max(ecutrhos)
        nvelect = sum(zvals[atom.element] for atom in self.stype.struct.atom)
        return ppfiles, nvelect, max_ecutwfc, max_ecutrho
class StructureType(object):
    """
    Class to generate QE input XML section related to structure.
    """
    X, Y, Z = list(range(3))

    def __init__(self, structs, to_primitive, log=None):
        """
        Initialize StructureType object.

        :type structs: List of structure.Structure or structure.Structure
        :param structs: Structure(s) to generate XML from. Structures after
            the first must match it atom-by-atom (see validateStruct)
        :type to_primitive: bool or None
        :param to_primitive: Whether to use primitive (True),
            conventional (False), don't run spglib (None)
        :type log: callable or None
        :param log: Called with informational messages; messages are
            discarded when not provided
        :raise ValueError: If a structure lacks PBC information or differs
            from the first structure
        """
        self.log = log if log else lambda x: None
        if isinstance(structs, Iterable):
            self.structs = [st.copy() for st in structs]
        else:
            self.structs = [structs.copy()]
        self.mag_species = qeu.MagSpecies()
        for idx, struct in enumerate(self.structs):
            self._validatePBCs(struct)
            if idx == 0:
                self.struct = self.structs[0]
            else:
                self.validateStruct(struct)
        self.setElements()
        if to_primitive is not None and len(self.structs) == 1:
            # Don't run spglib for NEB calculations
            self.setSymmetrizedCell(to_primitive)
        self._getAtomicSpecies()

    def setSymmetrizedCell(self, to_primitive):
        """
        Symmetrize cell using spglib and replace self.struct (and
        self.structs[0]) with the standardized structure.

        :param bool to_primitive: Whether to obtain primitive (True) or
            conventional (False)
        """
        vecs = numpy.array(xtal.get_vectors_from_chorus(self.struct))
        fcoords = xtal.trans_cart_to_frac_from_vecs(self.struct.getXYZ(), *vecs)
        spg_cell = (vecs, fcoords, self.anums)
        new_cell = spglib.standardize_cell(spg_cell,
                                           to_primitive=to_primitive,
                                           symprec=qeu.PDB_PREC)
        if new_cell is None:
            # Standardization failed; keep the input cell unchanged
            return
        vecs, fcoords, anums = new_cell
        # spglib returns the same "numbers" we passed in, i.e. indices into
        # self.elements, so self.elements stays valid for the new cell
        self.anums = anums
        coords = xtal.trans_frac_to_cart_from_vecs(fcoords, *vecs)
        struct = structure.create_new_structure()
        for number, xyz in zip(self.anums, coords):
            atom = struct.addAtom('C', *xyz)
            mag_name = self.elements[number]
            # Remove digits from atom type (might be present due to starting
            # magnetization)
            element = ''.join(x for x in mag_name if not x.isdigit())
            atom.element = element
            element_mag = self.mag_species.getMag(element, mag_name)
            qeu.set_mag_hubbu(atom, element_mag)
        if msutils.has_atom_property(self.struct, msprops.QE_CART_ATOM_CONSTR):
            # Only copy atom Cartesian constraints when number of atoms is same
            if struct.atom_total == self.struct.atom_total:
                for src_atom, dest_atom in zip(self.struct.atom, struct.atom):
                    qeu.copy_atom_constr(src_atom, dest_atom)
            else:
                self.log('Atomic constraints will be discarded.')
        xtal.set_pbc_properties(struct, vecs.flat)
        struct.property[xtal.SPACE_GROUP_KEY] = self.space_group
        self.structs[0] = self.struct = struct

    @staticmethod
    def _firstIndices(names):
        """
        Map every entry of names to the index of its first occurrence.

        Equivalent to [names.index(x) for x in names], but linear in
        len(names) instead of quadratic.

        :type names: list
        :param names: Hashable labels (e.g. unique element names)
        :rtype: list(int)
        :return: First-occurrence index for every entry of names
        """
        first = {}
        for idx, name in enumerate(names):
            first.setdefault(name, idx)
        return [first[name] for name in names]

    def setElements(self):
        """
        Set elements in self.elements and atomic numbers for spglib in
        self.anums.

        self.anums holds, per atom, the index of the first atom with the same
        unique element label, so equal labels share one spglib "number".
        """
        self.elements = []
        for atom in self.struct.atom:
            mag_element = qeu.get_mag_hubbu(atom)
            element = self.mag_species.createUniqueElement(
                atom.element, mag_element)
            self.elements.append(element)
        self.anums = self._firstIndices(self.elements)

    def getAtomicStructure(self):
        """
        Generate XML input related to atomic_structure, one section per
        structure in self.structs.

        :rtype: str
        :return: cell xml data
        """
        ret = ''
        for struct in self.structs:
            ret += ('<atomic_structure nat="%d" alat="1.0">\n' %
                    struct.atom_total)
            ret += '<crystal_positions>\n'
            vecs = xtal.get_vectors_from_chorus(struct)
            fracs = xtal.trans_cart_to_frac_from_vecs(struct.getXYZ(), *vecs)
            for jdx, (_, anum) in enumerate(zip(struct.atom, self.anums)):
                ret += ('<atom name="%s" index="%d">\n' %
                        (self.elements[anum], jdx + 1))
                ret += '%f %f %f\n' % tuple(fracs[jdx])
                ret += '</atom>\n'
            ret += '</crystal_positions>\n'
            # NOTE(review): the cell always comes from self.struct, while the
            # fractional coordinates above use each image's own cell; this
            # assumes all structures share one cell -- confirm
            ret += self._getCell()
            ret += '</atomic_structure>\n'
        return ret

    def getAtomicSpecies(self, ppfiles):
        """
        Generate XML input related to atomic species.

        :type ppfiles: dict({str: str})
        :param ppfiles: Dictionary containing paths to PPs ({'Element': 'Path'})
        :rtype: str
        :return: species xml data
        """
        ntyp = len(self.mag_species.species)
        ret = '<atomic_species ntyp="%d">\n' % ntyp
        for element_mag_name, element in self.mag_species.species.items():
            ret += '<species name="%s">\n' % element_mag_name
            mass = get_atomic_weight(
                mm.mmelement_get_atomic_number_by_symbol(element))
            ret += '<mass>%f</mass>\n' % mass
            # Only the file name is written; pseudo_dir points to ./
            pseudo_fn = os.path.split(ppfiles[element])[1]
            ret += '<pseudo_file>%s</pseudo_file>\n' % pseudo_fn
            element_mag = self.mag_species.getMag(element, element_mag_name)
            ret += ('<starting_magnetization>%f</starting_magnetization>\n' %
                    element_mag.mag)
            ret += '</species>\n'
        ret += '</atomic_species>\n'
        return ret

    def _getPPNelect(self, zvals):
        """
        Get number of the valence electrons based on the Z valence values of the
        pseudopotentials.

        :type zvals: dict
        :param zvals: Dict with valence electrons: {'element': 'zval'}
        :rtype: float
        :return: Total number of valence electrons for self.struct
        """
        return sum(zvals[atom.element] for atom in self.struct.atom)

    def fetchPPDB(self, functional):
        """
        Get pseudopotentials file paths and other related data from the database
        having the same functional and family.

        When several pseudopotentials match one element, the last row returned
        is used, and the reported cutoffs are the maxima over the selected
        pseudopotentials only.

        :type functional: str
        :param functional: DFT Functional (one of the DFT_FUNCTIONALS)
        :rtype: dict, float, float, float
        :return: Dict with paths to PPs ({'element': 'path'}), number of valence
            electrons, max ecutwfc, max ecutrho for elements
        :raise ValueError: If no pseudopotential is found for some element
        """
        import sqlite3
        from schrodinger.application.matsci.espresso import ppdb
        PP_TABLE_SELECT = """
            SELECT DISTINCT a.atomic_number as atomic_number,
            a.ecutwfc as ecutwfc, a.ecutrho as ecutrho, a.pp_fn as pp_fn,
            a.z_valence as zval
            FROM %s a INNER JOIN %s b ON a.family = b.family
            AND a.sha1_ppfile_checksum != b.sha1_ppfile_checksum
            WHERE a.dft_functional=? AND a.atomic_number IN (%s)
        """
        atomic_numbers = [
            mm.mmelement_get_atomic_number_by_symbol(symbol)
            for symbol in self.species
        ]
        an_placeholders = ', '.join(['?'] * len(atomic_numbers))
        # NOTE(review): the connection is not closed here; confirm whether
        # ppdb.get_ppdb() hands out a shared connection before closing it
        conn, _ = ppdb.get_ppdb()
        # row_factory only applies to cursors created afterwards, hence a new
        # cursor is created instead of using the returned one
        conn.row_factory = sqlite3.Row
        cur = conn.cursor()
        cur.execute(
            PP_TABLE_SELECT %
            (ppdb.PP_TABLE_NAME, ppdb.PP_TABLE_NAME, an_placeholders),
            ([functional] + atomic_numbers))
        results = cur.fetchall()
        ppfiles = {}
        zvals = {}
        selected_rows = {}
        for row in results:
            atomic_symbol = mm.mmelement_get_symbol_by_atomic_number(
                row['atomic_number'])
            ppfiles[atomic_symbol] = row['pp_fn']
            zvals[atomic_symbol] = row['zval']
            selected_rows[atomic_symbol] = row
        if len(self.species) != len(ppfiles):
            missing = set(self.species).difference(ppfiles)
            missing_species = ' '.join(sorted(missing))
            raise ValueError('Could not find pseudopotentials for all atom '
                             'types. Missing species: %s' % missing_species)
        pp_nelect = self._getPPNelect(zvals)
        # Cutoffs must correspond to the pseudopotentials actually chosen
        max_ecutwfc = max(row['ecutwfc'] for row in selected_rows.values())
        max_ecutrho = max(row['ecutrho'] for row in selected_rows.values())
        return ppfiles, pp_nelect, max_ecutwfc, max_ecutrho

    def _getAtomicSpecies(self):
        """
        Extract number of electrons (in self.nelect) and species
        (in self.species, {element: atomic weight}) from self.struct.
        """
        self.nelect = 0
        self.species = {}
        for atom in self.struct.atom:
            self.nelect += atom.atomic_number
            mag_element = qeu.get_mag_hubbu(atom)
            self.mag_species.createUniqueElement(atom.element, mag_element)
            self.species[atom.element] = get_atomic_weight(atom.atomic_number)

    def _getCell(self):
        """
        Generate XML input related to cell (lattice vectors in bohr).

        :rtype: str
        :return: cell xml data
        """
        vecs = xtal.get_vectors_from_chorus(self.struct)
        vecs_bohrs = vecs * qeu.A2B
        ret = '<cell>\n'
        ret += '<a1>%f %f %f</a1>\n' % tuple(vecs_bohrs[0])
        ret += '<a2>%f %f %f</a2>\n' % tuple(vecs_bohrs[1])
        ret += '<a3>%f %f %f</a3>\n' % tuple(vecs_bohrs[2])
        ret += '</cell>\n'
        return ret

    def _validatePBCs(self, struct):
        """
        Check that PBCs are present for the structure. Also set the space
        group in self.space_group (defaulting to P1) and store it on struct.

        :type struct: `schrodinger.structure.Structure`
        :param struct: Structure to check
        :raise ValueError: If struct has no PBC information
        """
        if not xtal.sync_pbc2(struct):
            raise ValueError('Structure (%s) does not contain PBC information' %
                             struct.title)
        try:
            self.space_group = struct.property[xtal.SPACE_GROUP_KEY]
        except KeyError:
            self.space_group = xtal.P1_SPACE_GROUP_SYMBOL
        struct.property[xtal.SPACE_GROUP_KEY] = self.space_group

    def alignStructWithPlane(self, vector_index):
        """
        Slide structure (self.struct) along the axis such that all atoms are on
        top of the plane perpendicular to the axis defined by vector_index.

        :type vector_index: int
        :param vector_index: Vector index (X or Y or Z) to align structure to
            the perpendicular plane
        """
        vecs = xtal.get_vectors_from_chorus(self.struct)
        xyz = self.struct.getXYZ()
        fxyz = xtal.trans_cart_to_frac_from_vecs(xyz, *vecs)
        # Shift only along the requested lattice vector
        shift = numpy.zeros(3)
        shift[vector_index] = fxyz.min(axis=0)[vector_index]
        shifted_xyz = xtal.trans_frac_to_cart_from_vecs(fxyz - shift, *vecs)
        self.struct.setXYZ(shifted_xyz)
        self.struct.property[xtal.PBC_POSITION_KEY] = 'anchor_0_0_0'

    def validateStruct(self, struct):
        """
        Validate provided structure against self.struct.

        :type struct: `schrodinger.structure.Structure`
        :param struct: Structure to validate
        :raise ValueError: If number of atoms or atomic types differ compared to
            the initial structure (self.struct)
        """
        if struct.atom_total != self.struct.atom_total:
            raise ValueError('Number of atoms differs compared to the initial '
                             'structure')
        for atom, other_atom in zip(self.struct.atom, struct.atom):
            if other_atom.element != atom.element:
                raise ValueError('Atom type differs compared to the initial '
                                 'structure')
class ControlType(GenericType):
    """
    Class to generate QE input XML section related to control.
    """
    FORCES_KEY = 'forces'
    DEFAULTS = {
        'stress': False,
        FORCES_KEY: False,
        'relax_steps': 50,
        'etot_conv_thr': 1e-5,
        'forc_conv_thr': 1e-3,
        'press_conv_thr': 0.5,
        'title': 'Default Title',
        'calculation_type': '',
        'prefix': '',
        WF_COLLECT_KEY: False,
        # NOTE(review): not written by __str__; presumably consumed elsewhere
        WF_KEEP_KEY: False,
    }

    def __str__(self):
        """
        Return XML for the <control_variables> section.

        :rtype: str
        :return: XML string
        :raise ValueError: If prefix is not set, or (presumably via
            validate_value) if calculation_type is not in CALCULATION_TYPE
        """
        validate_value('calculation_type', self.calculation_type,
                       CALCULATION_TYPE)
        if not self.prefix:
            raise ValueError('prefix is required and it is not set')
        ret = '<control_variables>\n'
        ret += '<title>%s</title>\n' % self.title
        ret += '<calculation>%s</calculation>\n' % self.calculation_type
        ret += '<restart_mode>from_scratch</restart_mode>\n'
        ret += '<prefix>%s</prefix>\n' % self.prefix
        ret += '<pseudo_dir>./</pseudo_dir>\n'
        ret += '<outdir>./</outdir>\n'
        ret += '<stress>%s</stress>\n' % str(self.stress).lower()
        ret += '<forces>%s</forces>\n' % str(self.forces).lower()
        ret += '<wf_collect>%s</wf_collect>\n' % str(self.wf_collect).lower()
        ret += '<disk_io>low</disk_io>\n'
        # Effectively no wall-time limit
        ret += '<max_seconds>%d</max_seconds>\n' % 1e7
        ret += '<nstep>%d</nstep>\n' % self.relax_steps
        ret += '<etot_conv_thr>%.5e</etot_conv_thr>\n' % self.etot_conv_thr
        ret += '<forc_conv_thr>%.5e</forc_conv_thr>\n' % self.forc_conv_thr
        ret += '<press_conv_thr>%.5e</press_conv_thr>\n' % self.press_conv_thr
        ret += '<verbosity>low</verbosity>\n'
        ret += '<print_every>100000</print_every>\n'
        ret += '</control_variables>\n'
        return ret
class SymmetryType(GenericType):
    """
    Class to generate QE input XML section related to symmetry.
    """
    USE_SYMM_KEY = 'use_symmetry'
    USE_PRIM_KEY = 'use_primitive'
    USE_CONV_KEY = 'use_conventional'
    USE_ALL_FRAC_KEY = 'use_all_frac'
    NO_INV_KEY = 'noinv'
    DEFAULTS = {
        USE_SYMM_KEY: False,
        USE_PRIM_KEY: False,
        USE_CONV_KEY: False,
        USE_ALL_FRAC_KEY: False,
        NO_INV_KEY: False,
    }

    def __str__(self):
        """
        Return XML for the <symmetry_flags> section.

        Only use_symmetry (written inverted as <nosym>), noinv and
        use_all_frac are emitted; the remaining flags are fixed to false.

        :rtype: str
        :return: XML string
        """
        ret = '<symmetry_flags>\n'
        ret += '<nosym>%s</nosym>\n' % str(not self.use_symmetry).lower()
        ret += '<nosym_evc>false</nosym_evc>\n'
        ret += '<noinv>%s</noinv>\n' % str(self.noinv).lower()
        ret += '<no_t_rev>false</no_t_rev>\n'
        ret += '<force_symmorphic>false</force_symmorphic>\n'
        ret += ('<use_all_frac>%s</use_all_frac>\n' %
                str(self.use_all_frac).lower())
        ret += '</symmetry_flags>\n'
        return ret
class FreePositionsType(GenericType):
    """
    Class to generate QE input XML section related to Cartesian atomic
    constraints.
    """
    DEFAULTS = {}

    @staticmethod
    def saveConstraint(data, atom):
        """
        Static method to save Cartesian atomic constraint.

        An existing constraint on the atom is combined with the new one so
        that a coordinate stays constrained if either marks it constrained.

        :type data: list of 3 integers
        :param data: Cartesian constraints for each coordinate:
            1 - not constrained, 0 - constrained (QE convention)
        :type atom: `structure._StructureAtom`
        :param atom: Atom to add/modify constraint to
        :raise ValueError: If data is not a list of 3 elements
        """
        if len(data) != 3:
            raise ValueError('"data" must be list of 3 integers')
        data_prop_s = atom.property.get(msprops.QE_CART_ATOM_CONSTR)
        if data_prop_s is None:
            atom.property[msprops.QE_CART_ATOM_CONSTR] = json.dumps(data)
            return
        data_prop = json.loads(data_prop_s)
        # Multiplication acts as logical AND on the "free" flags:
        # 1 and 1 : 1
        # 1 and 0 : 0
        # 0 and 1 : 0
        # 0 and 0 : 0
        new_data = [
            int(val) * int(val_prop) for val, val_prop in zip(data, data_prop)
        ]
        atom.property[msprops.QE_CART_ATOM_CONSTR] = json.dumps(new_data)

    @staticmethod
    def saveTorsionConstraint(struct, indices):
        """
        Static method to save torsion angle constraint.

        :type struct: structure.Structure
        :param struct: Structure to add constraint to as a property
        :type indices: list
        :param indices: Atomic indices that describe a torsion angle
        :raise ValueError: If indices are not exactly 4 distinct atom indices
        """
        # Exactly four distinct atoms define a torsion; a longer list with
        # duplicates must be rejected too
        if len(indices) != 4 or len(set(indices)) != 4:
            raise ValueError('torsion angle must be defined with 4 integers')
        data = struct.property.get(msprops.QE_TORS_ATOM_CONSTR)
        constraints = [] if data is None else json.loads(data)
        constraints.append(indices)
        struct.property[msprops.QE_TORS_ATOM_CONSTR] = json.dumps(constraints)

    def __str__(self):
        """
        Return XML for atomic position and torsion constraints.

        Atoms without a saved Cartesian constraint are written as fully free
        (1 1 1). The <atomic_constraints> section is only added when torsion
        constraints are stored on self.struct; their target is the torsion
        currently measured in self.struct.

        :rtype: str
        :return: XML string (empty if self.struct is not set)
        """
        if getattr(self, 'struct', None) is None:
            return ''
        # Deal with atomic coordinates constraints
        ret = ('<free_positions rank="2" dims="3 %d" order="F">\n' %
               self.struct.atom_total)
        for atom in self.struct.atom:
            data_s = atom.property.get(msprops.QE_CART_ATOM_CONSTR,
                                       json.dumps([1, 1, 1]))
            ret += '%d %d %d\n' % tuple(json.loads(data_s))
        ret += '</free_positions>\n'
        data = self.struct.property.get(msprops.QE_TORS_ATOM_CONSTR)
        if data is None:
            return ret
        # Deal with torsion constraints
        ret += '<atomic_constraints>\n'
        constraints = json.loads(data)
        for indices in constraints:
            atoms = [self.struct.atom[idx] for idx in indices]
            ret += '<atomic_constraint>\n'
            ret += '<constr_type>torsional_angle</constr_type>\n'
            ret += '<constr_parms>'
            # QE expects 0-based atom indices
            ret += ' '.join([str(idx - 1) for idx in indices])
            ret += '</constr_parms>\n'
            ret += '<constr_target>%.5f</constr_target>\n' % self.struct.measure(
                *atoms)
            ret += '</atomic_constraint>\n'
        ret += '<num_of_constraints>%d</num_of_constraints>\n' % len(
            constraints)
        ret += '<tolerance>1.e-5</tolerance>\n'
        ret += '</atomic_constraints>\n'
        return ret
class MdType(GenericType):
    """
    Class to generate XML section related to MD control.
    """
    DEFAULTS = {
        'ion_temperature': next(iter(ION_TEMP.values())),
        'timestep': 2.0,  # in femtoseconds
        'tempw': 300.0,
        'deltat': 1.0,
        'nraise': 1
    }

    def __str__(self):
        """
        Return XML for the <md> section.

        The timestep is converted from femtoseconds by dividing by
        qeu.AU2FS.

        :rtype: str
        :return: XML string
        :raise ValueError: Presumably from validate_value when
            ion_temperature is not one of ION_TEMP values
        """
        validate_value('ion_temperature', self.ion_temperature,
                       ION_TEMP.values())
        ret = '<md>\n'
        ret += '<pot_extrapolation>atomic</pot_extrapolation>\n'
        ret += '<wfc_extrapolation>none</wfc_extrapolation>\n'
        ret += '<ion_temperature>%s</ion_temperature>\n' % self.ion_temperature
        ret += '<timestep>%.5e</timestep>\n' % (self.timestep / qeu.AU2FS)
        ret += '<tempw>%.5e</tempw>\n' % self.tempw
        ret += '<tolp>100.</tolp>\n'
        ret += '<deltaT>%.5e</deltaT>\n' % self.deltat
        ret += '<nraise>%d</nraise>\n' % self.nraise
        ret += '</md>\n'
        return ret
class IonControlType(GenericType):
    """
    Class to generate QE ionic control XML section related to control.
    """
    ION_DYNAMICS_KEY = 'ion_dynamics'
    DEFAULTS = {
        ION_DYNAMICS_KEY: next(iter(ION_DYNAMICS.values())),
        'dynamics': MdType.DEFAULTS
    }

    @property
    def dynamics(self):
        """
        Getter for the dynamics attribute.

        :rtype: MdType
        :return: dynamics object
        """
        return self._dynamics

    @dynamics.setter
    def dynamics(self, data):
        """
        Setter for the dynamics attribute; creates the MdType on first use.

        :type data: dict
        :param data: Dictionary with settings for dynamics object
        """
        if not hasattr(self, 'dynamics'):
            self._dynamics = MdType()
        self.dynamics.updateWithData(data)

    def __str__(self):
        """
        Return XML for the <ion_control> section (including <md>).

        :rtype: str
        :return: XML string
        :raise ValueError: Presumably from validate_value when ion_dynamics
            is neither an ION_DYNAMICS nor an MD_DYNAMICS value
        """
        # Both relaxation and MD algorithms are valid ion dynamics
        dynamics_list = list(ION_DYNAMICS.values()) + list(MD_DYNAMICS.values())
        validate_value('ion_dynamics', self.ion_dynamics, dynamics_list)
        ret = '<ion_control>\n'
        ret += '<ion_dynamics>%s</ion_dynamics>\n' % self.ion_dynamics
        ret += str(self.dynamics)
        ret += '</ion_control>\n'
        return ret
class CellControlType(GenericType):
    """
    Class to generate QE cell control XML section related to control.
    """
    CELL_DYNAMICS_KEY, PRESSURE_KEY, CELL_FACTOR_KEY, CELL_DOFREE_KEY = (
        'cell_dynamics', 'pressure', 'cell_factor', 'cell_dofree')
    DEFAULTS = {
        CELL_DYNAMICS_KEY: list(CELL_DYNAMICS.values())[0],
        PRESSURE_KEY: 0.0,
        CELL_FACTOR_KEY: 2.0,
        CELL_DOFREE_KEY: 'all',
    }
    # cell_dofree values that map to a boolean XML flag; other values
    # (e.g. 'all') emit no flag
    CELL_DOFREE_XML = {
        'shape': 'fix_volume',
        '2Dshape': 'fix_area',
        CELL_DOFREE_XY: 'fix_xy'
    }

    def __str__(self):
        """
        Return XML for the <cell_control> section.

        :rtype: str
        :return: XML string
        :raise ValueError: Presumably from validate_value when cell_dynamics
            or cell_dofree has an unsupported value
        """
        validate_value('cell_dynamics', self.cell_dynamics,
                       list(CELL_DYNAMICS.values()))
        validate_value('cell_dofree', self.cell_dofree,
                       list(CELL_DOFREE.values()))
        ret = '<cell_control>\n'
        ret += '<cell_dynamics>%s</cell_dynamics>\n' % self.cell_dynamics
        ret += '<pressure>%f</pressure>\n' % self.pressure
        ret += '<cell_factor>%f</cell_factor>\n' % self.cell_factor
        if self.cell_dofree in self.CELL_DOFREE_XML:
            tag = self.CELL_DOFREE_XML[self.cell_dofree]
            ret += '<%s>true</%s>\n' % (tag, tag)
        ret += '</cell_control>\n'
        return ret
class EsmType(GenericType):
    """
    Class to generate QE input XML section related to esm type.
    """
    BC_KEY, BC_W_KEY, BC_EFIELD_KEY = 'bc', 'offset', 'efield'
    DEFAULTS = {BC_KEY: ESM_BC_BC1, BC_W_KEY: 0.0, BC_EFIELD_KEY: 0.0}

    def __str__(self):
        """
        Return XML for the <esm> section.

        The offset is multiplied by qeu.A2B (Angstrom to bohr) before being
        written as <w>.

        :rtype: str
        :return: XML string
        :raise ValueError: Presumably from validate_value when bc is not one
            of ESM_BC
        """
        validate_value('bc', self.bc, ESM_BC)
        ret = '<esm>\n'
        ret += '<bc>%s</bc>\n' % self.bc
        ret += '<nfit>4</nfit>\n'  # Hard-code default PW input for now
        ret += '<w>%s</w>\n' % (self.offset * qeu.A2B)
        ret += '<efield>%s</efield>\n' % self.efield
        ret += '</esm>\n'
        return ret
class BoundaryType(GenericType):
    """
    Class to generate QE input XML section related to boundary conditions type.
    """
    ASSUME_ISOLATED_KEY, ESM_KEY = 'assume_isolated', 'esm'
    EFERMI_KEY, RELATIVE_POT_KEY = 'efermi', 'relative_pot'
    DEFAULTS = {
        ASSUME_ISOLATED_KEY: '',
        ESM_KEY: EsmType.DEFAULTS,
        EFERMI_KEY: 0.0,
        RELATIVE_POT_KEY: None
    }

    @property
    def esm(self):
        """
        Getter for the esm attribute.

        :rtype: EsmType
        :return: esm object
        """
        return self._esm

    @esm.setter
    def esm(self, data):
        """
        Setter for the esm attribute; creates the EsmType on first use.

        :type data: dict
        :param data: Dictionary with settings for esm object
        """
        if not hasattr(self, 'esm'):
            self._esm = EsmType()
        self.esm.updateWithData(data)

    def _validateAssumeIsolated(self):
        """
        Validate attributes' values.

        :raise ValueError: If attributes have invalid values
        """
        validate_value('assume_isolated', self.assume_isolated, ASSUME_ISOLATED)
        # NOTE(review): relative_pot is only checked (and used) when bc is
        # already ESM_BC_BC3; with another bc it is silently ignored
        if self.relative_pot is not None and self.esm.bc == ESM_BC_BC3 and (
                self.assume_isolated != ASSUME_ISOLATED_ESM or
                self.efermi is None):
            raise ValueError('relative_pot is set, thus assume_isolated must '
                             'be set to "%s", bc to "%s" and "efermi" set.' %
                             (ASSUME_ISOLATED_ESM, ESM_BC_BC3))

    def __str__(self):
        """
        Return XML for the <boundary_conditions> section.

        Empty when assume_isolated is not set. With ESM the <esm> section is
        included. When relative_pot is given (converted from eV by dividing
        by qeu.RY2EV) and bc is ESM_BC_BC3, FCP is enabled with
        fcp_mu = efermi + relative_pot.

        :rtype: str
        :return: XML string (possibly empty)
        :raise ValueError: If attributes have invalid values
        """
        self._validateAssumeIsolated()
        if not self.assume_isolated:
            return ''
        ret = '<boundary_conditions>\n'
        ret += '<assume_isolated>%s</assume_isolated>\n' % self.assume_isolated
        if self.assume_isolated == ASSUME_ISOLATED_ESM:
            ret += str(self.esm)
        if self.relative_pot is not None and self.esm.bc == ESM_BC_BC3:
            relative_pot_ry = self.relative_pot / qeu.RY2EV
            ret += '<fcp_opt>true</fcp_opt>\n'
            ret += '<fcp_mu>%f</fcp_mu>\n' % (self.efermi + relative_pot_ry)
        ret += '</boundary_conditions>\n'
        return ret
class NEBPathType(GenericType):
    """
    Class to generate QE input XML section related to the NEB path.
    """
    MIMAGE_KEY, NIMAGE_KEY = 'minimum_image', 'nimages'
    DEFAULTS = {
        'string_method': list(NEB_STRING_TYPE.values())[0],
        NIMAGE_KEY: 4,
        'nsteps': 50,
        'path_thr': 0.05,
        'optimization_scheme': list(NEB_OPT_SCHEME_TYPE.values())[0],
        'climbing_image': list(NEB_CI_TYPE.values())[0],
        'use_masses': False,
        'optimize_first_last': False,
        # The esm_* values below are only written out when 'esm' is true
        'esm': False,
        'esm_efermi': 0.0,
        'esm_first_image_charge': 0.0,
        'esm_last_image_charge': 0.0,
        'restart_mode': 'from_scratch',
        MIMAGE_KEY: True
    }

    def _validate(self):
        """
        Validate attributes.

        string_method, optimization_scheme and climbing_image must be values
        of their respective NEB_* mappings. nimages must lie in
        [4, NEB_MAX_NIMAGES] and nsteps must be at least 1.

        :raise ValueError: If attributes have invalid values
        """
        validate_value('string_method', self.string_method,
                       list(NEB_STRING_TYPE.values()))
        # Both bounds are inclusive: NEB_MAX_NIMAGES itself is accepted
        if self.nimages < 4 or self.nimages > NEB_MAX_NIMAGES:
            raise ValueError(
                'nimages (%d) must be greater than or equal to 4 and less '
                'than or equal to %d.' % (self.nimages, NEB_MAX_NIMAGES))
        if self.nsteps < 1:
            raise ValueError('nsteps (%d) must be greater or equal than 1.' %
                             self.nsteps)
        validate_value('optimization_scheme', self.optimization_scheme,
                       list(NEB_OPT_SCHEME_TYPE.values()))
        validate_value('climbing_image', self.climbing_image,
                       list(NEB_CI_TYPE.values()))

    def __str__(self):
        """
        Return XML for the <path> section.

        The constant-bias (ESM) tags are appended only when esm is true.

        :raise ValueError: If attributes have invalid values
        :rtype: str
        :return: XML string
        """
        self._validate()
        lines = [
            '<path>',
            '<stringMethod>%s</stringMethod>' % self.string_method,
            '<restart_mode>%s</restart_mode>' % self.restart_mode,
            '<pathNstep>%d</pathNstep>' % self.nsteps,
            '<numOfImages>%d</numOfImages>' % self.nimages,
            '<optimizationScheme>%s</optimizationScheme>' %
            self.optimization_scheme,
            '<climbingImage>%s</climbingImage>' % self.climbing_image,
            '<endImagesOptimizationFlag>%s</endImagesOptimizationFlag>' %
            str(self.optimize_first_last).lower(),
            '<minimumImageFlag>%s</minimumImageFlag>' %
            str(self.minimum_image).lower(),
            '<optimizationStepLength>1.</optimizationStepLength>',
            '<pathThreshold>%.5e</pathThreshold>' % self.path_thr,
            '<elasticConstMax>0.1</elasticConstMax>',
            '<elasticConstMin>0.1</elasticConstMin>',
            '<useMassesFlag>%s</useMassesFlag>' % str(self.use_masses).lower(),
            '<useFreezingFlag>false</useFreezingFlag>',
        ]
        if self.esm:
            lines.extend([
                '<constantBiasFlag>true</constantBiasFlag>',
                '<targetFermiEnergy>%.5e</targetFermiEnergy>' %
                self.esm_efermi,
                '<totChargeFirst>%.5e</totChargeFirst>' %
                self.esm_first_image_charge,
                '<totChargeLast>%.5e</totChargeLast>' %
                self.esm_last_image_charge,
            ])
        lines.append('</path>')
        # Every tag line is newline-terminated, including the last one
        return '\n'.join(lines) + '\n'
class PHControlType(GenericType):
    """
    Class to generate phonon control input XML section related to control.
    """
    QPTS_MESH_KEY = 'qpts_mesh'
    QPTS_SPACING_KEY = 'qpts_spacing'
    DEFAULTS = {
        'prefix': '',
        'epsil': False,
        QPTS_MESH_KEY: [1, 1, 1],
        'lraman': False,
    }

    def setStructure(self, struct):
        """
        Copy structure into self.struct.

        :type struct: `structure.Structure`
        :param struct: Structure whose atoms' msprops.QE_CART_ATOM_CONSTR
            property decides which atoms are listed in the nat_todo section
        """
        self.struct = struct.copy()

    def __str__(self):
        """
        Return XML for the phonon (ph.x) input.

        self.struct must already be set (e.g. via setStructure). Every atom
        whose QE_CART_ATOM_CONSTR property is not exactly json "[0, 0, 0]"
        is listed in nat_todo, including atoms without the property. Atoms
        are listed by 1-based index.

        :raise ValueError: If prefix is not set
        :rtype: str
        :return: XML string
        """
        if not self.prefix:
            raise ValueError('prefix is required and it is not set')
        # Loop-invariant: serialized constraint value of atoms to skip
        # (presumably "all Cartesian directions fixed" -- confirm convention)
        fixed_constr = json.dumps([0, 0, 0])
        # enumerate yields unique, increasing indices, so no set/sort needed
        nat_todo = [
            idx for idx, atom in enumerate(self.struct.atom, start=1)
            if atom.property.get(msprops.QE_CART_ATOM_CONSTR) != fixed_constr
        ]
        lines = [
            '<files>',
            '<prefix>%s</prefix>' % self.prefix,
            '<fildyn>%s%s</fildyn>' % (self.prefix, DYN_EXT),
            '</files>',
            '<scf_ph>',
            '<tr2_ph>1.0e-14</tr2_ph>',
            '</scf_ph>',
            '<control_ph>',
            '<ldisp>true</ldisp>',
            '<epsil>%s</epsil>' % str(self.epsil).lower(),
            '<trans>true</trans>',
            '<zeu>false</zeu>',
            '<zue>false</zue>',
            '<lraman>%s</lraman>' % str(self.lraman).lower(),
            '</control_ph>',
            '<irr_repr>',
            '<nat_todo natom="%d">' % len(nat_todo),
        ]
        lines.extend('<atom>%d</atom>' % idx for idx in nat_todo)
        lines.extend([
            '</nat_todo>',
            '</irr_repr>',
            '<q_points>',
            '<grid nq1="%d" nq2="%d" nq3="%d"/>' % tuple(self.qpts_mesh),
            '</q_points>',
        ])
        # Every tag line is newline-terminated, including the last one
        return '\n'.join(lines) + '\n'
class PHDynmatType(GenericType):
    """
    Settings for the dynmat.x input, rendered as namelist assignments.
    """
    DEFAULTS = {
        'asr': next(iter(DYNMAT_ASR.values())),
        'q_dir': [0, 0, 0],
        'lperm': False,
        'lplasma': False,
        'loto_2d': False,
        'axis': 3
    }

    def __str__(self):
        """
        Render the settings as newline-separated namelist assignments.

        asr must be one of DYNMAT_ASR's values and axis must be 1, 2 or 3.
        Booleans are written in Fortran form (.true./.false.).

        :raise ValueError: If asr or axis is invalid (via validate_value)
        :rtype: str
        :return: Input string
        """
        validate_value('asr', self.asr, DYNMAT_ASR.values())
        validate_value('axis', self.axis, range(1, 4))
        q_line = 'q(1)=%.5f, q(2)=%.5f, q(3)=%.5f' % tuple(self.q_dir)
        asr_line = "asr='%s'" % self.asr
        axis_line = 'axis=%d' % self.axis
        lperm_line = 'lperm=.%s.' % str(self.lperm).lower()
        lplasma_line = 'lplasma=.%s.' % str(self.lplasma).lower()
        loto_line = 'loto_2d=.%s.' % str(self.loto_2d).lower()
        return '\n'.join((q_line, asr_line, axis_line, lperm_line,
                          lplasma_line, loto_line))
class ElasticType(GenericType):
    """
    Container for elastic-calculation settings.

    The only default is an empty list of deformation matrices.
    """
    STRAIN_DEFO_KEY = 'deformation_matrix'
    DEFAULTS = {
        STRAIN_DEFO_KEY: [],
    }
class GIPAWType(GenericType):
    """
    Settings container for a gipaw.x input file.
    """
    # Key names used in DEFAULTS
    JOBTYPE = 'jobtype'
    Q_GIPAW = 'q_gipaw'
    SPLINE_PS = 'spline_ps'
    USE_NMR_MACROSCOPIC_SHAPE = 'use_nmr_macroscopic_shape'
    DEFAULTS = {
        JOBTYPE: 'nmr',
        Q_GIPAW: 0.01,
        SPLINE_PS: True,
        USE_NMR_MACROSCOPIC_SHAPE: False,
    }
class EpsilonType(GenericType):
    """
    Settings container for an epsilon.x input file.
    """
    # Allowed values for 'smeartype' and 'calculation' (not checked here)
    SMEARTYPES = ['lorentz', 'gauss']
    CALCULATIONS = ['eps', 'jdos', 'offdiag']
    INTERSMEAR_KEY = 'intersmear'
    INTRASMEAR_KEY = 'intrasmear'
    DEFAULTS = {
        'smeartype': 'lorentz',
        'calculation': 'eps',
        'prefix': '',
        'shift': 0.,
        INTERSMEAR_KEY: 0.136,
        INTRASMEAR_KEY: 0.,
        'wmax': 30.,
        'wmin': 0.,
        'nw': 60,
    }
class PPlotType(GenericType):
    """
    Settings container for a pp.x input file.
    """
    DEFAULTS = {
        'plot_num': 0,
        'spin_component': 0,
    }
class HPType(GenericType):
    """
    Settings container for an hp.x input file.
    """
    # Key names used in DEFAULTS
    QPTS_MESH_KEY = 'qpts_mesh'
    QPTS_SPACING_KEY = 'qpts_spacing'
    DISABLE_TYPE_KEY = 'disable_type_analysis'
    DEFAULTS = {
        QPTS_MESH_KEY: [2, 2, 2],
        QPTS_SPACING_KEY: 0,
        DISABLE_TYPE_KEY: True,
    }
[docs]class CustomNEBType(NamedTuple):
    """
    Tuple to hold custom neb input.
    """
    method: str = list(CUSTOM_NEB_METHOD.values())[0]
    opt_method: str = list(CUSTOM_NEB_OPT_METHOD.values())[0]
    nsteps: int = 50
    path_thr: float = 0.05
    climb: bool = True