"""
Classes and functions for crystal planes.
Copyright Schrodinger, LLC. All rights reserved."""
# Contributor: Thomas F. Hughes
from collections import OrderedDict
from functools import reduce
from math import gcd
from past.utils import old_div
import numpy
from scipy.spatial import ConvexHull
from schrodinger.application.matsci import shapes
from schrodinger.application.matsci.nano import xtal
from schrodinger.math import mathutils
from schrodinger.structutils import transform
# Revision string recorded for this module.
_version = '$Revision 0.0 $'
def ext_gcd(a, b):
    """
    Solve ax + by = gcd(a, b) using the Extended Euclidean Algorithm.

    Return (1) the greatest common divisor (gcd) of integers a and b
    and (2) the integer Bezout coefficients x and y.  The returned gcd
    is the last non-zero remainder of the algorithm, so it can be
    negative for negative input (for example when b is zero the value
    of a is returned unchanged); take its absolute value if a
    non-negative divisor is needed.

    :type a: int
    :param a: the a coefficient

    :type b: int
    :param b: the b coefficient

    :rtype: int, int, int
    :return: gcd(a, b), Bezout coefficient x, and Bezout coefficient y
    """

    # invariant maintained for every remainder of the sequence:
    #     r_prev == a * x_prev + b * y_prev  and  r == a * x + b * y
    r_prev, r = a, b
    x_prev, x = 1, 0
    y_prev, y = 0, 1
    while r != 0:
        # integer (floor) quotient of the two most recent remainders
        q = r_prev // r
        r_prev, r = r, r_prev - q * r
        x_prev, x = x, x_prev - q * x
        y_prev, y = y, y_prev - q * y
    return r_prev, x_prev, y_prev
def reduce_hkl(hkl):
    """
    Reduce hkl to the smallest set of indices by dividing every index
    by the greatest common divisor of all of them.

    :param list hkl: Miller indices

    :rtype: list, int
    :return: Reduced Miller indices, divisor

    :raise ZeroDivisionError: if every index is zero, because the common
        divisor is then zero
    """

    divisor = abs(reduce(gcd, hkl))
    # the divisor divides every index exactly, so floor division gives the
    # exact quotient without a round trip through floating point (which
    # loses precision for very large indices)
    reduced = [int(index // divisor) for index in hkl]
    return reduced, divisor
class CrystalPlane(object):
    """
    Manage a crystal plane object.
    """

    # corners of a square with an edge-length of 2 that is centered on the
    # origin and lies in the XY-plane, keyed by a one-based vertex index
    SQUARE = OrderedDict([(1, numpy.array([-1.0, 1.0, 0.0])),
                          (2, numpy.array([-1.0, -1.0, 0.0])),
                          (3, numpy.array([1.0, -1.0, 0.0])),
                          (4, numpy.array([1.0, 1.0, 0.0]))])

    # passed as distance_thresh when intersecting box edges with a plane
    DISTANCE_THRESH = -0.001

    # two intersection points closer than this are considered identical
    SAME_VECTOR_THRESH = 0.0001

    # the orthogonalization step of getSlabVectors is skipped when the
    # magnitude of its denominator is not larger than this
    SLAB_THRESHOLD = 0.0000001

    def __init__(self,
                 h_index,
                 k_index,
                 l_index,
                 a_vec,
                 b_vec,
                 c_vec,
                 origin=None,
                 logger=None):
        """
        Create an instance.

        :type h_index: int
        :param h_index: the h Miller index

        :type k_index: int
        :param k_index: the k Miller index

        :type l_index: int
        :param l_index: the l Miller index

        :type a_vec: numpy.array
        :param a_vec: the a lattice vector

        :type b_vec: numpy.array
        :param b_vec: the b lattice vector

        :type c_vec: numpy.array
        :param c_vec: the c lattice vector

        :type origin: numpy.array
        :param origin: the origin of the lattice vectors in Angstrom, if
            None then xtal.ParserWrapper.ORIGIN is used

        :type logger: logging.getLogger
        :param logger: output logger

        :raise ValueError: if all three Miller indices are zero
        """

        self.h_index = h_index
        self.k_index = k_index
        self.l_index = l_index
        self.a_vec = a_vec
        self.b_vec = b_vec
        self.c_vec = c_vec
        if origin is None:
            self.origin = numpy.array(xtal.ParserWrapper.ORIGIN)
        else:
            self.origin = origin
        self.logger = logger

        # filled in by getReciprocals and getNormal below
        self.ra_vec = None
        self.rb_vec = None
        self.rc_vec = None
        self.inter_planar_separation = None
        self.normal_vec = None

        # filled in by getRotationToZ
        self.rotate_to_z = None
        self.inv_rotate_to_z = None

        self.checkMillerIndices()
        self.getReciprocals()
        self.getNormal()

        # filled in by getSimpleSlabVectors or getSlabVectors
        self.basis = None
        self.u_vec = None
        self.v_vec = None
        self.w_vec = None

    def checkMillerIndices(self):
        """
        Check the user provided Miller indices.

        :raise ValueError: if all three Miller indices are zero, the
            message is also sent to the logger if there is one
        """

        if self.h_index == self.k_index == self.l_index == 0:
            bad_indices_msg = (
                'You have provided the Miller indices '
                '(000).  At least a single Miller index must be non-zero.')
            if self.logger:
                self.logger.error(bad_indices_msg)
            raise ValueError(bad_indices_msg)

    def getReciprocals(self):
        """
        Compute, store, and return the reciprocal lattice vectors.

        :rtype: three numpy.array
        :return: the three reciprocal lattice vectors.
        """

        reciprocals = xtal.get_reciprocal_lattice_vectors(
            self.a_vec, self.b_vec, self.c_vec)
        self.ra_vec, self.rb_vec, self.rc_vec = reciprocals
        return self.ra_vec, self.rb_vec, self.rc_vec

    def getNormal(self):
        """
        Compute, store, and return the normal vector.  The inter-planar
        separation is stored as well and is the length of the returned
        vector.

        :rtype: numpy.array
        :return: the normal vector for this plane
        """

        # the hkl combination of the reciprocal lattice vectors, the
        # inter-planar separation is the inverse of its length
        rnormal_vec = (self.h_index * self.ra_vec +
                       self.k_index * self.rb_vec +
                       self.l_index * self.rc_vec)
        self.inter_planar_separation = \
            1.0 / transform.get_vector_magnitude(rnormal_vec)
        self.normal_vec = self.inter_planar_separation * \
            transform.get_normalized_vector(rnormal_vec)
        return self.normal_vec

    def getLinDepPlaneVectors(self):
        """
        Return three typically used plane vectors that are linearly dependent.

        :rtype: numpy.array, numpy.array, numpy.array
        :return: typically used plane vectors that are linearly dependent
        """

        h_idx, k_idx, l_idx = self.h_index, self.k_index, self.l_index
        a_vec_prime = k_idx * self.a_vec - h_idx * self.b_vec
        b_vec_prime = l_idx * self.a_vec - h_idx * self.c_vec
        c_vec_prime = l_idx * self.b_vec - k_idx * self.c_vec
        return a_vec_prime, b_vec_prime, c_vec_prime

    def getSimpleSlabVectors(self):
        """
        This sets the simple, i.e. two of the Miller indices
        are zero, transformation matrix into self.basis and
        sets the simple slab vectors.

        :raise ValueError: if the number of zero Miller indices is not two
        """

        miller_indices = [self.h_index, self.k_index, self.l_index]
        if miller_indices.count(0) != 2:
            raise ValueError('Number of null Miller indices is not two.')

        # exactly one index is non-zero, it selects the basis
        if self.h_index:
            basis = [(0, 0, 1), (0, 1, 0), (-1, 0, 0)]
        elif self.k_index:
            basis = [(1, 0, 0), (0, 0, 1), (0, -1, 0)]
        else:
            basis = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
        self.basis = numpy.array(basis)

        # NOTE(review): transformVectors is not defined in the part of this
        # class that is visible here - confirm that it is provided elsewhere
        self.u_vec, self.v_vec, self.w_vec = self.transformVectors(
            self.a_vec, self.b_vec, self.c_vec)

    def getSlabVectors(self):
        """
        This sets the transformation matrix into self.basis.  Basis
        vectors are chosen such that a and b axes are in the plane
        of the Miller plane, and c-axis is out of this plane (NOT
        necessarily normal to it).  Also sets the slab vectors.
        """

        # if this is a simple case, i.e. two of the Miller indices are
        # zero, then just handle it now.  the case is detected explicitly
        # rather than by catching the ValueError of getSimpleSlabVectors so
        # that an unrelated ValueError raised while transforming the vectors
        # is not silently swallowed
        miller_indices = [self.h_index, self.k_index, self.l_index]
        if miller_indices.count(0) == 2:
            self.getSimpleSlabVectors()
            return

        # the algorithm used here follows from that given in
        # https://wiki.fysik.dtu.dk/ase/_downloads/general_surface.pdf
        # to summarize, three plane vectors, a', b', and c', are
        # formed from linear combinations of the three lattice vectors,
        # a, b, and c, using the Miller indices:

        # a' = k*a - h*b
        # b' = l*a - h*c
        # c' = l*b - k*c

        # one of these vectors is picked to be one of the two plane
        # vectors while a linear combination of the other two vectors
        # is taken to form the last plane vector.  this linear combination
        # is taken so as to both minimize the planar area spanned by the two
        # planar vectors and to make them as orthogonal as possible.  these
        # are determined subject to their cross product having a length of one
        # inter-planar separation

        # get some typically used linearly dependent vectors that
        # span the plane
        a_vec_prime, b_vec_prime, c_vec_prime = self.getLinDepPlaneVectors()

        # set the first plane vector with linear combinations of
        # a_vec_prime and b_vec_prime, the linear combination to
        # minimize the area requires solving pk + ql = 1
        _, p_coef, q_coef = ext_gcd(self.k_index, self.l_index)
        u_vec = p_coef * a_vec_prime + q_coef * b_vec_prime

        # if we can make the planar vectors more orthogonal then go ahead,
        # in the original document cited above they mention the need for
        # a threshold due to numerical precision however I have not yet
        # seen a case where this is necessary
        numerator = numpy.dot(u_vec, c_vec_prime)
        denominator = numpy.dot(
            self.l_index * a_vec_prime - self.k_index * b_vec_prime,
            c_vec_prime)
        if abs(denominator) > self.SLAB_THRESHOLD:
            coef = -1 * int(mathutils.roundup(numerator / denominator))
            p_coef += coef * self.l_index
            q_coef -= coef * self.k_index

        # set transformation matrix into self.basis, the three vectors
        # below are its columns
        vec1 = (p_coef * numpy.array((self.k_index, -self.h_index, 0)) +
                q_coef * numpy.array((self.l_index, 0, -self.h_index)))
        vec2 = (numpy.array((0, self.l_index, -self.k_index)) /
                abs(xtal.gcd(self.l_index, self.k_index)))
        _, a_coef, b_coef = ext_gcd(
            p_coef * self.k_index + q_coef * self.l_index, self.h_index)
        vec3 = (b_coef, a_coef * p_coef, a_coef * q_coef)
        self.basis = numpy.array([vec1, vec2, vec3]).T

        # NOTE(review): transformVectors is not defined in the part of this
        # class that is visible here - confirm that it is provided elsewhere
        self.u_vec, self.v_vec, self.w_vec = self.transformVectors(
            self.a_vec, self.b_vec, self.c_vec)

    def getSpanningVectors(self, st):
        """
        Return the spanning vectors of this bounding box.

        :type st: schrodinger.structure.Structure
        :param st: the structure

        :rtype: list of numpy.array
        :return: contains vectors spanning the parallelepiped and its sides,
            in the order a, b, c, a + b, a + c, b + c, a + b + c
        """

        # prefer the chorus properties and fall back to the lattice
        # parameter properties if those raise a ValueError
        try:
            vecs = xtal.get_vectors_from_chorus(st)
        except ValueError:
            params = xtal.get_lattice_param_properties(st)
            vecs = xtal.get_lattice_vectors(*params)
        a_vec, b_vec, c_vec = vecs
        return [
            a_vec, b_vec, c_vec, a_vec + b_vec, a_vec + c_vec, b_vec + c_vec,
            a_vec + b_vec + c_vec
        ]

    def getBestSpanningVector(self, st):
        """
        Return the spanning vector with the largest projection onto the
        plane normal vector.

        :type st: schrodinger.structure.Structure
        :param st: the structure

        :rtype: numpy.array
        :return: the best spanning vector
        """

        unit_normal_vec = transform.get_normalized_vector(self.normal_vec)
        # ties are resolved in favor of the earliest spanning vector
        return max(self.getSpanningVectors(st),
                   key=lambda vec: abs(numpy.dot(vec, unit_normal_vec)))

    def getNumPlanes(self, st):
        """
        Return the number of planes that will fit inside the bounding box.

        :type st: schrodinger.structure.Structure
        :param st: the structure

        :rtype: int
        :return: the number of planes that will fit inside the bounding box
        """

        unit_normal_vec = transform.get_normalized_vector(self.normal_vec)
        spanning_vec = self.getBestSpanningVector(st)

        # component of the best spanning vector along the plane normal
        spanning_normal_vec = abs(numpy.dot(spanning_vec,
                                            unit_normal_vec)) * unit_normal_vec
        extent = transform.get_vector_magnitude(spanning_normal_vec)
        return int(round(extent / self.inter_planar_separation))

    def getInterPlanarSeparation(self):
        """
        Return the inter-planar separation in Angstrom.

        :rtype: float
        :return: the inter-planar separation in Angstrom
        """

        return self.inter_planar_separation

    def getRotationToZ(self):
        """
        Compute, store, and return the rotation matrix needed to rotate
        this plane to the XY-plane as well as its inverse.

        :rtype: two numpy.array
        :return: the rotation matrix that rotates this plane to the XY-plane and
            its inverse.
        """

        self.rotate_to_z = transform.get_alignment_matrix(
            self.normal_vec, numpy.array(transform.Z_AXIS))
        # the alignment matrix is expected to be square and invertible
        self.inv_rotate_to_z = numpy.linalg.inv(self.rotate_to_z)
        return self.rotate_to_z, self.inv_rotate_to_z

    def getSquareVertices(self):
        """
        Return the vertices of a square that lies in this plane.  The square
        has an edge-length of 2 Angstrom.  It is rotated from the XY-plane,
        centered on origin, into this plane.

        :rtype: list of numpy.array
        :return: the vertices of the square that lies in this plane
        """

        self.getRotationToZ()
        # transform copies so that the class level SQUARE is never touched
        return [
            transform.transform_atom_coordinates(numpy.copy(vertex),
                                                 self.inv_rotate_to_z)
            for vertex in self.SQUARE.values()
        ]

    def getParallelepipedLineSegments(self, st):
        """
        Return the line segments that make this bounding box.

        :type st: schrodinger.structure.Structure
        :param st: the structure

        :rtype: list of tuples of heads and tails of 12 line segments.
        :return: the line segments that make this bounding box
        """

        a_vec, b_vec, c_vec, ab_vec, ac_vec, bc_vec, abc_vec = \
            self.getSpanningVectors(st)

        # (start, end) line segment pairs
        # yapf: disable
        line_segments = [
            (self.origin, self.origin + a_vec),
            (self.origin, self.origin + b_vec),
            (self.origin, self.origin + c_vec),
            (self.origin + a_vec, self.origin + ab_vec),
            (self.origin + a_vec, self.origin + ac_vec),
            (self.origin + b_vec, self.origin + ab_vec),
            (self.origin + b_vec, self.origin + bc_vec),
            (self.origin + ab_vec, self.origin + abc_vec),
            (self.origin + c_vec, self.origin + ac_vec),
            (self.origin + c_vec, self.origin + bc_vec),
            (self.origin + ac_vec, self.origin + abc_vec),
            (self.origin + bc_vec, self.origin + abc_vec)
        ]
        # yapf: enable

        return line_segments

    def getPlaneBoxIntersections(self, st, vertices):
        """
        Return the points where the plane containing the specified
        vertices intersects the parallelepiped.

        :type st: schrodinger.structure.Structure
        :param st: the structure

        :type vertices: list of numpy.array
        :param vertices: the vertices of the square that lies in this plane

        :rtype: list of numpy.array
        :return: the points of intersection, without duplicates
        """

        num_unique = 1
        aface = shapes.Face(1, list(self.SQUARE), vertices, num_unique)

        intersections = []
        for start, end in self.getParallelepipedLineSegments(st):
            intersection = aface.intersectSegmentAndPlane(
                start, end, distance_thresh=self.DISTANCE_THRESH)
            if intersection is None:
                continue
            # skip points that coincide with one that was already collected,
            # e.g. where several box edges meet the plane at a shared corner
            redundant = any(
                transform.get_vector_magnitude(intersection - point) <
                self.SAME_VECTOR_THRESH for point in intersections)
            if not redundant:
                intersections.append(intersection)
        return intersections

    def getOrderedIntersections(self, intersections):
        """
        Return the provided list of planar points in counter-clockwise
        order.  The given list is not modified.

        :type intersections: list of numpy.array
        :param intersections: some intersection points in a plane

        :rtype: list of numpy.array
        :return: those planar intersections in counter-clockwise order
        """

        # the rotation matrices are normally set as a side effect of
        # getSquareVertices, make sure they exist if called on its own
        if self.rotate_to_z is None or self.inv_rotate_to_z is None:
            self.getRotationToZ()

        # work on a copy that also holds the centroid of the points so
        # that the list of the caller is left untouched
        points = list(intersections)
        points.append(sum(intersections) / len(intersections))

        # rotate the points to the XY-plane and drop the Z coordinate, the
        # Z coordinate of the last point is kept for all points when
        # rotating back
        points_xy = []
        z_coord = None
        for point in points:
            point_xy = transform.transform_atom_coordinates(
                numpy.copy(point) - self.origin, self.rotate_to_z)
            z_coord = point_xy[2]
            points_xy.append(numpy.delete(point_xy, 2))

        # the vertices of a two dimensional hull are listed in
        # counter-clockwise order
        indices = ConvexHull(numpy.array(points_xy)).vertices

        ordered = []
        for index in indices:
            point_xyz = numpy.append(points_xy[index], z_coord)
            ordered.append(
                transform.transform_atom_coordinates(
                    numpy.copy(point_xyz), self.inv_rotate_to_z) +
                self.origin)
        return ordered

    def getVertices(self, normal_vec, idx, draw_location=0, thickness=1):
        """
        Return a list of numpy.array containing vertices of a plane
        with the given index.

        :type normal_vec: numpy.array
        :param normal_vec: the normal vector

        :type idx: int
        :param idx: the plane index, the sign of the index controls
            whether behind or ahead of the normal vector origin

        :type draw_location: float
        :param draw_location: specifies the starting location at which planes
            will be drawn in terms of a fraction of the normal vector
            which has a length of one inter-planar spacing

        :type thickness: float
        :param thickness: specifies the thickness or distance between
            consecutive planes in terms of a fraction of the normal vector
            which has a length of one inter-planar spacing

        :rtype: list of numpy.array
        :return: the plane vertices
        """

        # get square vertices of the hkl plane that passes through (0, 0, 0)
        square_vertices = self.getSquareVertices()

        # define a vector from (0, 0, 0) to the given draw_location
        shift_vec = self.origin + draw_location * normal_vec

        # define an offset vector
        offset_vec = idx * thickness * normal_vec

        return [v + shift_vec + offset_vec for v in square_vertices]

    def getVerticesOfAllPlanes(self,
                               st,
                               draw_location=0,
                               thickness=1,
                               also_draw_planes_behind=True):
        """
        Return a list of lists of points where the set of planes
        intersect the parallelepiped.

        :type st: schrodinger.structure.Structure
        :param st: the structure

        :type draw_location: float
        :param draw_location: specifies the starting location at which planes
            will be drawn in terms of a fraction of the normal vector
            which has a length of one inter-planar spacing

        :type thickness: float
        :param thickness: specifies the thickness or distance between
            consecutive planes in terms of a fraction of the normal vector
            which has a length of one inter-planar spacing

        :type also_draw_planes_behind: bool
        :param also_draw_planes_behind: whether to also draw planes behind
            the specified draw location, this is in addition to always
            drawing planes ahead of the draw location

        :rtype: list of list of numpy.array
        :return: where the planes intersect the parallelepiped
        """

        # number of planes that can fit inside the largest bounding box
        nplanes = self.getNumPlanes(st)

        # get normal vector of proper sign, i.e. pointing along the best
        # spanning vector
        spanning_vec = self.getBestSpanningVector(st)
        normal_vec = numpy.sign(
            numpy.dot(spanning_vec,
                      transform.get_normalized_vector(
                          self.normal_vec))) * self.normal_vec

        if also_draw_planes_behind:
            idxs_plane = range(-nplanes, nplanes + 1)
        else:
            idxs_plane = range(nplanes + 1)

        vertices_of_all_planes = []
        for idx_plane in idxs_plane:
            vertices = self.getVertices(normal_vec,
                                        idx_plane,
                                        draw_location=draw_location,
                                        thickness=thickness)
            intersections = self.getPlaneBoxIntersections(st, vertices)
            if len(intersections) >= 3:
                # three points are a triangle for which the order does not
                # matter, more points must be ordered
                if len(intersections) >= 4:
                    intersections = self.getOrderedIntersections(intersections)
                vertices_of_all_planes.append(intersections)
            elif not idx_plane:
                # if the zero plane (idx_plane == 0) doesn't intersect the
                # parallelepiped then create one outside the box at the
                # parallelepiped origin at the start of the normal vector,
                # the size of the plane will be proportional to the length
                # of the normal vector
                normal_len = transform.get_vector_magnitude(normal_vec)
                vertices = [
                    self.origin + normal_len *
                    transform.get_normalized_vector(v - self.origin)
                    for v in vertices
                ]
                vertices_of_all_planes.append(vertices)
        return vertices_of_all_planes