"""
Utility functions for dealing with Jaguar basis sets
"""
import functools
from collections import namedtuple

from schrodinger.infra import mm
from schrodinger.infra import util
from schrodinger.structure import Structure
# Atom types exempt from the basis-set coverage check in
# num_functions_all_atoms(): an atom of one of these types may have zero basis
# functions without the whole structure being reported as uncovered.
DUMMY_ATOM_TYPES = {150, 61}
_BasisSet = namedtuple(
    "BasisSet",
    ("name", "backup", "nstar", "nplus", "is_ecp", "is_ps", "numd", "numf"))


class BasisSet(_BasisSet):
    """
    A Pythonic wrapper for basis set information.

    The ``is_ecp`` and ``is_ps`` fields are always stored as ``bool``, however
    the instance is created (direct construction, ``_make`` or ``_replace``).

    :ivar name: The name of the basis set
    :vartype name: str
    :ivar backup: The basis set used for non-effective-core-potential atoms
    :vartype backup: str
    :ivar nstar: The availability of polarization functions
    :vartype nstar: int
    :ivar nplus: The availability of diffuse functions
    :vartype nplus: int
    :ivar is_ecp: Does this basis set use effective core potentials on heavy
        atoms?
    :vartype is_ecp: bool
    :ivar is_ps: Is this basis set pseudospectral?
    :vartype is_ps: bool
    :ivar numd: The number of d functions
    :vartype numd: int
    :ivar numf: The number of f functions
    :vartype numf: int
    """
    # Match the namedtuple base class: no per-instance __dict__, so instances
    # stay lightweight and cannot grow arbitrary attributes.
    __slots__ = ()

    def __new__(cls, name, backup, nstar, nplus, is_ecp, is_ps, numd, numf):
        """
        Create a BasisSet object using the return values from
        mm.mmjag_basis_get().

        ``is_ecp`` and ``is_ps`` are coerced to ``bool``.
        """
        is_ecp = bool(is_ecp)
        is_ps = bool(is_ps)
        return super(BasisSet, cls).__new__(cls, name, backup, nstar, nplus,
                                            is_ecp, is_ps, numd, numf)

    @classmethod
    def _make(cls, iterable):
        """
        Create a BasisSet from an iterable of field values.

        The default namedtuple ``_make`` (which ``_replace`` also uses) calls
        ``tuple.__new__`` directly and would skip the bool coercion done in
        ``__new__``. Routing through ``cls(...)`` keeps the fields consistent.
        """
        return cls(*iterable)
def mmjag_function(func):
    """
    A decorator that wraps functions with mmjag_initialize and mmjag_terminate.

    The wrapper preserves the wrapped function's metadata (``__name__``,
    ``__doc__``, ``__wrapped__``, ...) via `functools.wraps`, so decorated
    public functions keep their documentation and introspectable names.

    :param func: The function to wrap
    :type func: callable
    :return: The wrapped function
    :rtype: callable
    """
    @functools.wraps(func)
    def dec(*args, **kwargs):
        # Initialize outside the try block so that mmjag_terminate() is only
        # called after mmjag_initialize() has returned successfully.
        mm.mmjag_initialize(mm.MMERR_DEFAULT_HANDLER)
        try:
            return func(*args, **kwargs)
        finally:
            mm.mmjag_terminate()
    return dec
@mmjag_function
def get_basis(i):
    """
    Look up a basis set by its index.

    :param i: Index of the basis set to look up (presumably in the range used
        by get_bases(), i.e. 0 up to mm.mmjag_basis_count() - 1)
    :type i: int
    :return: The basis set stored at index ``i``
    :rtype: `BasisSet`
    """
    return BasisSet(*mm.mmjag_basis_get(i))
@mmjag_function
def get_basis_by_name(name):
    """
    Get information about the specified basis set

    :param name: The name of the basis set to get information about
    :type name: str
    :return: Information about the specified basis set
    :rtype: `BasisSet`
    """
    # mmjag_basis_get_by_name() is expected to return the remaining BasisSet
    # fields (everything after the name), so the name is prepended here.
    basis_info = mm.mmjag_basis_get_by_name(name)
    return BasisSet(name, *basis_info)
def get_bases():
    """
    Iterate over every available basis set.

    The mmjag library is initialized when iteration starts and terminated when
    the generator is exhausted or closed.

    :return: An iterator over all basis sets, each yielded as a `BasisSet`
        object
    :rtype: iter
    """
    # The mmjag_function decorator can't be used for a generator: its wrapper
    # would terminate as soon as the generator object is returned, before any
    # iteration has taken place.
    mm.mmjag_initialize(mm.MMERR_DEFAULT_HANDLER)
    try:
        total = mm.mmjag_basis_count()
        for index in range(total):
            yield BasisSet(*mm.mmjag_basis_get(index))
    finally:
        mm.mmjag_terminate()
@mmjag_function
def default_basis(struc=None):
    """
    Determine the default basis set for a structure.

    :param struc: Structure to get the default basis set for.  When omitted,
        the general default basis set (6-31G**) is returned.
    :type struc: `schrodinger.structure.Structure` or NoneType
    :return: The name of the default basis set
    :rtype: str
    :raise TypeError: If `struc` is neither a Structure nor None
    """
    if struc is not None and not isinstance(struc, Structure):
        raise TypeError("Argument to default_basis() must be a Structure "
                        "object or None")
    # Without a structure, ask mmjag for the default default basis (6-31G**)
    # by passing the invalid-CT sentinel.
    target = mm.MMCT_INVALID_CT if struc is None else struc
    return mm.mmjag_basis_default(target)
@mmjag_function
def num_functions(basis_name, struc, func_type=mm.MMJAG_BASIS_DEFAULT):
    """
    Count the basis functions a basis set would use for a whole structure.

    :param basis_name: Name of the basis set
    :type basis_name: str
    :param struc: The structure to count basis functions for
    :type struc: `schrodinger.structure.Structure`
    :param func_type: Whether d and f functions are counted as Cartesian or non-
        Cartesian.  Must be one of mm.MMJAG_BASIS_DEFAULT, mm.MMJAG_BASIS_CARTESIAN,
        or mm.MMJAG_BASIS_NONCARTESIAN.
    :type func_type: int
    :return: A tuple of:
          - The number of basis functions (int).  Zero if the basis set does
            not cover every atom of the structure.
          - Whether pseudospectral calculations are possible (bool)
    :rtype: tuple
    :raise TypeError: If `struc` is not a Structure
    """
    if isinstance(struc, Structure):
        count, ps_flag = mm.mmjag_basis_functions(basis_name, func_type, struc)
        return count, bool(ps_flag)
    raise TypeError("Second argument to num_functions() must be a "
                    "Structure object")
@mmjag_function
def num_functions_per_atom(basis_name,
                           struc,
                           atom_num,
                           func_type=mm.MMJAG_BASIS_DEFAULT):
    """
    Calculate the number of basis functions that will be used for the given
    basis set and specified atom

    :param basis_name: The basis set
    :type basis_name: str
    :param struc: The structure containing `atom_num`
    :type struc: `schrodinger.structure.Structure`
    :param atom_num: The atom index in `struc` to calculate the number of basis
        functions for (num_functions_all_atoms() passes indices from 1 to
        ``struc.atom_total``)
    :type atom_num: int
    :param func_type: Whether d and f functions are counted as Cartesian or non-
        Cartesian.  Must be one of mm.MMJAG_BASIS_DEFAULT, mm.MMJAG_BASIS_CARTESIAN,
        or mm.MMJAG_BASIS_NONCARTESIAN.
    :type func_type: int
    :return: A tuple of:
          - The number of basis functions (int).  Will be zero if the basis set
            does not cover the atom.
          - Are pseudospectral calculations possible (bool)
    :rtype: tuple
    :raise TypeError: If `struc` is not a Structure
    """
    if not isinstance(struc, Structure):
        # Name this function (not num_functions) in the error message.
        raise TypeError("Second argument to num_functions_per_atom() must be "
                        "a Structure object")
    num_funcs, is_ps = mm.mmjag_basis_functions_per_atom(
        basis_name, func_type, struc, atom_num)
    return num_funcs, bool(is_ps)
@mmjag_function
def num_functions_all_atoms(basis_name,
                            struc,
                            per_atom=None,
                            func_type=mm.MMJAG_BASIS_DEFAULT):
    """
    Calculate the number of basis functions that will be used for the given
    basis set and structure, atom by atom

    :param basis_name: The basis set to determine the number of basis functions
        for.  Used for every atom not listed in `per_atom`.
    :type basis_name: str
    :param struc: The structure to determine the number of basis functions for
    :type struc: `schrodinger.structure.Structure`
    :param per_atom: An optional dictionary of {atom index: basis name} for
        per-atom basis sets.  Atom indices are 1-based.
    :type per_atom: dict
    :param func_type: Whether d and f functions are counted as Cartesian or non-
        Cartesian.  Must be one of mm.MMJAG_BASIS_DEFAULT, mm.MMJAG_BASIS_CARTESIAN,
        or mm.MMJAG_BASIS_NONCARTESIAN.
    :type func_type: int
    :return: A tuple of:
          - The number of basis functions for the entire structure (int). Will
            be zero if the basis sets do not cover all atoms of the structure.
            Atoms whose type is in DUMMY_ATOM_TYPES may have zero functions
            without counting as uncovered.
          - Are pseudospectral calculations possible for every atom (bool)
          - A list of the number of basis functions per atom
            (`util.OneIndexedList`)
    :rtype: tuple
    :raise TypeError: If `struc` is not a Structure
    """
    if not isinstance(struc, Structure):
        # Name this function (not num_functions) in the error message.
        raise TypeError("Second argument to num_functions_all_atoms() must be "
                        "a Structure object")
    if per_atom is None:
        per_atom = {}
    num_funcs_per_atom = util.OneIndexedList()
    is_ps_per_atom = []
    for atom_num in range(1, struc.atom_total + 1):
        cur_basis = per_atom.get(atom_num, basis_name)
        cur_num_funcs, cur_is_ps = mm.mmjag_basis_functions_per_atom(
            cur_basis, func_type, struc, atom_num)
        is_ps_per_atom.append(cur_is_ps)
        num_funcs_per_atom.append(cur_num_funcs)
    # The structure is covered when every atom either has basis functions or
    # is one of the exempt (dummy) atom types.
    covered = all(
        count != 0 or atom.atom_type in DUMMY_ATOM_TYPES
        for atom, count in zip(struc.atom, num_funcs_per_atom))
    num_funcs = sum(num_funcs_per_atom) if covered else 0
    is_ps = all(is_ps_per_atom)
    return num_funcs, is_ps, num_funcs_per_atom
def parse_basis(basis):
    """
    Split a basis set name into its base name and its `*`/`+` counts.

    :param basis: The full basis set name, e.g. ``"6-31+G**"``
    :type basis: str
    :return: A tuple of
          - The basis set name with every `*` and `+` removed (str)
          - The polarization function count (i.e. the number of `*`'s) (int)
          - The diffuse function count (i.e. the number of `+`'s) (int)
    :rtype: tuple
    """
    n_stars = basis.count("*")
    n_pluses = basis.count("+")
    bare_name = "".join(char for char in basis if char not in "*+")
    return bare_name, n_stars, n_pluses