"""
Block data calculation script.
For a MD Simulation, calculates properties from log, energy and simbox files
Calculates block averages of properties in energy file.
Copyright Schrodinger, LLC. All rights reserved.
"""
# Contributors: Pranav Dalal
import optparse
import os
import re
import shlex
import sys
from past.utils import old_div
import numpy as n
from schrodinger.utils import sea
def _debug_print(s):
    """
    Print the string 's' if the module-level flag 'is_debugging' is True.

    'is_debugging' is only assigned by the command-line entry point of this
    file, so the flag may not exist at all (for example when the module is
    imported). A missing flag is treated as False instead of raising
    NameError.

    :param s: Message to print.
    :type s: str
    """
    if globals().get('is_debugging', False):
        print(s)
def print_usage(text=None, exit_code=0):
    """
    Print the command-line usage message and terminate the process.

    :param text: Optional extra message printed after the usage text.
    :type text: str or None
    :param exit_code: Exit status handed to `sys.exit`. Defaults to 0, which
        is the status this function has always used.
    :type exit_code: int
    :raise SystemExit: Always; this function never returns.
    """
    print("""
Desmond .sba file writing utility
Usage: $SCHRODINGER/utilities/python sba.py
   -e <Input .ene file>
   -l <Input .log file>
   --cfg <Input .cfg file>
   -c <Output .sba file>
   [options]
  Options:
   -n  Input block length to calculate block averages (default: 10 ps)
   -s Simbox .dat file
   """)
    if text is not None:
        print(text)
    sys.exit(exit_code)
class Array_Prop(object):
    """
    One column of values read from a whitespace-separated text file.

    Set `name`, `fname`, `column` and `start_line` (and optionally
    `stop_line`) on an instance, then call `get_prop_value` to fill
    `prop_value`.
    """
    # Class-level default kept for backward compatibility only; __init__
    # gives every instance its own list, so this one is never shared by
    # instances created the normal way.
    prop_value = []

    def __init__(self):
        """
        Create an empty property description; -1 and "" mean "not set".
        """
        self.name = ""
        self.fname = ""
        self.units = ""
        self.start_line = -1
        self.stop_line = -1
        self.column = -1
        self.prop_value = []

    def get_prop_value(self):
        """
        Read column `column` of file `fname` and append the values to
        `prop_value`.

        Lines `start_line + 1` up to, but not including, `stop_line` are
        tokenised with `shlex.split`; the selected token of each line is
        appended unchanged (as a string). When `stop_line` is -1 it is set to
        the number of lines in the file.

        :raise ValueError: If `column`, `start_line`, `name` or `fname` is
            not set, or if `stop_line` is set and smaller than `start_line`.
        :raise IOError: If the file cannot be opened or read.
        """
        if self.column == -1:
            raise ValueError("Column is not set \n")
        if self.start_line == -1:
            raise ValueError("Start line for parsing property is not set \n")
        # Must be a logical `and`: with the bitwise `&` the expression parsed
        # as `stop_line != (-1 & stop_line) < start_line`, which is always
        # False, so the check could never fire.
        if self.stop_line != -1 and self.stop_line < self.start_line:
            raise ValueError(
                "Stop line for parsing property is less than start line\n")
        if not self.name:
            raise ValueError("Property name is not set \n")
        if not self.fname:
            raise ValueError("File name is not set \n")
        try:
            with open(self.fname, "r") as prop_file:
                prop_content = prop_file.readlines()
        except (OSError, UnicodeError) as err:
            raise IOError("Cannot open/read file %s \n" % self.fname) from err
        if self.stop_line == -1:
            self.stop_line = len(prop_content)
        for x in range(self.start_line + 1, self.stop_line):
            self.prop_value.append(shlex.split(prop_content[x])[self.column])
def calc_block_average(ene_fname,
                       prop_list,
                       sba_fname,
                       start_line,
                       time_column=0,
                       start_time=0.0,
                       block_length=10.0):
    """
    Append block averages of every property of an energy file to an sba file.

    The data lines after `start_line` are cut into blocks; a block is closed
    by the first line whose time exceeds the current block start by more than
    `block_length` (that line is included in the block). The trailing,
    unfinished block is discarded. Before averaging, column 1 is overwritten
    with the sum of columns 2 and 3, and a zero in column 3 of the first data
    row is replaced by the value of the second row.

    :param ene_fname: Energy file to read.
    :type ene_fname: str
    :param prop_list: Column headings in the form "name(unit)"; entry i
        describes data column i.
    :type prop_list: list[str]
    :param sba_fname: Output file; opened in append mode.
    :type sba_fname: str
    :param start_line: Index of the header line; data starts on the next line.
    :type start_line: int
    :param time_column: Column holding the time used to delimit blocks.
    :type time_column: int
    :param start_time: Not used by the calculation; accepted for backward
        compatibility.
    :type start_time: float
    :param block_length: Block length, in the time units of the energy file.
    :type block_length: float
    :raise ValueError: If `prop_list` is empty, `start_line` is -1, or the
        file is empty, shorter than `start_line`, or holds no data lines.
    :raise IOError: If either file cannot be opened or read.
    """
    if not prop_list:
        raise ValueError("Property list for block averaging is not defined\n")
    if start_line == -1:
        raise ValueError("Start line for parsing property is not set \n")
    # A `with` block closes the file on every path. The previous
    # `finally: prop_file.close()` raised UnboundLocalError when open()
    # itself failed, hiding the IOError.
    try:
        with open(ene_fname, "r") as prop_file:
            prop_content = prop_file.readlines()
    except (OSError, UnicodeError) as err:
        raise IOError("Cannot open/read file %s \n" % ene_fname) from err
    stop_line = len(prop_content)
    if stop_line == 0:
        raise ValueError("Ene file is empty\n")
    if stop_line < start_line:
        raise ValueError(
            "Stop line for parsing property is less than start line\n")

    # Split every "name(unit)" heading into its name and its bare unit.
    prop_name = []
    prop_unit = []
    for prop in prop_list:
        name_unit = shlex.split(prop.replace("(", " ("))
        prop_name.append(name_unit[0])
        prop_unit.append(name_unit[1].replace("(", "").replace(")", ""))

    # Read the data rows and remember the row count at which each block ends.
    data = []
    block_ends = []
    time_init = 0.0
    block_count = 0
    for x in range(start_line + 1, stop_line):
        line = prop_content[x]
        if not line.strip():
            continue
        tokens = shlex.split(line)
        data.append(tokens)
        time = float(tokens[time_column])
        # Time to do averaging?
        if (time - time_init) > block_length:
            block_count += 1
            block_ends.append(len(data))
            time_init = block_count * block_length
    if not data:
        raise ValueError("Ene file has no data after the header line\n")
    data_arr = n.array(data).astype(float)

    # Redefine total energy E as E_p + E_k. This assumes that these three
    # values always appear in columns 1, 2 and 3 in the *.ene file. Also
    # redefine the first data point of E_k if it is zero (only possible when
    # there is a second row to copy from).
    if data_arr.shape[0] > 1 and data_arr[0, 3] == 0.0:
        data_arr[0, 3] = data_arr[1, 3]
    data_arr[:, 1] = data_arr[:, 2] + data_arr[:, 3]

    # The last piece returned by split() is the unfinished block: drop it.
    avg_data = [
        block.mean(0) for block in n.split(data_arr, block_ends)[:-1]
    ]

    # Write out block average data
    try:
        sba_file = open(sba_fname, "a")
    except OSError as err:
        raise IOError("Cannot open/read file %s \n" % sba_fname) from err
    with sba_file:
        for i, (name, unit) in enumerate(zip(prop_name, prop_unit)):
            sba_file.write("Block: %s\n" % name)
            sba_file.write("Time(ps) " + name + "(" + unit + ")\n")
            for avg in avg_data:
                # Column 0 is written as the block time.
                sba_file.write("%7.3f   %7.3f \n" % (avg[0], avg[i]))
            sba_file.write("End_Block \n\n")
def get_listof_props_ene(ene_fname):
    """
    Given a .ene file, parse the header and get a list of properties.

    The property header is the line that starts with "#" followed by
    whitespace and "0:time"; if several lines match, the last one is used.
    Each "<index>:<name> (<unit>)" entry on it becomes "<name>(<unit>)".

    :param ene_fname: Energy file to read.
    :type ene_fname: str
    :return: Index of the header line (-1 if none was found) and the list of
        properties (empty if none was found).
    :rtype: tuple(int, list[str])
    :raise IOError: If the file cannot be opened or read.
    """
    try:
        with open(ene_fname, "r") as ene_file:
            ene_content = ene_file.readlines()
    except (OSError, UnicodeError) as err:
        # Report the file *name*: the handle variable is unbound when open()
        # itself is what failed.
        raise IOError("Cannot open/read file %s \n" % ene_fname) from err
    props = ""
    line_num = -1
    prop_line_pattern = re.compile(r"#\s+0:time")
    for k, line in enumerate(ene_content):
        if prop_line_pattern.match(line):
            line_num = k
            props = line
    prop_pattern = re.compile(r"\d+:([^\s]+)\s+(\([^\)]+\))")
    prop_list = [name + unit for name, unit in prop_pattern.findall(props)]
    return line_num, prop_list
def get_print_simbox_data(simbox_fname, sba_fname):
    """
    Given a simbox.dat file, parse the file and print simulation box
    details in the sba file.

    Every non-blank line that is not a "#" comment contributes one row made
    of its first nine whitespace-separated values. The rows are appended to
    the sba file as a "Simulation_Box" block.

    :param simbox_fname: Simulation box file to read.
    :type simbox_fname: str
    :param sba_fname: Output file; opened in append mode.
    :type sba_fname: str
    :raise IOError: If either file cannot be opened or read.
    :raise IndexError: If a data line has fewer than nine values.
    :raise ValueError: If a value cannot be converted to float.
    """
    try:
        with open(simbox_fname, "r") as box_file:
            box_content = box_file.readlines()
    except (OSError, UnicodeError) as err:
        raise IOError("Cannot open/read file %s \n" % simbox_fname) from err

    # Parse everything before touching the output file, so that a malformed
    # input line cannot leave an open handle behind.
    simbox_data = []
    for item in box_content:
        stripped = item.strip()
        # Skip blank lines and comments, including comments written without
        # a space after the "#".
        if not stripped or stripped.startswith("#"):
            continue
        tok = shlex.split(stripped)
        simbox_data.append([float(tok[i]) for i in range(9)])

    try:
        sba_file = open(sba_fname, "a")
    except OSError as err:
        raise IOError("Cannot open sba file %s \n" % sba_fname) from err
    with sba_file:
        sba_file.write("Block: Simulation_Box\n")
        sba_file.write("Time(ps)  Box_Dimensions(A) \n")
        for row in simbox_data:
            sba_file.write("".join(" %7.3f    " % value for value in row))
            sba_file.write("\n")
        sba_file.write("End_Block\n\n")
def get_top_level_props(ene_fname):
    """
    Given a .ene file, parse the header and get top level props.

    The number of atoms and of degrees of freedom are read from the
    "# N atoms = ..." and "# N dof = ..." header lines. The time is the first
    field of the last non-blank line of the file.

    :param ene_fname: Energy file to read.
    :type ene_fname: str
    :return: Number of atoms (-1 if not found), number of degrees of freedom
        (-1 if not found) and the time of the last line.
    :rtype: tuple(int, int, float)
    :raise IOError: If the file cannot be opened or read.
    :raise ValueError: If the file has no non-blank line, or a value cannot
        be converted to a number.
    """
    try:
        with open(ene_fname, "r") as ene_file:
            ene_content = ene_file.readlines()
    except (OSError, UnicodeError) as err:
        raise IOError("Cannot open()/read file %s \n" % ene_fname) from err
    n_atoms = -1
    n_dof = -1
    for line in ene_content:
        if line.startswith('#') and '=' in line:
            key, _, value = line[1:].partition('=')
            key = key.strip()
            if key == 'N atoms':
                n_atoms = int(value)
            elif key == 'N dof':
                # Only the first token after "=" is the number.
                n_dof = int(value.split()[0])
    # Ignore trailing blank lines: indexing the very last line failed with
    # IndexError whenever the file ended with an empty line.
    last_line = next(
        (line for line in reversed(ene_content) if line.strip()), None)
    if last_line is None:
        raise ValueError("Ene file %s contains no data \n" % ene_fname)
    ene_time = float(last_line.split()[0])
    return n_atoms, n_dof, ene_time
def get_ensemble_from_cfg(cfg_fname):
    """
    Given a cfg filename, get the ensemble class from it.

    The file contents are parsed with `sea.Map`; the value returned is
    `ORIG_CFG.ensemble.class_` of the resulting map.

    :param cfg_fname: cfg filename to be parsed.
    :type cfg_fname: str
    :return: The parsed ensemble class
    :rtype: str
    """
    with open(cfg_fname, 'r') as fh:
        cfg_contents = fh.read()
    cfg_map = sea.Map(cfg_contents)
    return cfg_map['ORIG_CFG'].ensemble.class_.val
def get_log_props(log_fname):
    """
    Given a .log file, parse it and determine if the job completed okay.

    Lines that start with "#" and blank lines are ignored. Whenever a value
    occurs more than once, the last occurrence wins.

    :param log_fname: Log file to read.
    :type log_fname: str
    :return: Status ("Normal" if a "::: finished :::" line is present,
        otherwise "Incomplete"), the time from the last "Chemical time" line,
        the reference temperature (-1.0 if not found), the reference
        pressure (-1.0 if not found) and the number of injected particles
        (0 if not found).
    :rtype: tuple(str, float, float, float, int)
    :raise IOError: If the file cannot be opened or read.
    :raise ValueError: If the log has no usable "Chemical time" line.
    """
    try:
        with open(log_fname, "r") as log_file:
            log_content = log_file.readlines()
    except (OSError, UnicodeError) as err:
        raise IOError("Cannot open/read file %s \n" % log_fname) from err
    normal_pat = re.compile(r'::: finished :::')
    temp_pat = re.compile(r"temperature = \[{")
    temp_pat_1 = re.compile(r"temperature = \{")
    t_ref_pat = re.compile(r"T_ref =")
    press_pat = re.compile(r"P_ref=\[")
    status = "Incomplete"
    time_str = ""
    particles = 0
    temperature = -1.0
    pressure = -1.0
    n_lines = len(log_content)
    for indx, line in enumerate(log_content):
        if not line or line[0] in ('#', "\n"):
            continue
        if normal_pat.search(line):
            status = "Normal"
        if temp_pat_1.search(line):
            # Scan the rest of the log; every T_ref line overrides the
            # previous one, so the last one is kept. The value is the last
            # double-quoted string on the line.
            for line2 in log_content[indx:]:
                if t_ref_pat.search(line2):
                    temperature = float(line2.split('"')[-2])
        if temp_pat.search(line):
            # Here T_ref is only looked for on the very next line, which may
            # not exist when the header is the last line of the log.
            if indx + 1 < n_lines and t_ref_pat.search(log_content[indx + 1]):
                temperature = float(log_content[indx + 1].split('"')[-2])
        if press_pat.search(line):
            pressure = float(line.split('"')[-2])
        if line[0:13] == 'Chemical time':
            time_str = line
        if line.find('Injected') != -1 and line.find('particles') != -1:
            # A leading "[" means the line carries one extra leading token
            # before the count.
            if line[0] == '[':
                particles = int(line.split()[2])
            else:
                particles = int(line.split()[1])
    time_fields = time_str.split()
    if len(time_fields) < 3:
        raise ValueError(
            "Cannot find chemical time in log file %s \n" % log_fname)
    time = float(time_fields[2])
    return status, time, temperature, pressure, particles
def parse_sba(sba_fname):
    """
    Given name of a sba file name find all blocks except time and job
    details and return their average, sd and slope.

    A block starts at a "Block: <name>" line and ends at the next
    "End_Block" line. Every block other than "Job_Details" and "time" is
    summarised with `get_block_summary` and the summaries are merged.

    :param sba_fname: sba file to read.
    :type sba_fname: str
    :return: Merged result of `get_block_summary` for all parsed blocks.
    :rtype: dict
    :raise IOError: If the file cannot be opened or read.
    """
    unused_properties = ('Job_Details', 'time')
    try:
        with open(sba_fname, "r") as inp_file:
            inp_content = inp_file.readlines()
    except (OSError, UnicodeError) as err:
        raise IOError("Cannot open/read file %s \n" % sba_fname) from err
    results = {}
    # -1 means "no block start / end seen yet".
    start_index = -1
    stop_index = -1
    for k, line in enumerate(inp_content):
        if line.startswith('Block: '):
            block_name = line.split()[1]
            start_index = -1 if block_name in unused_properties else k
        if line.startswith('End_Block'):
            stop_index = k
        # Line 0 is a valid block start, so compare against the -1 sentinel
        # rather than requiring a strictly positive index.
        if 0 <= start_index < stop_index:
            results.update(
                get_block_summary(inp_content, start_index, stop_index))
            start_index = -1
            stop_index = -1
    return results
def get_block_summary(inp_content, start_index, stop_index):
    """
    From block content get averages, sd, slope and units.
    Return them as dictionary.

    `inp_content[start_index]` is the "Block: <name>" line, the next line
    holds the "Time(<unit>) <name>(<unit>)" headings and the data rows follow
    up to, but not including, `stop_index`. Only the first two values of each
    row are used (time and property value); rows with the same time
    overwrite each other.

    :param inp_content: Lines of the sba file.
    :type inp_content: list[str]
    :param start_index: Index of the "Block:" line.
    :type start_index: int
    :param stop_index: Index of the "End_Block" line.
    :type stop_index: int
    :return: Dictionary with the keys "<name>.avg", "<name>.sd" (sample
        standard deviation), "<name>.slope" (least-squares slope of value
        against time), "<name>.units", "<name>.time_unit" and "<name>.block"
        (time -> value). For a block with a single row sd and slope are 0.0.
    :rtype: dict
    :raise ValueError: If the block contains no data rows.
    """
    from math import sqrt
    block_name = inp_content[start_index].split()[1]
    unit_list = inp_content[start_index + 1].split()
    # The unit is whatever stands between the parentheses of each heading.
    time_unit = unit_list[0][unit_list[0].find('(') + 1:unit_list[0].find(')')]
    prop_unit = unit_list[1][unit_list[1].find('(') + 1:unit_list[1].find(')')]
    x = {}
    for data in inp_content[start_index + 2:stop_index]:
        ind_list = [float(ind_data) for ind_data in shlex.split(data)]
        if not ind_list:
            continue
        x[ind_list[0]] = ind_list[1]
    count = len(x)
    if count == 0:
        raise ValueError(
            "Error calulating mean from the block, most probably no data is available. \n"
        )
    mean = sum(x.values()) / count
    var = 0.0
    sumx = 0.0
    sumy = 0.0
    sumx2 = 0.0
    sumxy = 0.0
    for key, value in x.items():
        var += (value - mean)**2.0
        sumx += key
        sumy += value
        sumx2 += (key * key)
        sumxy += key * value
    if count > 1:
        variance = var / (count - 1)
        slope = ((count * sumxy) - (sumx * sumy)) / (
            (count * sumx2) - (sumx * sumx))
    else:
        # A single point has neither spread nor trend; both formulas above
        # would divide by zero.
        variance = 0.0
        slope = 0.0
    sd = sqrt(variance)
    return {
        block_name + '.avg': mean,
        block_name + '.sd': sd,
        block_name + '.slope': slope,
        block_name + '.units': prop_unit,
        block_name + '.time_unit': time_unit,
        block_name + '.block': x,
    }
def write_sba_file(ene,
                   log,
                   cfg_fname,
                   sbafile,
                   simboxfile=None,
                   block_len=10.0):
    """
    Write the sba file from specified `*.ene`, `*.log` and `*.cfg` files.

    The file is (re)created with a header and a "Job_Details" block; the
    block averages of the energy file and, if `simboxfile` is given, the
    simulation box data are then appended to it.

    :param ene: Energy file name.
    :type ene: str
    :param log: Log file name.
    :type log: str
    :param cfg_fname: cfg file name.
    :type cfg_fname: str
    :param sbafile: Output sba file name; an existing file is overwritten.
    :type sbafile: str
    :param simboxfile: Optional simulation box file name.
    :type simboxfile: str or None
    :param block_len: Block length passed to `calc_block_average`.
    :type block_len: float
    :raise IOError: If the sba file cannot be opened, or the job details
        cannot be obtained from the log, cfg or ene file.
    """
    try:
        sba_file = open(sbafile, "w")
    except OSError as err:
        raise IOError("Cannot open/read file %s \n" % sbafile) from err
    # The `with` block closes the file even when one of the parsers below
    # raises.
    with sba_file:
        sba_file.write("Calculating Block Averages \n")
        sba_file.write("                    Version: %s \n" % "1.0")
        sba_file.write("                Energy_File: %s \n" % ene)
        sba_file.write("                   Log_File: %s \n" % log)
        sba_file.write("                   Cfg_File: %s \n" % cfg_fname)
        if simboxfile:
            sba_file.write("                Simbox File: %s \n" % simboxfile)
        sba_file.write("Block: Job_Details\n")
        simulation_summary = {}
        # Everything before the first "." of the energy file name.
        simulation_summary['Job_name'] = ene.partition('.')[0]
        try:
            # The pressure is parsed but not reported in the summary.
            status, time, temperature, pressure, particles = get_log_props(log)
        except Exception as err:
            raise IOError(
                "Cannot get job details from log file to %s \n" % log) from err
        try:
            ensemble = get_ensemble_from_cfg(cfg_fname)
        except (ValueError, KeyError) as err:
            raise IOError(
                f" Could not parse ensemble from {cfg_fname}: {err}") from err
        simulation_summary['Status'] = status
        simulation_summary['Duration'] = old_div(time, 1000)
        simulation_summary['Temperature'] = temperature
        simulation_summary['Particles'] = particles
        simulation_summary['Ensemble'] = ensemble
        try:
            n_atoms, n_dof, ene_time = get_top_level_props(ene)
        except Exception as err:
            raise IOError(
                "Cannot get job details from ene file to %s \n" % ene) from err
        simulation_summary['Atoms'] = n_atoms
        simulation_summary['Degrees_of_freedom'] = n_dof
        # Only reported when the energy file ends at a different time than
        # the log file.
        if ene_time != time:
            simulation_summary['Energy_file_duration'] = old_div(
                ene_time, 1000)
        for key, value in simulation_summary.items():
            sba_file.write(" %s = %s \n" % (key, value))
        sba_file.write("End_Block \n\n")
    # The helpers below reopen the sba file in append mode, so it has to be
    # closed (and flushed) first.
    line_num, prop_list = get_listof_props_ene(ene)
    calc_block_average(ene,
                       prop_list,
                       sbafile,
                       line_num,
                       block_length=block_len)
    if simboxfile:
        get_print_simbox_data(simboxfile, sbafile)
if (__name__ == '__main__'):
    # Command-line entry point: validate the input files, then write the
    # .sba file.
    usage = 'Usage: %prog <-e .ene file> <-l .log file> <--cfg .cfg file> <-s simbox.dat file> <-c .sba file> <-n block_length > [options]'
    opt = optparse.OptionParser(usage)
    opt.add_option('-e',
                   '--ene',
                   type='string',
                   default="",
                   help='.ene file for this run')
    opt.add_option('-l',
                   '--log',
                   type='string',
                   default="",
                   help='.log file for this run')
    # optparse only accepts "-x" or "--name" option strings; "-cfg" raised
    # OptionError before any argument was parsed.
    opt.add_option('--cfg',
                   type='string',
                   default="",
                   help='.cfg file for this run')
    opt.add_option('-c',
                   '--sbafile',
                   type='string',
                   default='',
                   help='output .sba file name')
    opt.add_option('-s',
                   '--simboxfile',
                   type='string',
                   default='',
                   help='input simbox.dat file name')
    opt.add_option('-n',
                   '--block_len',
                   type='float',
                   default=10.0,
                   help='block length for calculating averages.')
    opt.add_option('-d',
                   '--debug',
                   action='store_true',
                   default=False,
                   help='turn on debug mode.')
    opts, args = opt.parse_args()

    # Always define the flag read by _debug_print, not only in debug mode.
    is_debugging = bool(opts.debug)
    if is_debugging:
        print("Debugging mode is on.\n")

    def _exit_with_usage(message):
        """Print `message` and the usage text, then exit with status 1."""
        print(message)
        try:
            print_usage()
        except SystemExit:
            # print_usage() terminates with status 0; an input error has to
            # be reported to the caller as a failure instead.
            pass
        sys.exit(1)

    if (not os.path.isfile(opts.ene)):
        _exit_with_usage("Error: Input ene not found: " + opts.ene)
    if (not os.path.isfile(opts.log)):
        _exit_with_usage("Error: Input log file not found: " + opts.log)
    if (not opts.sbafile):
        _exit_with_usage("Error: Output .sba file is not defined")
    sys.stdout.flush()
    if (not os.path.isfile(opts.cfg)):
        _exit_with_usage("Error: Input cfg file not found: " + opts.cfg)
    write_sba_file(opts.ene, opts.log, opts.cfg, opts.sbafile, opts.simboxfile,
                   opts.block_len)
#EOF