"""
Block data analysis script.
For a block-average data file (.sba), determine whether the simulation passed
or failed according to the criteria given in a test file (.sbt).
Copyright Schrodinger, LLC. All rights reserved.
"""
# Contributors: Pranav Dalal
import optparse
import shlex
import sys
from past.utils import old_div
# Module-wide settings:
# Switched on by the -d/--debug command-line option; read by _debug_print().
is_debugging = False
# An expected average whose magnitude is below this value is treated as
# "no average criterion" by Test.update_pass_fail().
eps = 1e-05
def _debug_print(s):
    """
    Echo `s` to stdout, but only while the module-level flag
    `is_debugging` is set; otherwise do nothing.
    """
    if not is_debugging:
        return
    print(s)
def print_usage(text=None, exit_code=0):
    """
    Print the command-line usage message, then `text` if given, and
    terminate the process.

    :param text: optional extra message printed after the usage text.
    :param exit_code: process exit status.  Defaults to 0 for backward
        compatibility; pass a non-zero value when reporting an error so the
        failure is visible to the calling shell.
    :raises SystemExit: always.
    """
    print("""
Desmond .sbt testing utility
Usage: $SCHRODINGER/utilities/python sba.py
   -i <Input .sba file>
   -t <Input .sbt file>
   -o <Output .sba file>
   [options]
  Options:
   -d  Debug
   """)
    if text is not None:
        print(text)
    sys.exit(exit_code)
class Test(object):
    """
    Pass/fail criteria for one property block of a .sba file.

    The numeric criteria are filled in from a .sbt file through `set_val`.
    `update_pass_fail` then compares them with the statistics of the
    matching block and stores "Pass" or "Fail" in `pass_fail`.
    Subclasses define the block name as the class attribute `NAME`.
    """

    # Number of Test objects constructed so far.
    count = 0

    def __init__(self, name):
        """
        :param name: name of the property block this test applies to
            (subclasses pass their `NAME`).
        """
        Test.count += 1
        # The block's |standard deviation| must be strictly below |sd|.
        self.sd = 0.0
        # The block's |least-squares slope| must be strictly below |slope|.
        self.slope = 0.0
        # Expected block average; a value within eps of 0.0 disables the
        # average check.
        self.average = 0.0
        # Allowed absolute deviation of the block average from `average`.
        self.average_tol = 0.0
        self._name = name
        # Only data points whose time is greater than `start` are analysed.
        self.start = 0.0
        # "Pass" / "Fail" once update_pass_fail() has run.
        self.pass_fail = ""

    def set_val(self, key, token):
        """
        Set the numeric criterion `key` from a parsed .sbt token.

        :param key: criterion name, e.g. "sd", "slope", "average",
            "average_tol" or "start".
        :param token: ``(value_text, line_number)`` tuple.
        :raises KeyError: if `key` is not a numeric criterion of this test.
        :raises ValueError: if the value text is not a number.
        """
        # Only the float criteria may be set from a .sbt file, so entries
        # such as "pass_fail" or "_name" cannot be clobbered with a float.
        if not isinstance(self.__dict__.get(key), float):
            raise KeyError(" Unknown label '%s' at line #%d." % (key, token[1]))
        try:
            value = float(token[0])
        except ValueError:
            raise ValueError("Invalid value '%s' for label '%s' at line #%d." %
                             (token[0], key, token[1]))
        self.__dict__[key] = value

    def print_val(self):
        """
        Print every public attribute of the test, one per line.

        Attributes whose name ends in "file" are skipped while their value
        is the empty string.
        """
        for key, value in self.__dict__.items():
            if key.startswith('_'):
                continue
            if key.endswith("file") and value == "":
                continue
            print("%32s: %-30s" % (key, value))
        sys.stdout.flush()

    def update_pass_fail(self, sba_file):
        """
        Evaluate this test against its block in the .sba file `sba_file`.

        It sets `pass_fail` to "Pass" when all three conditions hold, and to
        "Fail" otherwise:

        - the block's |SD| is below |sd|;
        - the block's |slope| is below |slope|;
        - the average check is disabled, or the block average lies within
          `average_tol` of `average`.

        :return: the `Block` holding the computed statistics.
        """
        bl = Block(self.NAME)
        # Block only analyses points with time > bl.start, so hand over the
        # 'start' criterion read from the .sbt file.
        bl.start = self.start
        bl.find_block(sba_file, self.NAME)
        sd_ok = abs(bl.sd) < abs(self.sd)
        slope_ok = abs(bl.slope) < abs(self.slope)
        # An expected average of (nearly) zero means "do not check it".
        average_ok = (abs(self.average) < eps or
                      abs(self.average - bl.average) < self.average_tol)
        self.pass_fail = "Pass" if (sd_ok and slope_ok and average_ok) else "Fail"
        return bl
class Block(object):
    """
    Statistics of one named property block read from a .sba file.

    `find_block` computes the following over the block's (time, value)
    pairs, using only points whose time is greater than `start`:

    - the average;
    - the sample variance (n - 1 denominator);
    - the standard deviation;
    - the least-squares slope of value versus time.
    """

    # Kept for compatibility; always 0.
    count = 0

    def __init__(self, name):
        """
        :param name: name of the property block (e.g. "E" or "T").
        """
        self.sd = 0.0
        self.slope = 0.0
        self.average = 0.0
        self.average_tol = 0.0
        self._name = name
        # Only points with time > start are used by find_block().
        self.start = 0.0
        self.variance = 0.0
        self.prop_unit = ""
        self.time_unit = ""

    def find_block(self, sba_fname, property):
        """
        Given the name of a sba file, find the block with the name 'property'
        and compute its statistics.

        The block is expected to look like::

            Block: <property>
            <time label>(<time unit>) <property label>(<property unit>)
            <time> <value>
            ...
            End_Block

        The ``End_Block`` line must be exactly "End_Block \\n", with a
        trailing space.  Blank lines between the header and ``End_Block``
        are ignored.

        :raises IOError: if the file cannot be read.
        :raises ValueError: if the block is missing or unterminated, or if
            fewer than two data points lie after `start` (the variance and
            slope are then undefined).
        """
        from math import sqrt
        try:
            with open(sba_fname, "r") as inp_file:
                inp_content = inp_file.readlines()
        except (IOError, OSError):
            raise IOError("Cannot open/read file %s \n" % sba_fname)
        self._name = property
        try:
            start_index = inp_content.index("Block: " + property + "\n")
            stop_index = inp_content.index("End_Block \n", start_index)
        except ValueError:
            raise ValueError("Block '%s' not found or not terminated in %s" %
                             (property, sba_fname))

        def _unit(label):
            # Text between the parentheses of a column label like "T(K)".
            return label[label.find('(') + 1:label.find(')')]

        unit_list = inp_content[start_index + 1].split()
        self.time_unit = _unit(unit_list[0])
        self.prop_unit = _unit(unit_list[1])

        start = self.start
        points = {}  # time -> value; a repeated time keeps the last value
        for line in inp_content[start_index + 2:stop_index]:
            fields = [float(item) for item in shlex.split(line)]
            if not fields:
                continue  # tolerate blank lines inside the block
            if fields[0] > start:
                points[fields[0]] = fields[1]

        n = len(points)
        if n < 2:
            raise ValueError(
                "Block '%s' in %s has %d data point(s) after start=%g; "
                "at least 2 are needed." % (property, sba_fname, n, start))
        mean = sum(points.values()) / n
        var = 0.0
        sumx = 0.0
        sumy = 0.0
        sumx2 = 0.0
        sumxy = 0.0
        for t, value in points.items():
            var += (value - mean)**2.0
            sumx += t
            sumy += value
            sumx2 += t * t
            sumxy += t * value
        self.average = mean
        self.variance = var / (n - 1)
        # Ordinary least-squares slope; the denominator is positive because
        # the n >= 2 times are distinct dictionary keys.
        self.slope = (n * sumxy - sumx * sumy) / (n * sumx2 - sumx * sumx)
        self.sd = sqrt(self.variance)
class Energy(Test):
    """Pass/fail criteria for the "E" block (energy)."""
    NAME = "E"

    def __init__(self):
        super(Energy, self).__init__(Energy.NAME)
class PotentialEnergy(Test):
    """Pass/fail criteria for the "E_p" block (potential energy)."""
    NAME = "E_p"

    def __init__(self):
        super(PotentialEnergy, self).__init__(PotentialEnergy.NAME)
class Pressure(Test):
    """Pass/fail criteria for the "P" block (pressure)."""
    NAME = "P"

    def __init__(self):
        super(Pressure, self).__init__(Pressure.NAME)
class Temperature(Test):
    """Pass/fail criteria for the "T" block (temperature)."""
    NAME = "T"

    def __init__(self):
        super(Temperature, self).__init__(Temperature.NAME)
class BoxSize(Test):
    """Pass/fail criteria for the "Simulation_Box" block."""
    NAME = "Simulation_Box"

    def __init__(self):
        super(BoxSize, self).__init__(BoxSize.NAME)
class Job_Details(Test):
    """
    Test wrapper named "JobDetails".

    NOTE(review): this class is not listed in SUPPORTED_TEST, and
    get_job_details() reads a block called "Job_Details" (with an
    underscore) -- confirm whether this NAME is intended.
    """
    NAME = "JobDetails"

    def __init__(self):
        super(Job_Details, self).__init__(Job_Details.NAME)
class T_0(Test):
    """Pass/fail criteria for the "T_0" block."""
    NAME = "T_0"

    def __init__(self):
        super(T_0, self).__init__(T_0.NAME)
class E_k(Test):
    """Pass/fail criteria for the "E_k" block."""
    NAME = "E_k"

    def __init__(self):
        super(E_k, self).__init__(E_k.NAME)
class E_x(Test):
    """Pass/fail criteria for the "E_x" block."""
    NAME = "E_x"

    def __init__(self):
        # Pass this class's own block name (previously E_k.NAME by mistake).
        Test.__init__(self, E_x.NAME)
class E_f(Test):
    """Pass/fail criteria for the "E_f" block."""
    NAME = "E_f"

    def __init__(self):
        super(E_f, self).__init__(E_f.NAME)
class V(Test):
    """Pass/fail criteria for the "V" block."""
    NAME = "V"

    def __init__(self):
        super(V, self).__init__(V.NAME)
# Test classes that may appear as a block label in a .sbt file, and their
# labels in the same order (gen_test() looks a label up by its index).
SUPPORTED_TEST = [
    Energy,
    PotentialEnergy,
    Temperature,
    Pressure,
    BoxSize,
    E_k,
    E_x,
    E_f,
    V,
    T_0,
]
SUPPORTED_TESTNAME = [test_cls.NAME for test_cls in SUPPORTED_TEST]
def gen_test(s):
    """
    Instantiate the supported test class whose NAME equals `s`.

    :param s: a test label from a .sbt file, e.g. "E" or "T".
    :return: a new instance of the matching entry of SUPPORTED_TEST.
    :raises ValueError: if `s` is not in SUPPORTED_TESTNAME.
    """
    position = SUPPORTED_TESTNAME.index(s)
    test_cls = SUPPORTED_TEST[position]
    return test_cls()
def process_chunks(chunk, test):
    """
    Apply the ``label = value`` assignments in `chunk` to `test`.

    :param chunk: list of ``(token, line_number)`` tuples found between the
        braces of one test block in a .sbt file.
    :param test: object whose ``set_val(key, token)`` stores each value.
    :return: `test`, after all assignments have been applied.
    :raises SyntaxError: if the tokens do not form a sequence of complete
        ``label = value`` assignments, including a chunk that ends in the
        middle of one.
    """
    _debug_print("    Proccessing chunk: %s" % str(chunk))
    key = None
    has_sym = False
    for token in chunk:
        if key is None:
            # Expecting the label that starts an assignment.
            if token[0] == "=":
                raise SyntaxError(
                    "Expecting a label, but found '=' at line# %d." % token[1])
            key = token[0]
        elif token[0] == "=":
            # Exactly one '=' may follow a label.
            if has_sym:
                raise SyntaxError("Syntax error at line# %d." % token[1])
            has_sym = True
        elif not has_sym:
            # A value (or another label) with no '=' in between.
            raise SyntaxError("Syntax error at line# %d." % token[1])
        else:
            test.set_val(key, token)
            key = None
            has_sym = False
    if key is not None:
        # The chunk ended after "label" or "label =" with no value.
        raise SyntaxError("Syntax error at line# %d." % chunk[-1][1])
    return test
def parse_sbt(sbt_fname):
    """
    Given a .sbt file, parses the file and returns a list of `*test` objects.

    The file consists of blocks of the form::

        <test name> {
            <label> = <value>
            ...
        }

    Blank lines and lines whose first non-blank character is '#' are
    ignored.  '{', '}' and '=' need not be surrounded by spaces.

    :param sbt_fname: path of the .sbt file.
    :return: list of test objects (one per block, in file order), each
        built by gen_test() and filled in by process_chunks().
    :raises ValueError: if a block name is not a supported test.
    :raises SyntaxError: on malformed input, including a block whose closing
        '}' is missing.
    """
    _debug_print("\nPrinting debugging information for parsing sbt file:")
    test_list = []
    with open(sbt_fname, "r") as sbt_file:
        sbt_content = sbt_file.readlines()
    #
    _debug_print(
        "Removing blank and comment lines and striping leading and trailing spaces..."
    )
    new_content = []  # (stripped line, 1-based line number)
    for line_no, sbt_line in enumerate(sbt_content, 1):
        _debug_print("  Line to parse: %s" % sbt_line)
        sbt_line = sbt_line.strip()
        _debug_print("  Line stripped: %s" % sbt_line)
        if sbt_line == "":
            _debug_print("  Blank line")
        elif sbt_line.startswith("#"):
            _debug_print("  Comment line: %s" % sbt_line)
        else:
            new_content.append((sbt_line, line_no))
    #
    _debug_print("Tokenizing the content...")
    token = []  # (token text, line number)
    for sbt_line, line_no in new_content:
        # Pad the punctuation so that "E{sd=1}" splits like "E { sd = 1 }".
        padded = sbt_line.replace("{", " { ").replace("}", " } ").replace(
            "=", " = ")
        token.extend((item, line_no) for item in shlex.split(padded))
    num_token = len(token)
    _debug_print("  %d tokens" % num_token)
    _debug_print(str(token))
    #
    _debug_print("Parsing the tokens...")
    i = 0
    while i < num_token:
        # Every block starts with "<test name> {".
        if i + 1 >= num_token or token[i + 1][0] != "{":
            raise SyntaxError("Syntax error at line# %d." % token[i][1])
        try:
            test = gen_test(token[i][0])
        except ValueError:
            raise ValueError("Unknown task type: %s" % token[i][0])
        _debug_print("  Start of information for test# %d" %
                     (len(test_list) + 1))
        _debug_print("    This is a '%s' test" % test.NAME)
        i += 2
        chunk = []
        while i < num_token and token[i][0] != "}":
            chunk.append(token[i])
            i += 1
        if i >= num_token:
            # Ran out of tokens before the closing '}'.
            raise SyntaxError("Syntax error at line# %d." % token[-1][1])
        test_list.append(process_chunks(chunk, test))
        i += 1  # step over the closing '}'
    #
    _debug_print("\nParsing sbt file completed.\n")
    return test_list
def get_job_details(sba_fname):
    """
    Given a .sba file, parses the file and return Job Detail information.

    Each line of the ``Job_Details`` block has the form
    ``<key> = <value>``.  Values are converted as follows:

    - Atoms, Degrees_of_freedom and Particles become int;
    - Temperature and Pressure become float;
    - every other value is kept as a string.

    Blank lines inside the block are ignored.

    :param sba_fname: path of the .sba file.
    :return: dict mapping each key of the block to its value.
    :raises IOError: if the file cannot be read.
    :raises ValueError: if the ``Job_Details`` block is missing or
        unterminated.
    """
    try:
        with open(sba_fname, "r") as inp_file:
            inp_content = inp_file.readlines()
    except (IOError, OSError):
        raise IOError("Cannot open/read file %s \n" % sba_fname)
    try:
        start_index = inp_content.index("Block: Job_Details\n")
        stop_index = inp_content.index("End_Block \n", start_index)
    except ValueError:
        raise ValueError("Block 'Job_Details' not found or not terminated in %s"
                         % sba_fname)
    int_keys = ('Atoms', 'Degrees_of_freedom', 'Particles')
    flt_keys = ('Temperature', 'Pressure')
    job_details = {}
    for data in inp_content[start_index + 1:stop_index]:
        fields = data.split()
        if not fields:
            continue  # tolerate blank lines inside the block
        key = fields[0]
        if key in int_keys:
            job_details[key] = int(fields[2])
        elif key in flt_keys:
            job_details[key] = float(fields[2])
        else:
            # Ev:114557 Job_name may contain spaces, so simply splitting
            # would truncate it.  In the .sba file the format is
            # "%s = %s \n" % (key, val): skip the space after '=' and the
            # trailing " \n".
            equal_pos = data.find('=')
            job_details[key] = data[equal_pos + 2:-2]
    return job_details
def get_sbt_tests(sbt_fname):
    """
    Return the NAME of every test defined in the .sbt file `sbt_fname`,
    in the order the tests appear in the file.
    """
    return [test.NAME for test in parse_sbt(sbt_fname)]
if __name__ == '__main__':
    # Parse the command line.
    usage = 'Usage: %prog <-i .sba file> <-t .sbt file> <-o .sbafile> [options]'
    opt = optparse.OptionParser(usage)
    opt.add_option('-i',
                   '--inp',
                   type='string',
                   default='',
                   help='the input .sba file name')
    opt.add_option('-o',
                   '--out',
                   type='string',
                   default='',
                   help='the output .sba file name')
    opt.add_option('-t',
                   '--test',
                   type='string',
                   default='',
                   help='the input test file name')
    opt.add_option('-d',
                   '--debug',
                   action='store_true',
                   default=False,
                   help='turn on debug mode.')
    opts, args = opt.parse_args()
    if opts.debug:
        print("Debugging mode is on.\n")
        # Module-level assignment: _debug_print() reads this global.
        is_debugging = True
    # All three file names are mandatory.  Report the first missing one and
    # exit with status 1.  print_usage() is not used here because it exits
    # with status 0, which would hide the error from the caller.
    required = ((opts.inp, "Error: Input .sba file is not defined"),
                (opts.out, "Error: Output .sba file is not defined"),
                (opts.test, "Error: Input .sbt file is not defined"))
    for value, message in required:
        if not value:
            print(message)
            opt.print_help()
            sys.exit(1)
    # Fail early, with a clear message, if an input file cannot be opened.
    for fname in (opts.inp, opts.test):
        try:
            with open(fname, "r"):
                pass
        except (IOError, OSError):
            raise IOError("Cannot open/read file %s \n" % fname)
    try:
        out_file = open(opts.out, "w")
    except (IOError, OSError):
        raise IOError("Cannot open file for writing %s \n" % opts.out)
    # A single handle is used for the whole report and closed on exit.
    with out_file:
        # Copy the job details from the input .sba file to the report.
        out_file.write("Block: Job_Details \n")
        for key, value in get_job_details(opts.inp).items():
            out_file.write(" %s = %s \n" % (key, str(value)))
        out_file.write("End_Block\n\n")
        # For each test in the .sbt file, compute the statistics of its
        # block (standard deviation, slope, average) and record the verdict.
        for test in parse_sbt(opts.test):
            bl = test.update_pass_fail(opts.inp)
            rate_unit = bl.prop_unit + "/" + bl.time_unit
            criteria = ("Testing criteria: SD = %7.3f, Slope = %7.3f %s , "
                        "Average = %7.3f +/- %7.3f %s \n" %
                        (test.sd, test.slope, rate_unit, test.average,
                         test.average_tol, bl.prop_unit))
            block_stats = ("Block data: SD = %7.3f, Slope = %7.3f %s, "
                           "Average = %7.3f %s \n\n" %
                           (bl.sd, bl.slope, rate_unit, bl.average,
                            bl.prop_unit))
            _debug_print("Testing: %s" % test.NAME)
            _debug_print(criteria)
            _debug_print(block_stats)
            _debug_print("Test: %s " % test.pass_fail)
            out_file.write("Test for %s: %s\n" % (test.NAME, test.pass_fail))
            out_file.write(criteria)
            out_file.write(block_stats)