import os
import re
import sys
from functools import reduce
from past.utils import old_div
import numpy
from . import antlr3
from .enhanced_sampling.FcnTypes import getFcnSigs
from .enhanced_sampling.mexpLexer import mexpLexer
from .enhanced_sampling.mexpParser import ADDOP
from .enhanced_sampling.mexpParser import BIND
from .enhanced_sampling.mexpParser import BLOCK
from .enhanced_sampling.mexpParser import CUTOFF
from .enhanced_sampling.mexpParser import DECL_META
from .enhanced_sampling.mexpParser import DECL_OUTPUT
from .enhanced_sampling.mexpParser import DIM
from .enhanced_sampling.mexpParser import ELEM
from .enhanced_sampling.mexpParser import FIRST
from .enhanced_sampling.mexpParser import HEADER
from .enhanced_sampling.mexpParser import IF
from .enhanced_sampling.mexpParser import INITKER
from .enhanced_sampling.mexpParser import INTERVAL
from .enhanced_sampling.mexpParser import ITER
from .enhanced_sampling.mexpParser import LIT
from .enhanced_sampling.mexpParser import NAME
from .enhanced_sampling.mexpParser import SERIES
from .enhanced_sampling.mexpParser import STATIC
from .enhanced_sampling.mexpParser import STRING
from .enhanced_sampling.mexpParser import SUBTROP
from .enhanced_sampling.mexpParser import VAR
from .enhanced_sampling.mexpParser import mexpParser
# FIXME add line numbers to errors
def showtype(i):
    """Return a human-readable description of an expression type.

    Types in this module are either a string (``String.get_type`` returns
    ``'string'``) or an integer giving the length of a numeric array
    (``Lit.get_type`` returns ``len(value)``).

    :param i: the type value to describe.
    :return: ``'string'`` for string types, ``'length-<n> array'`` otherwise.
    """
    # Test the value itself; testing ``type(i)`` against str can never succeed.
    if isinstance(i, str):
        return 'string'
    return 'length-' + str(i) + ' array'
class Node(object):
    """Base class for the nodes of the expression tree.

    A node stores the environment it was created with and a list of child
    nodes. Tree-wide passes (atom selection resolution, constant folding)
    recurse through ``children``.
    """

    def __init__(self, env, children):
        """Store the environment and the list of child nodes."""
        self.children = children
        self.env = env

    def resolve_atomsel(self, aslobj, gids):
        """Resolve atom selections in every child and return this node.

        Each child is replaced by whatever its own ``resolve_atomsel``
        returns, so a child may substitute a different node for itself.
        """
        self.children = [c.resolve_atomsel(aslobj, gids) for c in self.children]
        return self

    def constant_fold(self):
        """Constant-fold every child and return this node.

        The folding routine may assume that the program is well-typed.
        """
        self.children = [c.constant_fold() for c in self.children]
        return self

    def get_type(self, env=None):
        """Return the type of this node; must be overridden by subclasses.

        get_type validates the type correctness of the expression and returns
        the type of the node. It also ensures that all variables can be
        uniquely resolved in the given scope.

        ``env`` is accepted (and ignored) so that the base signature agrees
        with the subclasses, which are all called as ``get_type(env)``.

        :raises ValueError: always, because the base class has no type rule.
        """
        raise ValueError('Internal error. Function not defined')
class Lit(Node):
    """Literal node holding a list of values."""

    def __init__(self, env, value):  # should validate input
        """Create a literal with no children; ``value`` is the list of elements."""
        Node.__init__(self, env, [])
        self.value = value

    def get_type(self, env):
        """Return the number of elements, which is the type of an array."""
        return len(self.value)

    def __str__(self):
        # A single element prints bare; anything else prints as a literal list.
        if len(self.value) != 1:
            rendered = ' '.join(repr(v) for v in self.value)
            return '[literal %s]' % rendered
        return repr(self.value[0])
class String(Node):
    """String literal node."""

    def __init__(self, env, value):  # should validate input
        """Create a string node with no children; ``value`` is the text."""
        Node.__init__(self, env, [])
        self.value = value

    def get_type(self, env):
        """Return the type marker used for strings."""
        return 'string'

    def __str__(self):
        # Assumes the value contains no characters that need escaping.
        return ''.join(('"', self.value, '"'))
class Var(Node):
    """Reference to a named variable."""

    def __init__(self, env, name):
        """Create a variable reference with no children."""
        Node.__init__(self, env, [])
        self.name = name

    def get_type(self, env):
        """Look the variable up in ``env.binds`` and return its type.

        :raises KeyError: if the variable is not bound in any visible scope.
        """
        try:
            return env.binds[self.name]
        except KeyError:
            message = 'Variable %s unknown' % self.name
            raise KeyError(message)

    def __str__(self):
        return '$' + self.name
class Bind(Node):
    """Binding of a name to the value of an expression."""

    def __init__(self, env, name, value):
        """Create a binding whose single child is the bound expression."""
        Node.__init__(self, env, [value])
        self.name = name

    def get_type(self, env):
        """Return the type of the bound expression."""
        bound = self.children[0]
        return bound.get_type(env)

    def __str__(self):
        return '[$%s %s]' % (self.name, self.children[0])
class FcnCall(Node):
    """Application of a named function to a list of argument nodes."""

    def __init__(self, env, name, children):
        """Create a call of function ``name`` with ``children`` as arguments."""
        Node.__init__(self, env, children)
        self.name = name

    def __str__(self):
        arglist = ' '.join(str(n) for n in self.children)
        return '[%s %s]' % (self.name, arglist)

    def resolve_atomsel(self, aslobj, gids):
        """Replace an ``atomsel`` call by an ``array`` call of literals.

        ``aslobj.atomsel`` is evaluated on the string argument; every value it
        yields is added to ``gids`` and becomes one literal of the new array.
        Any other function simply recurses into its arguments.

        :raises TypeError: if ``atomsel`` is not given exactly one argument.
        :raises ValueError: if the argument is not a ``String`` node.
        """
        if self.name != 'atomsel':
            return Node.resolve_atomsel(self, aslobj, gids)
        if len(self.children) != 1:
            raise TypeError('atomsel takes one argument')
        s = self.children[0]
        if not isinstance(s, String):
            raise ValueError('Argument to atomsel must have type string')
        env = Env()
        sel = aslobj.atomsel(s.value)
        gids.update(sel)
        cs = [Lit(env, [float(i)]) for i in sel]
        return FcnCall(env, 'array', cs)

    def constant_fold(self):
        """Fold ``array`` and two-argument arithmetic calls on literals.

        Arguments are folded first. If all of them are literals, ``array``
        is flattened into a single literal and ``+ * - /`` with two arguments
        are evaluated; every other call is returned unchanged.
        """
        self.children = [c.constant_fold() for c in self.children]
        if not all(isinstance(c, Lit) for c in self.children):
            return self
        vals = [c.value for c in self.children]
        if self.name == 'array':
            elems = []
            for v in vals:
                elems.extend(v)
            return Lit(Env(), elems)

        # this simple folding is primarily intended to handle cases where the
        # parser interprets "-5.0" as [* -1 5.0]
        def bin_thread(f, arg1, arg2):
            # A length-1 operand is applied against every element of the other.
            if len(arg1) == 1:
                x = arg1[0]
                return [f(x, y) for y in arg2]
            elif len(arg2) == 1:
                y = arg2[0]
                return [f(x, y) for x in arg1]
            elif len(arg1) == len(arg2):
                return [f(x, y) for x, y in zip(arg1, arg2)]
            else:
                raise RuntimeError(
                    'Internal error. Invalid type on binary operation')

        if self.name == '+' and len(vals) == 2:
            return Lit(Env(), bin_thread(lambda x, y: x + y, vals[0], vals[1]))
        elif self.name == '*' and len(vals) == 2:
            return Lit(Env(), bin_thread(lambda x, y: x * y, vals[0], vals[1]))
        elif self.name == '-' and len(vals) == 2:
            return Lit(Env(), bin_thread(lambda x, y: x - y, vals[0], vals[1]))
        elif self.name == '/' and len(vals) == 2:
            return Lit(Env(),
                       bin_thread(lambda x, y: old_div(x, y), vals[0], vals[1]))
        else:
            return self

    def get_type(self, env):
        """Type-check the call and return its result type.

        ``load`` and ``store`` are checked against ``env.statics``; all other
        functions are checked through their signature in ``env.sigs``. A
        ``meta`` call whose first argument is a single-element literal is
        additionally checked against the accumulators in ``env.metas``.

        :raises TypeError: on wrong argument counts or mismatching types.
        :raises ValueError: on unknown static variables or an accumulator id
            that is out of range.
        :raises KeyError: if the function name has no signature.
        """
        if self.name == 'load':
            if len(self.children) != 1:
                raise TypeError('Wrong number of arguments to load')
            arg = self.children[0]
            t = arg.get_type(env)
            if not isinstance(t, str):
                # Use showtype so the wording matches the store error below.
                raise TypeError('Must pass string to load, not ' + showtype(t))
            if arg.value not in env.statics:
                raise ValueError('unknown variable %s' % arg.value)
            return env.statics[arg.value]
        elif self.name == 'store':
            if len(self.children) != 2:
                raise TypeError('Wrong number of arguments to store')
            arg0 = self.children[0]
            t = arg0.get_type(env)
            if not isinstance(t, str):
                raise TypeError('Must pass string as argument 1 of store, not '
                                + showtype(t))
            if arg0.value not in env.statics:
                raise ValueError('variable %s is unknown in store' % arg0.value)
            tdes = env.statics[arg0.value]
            t = self.children[1].get_type(env)
            if tdes != t:
                raise TypeError(
                    'attempt to store %s in %s, but %s has type %s' %
                    (showtype(t), arg0.value, arg0.value, showtype(tdes)))
            return env.statics[arg0.value]
        else:
            child_types = [c.get_type(env) for c in self.children]
            # check that the function name exists and report it by name
            try:
                sig = env.sigs[self.name]
            except KeyError:
                raise KeyError('Unknown function %s' % self.name)
            t = sig.check(child_types)
            if self.name == 'meta' and isinstance(self.children[0], Lit) \
                    and len(self.children[0].value) == 1:
                mid = int(round(self.children[0].value[0]))
                if mid < 0 or mid >= len(env.metas):
                    raise ValueError(
                        'metadynamics accumulator id outside range of accumulators'
                    )
                d = env.metas[mid].dim
                if d != child_types[2]:
                    raise TypeError(
                        ('metadynamics accumulator %i has dimension %i' +
                         ' but was passed a %s collective variable')
                        % (mid, d, showtype(child_types[2])))
            return t
class Iter(Node):
    """Iterator variable with a lower and an upper bound expression."""

    def __init__(self, env, name, lb, ub):
        """Create an iterator; children are ``[lb, ub]``."""
        Node.__init__(self, env, [lb, ub])
        self.name = name

    def get_type(self, env):
        """Check that both bounds are length-1 arrays and return 1.

        :raises TypeError: if either bound has another type.
        """
        lower_type = self.children[0].get_type(env)
        upper_type = self.children[1].get_type(env)
        template = '%s bound of iterator %s must be a length-1 array but is a %s'
        # Both bound types are computed before either one is reported.
        for label, bound_type in (('Lower', lower_type), ('Upper', upper_type)):
            if bound_type != 1:
                raise TypeError(
                    template % (label, self.name, showtype(bound_type)))
        return 1

    def __str__(self):
        lower, upper = self.children[0], self.children[1]
        return '[$%s %s %s]' % (self.name, lower, upper)
class Let(Node):
    """Sequence of bindings followed by a result expression."""

    def __init__(self, env, binds, value):
        """Create a let block.

        Every entry of ``binds`` that is not already a ``Bind`` is wrapped in
        a ``Bind`` with a fresh gensym name. The wrapped bindings are stored
        both as ``self.binds`` and as the leading children, so that tree
        passes which rewrite ``children`` act on the same ``Bind`` objects
        that ``get_type`` and ``__str__`` read. ``Bind`` inherits the passes
        from ``Node``, which update in place and return the same object.

        :param binds: list of nodes evaluated before ``value``.
        :param value: node whose type is the type of the block.
        """
        wrapped = [
            b if isinstance(b, Bind) else Bind(env, env.gensym(), b)
            for b in binds
        ]
        Node.__init__(self, env, wrapped + [value])
        self.binds = wrapped

    def get_type(self, env):
        """Type the bindings in a new scope and return the type of the value."""
        env.binds.enter_scope()
        for b in self.binds:
            env.binds.add_var(b.name, b.get_type(env))
        t = self.children[-1].get_type(env)
        env.binds.leave_scope()
        return t

    def __str__(self):
        bindlist = ' '.join(map(str, self.binds))
        return '[let [%s] %s]' % (bindlist, str(self.children[-1]))
class Series(Node):
    """Expression evaluated under a list of iterators."""

    def __init__(self, env, iters, value):
        """Create a series; children are the iterators followed by ``value``."""
        Node.__init__(self, env, iters + [value])
        self.iters = iters

    def get_type(self, env):
        """Type the body in a scope that declares every iterator variable."""
        scope = env.binds
        scope.enter_scope()
        for iterator in self.iters:
            scope.add_var(iterator.name, iterator.get_type(env))
        body_type = self.children[-1].get_type(env)
        scope.leave_scope()
        return body_type

    def __str__(self):
        rendered_iters = ' '.join(str(it) for it in self.iters)
        return '[series [%s] %s]' % (rendered_iters, self.children[-1])
class If(Node):
    """Conditional expression with a then-branch and an else-branch."""

    def __init__(self, env, cond, then_case, else_case):
        """Create a conditional; children are ``[cond, then_case, else_case]``."""
        Node.__init__(self, env, [cond, then_case, else_case])

    def get_type(self, env):
        """Return the common type of the two branches.

        :raises TypeError: if the condition is not of type 1 or the branches
            have different types.
        """
        cond_type, then_type, else_type = [
            child.get_type(env) for child in self.children[:3]
        ]
        if cond_type != 1:
            raise TypeError(
                'Condition of if expression must have type 1 but has type %s' %
                showtype(cond_type))
        if then_type != else_type:
            raise TypeError(
                'The branches of the if expression must have the same types but '
                'the types are %s and %s' %
                (showtype(then_type), showtype(else_type)))
        return then_type

    def __str__(self):
        cond, then_case, else_case = self.children[:3]
        return '[if %s %s %s]' % (cond, then_case, else_case)
# binding represents a scoped mapping from variable names to their types
class binding(object):
    """Stack of scopes, each mapping a variable name to its type.

    ``bs[0]`` is the innermost scope; lookups search from the innermost
    scope outwards.
    """

    def __init__(self):
        """Start with a single, empty scope."""
        self.bs = [{}]

    def enter_scope(self):
        """Push a new, empty innermost scope."""
        self.bs = [{}] + self.bs

    def leave_scope(self):
        """Drop the innermost scope."""
        self.bs = self.bs[1:]

    def add_var(self, name, tp):
        """Declare ``name`` with type ``tp`` in the innermost scope.

        :raises ValueError: if ``name`` is already declared in that scope.
        """
        if name in self.bs[0]:
            raise ValueError('%s declared twice in the same scope' % name)
        self.bs[0][name] = tp

    def __getitem__(self, it):
        """Return the type of ``it`` from the innermost scope declaring it.

        :raises KeyError: carrying the name, if it is not declared anywhere.
        """
        for d in self.bs:
            if it in d:
                return d[it]
        # if we reach this statement, it has not been declared
        raise KeyError(it)

    def __str__(self):
        return ''.join(str(d) for d in self.bs)
class Env(object):
    """Environment collected while processing an expression program.

    It holds the static variables and their types, the metadynamics
    accumulators, the output declaration, the scoped variable bindings and
    the function signatures.
    """

    def __init__(self):
        """Create an empty environment with the known function signatures."""
        self.statics = {}
        self.metas = []
        self.gensym_cnt = 0
        self.output = None
        self.binds = binding()
        self.sigs = getFcnSigs()

    def gensym(self):
        """Return a fresh name of the form ``gensym<N>``."""
        i = self.gensym_cnt
        self.gensym_cnt = i + 1
        return 'gensym' + str(i)

    def add_static(self, nm, type):
        """Declare the static variable ``nm`` with the given type.

        :raises ValueError: if ``nm`` is already declared or the type is
            negative.
        """
        if nm in self.statics:
            raise ValueError('Variable (%s) bound twice in same scope' % nm)
        if type < 0:
            raise ValueError("Type of static variable %s cannot be negative" %
                             nm)
        self.statics[nm] = type

    def add_output(self, nm, first, interval):
        """Record the output declaration ``(nm, first, interval)``.

        :raises ValueError: if an output has already been declared.
        """
        if self.output is not None:
            # Python 3 cannot raise a bare string; use a real exception,
            # matching the ValueError raised by add_static.
            raise ValueError("Output cannot be declared twice")
        self.output = (nm, first, interval)

    def __str__(self):
        d = self.statics
        metas = ' '.join(map(str, self.metas))
        storage = ' '.join('%s=%i' % (k, d[k]) for k in sorted(d))
        if self.output is not None:
            name, first, interval = self.output
        else:
            name, first, interval = ("", 0.0, 0.0)
        return ('metadynamics_accumulators=[%s] storage={%s} '
                'name="%s" first=%s interval=%s'
                % (metas, storage, name, repr(first), repr(interval)))
def bodyToNode(tree, env):
    """Convert a parse tree into a tree of ``Node`` objects.

    The children of ``tree`` are converted first; the token type of ``tree``
    then selects the node to build. Tokens without a dedicated case become
    a ``FcnCall`` named after the token text.

    :param tree: parse tree node providing ``getToken()`` and ``children``.
    :param env: environment used for static lookups and gensym names.
    :return: the node representing ``tree``.
    :raises ValueError: if a block has no children.
    :raises SystemExit: if ``tree`` carries no token (parse failure).
    """
    tok = tree.getToken()
    if tok is None:
        sys.stderr.write('failed to parse m-expression\n')
        # sys.exit instead of the site-provided ``exit`` helper, which is
        # not guaranteed to exist.
        sys.exit(1)
    t = tok.getType()
    cs = [bodyToNode(c, env) for c in tree.children]
    if t == BLOCK:  # FIXME must handle this better for gensym
        if not cs:
            # A string cannot be raised in Python 3; use a real exception.
            raise ValueError("Error: empty block")
        if len(cs) == 1:
            return cs[0]
        return Let(env, cs[:-1], cs[-1])
    elif t == IF:
        return If(env, cs[0], cs[1], cs[2])
    elif t == BIND:
        return Bind(env, cs[0].name, cs[1])
    elif t == ITER:
        return Iter(env, cs[0].name, cs[1], cs[2])
    elif t == SERIES:
        return Series(env, cs[:-1], cs[-1])
    elif t == ELEM:
        # Chain the indices left to right: elem(elem(base, i), j) ...
        return reduce(lambda l, r: FcnCall(env, 'elem', [l, r]), cs[1:], cs[0])
    elif t == VAR:
        nm = tree.children[0].getText()
        # A name declared as static is read through an explicit load.
        if nm in env.statics:
            return FcnCall(env, 'load', [String(env, nm)])
        return Var(env, nm)
    elif t == LIT:
        val = tok.getText()
        return Lit(env, [float(val)])
    elif t == STRING:
        val = tok.getText()  # includes quote marks
        return String(env, val[1:len(val) - 1])
    elif t == ADDOP and len(cs) == 1:  # prefix +
        return cs[0]
    elif t == SUBTROP and len(cs) == 1:  # prefix -
        return FcnCall(env, '*', [Lit(env, [-1.0]), cs[0]])
    else:
        # due to some ambiguities in the processing, 'store' must be special
        # cased: its first argument has already been turned into a load, so
        # the variable name is taken back out of that load.
        if tok.getText() == 'store' and len(cs) >= 1 \
                and isinstance(cs[0], FcnCall) and cs[0].name == 'load':
            return FcnCall(env, 'store', [cs[0].children[0]] + cs[1:])
        return FcnCall(env, tok.getText(), cs)
def parse_indices(indices_string):
    """
    Parse an index string and return the unique indices in ascending order.

    Only individual indices and ranges (written with '-') are supported.
    ' ' and ',' are both separators, so '7 3 4, 2- ,, 7' is equivalent to
    '3, 7 4, 2-7' and both give [2, 3, 4, 5, 6, 7].

    A '-' that has no index before it is dropped together with the index
    that follows it, so '-7' gives []. The original author notes that
    evaluate_asl disagrees with this function on inputs such as
    '-7, -3- ,,,4'.

    :param indices_string: text containing integers, '-', ',' and spaces.
    :return: sorted list of unique integers.
    :raises RuntimeError: if a token is not an integer, or a '-' is not
        followed by an integer.
    """

    def is_integer(t):
        try:
            int(t)
        except ValueError:
            return False
        return True

    def failure():
        return RuntimeError("Failed to parse indices: %s" % indices_string)

    # ' ' and ',' are both separators: turn ',' into ' ' and split on
    # whitespace, then cut every word into integers and '-' tokens.
    tokens = []
    for word in indices_string.replace(',', ' ').split():
        tail = word
        while tail:
            head, dash, tail = tail.partition('-')
            if head:
                if not is_integer(head):
                    raise failure()
                tokens.append(head)
            if dash:
                tokens.append(dash)

    indices = set()
    pending = []  # plain indices seen since the last range
    pos = 0
    while pos < len(tokens):
        tok = tokens[pos]
        if tok != '-':
            pending.append(int(tok))
            pos += 1
            continue
        # A '-' must be followed by the integer that ends the range.
        pos += 1
        if pos >= len(tokens) or tokens[pos] == '-':
            raise failure()
        if pending:
            # The most recent index starts the range; the others are plain.
            begin = pending.pop()
            indices.update(range(begin, int(tokens[pos]) + 1))
            indices.update(pending)
            pending = []
        pos += 1
    indices.update(pending)
    # A set has no defined order; sort to deliver the promised ascending order.
    return sorted(indices)
class ASLObject:
    """Evaluate atom selection expressions, with or without a model."""

    _index_only_pattern = re.compile(r'atom.\s+(.*)')

    def __init__(self, model):
        """Keep the model used for selections; it may be falsy."""
        self._cms = model

    def atomsel(self, asl_expr):
        """Return the gids selected by ``asl_expr``.

        With a model, the selection and the gid lookup are delegated to it.
        Without one, only expressions matching ``_index_only_pattern`` are
        accepted and every parsed index is reduced by one.

        :raises RuntimeError: if there is no model and the expression does
            not match the index-only pattern.
        """
        expr_text = str(asl_expr)
        if self._cms:
            selected = numpy.asarray(self._cms.select_atom(expr_text))
            return self._cms.gid(selected)

        # FIXME
        # A .cms is needed to translate the front end config file to the
        # backend config file, but when restarting a simulation from a
        # checkpoint file no .cms file is available.
        # The fallback below obtains gids without one. It only works if
        # gid = atid - 1 holds for all atoms; in other words, it fails when
        # restarting FEP with metadynamics, or when the reaction coordinate
        # happens to involve molecules with virtual sites.
        matched = self._index_only_pattern.match(expr_text)
        if not matched:
            raise RuntimeError(
                "Failed to get gid from asl ('%s') without structure." %
                asl_expr)
        shifted = numpy.array(parse_indices(matched.group(1))) - 1
        return list(shifted)
def procText(text):
    """Lex and parse ``text`` and return the pair ``(action, env)``.

    The first child of the parsed program tree is handed to ``headerToEnv``
    to build the environment; the second child is converted to a node tree
    by ``bodyToNode`` using that environment.
    """
    char_stream = antlr3.StringStream(text)
    token_stream = antlr3.CommonTokenStream(mexpLexer(char_stream))
    program = mexpParser(token_stream).prog().tree
    env = headerToEnv(program.children[0])
    action = bodyToNode(program.children[1], env)
    return (action, env)
def resolve_atomsel(body, model):
    """Resolve all atom selections in ``body`` against ``model``.

    While the selection object is created, stdout is temporarily redirected
    to stderr (a leftover from when VMD printed its banner to stdout). The
    original stdout is always restored and the duplicated descriptor is
    always closed.

    :param body: root node of the expression tree.
    :param model: model handed to ``ASLObject``.
    :return: tuple ``(newbody, gids)`` with the rewritten tree and the set
        of gids collected from the selections.
    """
    sys.stdout.flush()
    sys.stderr.flush()
    stdout_fd = sys.stdout.fileno()
    saved_stdout = os.dup(stdout_fd)
    try:
        os.dup2(sys.stderr.fileno(), stdout_fd)
        try:
            aslobj = ASLObject(model)
            sys.stdout.flush()
            sys.stderr.flush()
        finally:
            # Put the real stdout back even if the construction failed.
            os.dup2(saved_stdout, stdout_fd)
    finally:
        # The duplicate is only needed for the restore; do not leak it.
        os.close(saved_stdout)
    gids = set()
    newbody = body.resolve_atomsel(aslobj, gids)
    return (newbody, gids)
def parseStr(system, mexp):
    """
    Return the partial frontend config text for the enhanced_sampling plugin.

    The expression is parsed, its atom selections are resolved against
    ``system``, it is type-checked and then constant-folded.

    :raises TypeError: if the expression is not a length-1 array.
    """
    action, env = procText(mexp)
    resolved, gids = resolve_atomsel(action, system)
    potential_type = resolved.get_type(env)
    # note that constant folding assumes a well-typed program
    folded = resolved.constant_fold()
    gid_text = ' '.join(str(g) for g in sorted(gids))
    is_scalar = type(potential_type) is not str and potential_type == 1
    if not is_scalar:
        raise TypeError(
            'Potential must be a length-1 array, but is currently ' +
            showtype(potential_type))
    template = '{type=enhanced_sampling gids=[%s] sexp=%s %s}'
    return template % (gid_text, folded, env)
def parse_mexpr(system, mexp):
    """
    Return the partial backend config text for the enhanced_sampling plugin.

    The text produced by ``parseStr`` is wrapped into a ``force.term`` entry
    named ``ES``.
    """
    plugin_text = parseStr(system, mexp)
    return 'force.term{list[+]=ES ES=%s}' % plugin_text