# Source code for schrodinger.application.matsci.nano.check
"""
Classes and functions for checking nanostructure input.
Copyright Schrodinger, LLC. All rights reserved."""
# Contributor: Thomas F. Hughes
import os
from schrodinger.application.matsci.nano import constants
from schrodinger.application.matsci.nano import util
from schrodinger.utils import fileutils
# Module revision tag.  NOTE(review): the '$Revision ... $' form looks like a
# version-control keyword placeholder -- confirm whether anything reads it.
_version = '$Revision 0.0 $'
def validate_options(parser):
    """
    Check the user-specified input options.

    Unless the ``cg`` option is set, the ``termfrag`` option must be one of
    the fragments listed in `constants.Constants.TERMFRAGS`; if it is not,
    the problem is reported through ``parser.error``.

    :param parser: The parser to get the options from
    :type parser: `parserutils.DriverParser`

    :rtype: Named tuple
    :return: The parsed command line options
    """
    opts = parser.options
    termfrags = constants.Constants.TERMFRAGS
    if not opts.cg and opts.termfrag not in termfrags:
        # The two literals are concatenated by the parser, so the first one
        # must end with a space to keep "list" and "to" apart.
        parser.error(
            'For an atomistic system choose a fragment from the following list '
            f'to terminate the lattice: {",".join(termfrags)}.')
    return opts
class CheckInput(object):
    """
    Check user input.

    Each ``check*`` method returns the value(s) it was given when they are
    supported.  An unsupported value is replaced by the corresponding default
    from `constants.Constants` and, if a logger was passed in, a warning is
    logged.
    """
    DEFAULTMSG = """
        You have specified a value for flag %s that is not supported.  Values
        must be %s.  Proceeding with the default value of %s."""
    MIDFIX = '-'

    def _useDefault(self, flag, constraint, default, logger=None):
        """
        Warn that the value of a flag is unsupported and return its default.

        :type flag: str
        :param flag: the quoted flag name to show in the warning

        :type constraint: str
        :param constraint: description of the supported values

        :type default: any
        :param default: the default value that replaces the bad value

        :type logger: logging.Logger or None
        :param logger: output logger, no warning is issued if None

        :return: the given default
        """
        if logger:
            logger.warning(self.DEFAULTMSG % (flag, constraint, default))
        return default

    def checkElements(self, element1, element2, logger=None):
        """
        Check the two element symbols.

        Both symbols are title-cased before they are looked up in the
        symbols returned by `util.get_atomic_element_symbols`.

        :type element1: str
        :param element1: the first element symbol

        :type element2: str
        :param element2: the second element symbol

        :type logger: logging.Logger or None
        :param logger: output logger

        :rtype: tuple
        :return: the title-cased (or defaulted) first and second symbols
        """
        symbols = util.get_atomic_element_symbols()
        constraint = 'valid atomic elements as given in the periodic table'
        element1 = element1.title()
        if element1 not in symbols:
            element1 = self._useDefault("'-element1'", constraint,
                                        constants.Constants.ELEMENT1, logger)
        element2 = element2.title()
        if element2 not in symbols:
            element2 = self._useDefault("'-element2'", constraint,
                                        constants.Constants.ELEMENT2, logger)
        return element1, element2

    def checkBondlength(self, bondlength, logger=None):
        """
        Check that the bond length is greater than zero.

        :type bondlength: int or float
        :param bondlength: the bond length

        :type logger: logging.Logger or None
        :param logger: output logger

        :return: the given bond length, or the default if it was not
            greater than zero
        """
        if bondlength <= 0:
            bondlength = self._useDefault("'-bondlength'", 'greater than zero',
                                          constants.Constants.BONDLENGTH,
                                          logger)
        return bondlength

    def checkEdgetypes(self, edgetype1, edgetype2, logger=None):
        """
        Check that both edge types are in `constants.Constants.EDGETYPES`.

        :param edgetype1: the first edge type

        :param edgetype2: the second edge type

        :type logger: logging.Logger or None
        :param logger: output logger

        :rtype: tuple
        :return: the (possibly defaulted) first and second edge types
        """
        edgetypes = constants.Constants.EDGETYPES
        constraint = 'in %s' % edgetypes
        if edgetype1 not in edgetypes:
            edgetype1 = self._useDefault("'-edgetype1'", constraint,
                                         constants.Constants.EDGETYPE1, logger)
        if edgetype2 not in edgetypes:
            edgetype2 = self._useDefault("'-edgetype2'", constraint,
                                         constants.Constants.EDGETYPE2, logger)
        return edgetype1, edgetype2

    def checkCellDims(self, ncell1, ncell2, logger=None):
        """
        Check that both cell counts are 1 or greater.

        :type ncell1: int
        :param ncell1: the first cell count

        :type ncell2: int
        :param ncell2: the second cell count

        :type logger: logging.Logger or None
        :param logger: output logger

        :rtype: tuple
        :return: the (possibly defaulted) first and second cell counts
        """
        constraint = '1 or greater'
        if ncell1 < 1:
            ncell1 = self._useDefault("'-ncell1'", constraint,
                                      constants.Constants.NCELL1, logger)
        if ncell2 < 1:
            ncell2 = self._useDefault("'-ncell2'", constraint,
                                      constants.Constants.NCELL2, logger)
        return ncell1, ncell2

    def checkTermFrag(self, termfrag, logger=None):
        """
        Check that the fragment is in `constants.Constants.TERMFRAGS`.

        :param termfrag: the terminating fragment

        :type logger: logging.Logger or None
        :param logger: output logger

        :return: the given fragment, or the default if it is not supported
        """
        termfrags = constants.Constants.TERMFRAGS
        if termfrag not in termfrags:
            termfrag = self._useDefault("'-termfrag'", 'in %s' % termfrags,
                                        constants.Constants.TERMFRAG, logger)
        return termfrag

    def checkBilayerSep(self, bilayersep, logger=None):
        """
        Check that the bilayer separation is greater than zero.

        :type bilayersep: int or float
        :param bilayersep: the bilayer separation

        :type logger: logging.Logger or None
        :param logger: output logger

        :return: the given separation, or the default if it was not
            greater than zero
        """
        if bilayersep <= 0:
            bilayersep = self._useDefault("'-bilayersep'", 'greater than zero',
                                          constants.Constants.BILAYERSEP,
                                          logger)
        return bilayersep

    def checkNumBilayers(self, nbilayers, logger=None):
        """
        Check that the number of bilayers is not negative.

        :type nbilayers: int
        :param nbilayers: the number of bilayers

        :type logger: logging.Logger or None
        :param logger: output logger

        :return: the given number, or the default if it was negative
        """
        if nbilayers < 0:
            # zero is accepted, so the constraint reported to the user is
            # "zero or greater" rather than "positive"
            nbilayers = self._useDefault("'-nbilayers'", 'zero or greater',
                                         constants.Constants.NBILAYERS, logger)
        return nbilayers

    def checkBilayerStackType(self, stacktype, logger=None):
        """
        Check that the stack type is in `constants.Constants.STACKTYPES`.

        :param stacktype: the bilayer stack type

        :type logger: logging.Logger or None
        :param logger: output logger

        :return: the given stack type, or `constants.Constants.ABAB` if it
            is not supported
        """
        stacktypes = constants.Constants.STACKTYPES
        if stacktype not in stacktypes:
            stacktype = self._useDefault("'-stacktype'", 'in %s' % stacktypes,
                                         constants.Constants.ABAB, logger)
        return stacktype

    def checkBilayerShift(self, bilayershift, logger=None):
        """
        Check that the bilayer shift is not negative.

        :type bilayershift: int or float
        :param bilayershift: the bilayer shift

        :type logger: logging.Logger or None
        :param logger: output logger

        :return: the given shift, or the default if it was negative
        """
        if bilayershift < 0:
            # zero is accepted, so the constraint reported to the user is
            # "zero or greater" rather than "positive"
            bilayershift = self._useDefault("'-bilayershift'",
                                            'zero or greater',
                                            constants.Constants.BILAYERSHIFT,
                                            logger)
        return bilayershift

    def checkIndicies(self, nindex, mindex, logger=None):
        """
        Check n-index and m-index.

        The checks are applied in order: an n-index below 1 and an m-index
        below 0 are replaced by their defaults, the indices are swapped if
        the m-index exceeds the n-index, and finally a (1,0) pair is replaced
        by the default pair.

        :type nindex: int
        :param nindex: the first chiral index

        :type mindex: int
        :param mindex: the second chiral index

        :type logger: logging.Logger or None
        :param logger: output logger

        :rtype: tuple
        :return: the validated n-index and m-index
        """
        if nindex < 1:
            nindex = self._useDefault("'-nindex'", '1 or greater',
                                      constants.Constants.NINDEX, logger)

        if mindex < 0:
            # an m-index of zero is accepted, so report "zero or greater"
            # rather than "positive"
            mindex = self._useDefault("'-mindex'", 'zero or greater',
                                      constants.Constants.MINDEX, logger)

        if mindex > nindex:
            if logger:
                logger.warning("""
            You have specified a value for flag \'-mindex\' that is not
            supported.  Values must be less than or equal to the value
            for flag \'-nindex\'.  Proceeding with the defined values for
            \'-nindex\' and \'-mindex\' swapped.""")
            nindex, mindex = mindex, nindex

        if nindex == 1 and mindex == 0:
            nindex = constants.Constants.NINDEX
            mindex = constants.Constants.MINDEX
            if logger:
                msg = """
            You have specified to build a nanotube with (1,0) indicies.  Such
            nanotubes are not supported.  Proceeding with the default values
            of indicies, i.e. a (%s,%s) nanotube."""
                logger.warning(msg % (nindex, mindex))

        return nindex, mindex

    def checkNumCells(self, ncells, logger=None):
        """
        Check the number of unit cells.

        :type ncells: int
        :param ncells: the number of unit cells

        :type logger: logging.Logger or None
        :param logger: output logger

        :return: the given number, or the default if it was below 1
        """
        if ncells < 1:
            ncells = self._useDefault("'-ncells'", '1 or greater',
                                      constants.Constants.NCELLS, logger)
        return ncells

    def checkUpToIndex(self, up_to_nindex, up_to_mindex, logger=None):
        """
        Check the enumeration options.

        The two options can not both be set.  If they are, the n-index
        option is reset to `constants.Constants.UP_TO_NINDEX` and the
        m-index option is kept set.

        :type up_to_nindex: bool
        :param up_to_nindex: enumerate on the n-index

        :type up_to_mindex: bool
        :param up_to_mindex: enumerate on the m-index

        :type logger: logging.Logger or None
        :param logger: output logger

        :rtype: tuple
        :return: the validated n-index and m-index enumeration options
        """
        if up_to_nindex and up_to_mindex:
            if logger:
                logger.warning("""
            You have simultaneously specified the flags \'-up_to_nindex\'
            and \'-up_to_mindex\' which is currently not supported.
            Enumeration of only one type or the other is currently supported.
            Proceeding with the \'-up_to_mindex\' flag set.""")
            up_to_nindex = constants.Constants.UP_TO_NINDEX
            up_to_mindex = True
        return up_to_nindex, up_to_mindex

    def checkNumWalls(self, nwalls, logger=None):
        """
        Check the number of walls.

        :type nwalls: int
        :param nwalls: the number of walls

        :type logger: logging.Logger or None
        :param logger: output logger

        :return: the given number, or the default if it was below 1
        """
        if nwalls < 1:
            nwalls = self._useDefault("'-nwalls'", '1 or greater',
                                      constants.Constants.NWALLS, logger)
        return nwalls

    def checkWallSep(self, wallsep, logger=None):
        """
        Check the desired wall separation.

        :type wallsep: float
        :param wallsep: wall separation in Angstrom

        :type logger: logging.Logger or None
        :param logger: output logger

        :return: the given separation, or the default if it was not
            greater than zero
        """
        if wallsep <= 0:
            wallsep = self._useDefault("'-wallsep'", 'greater than zero',
                                       constants.Constants.WALLSEP, logger)
        return wallsep

    def checkMaeExt(self, infile):
        """
        Check that the infile has a supported Maestro extension.

        :type infile: str
        :param infile: file name to check

        :rtype: str
        :return: infile if `fileutils.is_maestro_file` accepts it, otherwise
            infile with its extension replaced by
            `constants.Constants.DEFAULT_MAE_EXT`
        """
        if fileutils.is_maestro_file(infile):
            return infile
        basename, _ = fileutils.splitext(infile)
        return basename + constants.Constants.DEFAULT_MAE_EXT

    def checkExistingFile(self, infile):
        """
        Check if the infile already exists and find a new name if it does.

        :type infile: str
        :param infile: file name to check

        :rtype: str
        :return: infile if nothing exists at that path, otherwise the name
            given by `fileutils.get_next_filename` using `MIDFIX`
        """
        if os.path.exists(infile):
            return fileutils.get_next_filename(infile, self.MIDFIX)
        return infile