"""
Module for parameter validation. See `schrodinger.utils.sea` for more details.
Copyright Schrodinger, LLC. All rights reserved.
"""
import inspect
import os
import re
from copy import deepcopy
from .common import boolean
from .common import debug_print
from .common import is_equal
from .sea import Atom
from .sea import List
from .sea import Map
class Evalor:
    """
    Evaluator used while checking the validity of parameters.

    It evaluates validation expressions against the parameter map it was
    created with, accumulates error messages, and keeps the names of the
    parameters for which no validation criterion was found.
    """

    __slots__ = [
        "_map",
        "_err_break",
        "_err",
        "_unchecked_map",
    ]

    def __init__(self, map, err_break="\n\n"):
        """
        :param map: 'map' contains all parameters to be checked.
        :param err_break: Separator appended after every recorded error
            message.
        """
        self._map, self._err_break = map, err_break
        self._err, self._unchecked_map = "", []

    def __call__(self, arg):
        """
        Evaluates the validation expression `arg` against the stored map.

        :param arg: The validation criteria.
        :return: The value of the expression.
        """
        return _eval(self._map, arg)

    @property
    def err(self):
        """The accumulated error messages ("" if none were recorded)."""
        return self._err

    def is_ok(self):
        """
        Returns True if no error was recorded and no parameter was left
        unchecked, False otherwise.
        """
        return not (self._err or self._unchecked_map)

    def record_error(self, mapname=None, err=""):
        """
        Appends an error message, followed by the error separator.

        :param mapname: The name of the checked parameter. Its first character
            (normally the leading '.' of the parameter path) is dropped and
            the rest is written as a "<name>: " header.
        :param err: The error message.
        """
        debug_print("ERROR\n%s" % err)
        if mapname is not None:
            self._err += mapname[1:] + ": "
        self._err += err + self._err_break

    @property
    def unchecked_map(self):
        """
        Returns the names of the unchecked parameters as one string. The first
        character of every name is dropped and each name is followed by a
        space.
        """
        return "".join(name[1:] + " " for name in self._unchecked_map)

    def copy_from(self, ev):
        """
        Takes over the map, the error text and the unchecked-parameter list of
        `ev` (the list object is shared, not copied).

        :param ev: A 'Evalor' object.
        """
        self._map, self._err, self._unchecked_map = (
            ev._map,
            ev._err,
            ev._unchecked_map,
        )
def check_map(map, valid, ev, tag=set()):  # noqa: M511
    """
    Checks the validity of a map and prints a summary via `debug_print`.

    Nothing is checked when `map.has_tag(tag)` is false. Otherwise
    `map.sval` is checked against `valid`, with the results collected in
    `ev`.

    :param map: The parameters to check.
    :param valid: The validation criteria.
    :param ev: The evaluator ('Evalor') that collects errors and unchecked
        parameters.
    :param tag: Tags passed on to `map.has_tag` and the recursive checks.
    :return: The accumulated error text if any error was recorded, otherwise
        None.
    """
    if not map.has_tag(tag):
        debug_print("(none is tagged with: %s)" % (", ".join(tag)))
        return None

    _check_map(map.sval, valid, ev, "", tag)

    debug_print("\nUnchecked maps:")
    debug_print(ev.unchecked_map if ev._unchecked_map != [] else "(none)")

    debug_print("\nError summary:")
    if ev._err != "":
        debug_print(ev._err)
        return ev._err
    debug_print("(none)")
    return None
def __op_mul(map, arg):
    """
    Evaluates the "multiplication" expression and returns the product of
    arg[0], arg[1], arg[2], ...

    The product starts from 1.0, so the result is a float even when all
    factors are integers, and an empty 'arg' yields 1.0.

    :param arg: The 'arg' should be a 'sea.List' object that contains two or
        more elements.
    :param map: The original map that the elements in the 'arg' refer.
    """
    product = 1.0
    for factor in (_eval(map, element) for element in arg):
        product *= factor
    return product
def __op_eq(map, arg):
    """
    Evaluates the "equal" expression and returns True if arg[0] and arg[1]
    are equal, or False otherwise.

    When either value is a float, the comparison is delegated to `is_equal`
    instead of using exact `==`.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
        elements. Elements beyond the first two are ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    lhs, rhs = _eval(map, arg[0]), _eval(map, arg[1])
    if any(isinstance(value, float) for value in (lhs, rhs)):
        return is_equal(lhs, rhs)
    return lhs == rhs
def __op_lt(map, arg):
    """
    Evaluates the "less than" expression and returns True if arg[0] is less
    than arg[1], or False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
        elements. Elements beyond the first two are ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    lhs = _eval(map, arg[0])
    rhs = _eval(map, arg[1])
    return lhs < rhs
def __op_le(map, arg):
    """
    Evaluates the "less or equal" expression and returns True if arg[0] is
    less than or equal to arg[1], or False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
        elements. Elements beyond the first two are ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    lhs = _eval(map, arg[0])
    rhs = _eval(map, arg[1])
    return lhs <= rhs
def __op_gt(map, arg):
    """
    Evaluates the "greater than" expression and returns True if arg[0] is
    greater than arg[1], or False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
        elements. Elements beyond the first two are ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    lhs = _eval(map, arg[0])
    rhs = _eval(map, arg[1])
    return lhs > rhs
def __op_ge(map, arg):
    """
    Evaluates the "greater or equal" expression and returns True if arg[0] is
    greater than or equal to arg[1], or False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
        elements. Elements beyond the first two are ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    lhs = _eval(map, arg[0])
    rhs = _eval(map, arg[1])
    return lhs >= rhs
def __op_and(map, arg):
    """
    Evaluates the "logic and" expression with Python `and` semantics.

    arg[1] is evaluated only when arg[0] is truthy. The result is the value of
    arg[0] if it is falsy, otherwise the value of arg[1].

    :param arg: The 'arg' should be a 'sea.List' object that contains two
        elements. Elements beyond the first two are ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    first = _eval(map, arg[0])
    if not first:
        return first
    return _eval(map, arg[1])
def __op_or(map, arg):
    """
    Evaluates the "logic or" expression with Python `or` semantics.

    arg[1] is evaluated only when arg[0] is falsy. The result is the value of
    arg[0] if it is truthy, otherwise the value of arg[1].

    :param arg: The 'arg' should be a 'sea.List' object that contains two
        elements. Elements beyond the first two are ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    first = _eval(map, arg[0])
    if first:
        return first
    return _eval(map, arg[1])
def __op_not(map, arg):
    """
    Evaluates the "logic not" expression and returns the boolean negation of
    the value of arg[0].

    :param arg: The 'arg' should be a 'sea.List' object that contains exactly
        1 element. Any other number of elements causes a 'ValueError'
        exception.
    :param map: The original map that the elements in the 'arg' refer.
    """
    count = len(arg)
    if count != 1:
        raise ValueError(
            "'__op_not' function expects 1 argument, but there are %d" % count)
    operand = _eval(map, arg[0])
    return not operand
def __op_at(map, arg):
    """
    Evaluates the "at" (reference) expression.

    The value of arg[0] is used as a key into 'map'. The referenced object's
    '.val' is returned when it has one, otherwise the object itself.

    :param arg: The 'arg' should be a 'sea.List' object that contains exactly
        1 element. Any other number of elements causes a 'ValueError'
        exception.
    :param map: The original map that the elements in the 'arg' refer.
    """
    count = len(arg)
    if count != 1:
        raise ValueError(
            "'__op_at' function expects 1 argument, but there are %d" % count)
    referenced = map[_eval(map, arg[0])]
    # getattr's default covers the AttributeError case: objects without
    # '.val' are returned unchanged.
    return getattr(referenced, "val", referenced)
def __op_minus(map, arg):
    """
    Evaluates the "minus" expression and returns the arithmetic result.

    With one element, the result is the negation of the value of arg[0].
    With two elements, the result is the value of arg[0] minus the value of
    arg[1].

    :param arg: The 'arg' should be a 'sea.List' object that contains one or
        two elements.
    :param map: The original map that the elements in the 'arg' refer.
    :raises ValueError: If 'arg' contains neither 1 nor 2 elements.
    """
    num_arg = len(arg)
    # Reject empty argument lists too; previously they failed later with an
    # IndexError on arg[0] instead of the documented ValueError.
    if num_arg not in (1, 2):
        raise ValueError(
            "'__op_minus' function expects 1 or 2 arguments, but there are %d"
            % num_arg)
    if num_arg == 1:
        return -_eval(map, arg[0])
    return _eval(map, arg[0]) - _eval(map, arg[1])
def __op_cat(map, arg):
    """
    Concatenates the string forms of the values of all elements of 'arg' and
    returns the result.

    :param arg: The 'arg' should be a 'sea.List' object that contains at least
        1 element. An empty 'arg' causes a 'ValueError' exception.
    :param map: The original map that the elements in the 'arg' refer.
    """
    if not len(arg):
        raise ValueError(
            "'__op_cat' function expects at least 1 argument, but there is none"
        )
    return "".join(str(_eval(map, element)) for element in arg)
def __op_sizeof(map, arg):
    """
    Evaluates the "sizeof" expression and returns the length (via `len`) of
    the value of arg[0].

    :param arg: The 'arg' should be a 'sea.List' object that contains exactly
        1 element.
    :param map: The original map that the elements in the 'arg' refer.
    :raises ValueError: If 'arg' does not contain exactly 1 element.
    """
    num_arg = len(arg)
    if num_arg != 1:
        raise ValueError(
            "'__op_sizeof' function expects 1 argument, but there are %d" %
            num_arg)
    return len(_eval(map, arg[0]))
def is_powerof2(x):
    """
    Returns True if the integer 'x' is a power of 2 (1, 2, 4, ...), or False
    otherwise.

    Zero and negative numbers are not powers of 2.
    """
    # A positive power of two has exactly one bit set, so clearing its lowest
    # set bit (x & (x - 1)) leaves 0. The bit test alone also yields 0 for
    # x == 0, hence the explicit positivity check.
    return x > 0 and not (x & (x - 1))
def _regex_match(pattern):
    """
    Returns a function that matches a string against 'pattern' from its
    start (via `re.match`).

    The returned object is a plain Python function, which `_check_atom`
    relies on (it tests it with `inspect.isfunction`).

    :param pattern: The regular expression.
    """
    def _matcher(s):
        return re.match(pattern, s)

    return _matcher
def _xchk_power2(map, valid, ev, prefix):
    """
    External checker: records an error unless the value is an integer power
    of 2 (as decided by `is_powerof2`).

    :param map: Holds the value to check; it is read through 'map.val'.
    :param valid: The validation criteria for the value (not used here).
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    val = map.val
    if is_powerof2(val):
        debug_print("OK - value is an integer of powere of 2")
        return
    message = "Value %d is not an integer of power of 2" % val
    debug_print("Error:\n" + message)
    ev.record_error(prefix, message)
def _xchk_file_exists(map, valid, ev, prefix):
    """
    External checker: records an error unless the value names an existing
    regular file (a directory does not qualify). An empty string is accepted.

    :param map: Holds the file name to check; it is read through 'map.val'.
    :param valid: The validation criteria for the value (not used here).
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    path = map.val
    if path == "" or os.path.isfile(path):
        debug_print("OK - file exists")
        return
    message = "File not found: %s" % path
    debug_print("Error:\n" + message)
    ev.record_error(prefix, message)
def _xchk_dir_exists(map, valid, ev, prefix):
    """
    External checker: records an error unless the value names an existing
    directory (a regular file does not qualify). An empty string is accepted.

    :param map: Holds the directory name to check; it is read through
        'map.val'.
    :param valid: The validation criteria for the value (not used here).
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    path = map.val
    if path == "" or os.path.isdir(path):
        debug_print("OK - Directory exists")
        return
    message = "Directory not found: %s" % path
    debug_print("Error:\n" + message)
    ev.record_error(prefix, message)
def _eval(map, arg):
    """
    Evaluates a prefix expression and returns its value.

    - A list whose first element evaluates to a string naming an operator in
      `__OP` (surrounding whitespace ignored) is an operator application. The
      operator function is called with 'map' and the remaining elements.
    - Any other list is evaluated element-wise into a Python list. An empty
      list evaluates to [].
    - An atom whose value is "-", "@" or "" evaluates to that string.
    - An atom whose value starts with "@" is a reference. The object stored in
      'map' under the rest of the string is returned (its '.val' if it has
      one).
    - Any other atom evaluates to its own value.

    :param arg: 'arg' can be either a 'sea.List' object or a 'sea.Atom'
        object, representing a prefix expression.
    :param map: The original map that the elements in the 'arg' refer.
    """
    if isinstance(arg, List):
        if len(arg) == 0:
            # No head to dispatch on: evaluate like any other plain list.
            return []
        head = _eval(map, arg[0])
        # Only strings can name an operator. Checking the type first also
        # avoids a TypeError when the head evaluates to an unhashable value
        # such as a nested list.
        if isinstance(head, str):
            op = __OP.get(head.strip())
            if op is not None:
                return op(map, arg[1:])
        return [_eval(map, e) for e in arg]

    val = arg.val
    if val in ['-', '@', '']:
        # Bare operator symbols and the empty string are literals.
        return val
    try:
        if val[0] == "@":
            k = map[val[1:]]
            try:
                return k.val
            except AttributeError:
                return k
    except TypeError:
        # 'val' (e.g. a number or None) cannot be used as a reference; treat
        # it as a literal.
        pass
    return val
# Operators understood by `_eval`. When the first element of an expression
# list names one of these keys, the function is called with the map and the
# remaining elements.
__OP = {
    "*": __op_mul,
    "==": __op_eq,
    "<": __op_lt,
    "<=": __op_le,
    ">": __op_gt,
    ">=": __op_ge,
    "&&": __op_and,
    "||": __op_or,
    "!": __op_not,
    "@": __op_at,
    "-": __op_minus,
    "cat": __op_cat,
    "sizeof": __op_sizeof,
}
# Type names usable in a validation spec's 'type' field. Each maps either to
# the expected Python type (or checker callable), or to a (type, range) tuple
# carrying a built-in range. For strings the range bounds the length; for
# booleans it lists the allowed values; for numbers it is [low, high].
# "regex" is a factory called with the pattern after a "regex:" prefix.
__TYPE = {
    # strings
    "str": str,
    "str1": (str, [1, 1000000000]),
    # floating-point numbers
    "float": float,
    "float+": (float, [0, float("inf")]),
    "float-": (float, [float("-inf"), 0]),
    "float0_1": (float, [0, 1.0]),
    # integers
    "int": int,
    "int0": (int, [0, 1000000000]),
    "int1": (int, [1, 1000000000]),
    # booleans
    "bool": boolean,
    "bool0": (boolean, [False]),
    "bool1": (boolean, [True]),
    # everything else
    "enum": str,
    "list": list,
    "none": None,
    "regex": _regex_match,
}
# Value type -> expected types it is also accepted as. An int value satisfies
# a float or str spec; a float value satisfies a str spec.
__CONVERTIBLE_TO = {int: [float, str], float: [str]}
# Registry of external checkers, keyed by the names used in a spec's
# '_check' / '_mapcheck' fields. Further checkers are added via `reg_xcheck`.
__xcheck = dict(
    power2=_xchk_power2,
    file_exists=_xchk_file_exists,
    dir_exists=_xchk_dir_exists,
)
def reg_xcheck(name, func):
    """
    Registers an external checker under 'name', replacing any checker already
    registered under that name.

    :param name: Name of the checker, as referenced from '_check' or
        '_mapcheck' in validation specs.
    :param func: Callable object that checks the validity of a parameter. It
        is called as func(map, valid, ev, prefix); see '_xchk_power2',
        '_xchk_file_exists' or '_xchk_dir_exists' for examples.
    """
    __xcheck.update({name: func})
def _match(map, valid, ev, prefix, tag):
    """
    Checks 'map' against each alternative criterion in 'valid' and keeps the
    result of the best-matching alternative.

    Each alternative is checked with its own deep copy of 'ev'. The best
    alternative is chosen as follows:
    - the fewest unchecked parameters wins;
    - among ties, alternatives that recorded no new error are preferred;
    - then the fewest "Wrong type:" errors;
    - finally, the earliest in 'valid'.

    The winner's state is copied back into 'ev'.

    :param map: The parameter being checked.
    :param valid: A 'sea.List' of alternative validation criteria.
    :param ev: The evaluator collecting errors; updated in place.
    :param prefix: The prefix of the checked parameter.
    :param tag: Tags forwarded to the check of every alternative.
    """
    ev_list = []
    for alternative in valid:
        # NOTE(review): '_if' is read from the list of alternatives, not from
        # 'alternative'; an alternative's own '_if' is evaluated by
        # `_check_map`.
        try:
            _if = ev(valid._if)
        except AttributeError:
            pass
        else:
            debug_print("_if: {} = {}".format(
                str(valid._if),
                _if,
            ), False)
            if _if:
                debug_print("True")
            else:
                debug_print("False - Skip checking the whole map.")
                return
        ev_ = deepcopy(ev)
        # Forward 'tag' so tag filtering also applies inside alternatives.
        _check_map(map, alternative, ev_, prefix, tag)
        ev_list.append(ev_)

    if not ev_list:
        return

    least = min(len(e._unchecked_map) for e in ev_list)
    candidates = [e for e in ev_list if len(e._unchecked_map) == least]
    # Prefer alternatives that did not add any error to what 'ev' had.
    no_new_error = [e for e in candidates if e._err == ev._err]
    if no_new_error:
        candidates = no_new_error
    # min() returns the first of several equally good candidates.
    best_ev = min(candidates, key=lambda e: e._err.count("Wrong type:"))
    ev.copy_from(best_ev)
def _check_atom(atom, valid, ev, prefix):
    """
    Checks the validity of an atom (a single, non-composite parameter).

    The checks run in this order and record any failure in 'ev':

    1. Type: 'valid.type' is evaluated and looked up in `__TYPE` (a
       "regex:<pattern>" type becomes a pattern matcher). A type failure
       ends the check.
    2. Range: a range built into a (type, range) entry of `__TYPE` takes
       precedence; otherwise 'valid.range' is used if present. Enum and
       boolean values must be one of the listed values, strings are checked
       by length, and other values numerically against [low, high].
    3. External checkers named in 'valid._check' (one name or a list).

    :param atom: The 'sea.Atom' to check.
    :param valid: The validation criteria for 'atom'.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    rr = None  # Range
    t = None  # Type name; stays None if evaluating 'valid.type' fails.
    # type
    debug_print(prefix + ":")
    debug_print(" checking its type...", False)
    try:
        t = ev(valid.type)
        if t.startswith("regex:"):
            tt = __TYPE["regex"](t[6:])
        else:
            tt = __TYPE[t]
        if isinstance(tt, tuple):
            tt, rr = tt[0], tt[1]
    except AttributeError:
        ev.record_error(
            prefix,
            "Wrong type: expecting a composite parameter, but got an atom")
        return
    except KeyError:
        # Either the type name is unknown, or evaluating 'valid.type' itself
        # failed (then 't' was never assigned, so show the raw spec).
        ev.record_error(
            prefix,
            "Wrong type: %s. 'type' is likely a parameter than a description." %
            ((valid.type if t is None else t),))
        return
    atom_val = atom.val
    if atom_val is None:
        # None is acceptable only for the "none" type (which maps to None).
        if tt is None:
            debug_print("OK - value None is acceptable")
        else:
            ev.record_error(prefix,
                            "Wrong value: expecting %s, but got None" % str(tt))
        return
    if atom._type == str and inspect.isfunction(tt) and tt != boolean:
        # 'tt' is the matcher built for a "regex:<pattern>" type; 'boolean'
        # is excluded explicitly (presumably it is a plain function too).
        if tt(atom_val):
            debug_print("OK - {} matches the pattern: {}".format(
                atom_val, t[6:]))
        else:
            ev.record_error(
                prefix,
                "Wrong type: expecting a string matching {}, but got {}".format(
                    t[6:],
                    atom_val,
                ))
        return
    elif (atom._type != tt and (atom._type not in __CONVERTIBLE_TO or
                                tt not in __CONVERTIBLE_TO[atom._type])):
        ev.record_error(
            prefix, "Wrong type: expecting {}, but got {}".format(
                "boolean" if tt == boolean else str(tt),
                str(atom._type),
            ))
        return
    else:
        debug_print("OK - %s" % t)
    # range
    debug_print(" checking its range...", False)
    try:
        if rr is None:
            rr = ev(valid.range)
    except AttributeError:
        debug_print("N/A")
    else:
        if t == "enum" or tt == boolean:
            # The value must be one of the listed values.
            if atom_val not in rr:
                ev.record_error(
                    prefix,
                    "Wrong value: should be one of {}, but got '{}'".format(
                        str(rr),
                        str(atom_val),
                    ))
            else:
                debug_print("OK - '{}' is one of {}".format(
                    str(atom_val),
                    str(rr),
                ))
        elif tt == str:
            # For strings the range bounds the length.
            if atom._type != tt:
                atom_val = str(atom_val)
            length = len(atom_val)
            if length > int(rr[1]):
                ev.record_error(prefix,
                                "String is too long (%d char's)" % length)
            elif length < int(rr[0]):
                ev.record_error(
                    prefix,
                    "String is too short: it must have at least %d char's" %
                    rr[0])
            else:
                debug_print("OK - string has %d char's" % length)
        else:
            # Inclusive numeric interval [rr[0], rr[1]].
            if atom_val > tt(rr[1]) or atom_val < tt(rr[0]):
                ev.record_error(
                    prefix,
                    "Value out of range: expecting within %s, but got '%s'" %
                    (str(rr), str(atom_val)))
            else:
                debug_print("OK - {} is within {}".format(
                    str(atom_val),
                    str(rr),
                ))
    # _check
    try:
        cc = valid._check
    except AttributeError:
        pass
    else:
        debug_print(" external checking...")
        if isinstance(cc, List):
            for e in cc:
                debug_print(" %s: " % e.val, False)
                __xcheck[e.val](atom, valid, ev, prefix)
        elif cc.val != "":
            debug_print(" %s: " % cc.val, False)
            __xcheck[cc.val](atom, valid, ev, prefix)
def _check_list(map, valid, ev, prefix, tag):
    """
    Checks the validity of a list parameter.

    The checks run in this order and record any failure in 'ev':

    1. Type: 'valid.type' must name the "list" type. A failure ends the check.
    2. Size (if 'valid.size' is given): a positive size requires exactly that
       many elements, a negative size -n requires at least n elements, and 0
       accepts any length.
    3. Elements (if 'valid.elem' is given): a single criterion applies to
       every element. A list of criteria applies element-wise, and its last
       criterion is reused for any extra elements.
    4. External checkers named in 'valid._check', applied to the whole list.

    :param map: The 'sea.List' to check.
    :param valid: The validation criteria for 'map'.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    :param tag: Tags forwarded to the element checks.
    """
    kk = map
    vv = valid
    # type
    debug_print(prefix + ":")
    debug_print(" checking its type...", False)
    t = None  # stays None if evaluating 'vv.type' itself fails
    try:
        t = ev(vv.type)
        tt = __TYPE[t]
    except AttributeError:
        ev.record_error(
            prefix,
            "Wrong type: expecting a composite parameter, but got a list")
        return
    except KeyError:
        # Unknown type name: report it instead of crashing.
        ev.record_error(
            prefix,
            "Wrong type: %s. 'type' is likely a parameter than a description." %
            ((vv.type if t is None else t),))
        return
    if tt != list:
        ev.record_error(
            prefix, "Wrong type: expecting %s, but got <type 'list'>" % str(tt))
        return
    debug_print("OK - list")
    # size
    debug_print(" checking its size...", False)
    try:
        size = ev(vv.size)
    except AttributeError:
        pass
    else:
        ll = len(kk)
        if size > 0 and ll != size:
            ev.record_error(
                prefix,
                "Wrong list length: expecting %d elements, but got %d" % (
                    size,
                    ll,
                ))
        elif size < 0 and ll < -size:
            ev.record_error(
                prefix,
                "Wrong list length: expecting at least %d elements, but got %d"
                % (
                    -size,
                    ll,
                ))
        else:
            debug_print("OK - %d" % ll)
    # elem: only the attribute lookup is guarded, so AttributeErrors raised by
    # nested checks are no longer misreported as "no requirement".
    debug_print(" checking each element in list...", False)
    try:
        elem_valid = vv.elem
    except AttributeError:
        debug_print("OK - No requirement for elements")
    else:
        if isinstance(elem_valid, List):
            num_valid, num_items = len(elem_valid), len(kk)
            for i, (k, v) in enumerate(zip(kk, elem_valid)):
                _check_map(k, v, ev, ("%s[%d]" % (prefix, i)), tag)
            # Extra elements reuse the last criterion; an empty criteria list
            # has no last criterion and imposes no per-element requirement.
            if 0 < num_valid < num_items:
                last = elem_valid[-1]
                for i in range(num_valid, num_items):
                    _check_map(kk[i], last, ev, ("%s[%d]" % (prefix, i)), tag)
        else:
            for i, elem in enumerate(kk):
                _check_map(elem, elem_valid, ev, ("%s[%d]" % (prefix, i)), tag)
    # _check
    try:
        cc = vv._check
    except AttributeError:
        return
    debug_print(" external checking for the whole list...")
    if isinstance(cc, List):
        for e in cc:
            debug_print(" %s: " % e.val, False)
            __xcheck[e.val](map, valid, ev, prefix)
    elif cc.val != "":
        debug_print(" %s: " % cc.val, False)
        __xcheck[cc.val](map, valid, ev, prefix)
def _check_map(map, valid, ev, prefix="", tag=set()):  # noqa: M511
    """
    Checks the validity of a parameter against its validation criteria.

    If 'valid' has an '_if' expression that evaluates to false, nothing is
    checked. If 'valid' is a 'sea.List', it holds alternative criteria and
    `_match` picks the best one. Otherwise the check depends on the kind of
    'map': atoms go to `_check_atom`, lists go to `_check_list`, and maps are
    handled here:

    - '_skip': a list of keys to leave unchecked, or the string "all".
    - '_mapcheck': external checker(s) applied to the whole map.
    - '_enforce': a list of keys that must be present in the map.
    - Every remaining key returned by 'map.key_value(tag)' is checked
      recursively against the criterion of the same name. Keys without a
      criterion are recorded as unchecked.

    :param map: The parameter to check ('sea.Atom', 'sea.List' or 'sea.Map').
    :param valid: The validation criteria for 'map'.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter; map keys are appended
        as ".<key>".
    :param tag: Tags passed to 'map.key_value' and to the nested checks.
    :raises ValueError: If '_skip' or '_enforce' is malformed.
    """
    # _if
    try:
        _if = ev(valid._if)
    except AttributeError:
        pass
    else:
        debug_print("_if: {} = {}".format(
            str(valid._if),
            _if,
        ), False)
        if _if:
            debug_print("True")
        else:
            debug_print("False - Skip checking the whole map.")
            return
    if isinstance(valid, List):
        return _match(map, valid, ev, prefix, tag)
    if isinstance(map, Atom):
        _check_atom(map, valid, ev, prefix)
    elif isinstance(map, List):
        _check_list(map, valid, ev, prefix, tag)
    elif isinstance(map, Map):
        # _skip
        try:
            skip = valid._skip.val
        except AttributeError:
            skip = []
        else:
            if not isinstance(skip, list) and skip != "all":
                raise ValueError(
                    "_skip must be either a list of strings or the string \"all\""
                )
        # _mapcheck
        try:
            cc = valid._mapcheck
        except AttributeError:
            pass
        else:
            debug_print(prefix + ":")
            debug_print(" external checking for the whole map...")
            if isinstance(cc, List):
                for e in cc:
                    debug_print(" %s: " % e.val, False)
                    __xcheck[e.val](map, valid, ev, prefix)
            elif cc.val != "":
                debug_print(" %s: " % cc.val, False)
                __xcheck[cc.val](map, valid, ev, prefix)
        # _enforce
        try:
            cc = valid._enforce
        except AttributeError:
            pass
        else:
            if not isinstance(cc, List):
                raise ValueError("_enforce must be a list of strings")
            debug_print(prefix + ":")
            debug_print(" enforcing keys...", False)
            # ", ".join leaves no trailing separator, so the joined string is
            # reported as is (slicing off two characters would truncate the
            # last key name).
            missing_key = ", ".join(e for e in cc.val if e not in map)
            if missing_key == "":
                debug_print("OK - All enforced keys present")
            else:
                debug_print("Error\nMissing keys: " + missing_key)
                ev.record_error(prefix, "Missing keys: " + missing_key)
        if "all" != skip:
            # Key-value pairs
            key_value = [
                (k, kk) for k, kk in map.key_value(tag) if k not in skip
            ]
            for k, kk in key_value:
                try:
                    vv = valid[k]
                except KeyError:
                    ev._unchecked_map.append(prefix + '.' + k)
                    continue
                _check_map(kk, vv, ev, prefix + '.' + k, tag)