"""
Classes and functions to enumerate surfaces and interfaces.
Copyright Schrodinger, LLC. All rights reserved."""
# Contributor: Thomas F. Hughes
import argparse
import itertools
from schrodinger.application.matsci import parserutils
from schrodinger.application.matsci import textlogger
from schrodinger.application.matsci.nano import interface_mod
from schrodinger.application.matsci.nano import slab
from schrodinger.application.matsci.nano import xtal
from schrodinger.application.matsci import jobutils
from schrodinger.utils import cmdline
from schrodinger.utils import fileutils
# Width passed to textlogger.get_param_string when logging parameters
MSGWIDTH = 100
# Keys of the six lattice parameters (a, b, c, alpha, beta, gamma)
LATTICE_PARAM_KEYS = [
    xtal.Crystal.A_KEY,
    xtal.Crystal.B_KEY,
    xtal.Crystal.C_KEY,
    xtal.Crystal.ALPHA_KEY,
    xtal.Crystal.BETA_KEY,
    xtal.Crystal.GAMMA_KEY,
]
# Command line flags; several are also interpolated into the help text
# of other options
BASE_ASU_FLAG = '-ref_xtal_asu'
BASE_OPTIONS_FLAG = '-ref_surface_options'
ADSORPTION_ASU_FLAG = '-ads_xtal_asu'
ADSORPTION_OPTIONS_FLAG = '-ads_surface_options'
INTERFACE_OPTIONS_FLAG = '-interface_options'
def get_base_names(ref_xtal_asu,
                   ref_base_name_in=None,
                   ads_xtal_asu=None,
                   ads_base_name_in=None):
    """
    Return the reference, adsorption, and final combined base
    names.

    A given (non-empty) base name is used as-is, otherwise the base
    name is the corresponding input file name with its extension
    removed.  The combined name is the reference name, or
    "<reference>_<adsorption>" when an adsorption file is given.

    :type ref_xtal_asu: str
    :param ref_xtal_asu: reference xtal input file
    :type ref_base_name_in: str or None
    :param ref_base_name_in: reference base name or None if
        none has been given
    :type ads_xtal_asu: str or None
    :param ads_xtal_asu: adsorption xtal input file or None
        if none has been given
    :type ads_base_name_in: str or None
    :param ads_base_name_in: adsorption base name or None if
        none has been given
    :rtype: str, str or None, str
    :return: reference, adsorption, and final combined base
        names, the adsorption base name will be None if no adsorption
        xtal asu is provided
    """

    def _pick_name(file_name, given_name):
        # An explicitly given name wins; otherwise strip the extension
        if given_name:
            return given_name
        stem, _ = fileutils.splitext(file_name)
        return stem

    ref_name = _pick_name(ref_xtal_asu, ref_base_name_in)
    if not ads_xtal_asu:
        return ref_name, None, ref_name
    ads_name = _pick_name(ads_xtal_asu, ads_base_name_in)
    return ref_name, ads_name, ref_name + '_' + ads_name
class ParserWrapper(object):
    """
    Manages the argparse module to parse user command line
    arguments.
    """

    def __init__(self, scriptname, description):
        """
        Create a ParserWrapper instance holding a new command line
        parser.  No options are added here; call loadIt to add them and
        parseArgs to parse a command line.

        :type scriptname: str
        :param scriptname: name of this script
        :type description: str
        :param description: description of this script
        """
        name = '$SCHRODINGER/run ' + scriptname
        self.parserobj = parserutils.DriverParser(
            prog=name,
            description=description,
            add_help=False,
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        # Parsed options namespace, set by parseArgs
        self.options = None

    def loadIt(self):
        """
        Load ParserWrapper with options.

        Adds the crystal ASU and options-file flags, the base name
        flags, -verbose, and the HOST, NOJOBID, SAVE and WAIT job
        control options.
        """
        ref_xtal_asu_help = """Specify a Maestro file containing a
            crystalline ASU from which to create reference surfaces."""
        self.parserobj.add_argument(BASE_ASU_FLAG,
                                    type=str,
                                    required=True,
                                    help=ref_xtal_asu_help)
        ref_surface_options_help = """Specify a file containing input
            options for building reference surfaces.  See the surface
            builder help ($SCHRODINGER/run surface.py -h) for available
            options.  Instead of the -x_index (x = h, k, or l) options
            the options -x_index_min and -x_index_max can be used to
            enumerate hkl triples.  By default these ranges are %s, %s,
            and %s for h, k, and l, respectively.  Such triples may also
            be specified directly using the -hkl_indices option which
            takes whitespace separated integers and internally groups
            them by order into triples.  Note that the option
            -input_file is not needed here."""
        h_range = [slab.H_INDEX_MIN_DEFAULT, slab.H_INDEX_MAX_DEFAULT]
        k_range = [slab.K_INDEX_MIN_DEFAULT, slab.K_INDEX_MAX_DEFAULT]
        l_range = [slab.L_INDEX_MIN_DEFAULT, slab.L_INDEX_MAX_DEFAULT]
        self.parserobj.add_argument(BASE_OPTIONS_FLAG,
                                    type=str,
                                    help=ref_surface_options_help %
                                    (h_range, k_range, l_range))
        ads_xtal_asu_help = """Specify a Maestro file containing a
            crystalline ASU from which to create epitaxial surfaces.
            If specified then all possible interfaces (see the
            %s option) will be enumerated by taking
            reference/epitaxial surface pairs."""
        self.parserobj.add_argument(ADSORPTION_ASU_FLAG,
                                    type=str,
                                    help=ads_xtal_asu_help %
                                    INTERFACE_OPTIONS_FLAG)
        ads_surface_options_help = """Specify a file containing input
            options for building epitaxial surfaces.  These options
            are analogous to those used in %s."""
        self.parserobj.add_argument(ADSORPTION_OPTIONS_FLAG,
                                    type=str,
                                    help=ads_surface_options_help %
                                    BASE_OPTIONS_FLAG)
        interface_options_help = """Specify a file containing input
            options for building interfaces.  See the interface builder
            help ($SCHRODINGER/run interface.py -h) for available options.
            Note that the options -ref_layer and -ads_layer are not
            needed here."""
        self.parserobj.add_argument(INTERFACE_OPTIONS_FLAG,
                                    type=str,
                                    help=interface_options_help)
        ref_base_name_help = """Specify a base name to use in naming
            reference-related output files.  If none is provided then
            the base name of the file given for %s will
            be used."""
        self.parserobj.add_argument('-ref_base_name',
                                    type=str,
                                    help=ref_base_name_help % BASE_ASU_FLAG)
        ads_base_name_help = """Specify a base name to use in naming
            epitaxial-related output files.  If none is provided then
            the base name of the file given for %s will
            be used."""
        self.parserobj.add_argument('-ads_base_name',
                                    type=str,
                                    help=ads_base_name_help %
                                    ADSORPTION_ASU_FLAG)
        self.parserobj.add_argument('-verbose',
                                    action='store_true',
                                    help='Turn on verbose printing.')
        jc_options = [
            cmdline.HOST, cmdline.NOJOBID, cmdline.SAVE, cmdline.WAIT
        ]
        cmdline.add_jobcontrol_options(self.parserobj, options=jc_options)

    def parseArgs(self, args):
        """
        Parse the command line arguments and store the resulting
        namespace in self.options.

        :type args: tuple
        :param args: command line arguments
        """
        self.options = self.parserobj.parse_args(args)
def _get_flags_from_file(afile):
    """
    Return a list of flags from the given file.
    :type afile: str
    :param afile: file name
    :rtype: list
    :return: contains flags from the given file
    """
    flags = []
    with open(afile) as file_obj:
        for line in file_obj.read().splitlines():
            alist = line.split()
            flags.append(alist[0])
            if len(alist) > 1:
                try:
                    list(map(int, alist[1:]))
                except ValueError:
                    flags.append(' '.join(alist[1:]))
                else:
                    flags.extend(alist[1:])
    return flags
def get_surface_kwargs(options_file):
    """
    Return a dictionary of surface options from the given
    options file.

    :type options_file: str or None
    :param options_file: contains options for the surface build
        or None if there are none in which case the defaults are
        used
    :rtype: dict
    :return: contains options for the surface build
    """
    # With no options file, parse an empty argument list so every option
    # takes its parser default (the file branch also yields a list)
    args = _get_flags_from_file(options_file) if options_file else []
    parser = slab.ParserWrapper('surface.py', 'Build it.')
    parser.loadEnumeration()
    parser.loadOptions()
    parser.parseArgs(args)
    return vars(parser.options)
def get_interface_kwargs(options_file):
    """
    Return a dictionary of interface options from the given
    options file.

    :type options_file: str or None
    :param options_file: contains options for the interface build
        or None if there are none in which case the defaults are
        used
    :rtype: dict
    :return: contains options for the interface build
    """
    # With no options file, parse an empty argument list so every option
    # takes its parser default (the file branch also yields a list)
    args = _get_flags_from_file(options_file) if options_file else []
    parser = interface_mod.ParserWrapper('interface.py', 'Build it.')
    parser.loadOptions()
    parser.parseArgs(args)
    return vars(parser.options)
def get_hkl_indices(flattened_hkl_indices=None,
                    h_min=0,
                    h_max=0,
                    k_min=0,
                    k_max=0,
                    l_min=0,
                    l_max=0):
    """
    Collect and return all hkl Miller index triples sorted by
    increasing l then k then h.

    The result is the union of every triple in the inclusive
    [h_min, h_max] x [k_min, k_max] x [l_min, l_max] box and the
    explicitly given triples.  Duplicates are removed and (0, 0, 0)
    is dropped.

    :type flattened_hkl_indices: list or None
    :param flattened_hkl_indices: flattened hkl indices, for
        example 100, 110, etc. as [1, 0, 0, 1, 1, 0, ...]
    :type h_min: int
    :param h_min: minimum h index
    :type h_max: int
    :param h_max: maximum h index
    :type k_min: int
    :param k_min: minimum k index
    :type k_max: int
    :param k_max: maximum k index
    :type l_min: int
    :param l_min: minimum l index
    :type l_max: int
    :param l_max: maximum l index
    :raise ValueError: if the length of flattened_hkl_indices is not a
        multiple of 3 or if no triple other than (0, 0, 0) remains
    :rtype: list
    :return: contains sorted tuples of hkl triples
    """
    size = 3
    all_hkl_indices = set(
        itertools.product(range(h_min, h_max + 1), range(k_min, k_max + 1),
                          range(l_min, l_max + 1)))
    if flattened_hkl_indices:
        if len(flattened_hkl_indices) % size:
            msg = ('The list of flattened hkl indices must be a '
                   f'multiple of {size}.')
            raise ValueError(msg)
        all_hkl_indices.update(
            tuple(flattened_hkl_indices[idx:idx + size])
            for idx in range(0, len(flattened_hkl_indices), size))
    # (0, 0, 0) is not a valid Miller index triple
    all_hkl_indices.discard((0, 0, 0))
    if not all_hkl_indices:
        msg = ('No valid hkl Miller index triples could be found.')
        raise ValueError(msg)
    return sorted(all_hkl_indices, key=lambda hkl: (hkl[2], hkl[1], hkl[0]))
def write_to_file(file_name, structs):
    """
    Write the given structures to file via jobutils.write_mae_with_wam,
    tagged with the MS_SURFACES_INTERFACES WAM type.

    :param str file_name: The path to the file
    :param list structs: List of structures to write to file
    """
    wam_type = jobutils.WAM_TYPES.MS_SURFACES_INTERFACES
    jobutils.write_mae_with_wam(structs, file_name, wam_type)
class Surfaces(object):
    """
    Manage the enumeration of surfaces.
    """

    def __init__(self, xtal_asu, surface_kwargs=None, logger=None):
        """
        Create an instance.

        :type xtal_asu: `schrodinger.structure.Structure`
        :param xtal_asu: the crystalline ASU from which to create surfaces
        :type surface_kwargs: None or dict
        :param surface_kwargs: kwargs for the surface build or None if
            there are none in which case the defaults will be used.  The
            dict is copied, so the caller's dict is never modified.
        :type logger: logging.Logger or None
        :param logger: output logger or None if there isn't one
        """
        self.xtal_asu = xtal_asu
        # Private copy: getHKLIndices pops the enumeration-only keys, and
        # that must not leak into a dict the caller may reuse or share
        self.surface_kwargs = dict(surface_kwargs) if surface_kwargs else {}
        self.logger = logger
        self.hkl_indices = []
        self.xtal_cell = None
        self.surfaces = []
        # Enumeration-only options popped from surface_kwargs by
        # getHKLIndices; None until that method has run
        self._hkl_options = None

    def getHKLIndices(self):
        """
        Collect and return all hkl Miller index triples for this
        surface enumeration and sort them according to increasing
        l then k then h.

        The enumeration-only options (hkl_indices and the h/k/l index
        min/max values) are removed from surface_kwargs so that they are
        not forwarded to slab.Surface.  Missing options fall back to the
        slab module defaults.  Repeated calls reuse the options removed
        by the first call.

        :raise ValueError: propagated from get_hkl_indices
        :rtype: list
        :return: contains sorted tuples of hkl triples
        """
        if self._hkl_options is None:
            defaults = [
                ('hkl_indices', None),
                ('h_index_min', slab.H_INDEX_MIN_DEFAULT),
                ('h_index_max', slab.H_INDEX_MAX_DEFAULT),
                ('k_index_min', slab.K_INDEX_MIN_DEFAULT),
                ('k_index_max', slab.K_INDEX_MAX_DEFAULT),
                ('l_index_min', slab.L_INDEX_MIN_DEFAULT),
                ('l_index_max', slab.L_INDEX_MAX_DEFAULT),
            ]
            self._hkl_options = {
                key: self.surface_kwargs.pop(key, default)
                for key, default in defaults
            }
        opts = self._hkl_options
        self.hkl_indices = get_hkl_indices(
            flattened_hkl_indices=opts['hkl_indices'],
            h_min=opts['h_index_min'],
            h_max=opts['h_index_max'],
            k_min=opts['k_index_min'],
            k_max=opts['k_index_max'],
            l_min=opts['l_index_min'],
            l_max=opts['l_index_max'])
        return self.hkl_indices

    def getSurface(self, cell, hkl, logger=None):
        """
        Build and return a slab.Surface.

        The remaining surface_kwargs are forwarded to slab.Surface.

        :type cell: `schrodinger.structure.Structure`
        :param cell: a cell
        :type hkl: tuple
        :param hkl: a triple of Miller indices
        :type logger: logging.Logger or None
        :param logger: output logger or None if there isn't one
        :rtype: slab.Surface
        :return: the surface object
        """
        h_index, k_index, l_index = hkl
        surface = slab.Surface(cell,
                               h_index=h_index,
                               k_index=k_index,
                               l_index=l_index,
                               logger=logger,
                               **self.surface_kwargs)
        surface.runIt()
        return surface

    def logParams(self):
        """
        Log the number of surfaces, if there is a logger.
        """
        if self.logger is None:
            return
        num_surfaces = textlogger.get_param_string('Number of surfaces',
                                                   len(self.hkl_indices),
                                                   MSGWIDTH)
        self.logger.info(num_surfaces)
        self.logger.info('')

    def getXtalCell(self):
        """
        Build and return the crystal cell from which surfaces
        will be created.

        The cell comes from xtal.get_cell and is then passed through
        xtal.make_p1.

        :rtype: `schrodinger.structure.Structure`
        :return: the crystal cell
        """
        self.xtal_cell = xtal.get_cell(self.xtal_asu)
        self.xtal_cell = xtal.make_p1(self.xtal_cell)
        return self.xtal_cell

    def getSurfaces(self):
        """
        Build and return the slab.Surface objects for all surfaces.

        One surface is built from self.xtal_cell per triple in
        self.hkl_indices and appended to self.surfaces.

        :rtype: list of slab.Surface
        :return: contains surface objects
        """
        # FIXME - see MATSCI-2277, this can be done faster in parallel
        for hkl in self.hkl_indices:
            surface = self.getSurface(self.xtal_cell, hkl, logger=self.logger)
            self.surfaces.append(surface)
        return self.surfaces

    def writeSurfaces(self, file_name):
        """
        Write surfaces to a Maestro file with the given file name.

        Nothing is written if there are no surfaces.

        :type file_name: str
        :param file_name: file name of the Maestro file
        """
        if self.surfaces:
            write_to_file(file_name,
                          [surface.surface for surface in self.surfaces])

    def runIt(self):
        """
        Create the surfaces.
        """
        # FIXME - see MATSCI-2278, the hkl triples should probably be uniquified
        self.getHKLIndices()
        self.logParams()
        self.getXtalCell()
        self.getSurfaces()
class Interfaces(object):
    """
    Manage the enumeration of interfaces.
    """

    def __init__(self,
                 ref_surfaces,
                 ads_surfaces,
                 interface_kwargs=None,
                 logger=None):
        """
        Create an instance.

        :type ref_surfaces: list of `schrodinger.structure.Structure`
        :param ref_surfaces: reference surface ASUs from which interfaces
            will be created
        :type ads_surfaces: list of `schrodinger.structure.Structure`
        :param ads_surfaces: adsorption surface ASUs from which interfaces
            will be created
        :type interface_kwargs: None or dict
        :param interface_kwargs: kwargs for the interface builds or None
            if there are none in which case the defaults will be used
        :type logger: logging.Logger or None
        :param logger: output logger or None if there isn't one
        """
        self.ref_surfaces = ref_surfaces
        self.ads_surfaces = ads_surfaces
        # getInterface unpacks these with **, so None must become an empty
        # dict (builder defaults); copied to avoid aliasing the caller's dict
        self.interface_kwargs = (dict(interface_kwargs)
                                 if interface_kwargs else {})
        self.logger = logger
        self.interfaces = []

    def logParams(self):
        """
        Log the number of interfaces (one per reference/adsorption
        surface pair), if there is a logger.
        """
        if self.logger is None:
            return
        num_interfaces = textlogger.get_param_string(
            'Number of interfaces',
            len(self.ref_surfaces) * len(self.ads_surfaces), MSGWIDTH)
        self.logger.info(num_interfaces)
        self.logger.info('')

    def getInterface(self, ref_surface, ads_surface):
        """
        Build and return a interface_mod.Interface.

        :type ref_surface: `schrodinger.structure.Structure`
        :param ref_surface: reference surface ASU
        :type ads_surface: `schrodinger.structure.Structure`
        :param ads_surface: adsorption surface ASU
        :rtype: interface_mod.Interface
        :return: the interface object
        """
        interface = interface_mod.Interface(ref_surface,
                                            ads_surface,
                                            logger=self.logger,
                                            **self.interface_kwargs)
        interface.runIt()
        return interface

    def getInterfaces(self):
        """
        Build and return the interface_mod.Interface objects for all
        interfaces, one per reference/adsorption surface pair, appending
        them to self.interfaces.

        :rtype: list of interface_mod.Interface
        :return: contains interface objects
        """
        # FIXME - see MATSCI-2277, this can be done faster in parallel
        for ref_surface in self.ref_surfaces:
            for ads_surface in self.ads_surfaces:
                interface = self.getInterface(ref_surface, ads_surface)
                self.interfaces.append(interface)
        return self.interfaces

    def writeInterfaces(self, file_name):
        """
        Write interfaces to a Maestro file with the given file name.

        Nothing is written if there are no interfaces.

        :type file_name: str
        :param file_name: file name of the Maestro file
        """
        if self.interfaces:
            write_to_file(
                file_name,
                [interface.interface for interface in self.interfaces])

    def runIt(self):
        """
        Create the interfaces.
        """
        self.logParams()
        self.getInterfaces()
[docs]class SurfacesInterfaces(object):
    """
    Manage the enumeration of surfaces and interfaces.
    """
[docs]    def __init__(self,
                 ref_xtal_asu,
                 ref_surface_kwargs=None,
                 ads_xtal_asu=None,
                 ads_surface_kwargs=None,
                 interface_kwargs=None,
                 logger=None):
        """
        Create an instance.
        :type ref_xtal_asu: `schrodinger.structure.Structure`
        :param ref_xtal_asu: the crystalline ASU from which to create
            reference surfaces
        :type ref_surface_kwargs: None or dict
        :param ref_surface_kwargs: kwargs for the reference surface build
            or None if there are none in which case the defaults will be used
        :type ads_xtal_asu: None or `schrodinger.structure.Structure`
        :param ads_xtal_asu: the crystalline ASU from which to create
            adsorption surfaces or None if interfaces are not needed
        :type ads_surface_kwargs: None or dict
        :param ads_surface_kwargs: kwargs for the adsorption surface build
            or None if there are none in which case the defaults will be used
        :type interface_kwargs: None or dict
        :param interface_kwargs: kwargs for the interface build or None
            if there are none in which case the defaults will be used
        :type logger: logging.Logger or None
        :param logger: output logger or None if there isn't one
        """
        self.ref_xtal_asu = ref_xtal_asu
        self.ads_xtal_asu = ads_xtal_asu
        self.ref_surface_kwargs = ref_surface_kwargs
        if not self.ref_surface_kwargs:
            self.ref_surface_kwargs = {}
        self.ads_surface_kwargs = ads_surface_kwargs
        if not self.ads_surface_kwargs:
            self.ads_surface_kwargs = {}
        self.interface_kwargs = interface_kwargs
        if not self.interface_kwargs:
            self.interface_kwargs = {}
        self.logger = logger
        self.ref_surfaces = None
        self.ads_surfaces = None
        self.interfaces = None 
[docs]    def createSurfaces(self):
        """
        Create the surfaces.
        :raise ValueError: if something goes wrong with either surface
            build
        """
        self.ref_surfaces = Surfaces(self.ref_xtal_asu, \
            
surface_kwargs=self.ref_surface_kwargs, logger=self.logger)
        if self.logger:
            header = 'Reference Surfaces'
            self.logger.info(header)
            self.logger.info('-' * len(header))
        try:
            self.ref_surfaces.runIt()
        except ValueError as err:
            msg = ('Working on reference surfaces.')
            raise ValueError(' '.join([msg, str(err)]))
        if self.ads_xtal_asu:
            self.ads_surfaces = Surfaces(self.ads_xtal_asu, \
                
surface_kwargs=self.ads_surface_kwargs, logger=self.logger)
            if self.logger:
                header = 'Epitaxial Surfaces'
                self.logger.info(header)
                self.logger.info('-' * len(header))
            try:
                self.ads_surfaces.runIt()
            except ValueError as err:
                msg = ('Working on epitaxial surfaces.')
                raise ValueError(' '.join([msg, str(err)])) 
[docs]    def createInterfaces(self):
        """
        Create the interfaces.
        """
        afunc = lambda x: [surface.surface.copy() for surface in x.surfaces]
        ref_surfaces, ads_surfaces = list(
            map(afunc, [self.ref_surfaces, self.ads_surfaces]))
        self.interfaces = Interfaces(ref_surfaces,
                                     ads_surfaces,
                                     interface_kwargs=self.interface_kwargs,
                                     logger=self.logger)
        if self.logger:
            header = 'Interfaces'
            self.logger.info(header)
            self.logger.info('-' * len(header))
        self.interfaces.runIt() 
[docs]    def runIt(self):
        """
        Create the surfaces and interfaces.
        """
        self.checkInput()
        self.createSurfaces()
        if self.ads_xtal_asu:
            self.createInterfaces()