import os
import re
import sys
from functools import reduce
from past.utils import old_div
import numpy
from . import antlr3
from .enhanced_sampling.FcnTypes import getFcnSigs
from .enhanced_sampling.mexpLexer import mexpLexer
from .enhanced_sampling.mexpParser import ADDOP
from .enhanced_sampling.mexpParser import BIND
from .enhanced_sampling.mexpParser import BLOCK
from .enhanced_sampling.mexpParser import CUTOFF
from .enhanced_sampling.mexpParser import DECL_META
from .enhanced_sampling.mexpParser import DECL_OUTPUT
from .enhanced_sampling.mexpParser import DIM
from .enhanced_sampling.mexpParser import ELEM
from .enhanced_sampling.mexpParser import FIRST
from .enhanced_sampling.mexpParser import HEADER
from .enhanced_sampling.mexpParser import IF
from .enhanced_sampling.mexpParser import INITKER
from .enhanced_sampling.mexpParser import INTERVAL
from .enhanced_sampling.mexpParser import ITER
from .enhanced_sampling.mexpParser import LIT
from .enhanced_sampling.mexpParser import NAME
from .enhanced_sampling.mexpParser import SERIES
from .enhanced_sampling.mexpParser import STATIC
from .enhanced_sampling.mexpParser import STRING
from .enhanced_sampling.mexpParser import SUBTROP
from .enhanced_sampling.mexpParser import VAR
from .enhanced_sampling.mexpParser import mexpParser
# FIXME add line numbers to errors
def showtype(i):
    """Return a human-readable description of an m-expression type.

    Types in this module are either a string (``String.get_type`` returns
    ``'string'``) or an integer array length (``Lit.get_type`` returns
    ``len(value)``).

    :param i: the type value to describe.
    :return: ``'string'`` for string types, otherwise ``'length-<i> array'``.
    """
    # The check must be on the value itself; ``type(i)`` is never a str
    # instance, so testing it would make this branch unreachable.
    if isinstance(i, str):
        return 'string'
    return 'length-' + str(i) + ' array'
class Node(object):
    """Base class for the nodes of the m-expression syntax tree.

    A node stores the environment it was created with and a list of child
    nodes.  Subclasses override ``get_type`` and, where needed, the tree
    rewriting passes ``resolve_atomsel`` and ``constant_fold``.
    """

    def __init__(self, env, children):
        """Store the environment and the list of child nodes."""
        self.children = children
        self.env = env

    def resolve_atomsel(self, aslobj, gids):
        """Apply ``resolve_atomsel`` to every child and return this node.

        Each child is replaced by whatever its own ``resolve_atomsel``
        returns, so a child may substitute a different node for itself.
        """
        self.children = [c.resolve_atomsel(aslobj, gids) for c in self.children]
        return self

    def constant_fold(self):
        """Apply ``constant_fold`` to every child and return this node."""
        # constant folding routine may assume that the program is well-typed
        self.children = [c.constant_fold() for c in self.children]
        return self

    # get_type validates the type correctness of the expression and
    # returns the type of the node.
    # get_type also ensures that all variables can be uniquely resolved
    # in the given scope
    def get_type(self, env=None):
        """Return the type of this node; must be overridden by subclasses.

        ``env`` is accepted (and ignored) so that the signature matches the
        subclasses, which are all called as ``node.get_type(env)``.

        :raises ValueError: always, because the base class has no type.
        """
        raise ValueError('Internal error.  Function not defined')
class Lit(Node):
    """Leaf node holding a literal array, stored as a list in ``value``."""

    def __init__(self, env, value):  # should validate input
        super(Lit, self).__init__(env, [])
        self.value = value

    def get_type(self, env):
        """The type of a literal is the number of elements it holds."""
        return len(self.value)

    def __str__(self):
        # A single element prints bare; anything else prints as a
        # bracketed "literal" form.
        if len(self.value) != 1:
            rendered = [repr(item) for item in self.value]
            return '[literal %s]' % ' '.join(rendered)
        return repr(self.value[0])
class String(Node):
    """Leaf node holding a string constant in ``value``."""

    def __init__(self, env, value):  # should validate input
        super(String, self).__init__(env, [])
        self.value = value

    def get_type(self, env):
        """Strings always have the type ``'string'``."""
        return 'string'

    # the following str assumes that there are no characters that need
    # escaping
    def __str__(self):
        quote = '"'
        return quote + self.value + quote
class Var(Node):
    """Leaf node referring to a bound variable by ``name``."""

    def __init__(self, env, name):
        super(Var, self).__init__(env, [])
        self.name = name

    def get_type(self, env):
        """Look the variable up in ``env.binds`` and return its type.

        :raises KeyError: if the name is not bound in ``env.binds``.
        """
        try:
            var_type = env.binds[self.name]
        except KeyError:
            message = 'Variable %s unknown' % self.name
            raise KeyError(message)
        return var_type

    def __str__(self):
        return '$%s' % (self.name,) if isinstance(self.name, str) \
            else '$' + self.name
class Bind(Node):
    """Node binding ``name`` to a value expression (its single child)."""

    def __init__(self, env, name, value):
        super(Bind, self).__init__(env, [value])
        self.name = name

    def get_type(self, env):
        """A binding has the type of the expression it binds."""
        bound_expr = self.children[0]
        return bound_expr.get_type(env)

    def __str__(self):
        label = str(self.name)
        bound_expr = str(self.children[0])
        return '[$%s %s]' % (label, bound_expr)
class FcnCall(Node):
    """Application of a named function or operator to argument nodes.

    ``name`` is the function name; the arguments are the node's children.
    The names ``atomsel``, ``array``, ``load``, ``store`` and ``meta`` get
    special treatment in the methods below; every other name is type
    checked through ``env.sigs``.
    """

    # Binary operators that constant_fold can evaluate.  The lambdas keep
    # the lookup of old_div lazy (it is only needed when '/' is folded).
    _FOLDABLE_BINOPS = {
        '+': lambda x, y: x + y,
        '*': lambda x, y: x * y,
        '-': lambda x, y: x - y,
        '/': lambda x, y: old_div(x, y),
    }

    def __init__(self, env, name, children):
        Node.__init__(self, env, children)
        self.name = name

    def __str__(self):
        arglist = ' '.join(str(n) for n in self.children)
        return '[%s %s]' % (self.name, arglist)

    def resolve_atomsel(self, aslobj, gids):
        """Replace an ``atomsel`` call by an ``array`` of the selected ids.

        For ``atomsel`` the single string argument is handed to
        ``aslobj.atomsel``; the returned ids are added to ``gids`` and a new
        ``array`` call of one-element literals is returned in place of this
        node.  Any other call just recurses into its children.

        :raises TypeError: if ``atomsel`` does not have exactly one argument.
        :raises ValueError: if the argument is not a ``String`` node.
        """
        if self.name != 'atomsel':
            return Node.resolve_atomsel(self, aslobj, gids)
        if len(self.children) != 1:
            raise TypeError('atomsel takes one argument')
        s = self.children[0]
        if not isinstance(s, String):
            raise ValueError('Argument to atomsel must have type string')
        env = Env()
        selected = aslobj.atomsel(s.value)
        gids.update(selected)
        return FcnCall(env, 'array', [Lit(env, [float(i)]) for i in selected])

    @staticmethod
    def _bin_thread(f, arg1, arg2):
        """Apply ``f`` element-wise, broadcasting a length-1 operand."""
        if len(arg1) == 1:
            x = arg1[0]
            return [f(x, y) for y in arg2]
        if len(arg2) == 1:
            y = arg2[0]
            return [f(x, y) for x in arg1]
        if len(arg1) == len(arg2):
            return [f(x, y) for x, y in zip(arg1, arg2)]
        raise RuntimeError('Internal error.  Invalid type on binary operation')

    def constant_fold(self):
        """Fold this call into a ``Lit`` when all arguments are literals.

        ``array`` of literals is flattened into one literal, and the binary
        operators ``+ * - /`` applied to two literals are evaluated.  Every
        other call is returned unchanged (with folded children).
        """
        self.children = [c.constant_fold() for c in self.children]
        if not all(isinstance(c, Lit) for c in self.children):
            return self
        vals = [c.value for c in self.children]
        if self.name == 'array':
            elems = []
            for v in vals:
                elems.extend(v)
            return Lit(Env(), elems)
        # this simple folding is primarily intended to handle cases where the
        #   parser interprets "-5.0" as [* -1 5.0]
        op = self._FOLDABLE_BINOPS.get(self.name)
        if op is None or len(vals) != 2:
            return self
        return Lit(Env(), self._bin_thread(op, vals[0], vals[1]))

    def get_type(self, env):
        """Type check this call against ``env`` and return its type.

        :raises TypeError: on a wrong argument count or argument type for
            ``load``/``store``, or a dimension mismatch for ``meta``.
        :raises ValueError: if ``load``/``store`` name an unknown static, or
            a literal ``meta`` accumulator id is out of range.
        :raises KeyError: if the function name is not in ``env.sigs``.
        """
        if self.name == 'load':
            if len(self.children) != 1:
                raise TypeError('Wrong number of arguments to load')
            arg = self.children[0]
            t = arg.get_type(env)
            if not isinstance(t, str):
                raise TypeError('Must pass string to load, not ' + showtype(t))
            if arg.value not in env.statics:
                raise ValueError('unknown variable %s' % arg.value)
            return env.statics[arg.value]

        if self.name == 'store':
            if len(self.children) != 2:
                raise TypeError('Wrong number of arguments to store')
            arg0 = self.children[0]
            t = arg0.get_type(env)
            if not isinstance(t, str):
                raise TypeError(
                    'Must pass string as argument 1 of store, not ' +
                    showtype(t))
            if arg0.value not in env.statics:
                raise ValueError('variable %s is unknown in store' % arg0.value)
            tdes = env.statics[arg0.value]
            t = self.children[1].get_type(env)
            if tdes != t:
                raise TypeError(
                    'attempt to store %s in %s, but %s has type %s' %
                    (showtype(t), arg0.value, arg0.value, showtype(tdes)))
            return env.statics[arg0.value]

        child_types = [c.get_type(env) for c in self.children]
        # Report an unknown function by name instead of a bare KeyError.
        try:
            sig = env.sigs[self.name]
        except KeyError:
            raise KeyError('Function %s unknown' % self.name)
        t = sig.check(child_types)
        # When the accumulator id of a meta call is a literal, it can be
        # range checked here and its dimension compared with the third
        # argument's type.
        if (self.name == 'meta' and isinstance(self.children[0], Lit)
                and len(self.children[0].value) == 1):
            mid = int(round(self.children[0].value[0]))
            if mid < 0 or mid >= len(env.metas):
                raise ValueError(
                    'metadynamics accumulator id outside range of accumulators')
            d = env.metas[mid].dim
            if d != child_types[2]:
                raise TypeError(
                    ('metadynamics accumulator %i has dimension %i'
                     ' but was passed a %s collective variable') %
                    (mid, d, showtype(child_types[2])))
        return t
class Iter(Node):
    """Iterator declaration: a name with lower and upper bound children."""

    def __init__(self, env, name, lb, ub):
        super(Iter, self).__init__(env, [lb, ub])
        self.name = name

    def get_type(self, env):
        """Check that both bounds are length-1 arrays and return 1.

        :raises TypeError: if either bound has a type other than 1.
        """
        lower, upper = self.children[0], self.children[1]
        # Both bounds are typed before either is validated.
        bound_types = (('Lower', lower.get_type(env)),
                       ('Upper', upper.get_type(env)))
        errmsg = '%s bound of iterator %s must be a length-1 array but is a %s'
        for label, bound_type in bound_types:
            if bound_type != 1:
                raise TypeError(errmsg % (label, self.name, showtype(bound_type)))
        return 1

    def __str__(self):
        lower, upper = self.children[0], self.children[1]
        return '[$%s %s %s]' % (self.name, lower, upper)
class Let(Node):
    """A block of bindings followed by a value expression.

    The children are the bindings followed by the value expression (the
    last child).  ``binds`` lists the bindings only.
    """

    def __init__(self, env, binds, value):
        """Build the node, wrapping non-``Bind`` entries in gensym binds.

        The wrapping happens *before* the children are set, so ``children``
        and ``binds`` refer to the same ``Bind`` objects.  ``Bind`` nodes are
        rewritten in place by ``resolve_atomsel``/``constant_fold``, which
        keeps ``binds`` consistent with ``children`` after those passes.  The
        caller's ``binds`` list is left untouched.
        """
        # for each bind, if it is not an assignment, make a gensym
        wrapped = [
            b if isinstance(b, Bind) else Bind(env, env.gensym(), b)
            for b in binds
        ]
        Node.__init__(self, env, wrapped + [value])
        self.binds = wrapped

    def get_type(self, env):
        """Type the bindings in a fresh scope, then return the value's type.

        The scope is always left again, even when type checking raises.
        """
        env.binds.enter_scope()
        try:
            for b in self.binds:
                env.binds.add_var(b.name, b.get_type(env))
            return self.children[-1].get_type(env)
        finally:
            env.binds.leave_scope()

    def __str__(self):
        bindlist = ' '.join(map(str, self.binds))
        return '[let [%s] %s]' % (bindlist, str(self.children[-1]))
class Series(Node):
    """A series expression: iterator declarations followed by a value.

    The children are the iterators followed by the value expression (the
    last child).  ``iters`` lists the iterators only.
    """

    def __init__(self, env, iters, value):
        Node.__init__(self, env, iters + [value])
        self.iters = iters

    def get_type(self, env):
        """Type the iterators in a fresh scope, then return the value's type.

        The scope is always left again, even when type checking raises.
        """
        env.binds.enter_scope()
        try:
            for b in self.iters:
                env.binds.add_var(b.name, b.get_type(env))
            return self.children[-1].get_type(env)
        finally:
            env.binds.leave_scope()

    def __str__(self):
        bindlist = ' '.join(map(str, self.iters))
        return '[series [%s] %s]' % (bindlist, str(self.children[-1]))
class If(Node):
    """Conditional expression with condition, then and else children."""

    def __init__(self, env, cond, then_case, else_case):
        super(If, self).__init__(env, [cond, then_case, else_case])

    def get_type(self, env):
        """Return the common type of the two branches.

        :raises TypeError: if the condition does not have type 1 or the
            branches have different types.
        """
        # All three sub-expressions are typed before any check is made.
        cond_type, then_type, else_type = [
            child.get_type(env) for child in self.children[:3]
        ]
        if cond_type != 1:
            raise TypeError(
                'Condition of if expression must have type 1 but has type %s' %
                showtype(cond_type))
        if then_type != else_type:
            raise TypeError(
                'The branches of the if expression must have the same types but '
                'the types are %s and %s' %
                (showtype(then_type), showtype(else_type)))
        return then_type

    def __str__(self):
        cond, then_case, else_case = (str(self.children[k]) for k in range(3))
        return '[if %s %s %s]' % (cond, then_case, else_case)
# binding represents a scoped mapping from variable names to their types
class binding(object):
    """A stack of scopes, each mapping a variable name to its type.

    ``bs`` is a list of dicts with the innermost scope first.  Lookups walk
    the scopes from innermost to outermost.
    """

    def __init__(self):
        self.bs = [{}]

    def enter_scope(self):
        """Push a new, empty innermost scope."""
        self.bs = [{}] + self.bs

    def leave_scope(self):
        """Drop the innermost scope."""
        self.bs = self.bs[1:]

    def add_var(self, name, tp):
        """Declare ``name`` with type ``tp`` in the innermost scope.

        :raises ValueError: if ``name`` is already declared in that scope.
        """
        if name in self.bs[0]:
            raise ValueError('%s declared twice in the same scope' % name)
        self.bs[0][name] = tp

    def __getitem__(self, it):
        """Return the type bound to ``it`` in the nearest enclosing scope.

        :raises KeyError: carrying the missing name, if ``it`` is not
            declared in any scope.
        """
        for d in self.bs:
            if it in d:
                return d[it]
        # if we reach this statement, it has not been declared
        raise KeyError(it)

    def __str__(self):
        return ''.join(str(d) for d in self.bs)
class Env(object):
    """Environment shared by the nodes of one m-expression program.

    Attributes set by ``__init__``:

    * ``statics`` -- dict mapping static variable names to their types
    * ``metas`` -- list of metadynamics accumulators
    * ``gensym_cnt`` -- counter used by ``gensym``
    * ``output`` -- ``None`` or a ``(name, first, interval)`` tuple
    * ``binds`` -- a ``binding`` holding the scoped variable types
    * ``sigs`` -- the function signatures returned by ``getFcnSigs()``
    """

    def __init__(self):
        self.statics = {}
        self.metas = []
        self.gensym_cnt = 0
        self.output = None
        self.binds = binding()
        self.sigs = getFcnSigs()

    def gensym(self):
        """Return a fresh name of the form ``gensym<N>``."""
        i = self.gensym_cnt
        self.gensym_cnt = i + 1
        return 'gensym' + str(i)

    def add_static(self, nm, type):
        """Declare the static variable ``nm`` with the given type.

        :raises ValueError: if ``nm`` is already declared or ``type`` is
            negative.
        """
        if nm in self.statics:
            raise ValueError('Variable (%s) bound twice in same scope' % nm)
        if type < 0:
            raise ValueError("Type of static variable %s cannot be negative" %
                             nm)
        self.statics[nm] = type

    def add_output(self, nm, first, interval):
        """Record the output declaration.

        :raises ValueError: if an output has already been declared.
        """
        # Raising a plain string is itself a TypeError in Python 3, so a
        # real exception type is needed to report the actual problem.
        if self.output is not None:
            raise ValueError("Output cannot be declared twice")
        self.output = (nm, first, interval)

    def __str__(self):
        d = self.statics
        if self.output is not None:
            name, first, interval = self.output
        else:
            name, first, interval = ("", 0.0, 0.0)
        return ''.join([
            'metadynamics_accumulators=[%s] ' % ' '.join(map(str, self.metas)),
            "storage={%s} " % ' '.join('%s=%i' % (k, d[k]) for k in sorted(d)),
            'name="%s" first=%s interval=%s' %
            (name, repr(first), repr(interval)),
        ])
def bodyToNode(tree, env):
    """Convert a parse (sub)tree of the program body into a ``Node`` tree.

    The children of ``tree`` are converted first; the token type of ``tree``
    then selects which node is built from them.  Token types without a
    dedicated branch become a ``FcnCall`` named after the token text.

    :param tree: parse tree node providing ``getToken()`` and ``children``.
    :param env: the ``Env`` passed to every created node.
    :raises ValueError: if a block has no children.
    :raises SystemExit: (status 1) if ``tree`` carries no token, which is
        treated as a failed parse.
    """
    tok = tree.getToken()
    if tok is None:
        sys.stderr.write('failed to parse m-expression\n')
        # sys.exit rather than the interactive-only ``exit`` helper, which
        # is not guaranteed to exist (e.g. when run with ``python -S``).
        sys.exit(1)
    t = tok.getType()
    cs = [bodyToNode(c, env) for c in tree.children]
    if t == BLOCK:  # FIXME must handle this better for gensym
        if not cs:
            # A string cannot be raised in Python 3; use a real exception.
            raise ValueError("Error: empty block")
        if len(cs) == 1:
            return cs[0]
        return Let(env, cs[:-1], cs[-1])
    elif t == IF:
        return If(env, cs[0], cs[1], cs[2])
    elif t == BIND:
        return Bind(env, cs[0].name, cs[1])
    elif t == ITER:
        return Iter(env, cs[0].name, cs[1], cs[2])
    elif t == SERIES:
        return Series(env, cs[:-1], cs[-1])
    elif t == ELEM:
        # a[i][j] becomes nested elem calls, applied left to right
        return reduce(lambda l, r: FcnCall(env, 'elem', [l, r]), cs[1:], cs[0])
    elif t == VAR:
        nm = tree.children[0].getText()
        # a reference to a static variable is a load from storage
        if nm in env.statics:
            return FcnCall(env, 'load', [String(env, nm)])
        return Var(env, nm)
    elif t == LIT:
        return Lit(env, [float(tok.getText())])
    elif t == STRING:
        val = tok.getText()  # includes quote marks
        return String(env, val[1:len(val) - 1])
    elif t == ADDOP and len(cs) == 1:  # prefix +
        return cs[0]
    elif t == SUBTROP and len(cs) == 1:  # prefix -
        return FcnCall(env, '*', [Lit(env, [-1.0]), cs[0]])
    else:
        # due to some ambiguities in the processing, 'store' must be special
        # cased: its first argument has already been turned into a load of
        # the static, so the static's name is unwrapped again here.
        if (tok.getText() == 'store' and len(cs) >= 1
                and isinstance(cs[0], FcnCall) and cs[0].name == 'load'):
            return FcnCall(env, 'store', [cs[0].children[0]] + cs[1:])
        return FcnCall(env, tok.getText(), cs)
def parse_indices(indices_string):
    """
    Parse an indices string and return the unique indices in ascending order.

    Only individual indices and inclusive ranges (written with '-') are
    supported.  ' ' and ',' are both separators, as in ASL.
    '7 3 4, 2-  ,, 7' is equivalent to '3, 7 4, 2-7'.

    A range without a lower bound (a '-' with no preceding index, e.g. '-7')
    is silently ignored.  evaluate_asl and parse_indices do not agree on
    '-7, -3- ,,,4': the former gives [5, 6, 7], which is not consistent with
    the definition of ASL.

    :param indices_string: the text to parse.
    :return: sorted list of unique integer indices.
    :raises RuntimeError: if a token is not an integer, or a '-' is not
        followed by an integer.
    """

    def is_integer(t):
        try:
            int(t)
        except ValueError:
            return False
        return True

    def failure():
        return RuntimeError("Failed to parse indices: %s" % indices_string)

    # ' ' and ',' are both considered as white space. Replace ',' with ' ' to
    # make it easier when calling split function
    tokens = []
    for tail in indices_string.replace(',', ' ').split():
        # split each chunk into integers and '-' markers
        while tail:
            head, dash, tail = tail.partition('-')
            if head:
                if not is_integer(head):
                    raise failure()
                tokens.append(head)
            if dash:
                tokens.append(dash)

    i = 0
    stack = []  # indices seen since the last range
    indices = []
    while i < len(tokens):
        tok = tokens[i]
        if tok != '-':
            stack.append(tok)
            i += 1
            continue
        # a '-' must be followed by the integer that ends the range
        i += 1
        if i >= len(tokens) or not is_integer(tokens[i]):
            raise failure()
        end = int(tokens[i])
        if stack:
            # the most recent index starts the range; the ones before it are
            # plain individual indices
            begin = int(stack.pop())
            indices.extend(range(begin, end + 1))
            indices.extend(int(e) for e in stack)
            stack = []
        i += 1
    indices.extend(int(e) for e in stack)
    # unique indices in ascending order to match ASL's behavior; a set alone
    # does not guarantee any ordering
    return sorted(set(indices))
class ASLObject:
    """Resolve ASL expressions to gids, with or without a model."""

    _index_only_pattern = re.compile(r'atom.\s+(.*)')

    def __init__(self, model):
        self._cms = model

    def atomsel(self, asl_expr):
        """Return the gids selected by ``asl_expr``.

        With a model, the selection is delegated to the model's
        ``select_atom`` and ``gid`` methods.  Without one, only expressions
        matched by ``_index_only_pattern`` are accepted.

        :raises RuntimeError: if there is no model and the expression does
            not match ``_index_only_pattern``.
        """
        if self._cms:
            selected = numpy.asarray(self._cms.select_atom(str(asl_expr)))
            return self._cms.gid(selected)

        # FIXME
        # .cms is needed in order to translate front end config file to
        # backend config file. Unfortunately, to restart the simulation
        # from checkpoint file, one does not have a .cms file.
        # Below is a workaround to get gids without .cms file. It will only
        # work if gid = atid - 1 is true for all atoms. In other words, it
        # will fail on restarting FEP with metadynamics or the reaction
        # coordinate happen to involve molecules with virtual site.
        found = self._index_only_pattern.match(str(asl_expr))
        if not found:
            raise RuntimeError(
                "Failed to get gid from asl ('%s') without structure." %
                asl_expr)
        shifted = numpy.array(parse_indices(found.group(1))) - 1
        return list(shifted)
def procText(text):
    """Lex and parse m-expression source text.

    The first child of the parse tree is turned into the environment and
    the second child into the node tree of the program body.

    :return: the tuple ``(action, env)``.
    """
    # NOTE(review): headerToEnv is not visible in this part of the file;
    # it is presumably defined elsewhere in the module.
    char_stream = antlr3.StringStream(text)
    token_stream = antlr3.CommonTokenStream(mexpLexer(char_stream))
    parse_tree = mexpParser(token_stream).prog().tree
    env = headerToEnv(parse_tree.children[0])
    return (bodyToNode(parse_tree.children[1], env), env)
def resolve_atomsel(body, model):
    """Resolve the atom selections in ``body`` against ``model``.

    While the ``ASLObject`` is being constructed, the process-level stdout
    descriptor is pointed at stderr, so anything written to stdout during
    construction ends up on stderr.  The original stdout is restored
    afterwards, also on error, and the temporary descriptor is closed.

    :return: ``(newbody, gids)`` where ``gids`` is the set of ids collected
        by ``body.resolve_atomsel``.
    """
    sys.stdout.flush()
    sys.stderr.flush()
    stdout_fd = sys.stdout.fileno()
    saved_stdout = os.dup(stdout_fd)
    try:
        os.dup2(sys.stderr.fileno(), stdout_fd)
        try:
            aslobj = ASLObject(model)
        finally:
            # flush whatever was written while redirected, then put the
            # real stdout back
            sys.stdout.flush()
            sys.stderr.flush()
            os.dup2(saved_stdout, stdout_fd)
    finally:
        # the duplicate is only needed for the restore; release it so that
        # repeated calls do not leak file descriptors
        os.close(saved_stdout)
    gids = set()
    newbody = body.resolve_atomsel(aslobj, gids)
    return (newbody, gids)
def parseStr(system, mexp):
    """
    Return the partial frontend config text for the enhanced_sampling plugin.

    The m-expression ``mexp`` is parsed, its atom selections are resolved
    against ``system``, and the result is type checked and constant folded.

    :raises TypeError: if the overall expression is not a length-1 array.
    """
    parsed, env = procText(mexp)
    resolved, gids = resolve_atomsel(parsed, system)
    result_type = resolved.get_type(env)
    # note that constant folding assumes a well-typed program
    folded = resolved.constant_fold()
    gid_text = ' '.join(str(g) for g in sorted(gids))
    if type(result_type) is str or result_type != 1:
        raise TypeError(
            'Potential must be a length-1 array, but is currently ' +
            showtype(result_type))
    template = '{type=enhanced_sampling gids=[%s] sexp=%s %s}'
    return template % (gid_text, str(folded), str(env))
def parse_mexpr(system, mexp):
    """
    Return the partial backend config text for the enhanced_sampling plugin.

    The text produced by ``parseStr`` is wrapped in a ``force.term`` entry
    named ``ES``.
    """
    plugin_text = parseStr(system, mexp)
    return 'force.term{list[+]=ES ES=%s}' % plugin_text