"""
Classes and functions for creating crystals by unit cell.
Copyright Schrodinger, LLC. All rights reserved."""
import typing
from collections import defaultdict
import numpy
# Please DO NOT add any dependencies, to prevent circular imports
from schrodinger.infra import mm
# Number of decimal places to which symmetry operator entries are rounded
# before they are normalized (get_symmops_from_spglib) or compared
# (equal_rotations)
SYMMOP_ROUND = 2
# Module-level cache for the SpaceGroups object; created lazily by
# get_spacegroups() on its first call
_SPACE_GROUPS = None
def get_spacegroups():
    """
    Return the shared space groups object, building it on first use.

    The SpaceGroups object is created the first time this function is
    called and stored in the module-level ``_SPACE_GROUPS`` variable; every
    later call hands back that very same object.

    :rtype: SpaceGroups
    :return: Cached space groups object, do not modify!!
    """
    global _SPACE_GROUPS
    if _SPACE_GROUPS is not None:
        return _SPACE_GROUPS
    _SPACE_GROUPS = SpaceGroups()
    return _SPACE_GROUPS
def get_symmops_from_spglib(symm):
    """
    Build 4 x 4 symmetry operators from spglib rotations and translations.

    In every returned operator the top-left 3 x 3 block is the rotation
    matrix, the top three entries of the last column hold the translation
    (rounded to SYMMOP_ROUND decimals), and the bottom row is [0 0 0 1].

    :type symm: dict of two keys: 'rotations': list of 3 x 3 matrices,
        'translations': list of 1 x 3 matrices
    :param symm: dictionary of list of rotations and translations
    :return: list of 4x4 matrices
    :rtype: list of symmetry operators
    """
    operators = []
    pairs = zip(symm['rotations'], symm['translations'])
    for rotation, translation in pairs:
        # numpy.round returns a new array, so the caller's data is untouched
        shift = numpy.round(translation, SYMMOP_ROUND)
        # spglib sometimes reports -1, -0 or 1 for a translation component;
        # all of these mean "no shift", so store them as a plain +0
        shift[numpy.abs(shift) == 1.] = 0.
        shift[shift == 0.] = 0.
        # start from the identity: it already supplies the [0 0 0 1] row
        operator = numpy.eye(4)
        operator[:3, :3] = rotation
        operator[:3, 3] = shift
        operators.append(operator)
    return operators
class CrystalSystems(object):
    """
    Manage the properties of the seven crystal systems.

    Each crystal system is represented by its own nested class; use
    `getCrystalSystem` to create the object for a crystal system name.
    """
    # All names are lowercase, consistent with one another
    TRICLINIC_NAME = 'triclinic'
    MONOCLINIC_NAME = 'monoclinic'
    ORTHORHOMBIC_NAME = 'orthorhombic'
    TETRAGONAL_NAME = 'tetragonal'
    TRIGONAL_NAME = 'trigonal'
    HEXAGONAL_NAME = 'hexagonal'
    CUBIC_NAME = 'cubic'

    class _CrystalSystem(object):
        """
        Common base of the seven crystal system classes.
        """

        def __init__(self, name):
            """
            Create an instance.

            :type name: str
            :param name: crystal system name
            """
            self.name = name

    class Triclinic(_CrystalSystem):
        """
        Manage the triclinic system.
        """

    class Monoclinic(_CrystalSystem):
        """
        Manage the monoclinic system.
        """

    class Orthorhombic(_CrystalSystem):
        """
        Manage the orthorhombic system.
        """

    class Tetragonal(_CrystalSystem):
        """
        Manage the tetragonal system.
        """

    class Trigonal(_CrystalSystem):
        """
        Manage the trigonal system.
        """

    class Hexagonal(_CrystalSystem):
        """
        Manage the hexagonal system.
        """

    class Cubic(_CrystalSystem):
        """
        Manage the cubic system.
        """

    def getCrystalSystem(self, name):
        """
        Return the crystal system object for the crystal system of the
        provided name.

        The name is matched case-insensitively (surrounding whitespace is
        ignored); the returned object keeps ``name`` exactly as given.

        :type name: str
        :param name: crystal system name
        :rtype: one of the seven crystal system objects
        :return: crystal_system_obj
        :raise ValueError: if ``name`` is not one of the seven crystal systems
        """
        systems = {
            self.TRICLINIC_NAME: self.Triclinic,
            self.MONOCLINIC_NAME: self.Monoclinic,
            self.ORTHORHOMBIC_NAME: self.Orthorhombic,
            self.TETRAGONAL_NAME: self.Tetragonal,
            self.TRIGONAL_NAME: self.Trigonal,
            self.HEXAGONAL_NAME: self.Hexagonal,
            self.CUBIC_NAME: self.Cubic,
        }
        # Compare case-insensitively so e.g. 'Tetragonal' and 'tetragonal'
        # both resolve to the same system
        lookup = {key.lower(): system for key, system in systems.items()}
        system_cls = lookup.get(str(name).strip().lower())
        if system_cls is None:
            raise ValueError('Unknown crystal system name: %r' % (name,))
        return system_cls(name)
class SpaceGroup(typing.NamedTuple):
    """
    Collect the properties of a space group.

    Instances are immutable; create them with `fromData`, which also fills
    in the operator counts.
    """
    # Column headers of the space group table (see __repr__)
    DEFINITION_ID = 'Def. ID'
    SPACE_GROUP_ID = 'Space Group ID'
    CRYSTAL_SYSTEM = 'Crystal System'
    SHORT_HERMANN_MAUGUIN_SYMBOL = 'Short H.-M. Symbol'
    FULL_HERMANN_MAUGUIN_SYMBOL = 'Full H.-M. Symbol'
    POINT_GROUP_NAME = 'Point Group'
    NUM_CENTERING_OPERS = 'N Centering Ops.'
    NUM_PRIMARY_OPERS = 'N Primary Ops.'
    NUM_SYMMETRY_OPERS = 'N Symmetry Ops.'
    # Section titles used by printSymmetryOpers
    CENTERING_OPERS = 'Centering Operators'
    PRIMARY_OPERS = 'Primary Operators'
    SYMMETRY_OPERS = 'Symmetry Operators'
    SPACE_GROUP_SETTING = 'Space Group Setting'
    # Line prefixes written by printDatabaseEntry (mmspg/spgbase.dat format)
    ID_TAG = 'spgid  '
    SHORT_NAME_TAG = 'sspgname  '
    FULL_NAME_TAG = 'fspgname  '
    POINT_GROUP_TAG = 'pgname    '
    CRYSTAL_SYSTEM_TAG = 'crysym    '
    SETTING_TAG = 'setting   '
    ASU_TAG = 'xyzasu    '
    PRIMARY_OPERATIONS_TAG = 'primoper  '
    CENTERING_OPERATIONS_TAG = 'centoper  '
    END_OF_DEF_TAG = 'endofdef'
    definition_id: int
    space_group_id: int
    ichoice: int
    num_choices: int
    space_group_short_name: str
    space_group_full_name: str
    point_group_name: str
    centering_opers: typing.List
    primary_opers: typing.List
    symmetry_opers: typing.List
    num_centering_opers: int
    num_primary_opers: int
    num_symmetry_opers: int
    centering_opers_strs: typing.List
    primary_opers_strs: typing.List
    symmetry_opers_strs: typing.List
    # one of the crystal system objects nested in CrystalSystems (anything
    # with a ``name`` attribute)
    crystal_system: typing.Any
    xyzasu: str
    spg_setting: str

    @classmethod
    def fromData(cls, definition_id, space_group_id, ichoice, num_choices,
                 space_group_short_name, space_group_full_name,
                 point_group_name, centering_opers, primary_opers,
                 symmetry_opers, centering_opers_strs, primary_opers_strs,
                 symmetry_opers_strs, crystal_system, xyzasu, spg_setting):
        """
        Create an instance; the operator counts are derived from the
        lengths of the operator sequences.

        :type definition_id: int
        :param definition_id: the id of the definition, i.e.
            a number ranging from 1 to 291 (some of the 230 space
            groups have more than a single unit cell definition).
        :type space_group_id: int
        :param space_group_id: the id of the space group, i.e.
            a number ranging from 1 to 230, where 230 is the number of
            space groups.
        :type ichoice: int
        :param ichoice: the space group setting index
        :type num_choices: int
        :param num_choices: the number of different unit cell
            settings for this space group.  For example, a setting
            may be a choice of axes, etc.
        :type space_group_short_name: str
        :param space_group_short_name: the short Hermann-Mauguin symbol
            of the space group.
        :type space_group_full_name: str
        :param space_group_full_name: the full Hermann-Mauguin symbol
            of the space group.
        :type point_group_name: str
        :param point_group_name: the name of the point group of the
            space group.
        :type centering_opers: list of numpy.array
        :param centering_opers: contains the centering matrices
            of the space group.
        :type primary_opers: list of numpy.array
        :param primary_opers: contains the primary matrices
            of the space group.
        :type symmetry_opers: list of numpy.array
        :param symmetry_opers: contains the symmetry matrices
            of the space group, i.e. the combinations of the centering
            and primary matrices.
        :type centering_opers_strs: list
        :param centering_opers_strs: string representation of the
            centering operators.
        :type primary_opers_strs: list
        :param primary_opers_strs: string representation of the
            primary operators.
        :type symmetry_opers_strs: list
        :param symmetry_opers_strs: string representation of the
            symmetry operators, i.e. the combination of the centering
            and primary string representations.
        :type crystal_system: one of the seven crystal system objects
        :param crystal_system: the crystal system.
        :type xyzasu: str
        :param xyzasu: the xyzasu descriptor which will be parsed but
            not used
        :type spg_setting: str
        :param spg_setting: the setting of the space group
        :rtype: SpaceGroup
        :return: the new instance (of ``cls``, so subclasses are honored)
        """
        return cls(definition_id=definition_id,
                   space_group_id=space_group_id,
                   ichoice=ichoice,
                   num_choices=num_choices,
                   space_group_short_name=space_group_short_name,
                   space_group_full_name=space_group_full_name,
                   point_group_name=point_group_name,
                   centering_opers=centering_opers,
                   primary_opers=primary_opers,
                   symmetry_opers=symmetry_opers,
                   num_centering_opers=len(centering_opers),
                   num_primary_opers=len(primary_opers),
                   num_symmetry_opers=len(symmetry_opers),
                   centering_opers_strs=centering_opers_strs,
                   primary_opers_strs=primary_opers_strs,
                   symmetry_opers_strs=symmetry_opers_strs,
                   crystal_system=crystal_system,
                   xyzasu=xyzasu,
                   spg_setting=spg_setting)

    def __repr__(self):
        """
        Representation of this class: one row of the space group table.

        Each value is padded to the width of its column header (the two id
        columns left-justified, the rest right-justified) and the columns
        are separated by three spaces.
        """
        columns = (
            str(self.definition_id).ljust(len(self.DEFINITION_ID)),
            str(self.space_group_id).ljust(len(self.SPACE_GROUP_ID)),
            str(self.crystal_system.name).rjust(len(self.CRYSTAL_SYSTEM)),
            str(self.space_group_short_name).rjust(
                len(self.SHORT_HERMANN_MAUGUIN_SYMBOL)),
            str(self.space_group_full_name).rjust(
                len(self.FULL_HERMANN_MAUGUIN_SYMBOL)),
            str(self.point_group_name).rjust(len(self.POINT_GROUP_NAME)),
            str(self.num_centering_opers).rjust(len(self.NUM_CENTERING_OPERS)),
            str(self.num_primary_opers).rjust(len(self.NUM_PRIMARY_OPERS)),
            str(self.num_symmetry_opers).rjust(len(self.NUM_SYMMETRY_OPERS)),
            str(self.spg_setting).rjust(len(self.SPACE_GROUP_SETTING)),
        )
        return '   '.join(columns)

    def printSymmetryOpers(self, logger=None):
        """
        Log a formatted print of all of the symmetry operators for
        this space group: the centering, primary and symmetry operators,
        each as its string representation followed by its matrix rows.

        :type logger: logging.Logger
        :param logger: output logger; nothing is logged if not given
        """
        if not logger:
            return

        def log_title(title):
            # underlined title followed by a blank line
            logger.info(title)
            logger.info('-' * len(title))
            logger.info('')

        log_title(self.space_group_full_name)
        sections = (
            (self.CENTERING_OPERS, self.centering_opers_strs,
             self.centering_opers),
            (self.PRIMARY_OPERS, self.primary_opers_strs, self.primary_opers),
            (self.SYMMETRY_OPERS, self.symmetry_opers_strs,
             self.symmetry_opers),
        )
        for title, oper_strs, opers in sections:
            log_title(title)
            for oper_str, oper in zip(oper_strs, opers):
                logger.info(oper_str)
                for row in oper:
                    logger.info(row)
                logger.info('')

    def printDatabaseEntry(self, logger=None):
        """
        Print a space group object in mmspg/spgbase.dat format.

        :type logger: logging.Logger
        :param logger: output logger; nothing is logged if not given
        """
        # Consistent with printSymmetryOpers: the default of None used to
        # crash with AttributeError on logger.info
        if not logger:
            return
        lines = [
            self.ID_TAG + str(self.space_group_id),
            self.SHORT_NAME_TAG + self.space_group_short_name,
            self.FULL_NAME_TAG + self.space_group_full_name,
            self.POINT_GROUP_TAG + self.point_group_name,
            self.CRYSTAL_SYSTEM_TAG + self.crystal_system.name,
            self.SETTING_TAG + self.spg_setting,
            self.ASU_TAG + self.xyzasu,
        ]
        lines.extend(self.PRIMARY_OPERATIONS_TAG + oper
                     for oper in self.primary_opers_strs)
        lines.extend(self.CENTERING_OPERATIONS_TAG + oper
                     for oper in self.centering_opers_strs)
        lines.extend((self.END_OF_DEF_TAG, ''))
        for line in lines:
            logger.info(line)
class SpaceGroups(object):
    """
    Manage space group objects.

    Construction loads every setting of every space group from the mmspg
    library (mmspg/spgbase.dat) and stores the results in:

    - ``all_spg_objs``: tuple of all SpaceGroup objects, in load order
    - ``all_spg_full_symbols``: tuple of their full H.-M. symbols, same order
    - ``all_spg_short_symbols``: tuple of their short H.-M. symbols, same
      order
    - ``spg_objs``: dict mapping a space group id to the tuple of its
      SpaceGroup objects (one per setting)
    """
    NUM_SPACE_GROUPS = 230
    CRYSTAL_SYSTEMS = CrystalSystems()
    CENTERING = 'centering'
    PRIMARY = 'primary'
    SYMMETRY = 'symmetry'

    def __init__(self):
        """
        Create an instance and load all space groups.
        """
        self.getAllSpaceGroups()

    def _unpackSymmetryOpers(self, handle, symtype):
        """
        Get the operators of one type for a space group setting.

        :param handle: mmspg handle of the space group setting
        :type symtype: str
        :param symtype: one of CENTERING, PRIMARY or SYMMETRY
        :rtype: tuple, tuple
        :return: the operators as 4 x 4 nested tuples, and their string
            representations, in the same order
        :raise ValueError: for an unknown ``symtype``
        """
        # get the number of operators and the operators themselves
        if symtype == self.CENTERING:
            nsym = mm.mmspg_num_cent_oper_get(handle)
            symlist = mm.mmspg_cent_oper_get(handle, nsym)
        elif symtype == self.PRIMARY:
            nsym = mm.mmspg_num_prim_oper_get(handle)
            symlist = mm.mmspg_prim_oper_get(handle, nsym)
        elif symtype == self.SYMMETRY:
            nsym = mm.mmspg_num_symmetry_oper_get(handle)
            symlist = mm.mmspg_symmetry_oper_get(handle, nsym)
        else:
            raise ValueError('Unknown symmetry operator type: %r' %
                             (symtype,))
        # one 4 x 4 matrix per operator
        matrices = numpy.reshape(numpy.array(symlist), (nsym, 4, 4))
        symopers = tuple(
            tuple(tuple(row) for row in oper) for oper in matrices)
        symopersstrs = tuple(
            mm.mmspg_get_string_oper_by_matrix(oper.flatten().tolist())
            for oper in matrices)
        return symopers, symopersstrs

    def _loadSpaceGroup(self, definition_id, spgid, ichoice, nchoices):
        """
        Build the SpaceGroup object of one setting of a space group; the
        mmspg space group database must already be loaded.

        :type definition_id: int
        :param definition_id: running id of this definition
        :type spgid: int
        :param spgid: space group id
        :type ichoice: int
        :param ichoice: 1-based setting index within the space group
        :type nchoices: int
        :param nchoices: number of settings of the space group
        :rtype: SpaceGroup
        :return: the space group object
        """
        handle = mm.mmspg_spg_of_choice_get(spgid, ichoice)
        spgshortname = mm.mmspg_short_spgname_get(handle)
        spgfullname = mm.mmspg_full_spgname_get(handle)
        pgname = mm.mmspg_pgname_get(handle)
        # (1) centering, (2) primary and (3) the "real" final symmetry
        # operators (centering*primary), with their string representations
        centopers, centopers_strs = self._unpackSymmetryOpers(
            handle, self.CENTERING)
        primopers, primopers_strs = self._unpackSymmetryOpers(
            handle, self.PRIMARY)
        symopers, symopers_strs = self._unpackSymmetryOpers(
            handle, self.SYMMETRY)
        crysys = self.CRYSTAL_SYSTEMS.getCrystalSystem(
            mm.mmspg_crystal_system_name_get(handle))
        # the xyz ASU descriptor is stored but not otherwise used
        xyzasu = mm.mmspg_get_xyzasu(handle)
        spgsetting = mm.mmspg_get_setting(handle)
        return SpaceGroup.fromData(definition_id, spgid, ichoice, nchoices,
                                   spgshortname, spgfullname, pgname,
                                   centopers, primopers, symopers,
                                   centopers_strs, primopers_strs,
                                   symopers_strs, crysys, xyzasu, spgsetting)

    def getAllSpaceGroups(self):
        """
        Make a list of all SpaceGroup objects each of which contains
        some space group parameters from mmspg/spgbase.dat.

        Sets the ``all_spg_objs``, ``all_spg_full_symbols``,
        ``all_spg_short_symbols`` and ``spg_objs`` attributes.
        """
        all_spg_objs = []
        spg_objs = defaultdict(list)
        definition_id = 0
        # initialize and load the spgbase.dat file contents
        mm.mmspg_initialize()
        try:
            mm.mmspg_load_spgbase()
            # loop over space group ids 1..230 and over the settings of each
            # id; this includes the standard setting, any non-standard
            # settings as well as some alias space group symbols
            for spgid in range(1, self.NUM_SPACE_GROUPS + 1):
                nchoices = mm.mmspg_nchoices_get(spgid)
                for ichoice in range(1, nchoices + 1):
                    definition_id += 1
                    spgobj = self._loadSpaceGroup(definition_id, spgid,
                                                  ichoice, nchoices)
                    all_spg_objs.append(spgobj)
                    spg_objs[spgid].append(spgobj)
        finally:
            # clean up the mmspg library even if loading fails part-way
            mm.mmspg_terminate()
        # Store immutable tuples
        self.all_spg_objs = tuple(all_spg_objs)
        self.all_spg_full_symbols = tuple(
            spgobj.space_group_full_name for spgobj in all_spg_objs)
        self.all_spg_short_symbols = tuple(
            spgobj.space_group_short_name for spgobj in all_spg_objs)
        self.spg_objs = {
            spgid: tuple(objs) for spgid, objs in spg_objs.items()
        }

    def getSpgObjByName(self, name, first=True):
        """
        Get a space group object by name.

        Spaces are ignored when comparing names. Space groups are searched
        in the order of ``spg_objs`` and, within a space group, in setting
        order; the first setting whose short or full Hermann-Mauguin symbol
        matches ``name`` is the match.

        :param str name: short or full Hermann-Mauguin symbol
        :param bool first: If True, return the first setting of the matched
            space group whose symmetry operators equal (in any order) those
            of the match, rather than the match itself
        :rtype: SpaceGroup or None
        :return: Space group object or None if not found
        """
        target = name.replace(' ', '')
        found = next(
            ((spgid, spgobj)
             for spgid, spgobjs in self.spg_objs.items()
             for spgobj in spgobjs
             if target in (spgobj.space_group_short_name.replace(' ', ''),
                           spgobj.space_group_full_name.replace(' ', ''))),
            None)
        if found is None:
            return None
        spgid, match = found
        if not first:
            return match
        for candidate in self.spg_objs[spgid]:
            if equal_rotations(match.symmetry_opers,
                               candidate.symmetry_opers):
                return candidate
        # Not expected: the match itself is one of the candidates
        return match

    def printAllSpgInfo(self, verbose, logger):
        """
        Print all space group information.

        :type verbose: bool
        :param verbose: if True, also log the symmetry operators of every
            space group
        :type logger: logging.Logger
        :param logger: output logger
        """
        title = 'Space group parameters'
        logger.info(title)
        logger.info('-' * len(title))
        logger.info('')
        header = '   '.join((SpaceGroup.DEFINITION_ID,
                             SpaceGroup.SPACE_GROUP_ID,
                             SpaceGroup.CRYSTAL_SYSTEM,
                             SpaceGroup.SHORT_HERMANN_MAUGUIN_SYMBOL,
                             SpaceGroup.FULL_HERMANN_MAUGUIN_SYMBOL,
                             SpaceGroup.POINT_GROUP_NAME,
                             SpaceGroup.NUM_CENTERING_OPERS,
                             SpaceGroup.NUM_PRIMARY_OPERS,
                             SpaceGroup.NUM_SYMMETRY_OPERS,
                             SpaceGroup.SPACE_GROUP_SETTING))
        logger.info(header)
        logger.info('-' * len(header))
        current_spg_id = 1
        for spgobj in self.all_spg_objs:
            # blank line between the settings of different space groups
            if spgobj.space_group_id != current_spg_id:
                logger.info('')
                current_spg_id = spgobj.space_group_id
            logger.info(spgobj)
        if not verbose:
            return
        logger.info('')
        title = 'Symmetry operators by space group name'
        logger.info(title)
        logger.info('-' * len(title))
        logger.info('')
        for spgobj in self.all_spg_objs:
            spgobj.printSymmetryOpers(logger)
def equal_rotations(rotations1, rotations2):
    """
    Check if rotations are equal.

    Every matrix is rounded to SYMMOP_ROUND decimals before comparison and
    the order in which the matrices are listed does not matter: the two
    collections are compared as sets of matrices.

    :type rotations1: 3D numpy.array
    :param rotations1: Array of rotation matrices (2D arrays) associated with a
        space group
    :type rotations2: 3D numpy.array
    :param rotations2: Array of rotation matrices (2D arrays) associated with a
        space group
    :rtype: bool
    :return: True, if rotations are the same, otherwise False
    """
    rotations1 = numpy.asarray(rotations1)
    rotations2 = numpy.asarray(rotations2)
    if rotations1.shape != rotations2.shape:
        return False

    def rounded_keys(rotations):
        # Compare numerically via tuples of rounded entries rather than by
        # concatenated str() values: strings joined without a separator can
        # collide ('1.01' + '1.0' == '1.0' + '11.0'), and str() tells -0.0
        # from 0.0 and 1 from 1.0 although the values are equal.
        return {
            tuple(numpy.round(numpy.ravel(rot), SYMMOP_ROUND).tolist())
            for rot in rotations
        }

    # Rotations 1 and 2 are not necessarily listed in the same order
    return rounded_keys(rotations1) == rounded_keys(rotations2)