"""
Module for parameter validation. See `schrodinger.utils.sea` for more details.
Copyright Schrodinger, LLC. All rights reserved.
"""
import inspect
import os
import re
from copy import deepcopy
from .common import boolean
from .common import debug_print
from .common import is_equal
from .sea import Atom
from .sea import List
from .sea import Map
class Evalor:
    """
    This is the evaluator class for checking validity of parameters.

    An instance holds the parameter map being checked and accumulates, while
    a check is in progress, the error messages and the names of parameters
    for which no validation criteria were found.
    """
    __slots__ = [
        "_map",
        "_err_break",
        "_err",
        "_unchecked_map",
    ]

    def __init__(self, map, err_break="\n\n"):
        """
        :param map: 'map' contains all parameters to be checked.
        :param err_break: Separator appended after each recorded error message.
        """
        self._map = map
        self._err_break = err_break
        self._err = ""
        # Full (prefixed) names of parameters that had no matching criteria.
        self._unchecked_map = []

    def __call__(self, arg):
        """
        Evaluates a validation-criteria expression against the parameter map.

        :param arg: The validation criteria.
        """
        return _eval(self._map, arg)

    @property
    def err(self):
        """
        The accumulated error messages ("" if none was recorded).
        """
        return self._err

    def is_ok(self):
        """
        Returns true if there is no error and unchecked maps.
        """
        return not self._err and not self._unchecked_map

    def record_error(self, mapname=None, err=""):
        """
        Records the error.

        :param mapname: The name of the checked parameter. Its first character
            (the leading '.' of a parameter path) is dropped when reported. If
            nothing remains (e.g. the root map's empty prefix), the message is
            recorded without a "name: " label.
        :param err: The error message.
        """
        debug_print("ERROR\n%s" % err)
        if mapname is not None and mapname[1:]:
            self._err += mapname[1:] + ": "
        self._err += err + self._err_break

    @property
    def unchecked_map(self):
        """
        Returns a string that tells which parameters have not been checked.

        Each name is reported without its first character and is followed by
        a single space.
        """
        return "".join(name[1:] + " " for name in self._unchecked_map)

    def copy_from(self, ev):
        """
        Makes a copy from 'ev'.

        The parameter map is shared. The error text, the error separator and
        the list of unchecked names are copied, so later changes to 'ev' do
        not leak into this object.

        :param ev: A 'Evalor' object.
        """
        self._map = ev._map
        self._err_break = ev._err_break
        self._err = ev._err
        self._unchecked_map = list(ev._unchecked_map)
def check_map(map, valid, ev, tag=set()):  # noqa: M511
    """
    Checks the validity of a map.

    Nothing is checked (and None is returned) when 'map.has_tag(tag)' is
    false. Otherwise 'map.sval' is checked against 'valid', and the unchecked
    parameters and errors collected in 'ev' are summarized via 'debug_print'.

    :param map: The parameter map to be checked.
    :param valid: The validation criteria.
    :param ev: The 'Evalor' object that collects the results.
    :param tag: Tags passed on to 'map.has_tag' and to the recursive checks.
    :return: The accumulated error string if an error was found, else None.
    """
    if not map.has_tag(tag):
        debug_print("(none is tagged with: %s)" % (", ".join(tag)))
        return None
    _check_map(map.sval, valid, ev, "", tag)

    debug_print("\nUnchecked maps:")
    debug_print("(none)" if ev._unchecked_map == [] else ev.unchecked_map)

    debug_print("\nError summary:")
    if ev._err == "":
        debug_print("(none)")
        return None
    debug_print(ev._err)
    return ev._err
def __op_mul(map, arg):
    """
    Evaluates the "multiplication" expression.

    Each element of 'arg' is evaluated in order and the product of all of
    them is returned; the product starts from the float 1.0.

    :param arg: The 'arg' should be a 'sea.List' object that contains two or more elements.
    :param map: The original map that the elements in the 'arg' refer.
    """
    result = 1.0
    for factor in (_eval(map, e) for e in arg):
        result = result * factor
    return result
def __op_eq(map, arg):
    """
    Evaluates the "equal" expression.

    Returns True if arg[0] and arg[1] evaluate to equal values, or False
    otherwise. When either value is a float, the comparison is delegated to
    'is_equal' (presumably a tolerance-based float comparison; see
    '.common'); otherwise '==' is used.

    :param arg: The 'arg' should be a 'sea.List' object that contains two elements. Elements beyond the first two will be
            ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    lhs, rhs = _eval(map, arg[0]), _eval(map, arg[1])
    if not (isinstance(lhs, float) or isinstance(rhs, float)):
        return lhs == rhs
    return is_equal(lhs, rhs)
def __op_lt(map, arg):
    """
    Evaluates the "less than" expression.

    Returns True if arg[0] is less than arg[1], or False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two elements. Elements beyond the first two will be
            ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    left = _eval(map, arg[0])
    right = _eval(map, arg[1])
    return left < right
def __op_le(map, arg):
    """
    Evaluates the "less or equal" expression.

    Returns True if arg[0] is less than or equal to arg[1], or False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two elements. Elements beyond the first two will be
            ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    left = _eval(map, arg[0])
    right = _eval(map, arg[1])
    return left <= right
def __op_gt(map, arg):
    """
    Evaluates the "greater than" expression.

    Returns True if arg[0] is greater than arg[1], or False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two elements. Elements beyond the first two will be
            ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    left = _eval(map, arg[0])
    right = _eval(map, arg[1])
    return left > right
def __op_ge(map, arg):
    """
    Evaluates the "greater or equal" expression.

    Returns True if arg[0] is greater than or equal to arg[1], or False
    otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two elements. Elements beyond the first two will be
            ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    left = _eval(map, arg[0])
    right = _eval(map, arg[1])
    return left >= right
def __op_and(map, arg):
    """
    Evaluates the "logic and" expression.

    Short-circuits like Python's 'and': if arg[0] evaluates to a falsy value
    that value is returned and arg[1] is not evaluated; otherwise the value
    of arg[1] is returned.

    :param arg: The 'arg' should be a 'sea.List' object that contains two elements. Elements beyond the first two will be
            ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    first = _eval(map, arg[0])
    if not first:
        return first
    return _eval(map, arg[1])
def __op_or(map, arg):
    """
    Evaluates the "logic or" expression.

    Short-circuits like Python's 'or': if arg[0] evaluates to a truthy value
    that value is returned and arg[1] is not evaluated; otherwise the value
    of arg[1] is returned.

    :param arg: The 'arg' should be a 'sea.List' object that contains two elements. Elements beyond the first two will be
            ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    first = _eval(map, arg[0])
    if first:
        return first
    return _eval(map, arg[1])
def __op_not(map, arg):
    """
    Evaluates the "logic not" expression.

    Returns True if arg[0] evaluates to a falsy value, or False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains exactly 1 element. Any other number of elements
            will cause a 'ValueError' exception.
    :param map: The original map that the elements in the 'arg' refer.
    """
    count = len(arg)
    if count != 1:
        raise ValueError(
            "'__op_not' function expects 1 argument, but there are %d" % count)
    operand = _eval(map, arg[0])
    return not operand
def __op_at(map, arg):
    """
    Evaluates the "at" expression and returns the referenced value.

    arg[0] is evaluated to a key, and 'map[key]' is looked up. If the looked-up
    object has a '.val' attribute, that value is returned; otherwise the
    object itself is returned.

    :param arg: The 'arg' should be a 'sea.List' object that contains exactly 1 element. Any other number of elements
            will cause a 'ValueError' exception.
    :param map: The original map that the elements in the 'arg' refer.
    """
    count = len(arg)
    if count != 1:
        raise ValueError(
            "'__op_at' function expects 1 argument, but there are %d" % count)
    referenced = map[_eval(map, arg[0])]
    return getattr(referenced, "val", referenced)
def __op_minus(map, arg):
    """
    Evaluates the "minus" expression and returns the arithmetic result.

    With one element the negated value of arg[0] is returned; with two
    elements the difference arg[0] - arg[1] is returned.

    :param arg: The 'arg' should be a 'sea.List' object that contains 1 or 2 elements. Any other number of elements
            (including none) will cause a 'ValueError' exception.
    :param map: The original map that the elements in the 'arg' refer.
    """
    num_arg = len(arg)
    # Reject an empty list as well: otherwise 'arg[0]' below would raise an
    # unhelpful IndexError instead of the documented ValueError.
    if num_arg not in (1, 2):
        raise ValueError(
            "'__op_minus' function expects 1 or 2 arguments, but there are %d" %
            num_arg)
    if num_arg == 1:
        return -_eval(map, arg[0])
    return _eval(map, arg[0]) - _eval(map, arg[1])
def __op_cat(map, arg):
    """
    Concatenates the string forms of all evaluated elements and returns the result.

    :param arg: The 'arg' should be a 'sea.List' object that contains at least 1 element. Each element is evaluated and
            converted with 'str' before concatenation.
    :param map: The original map that the elements in the 'arg' refer.
    :raises ValueError: If 'arg' is empty.
    """
    if not len(arg):
        raise ValueError(
            "'__op_cat' function expects at least 1 argument, but there is none"
        )
    return "".join([str(_eval(map, piece)) for piece in arg])
def __op_sizeof(map, arg):
    """
    Evaluates the "sizeof" expression and returns 'len()' of the evaluated arg[0].

    :param arg: The 'arg' should be a 'sea.List' object that contains exactly 1 element. Any other number of elements
            will cause a 'ValueError' exception.
    :param map: The original map that the elements in the 'arg' refer.
    """
    num_arg = len(arg)
    if num_arg != 1:
        raise ValueError(
            "'__op_sizeof' function expects 1 argument, but there are %d" %
            num_arg)
    return len(_eval(map, arg[0]))
def is_powerof2(x):
    """
    Returns True if 'x' is a power of 2, or False otherwise.

    Only positive integers can be powers of 2, so 0 and negative values
    return False. Without the sign check, 0 would pass the bit test below,
    because 0 & -1 == 0.

    :param x: An integer.
    """
    # A positive power of 2 has exactly one bit set, so clearing its lowest
    # set bit with 'x & (x - 1)' leaves 0.
    return x > 0 and not (x & (x - 1))
def _regex_match(pattern):
    """
    Returns a predicate that matches a string against 'pattern'.

    The returned callable takes one string and returns the result of
    're.match(pattern, s)': a match object when the beginning of the string
    matches, or None otherwise. The pattern is only used when the predicate
    is called. The predicate is a plain Python function, which the regex
    branch of '_check_atom' detects with 'inspect.isfunction'.
    """
    def matches(text):
        return re.match(pattern, text)

    return matches
def _xchk_power2(map, valid, ev, prefix):
    """
    External checker: verifies that the integer value is a power of 2.

    When it is not, an error is recorded in 'ev' (and echoed via
    'debug_print').

    :param map: 'map' contains the value to be checked. Use 'map.val' to get the value.
    :param valid: 'valid' contains the validation criteria for the to-be-checked value (unused here).
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    value = map.val
    if is_powerof2(value):
        debug_print("OK - value is an integer of powere of 2")
        return
    message = "Value %d is not an integer of power of 2" % value
    debug_print("Error:\n" + message)
    ev.record_error(prefix, message)
def _xchk_file_exists(map, valid, ev, prefix):
    """
    External checker: verifies that the value names an existing file (not a
    directory).

    An empty value is accepted without touching the filesystem. Otherwise,
    if 'os.path.isfile' is false, an error is recorded in 'ev'.

    :param map: 'map' contains the value to be checked. Use 'map.val' to get the file name.
    :param valid: 'valid' contains the validation criteria for the to-be-checked value (unused here).
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    path = map.val
    if path == "" or os.path.isfile(path):
        debug_print("OK - file exists")
        return
    debug_print("Error:\nFile not found: %s" % path)
    ev.record_error(prefix, "File not found: %s" % path)
def _xchk_dir_exists(map, valid, ev, prefix):
    """
    External checker: verifies that the value names an existing directory
    (not a file).

    An empty value is accepted without touching the filesystem. Otherwise,
    if 'os.path.isdir' is false, an error is recorded in 'ev'.

    :param map: 'map' contains the value to be checked. Use 'map.val' to get the directory name.
    :param valid: 'valid' contains the validation criteria for the to-be-checked value (unused here).
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    path = map.val
    if path == "" or os.path.isdir(path):
        debug_print("OK - Directory exists")
        return
    debug_print("Error:\nDirectory not found: %s" % path)
    ev.record_error(prefix, "Directory not found: %s" % path)
def _eval(map, arg):
    """
    Evaluates the expression and returns the results.

    A 'sea.List' whose first element evaluates to an operator name in '__OP'
    (after stripping whitespace) is a prefix expression: the operator is
    applied to the remaining elements. Any other 'sea.List' evaluates to a
    Python list of its evaluated elements; an empty 'sea.List' evaluates to
    an empty list.

    For an atom, the values '-', '@' and '' evaluate to themselves. A value
    whose first item is "@" is a reference: "@key" evaluates to 'map["key"]',
    unwrapped through '.val' when available. Any other value (including
    values that cannot be indexed) evaluates to itself.

    :param arg: 'arg' can be either a 'sea.List' object or a 'sea.Atom' object, representing a prefix expression.
    :param map: The original map that the elements in the 'arg' refer.
    """
    if isinstance(arg, List):
        # Without this guard an empty list would fail on 'arg[0]'.
        if len(arg) == 0:
            return []
        val0 = _eval(map, arg[0])
        if isinstance(val0, str):
            val0 = val0.strip()
            if val0 in __OP:
                return __OP[val0](map, arg[1:])
        return [_eval(map, e) for e in arg]
    val = arg.val
    # Bare operator symbols / empty string are literal values.
    if val in ['-', '@', '']:
        return val
    try:
        if val[0] == "@":
            k = map[val[1:]]
            try:
                return k.val
            except AttributeError:
                return k
    except TypeError:
        # Non-indexable values (numbers, None, ...) are returned as-is.
        pass
    return val
# Operators usable in prefix expressions. '_eval' dispatches on the
# whitespace-stripped first element of a list and passes the remaining
# elements to the operator function.
__OP = {
    "*": __op_mul,
    "==": __op_eq,
    "<": __op_lt,
    "<=": __op_le,
    ">": __op_gt,
    ">=": __op_ge,
    "&&": __op_and,
    "||": __op_or,
    "!": __op_not,
    "@": __op_at,
    "-": __op_minus,
    "cat": __op_cat,
    "sizeof": __op_sizeof,
}
# Type names usable in validation criteria. An entry is either the type
# itself or a (type, default range) pair. For "str1" the range bounds the
# string length, for "bool0"/"bool1" it lists the only accepted value, and
# for the numeric entries it bounds the value.
__TYPE = {
    "str": str,
    "str1": (str, [1, 1000000000]),
    "float": float,
    "float+": (float, [0, float("inf")]),
    "float-": (float, [float("-inf"), 0]),
    "float0_1": (float, [0, 1.0]),
    "int": int,
    "int0": (int, [0, 1000000000]),
    "int1": (int, [1, 1000000000]),
    "bool": boolean,
    "bool0": (boolean, [False]),
    "bool1": (boolean, [True]),
    "enum": str,
    "list": list,
    "none": None,
    # "regex:<pattern>" types are built by calling this with <pattern>.
    "regex": _regex_match,
}
# Atom type -> declared types that an atom of that type is still accepted as
# (consulted by the type check in '_check_atom').
__CONVERTIBLE_TO = {int: [float, str], float: [str]}
# Registry of external checkers, looked up by the names listed in '_check' /
# '_mapcheck' criteria. More can be added with 'reg_xcheck'.
__xcheck = dict(
    power2=_xchk_power2,
    file_exists=_xchk_file_exists,
    dir_exists=_xchk_dir_exists,
)
def reg_xcheck(name, func):
    """
    Registers external checker.

    A checker already registered under the same name is replaced.

    :param name: Name of the checker, as referenced from '_check' or '_mapcheck' criteria.
    :param func: Callable object that checks validity of a parameter. It is called as
        'func(map, valid, ev, prefix)'; see '_xchk_power2', '_xchk_file_exists' or
        '_xchk_dir_exists' for examples.
    """
    __xcheck.update({name: func})
def _best_match(ev_list, base_err):
    """
    Returns the best of the candidate 'Evalor' objects built by '_match'.

    Selection, applied in order:

    1. keep the candidates with the fewest unchecked parameters;
    2. among them, prefer those whose error text still equals 'base_err'
       (i.e. checking added no error), if there are any;
    3. return the first remaining candidate with the fewest "Wrong type:"
       errors.

    :param ev_list: Non-empty list of candidate 'Evalor' objects.
    :param base_err: The error text all candidates started from.
    """
    least_unchecked = min(len(e._unchecked_map) for e in ev_list)
    candidates = [
        e for e in ev_list if len(e._unchecked_map) == least_unchecked
    ]
    no_new_error = [e for e in candidates if e._err == base_err]
    if no_new_error:
        candidates = no_new_error
    # min() returns the first of several equal minima, so ties keep the order
    # in which the alternatives were listed.
    return min(candidates, key=lambda e: e._err.count("Wrong type:"))


def _match(map, valid, ev, prefix, tag):
    """
    Finds the best match.

    'valid' is a list of alternative validation criteria. 'map' is checked
    against every alternative, each time with a private deep copy of 'ev',
    and the state of the best-fitting copy (see '_best_match') is copied back
    into 'ev'.

    :param map: The parameter to be checked.
    :param valid: A 'sea.List' of alternative validation criteria.
    :param ev: The evaluator; receives the result of the best alternative.
    :param prefix: The prefix (path) of the checked parameter.
    :param tag: Tags passed down to the check of each alternative.
    """
    if len(valid) == 0:
        return
    # An '_if' on the list of alternatives applies to all of them and does
    # not change while alternatives are tried, so it is evaluated once.
    try:
        _if = ev(valid._if)
    except AttributeError:
        pass
    else:
        debug_print("_if: {} = {}".format(
            str(valid._if),
            _if,
        ), False)
        if _if:
            debug_print("True")
        else:
            debug_print("False - Skip checking the whole map.")
            return
    ev_list = []
    for alternative in valid:
        ev_ = deepcopy(ev)
        # Pass 'tag' through, like every other recursive check does.
        _check_map(map, alternative, ev_, prefix, tag)
        ev_list.append(ev_)
    ev.copy_from(_best_match(ev_list, ev._err))
def _check_atom(atom, valid, ev, prefix):
    """
    Checks the validity of atom.

    Runs up to three stages, recording problems in 'ev' via
    'ev.record_error':

    1. type: 'valid.type' is evaluated to a type name and looked up in
       '__TYPE'; a name of the form "regex:<pattern>" builds a pattern
       matcher instead. A None value is accepted only for the "none" type.
       After a type error (and after any None value) the function returns.
    2. range: the default range that comes with the type name (if any),
       otherwise 'valid.range' (if present), is applied. Enum and boolean
       types require membership, string types bound the string length, and
       all other types bound the value inclusively.
    3. _check: each external checker named by 'valid._check' is looked up in
       '__xcheck' and called as 'checker(atom, valid, ev, prefix)'.

    :param atom: The atom to be checked ('atom.val' is its value and
        'atom._type' its Python type).
    :param valid: The validation criteria for the atom.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix (path) of the checked parameter.
    """
    rr = None  # Range
    # type
    debug_print(prefix + ":")
    debug_print("   checking its type...", False)
    try:
        t = ev(valid.type)
    except AttributeError:
        ev.record_error(
            prefix,
            "Wrong type: expecting a composite parameter, but got an atom")
        return
    # The lookup is kept apart from the evaluation so that the KeyError
    # handler below always has 't' bound. A KeyError raised while evaluating
    # 'valid.type' itself (presumably a dangling '@' reference) propagates.
    try:
        if t.startswith("regex:"):
            # Strip the 6-character "regex:" prefix to get the pattern.
            tt = __TYPE["regex"](t[6:])
        else:
            tt = __TYPE[t]
    except AttributeError:
        # 't' is not a string.
        ev.record_error(
            prefix,
            "Wrong type: expecting a composite parameter, but got an atom")
        return
    except KeyError:
        ev.record_error(
            prefix,
            "Wrong type: %s. 'type' is likely a parameter than a description." %
            t)
        return
    if isinstance(tt, tuple):
        # Some type names carry a default range: (type, range).
        tt, rr = tt[0], tt[1]
    atom_val = atom.val
    if atom_val is None:
        if tt is None:
            debug_print("OK - value None is acceptable")
        else:
            ev.record_error(prefix,
                            "Wrong value: expecting %s, but got None" % str(tt))
        return
    if atom._type == str and inspect.isfunction(tt) and tt != boolean:
        # Apart from 'boolean' (excluded explicitly), the only function-valued
        # type is the regex matcher built above.
        if tt(atom_val):
            debug_print("OK - {} matches the pattern: {}".format(
                atom_val, t[6:]))
        else:
            ev.record_error(
                prefix,
                "Wrong type: expecting a string matching {}, but got {}".format(
                    t[6:],
                    atom_val,
                ))
            return
    elif atom._type != tt and (atom._type not in __CONVERTIBLE_TO or
                               tt not in __CONVERTIBLE_TO[atom._type]):
        # A mismatch not covered by an allowed conversion (e.g. int -> float).
        ev.record_error(
            prefix, "Wrong type: expecting {}, but got {}".format(
                "boolean" if tt == boolean else str(tt),
                str(atom._type),
            ))
        return
    else:
        debug_print("OK - %s" % t)
    # range
    debug_print("   checking its range...", False)
    try:
        if rr is None:
            rr = ev(valid.range)
    except AttributeError:
        debug_print("N/A")
    else:
        if t == "enum" or tt == boolean:
            # 'rr' lists the accepted values.
            if atom_val not in rr:
                ev.record_error(
                    prefix,
                    "Wrong value: should be one of {}, but got '{}'".format(
                        str(rr),
                        str(atom_val),
                    ))
            else:
                debug_print("OK - '{}' is one of {}".format(
                    str(atom_val),
                    str(rr),
                ))
        elif tt == str:
            # 'rr' bounds the length; an int/float atom accepted as a string
            # is measured in its str() form.
            if atom._type != tt:
                atom_val = str(atom_val)
            length = len(atom_val)
            if length > int(rr[1]):
                ev.record_error(prefix,
                                "String is too long (%d char's)" % length)
            elif length < int(rr[0]):
                ev.record_error(
                    prefix,
                    "String is too short: it must have at least %d char's" %
                    rr[0])
            else:
                debug_print("OK - string has %d char's" % length)
        else:
            # Inclusive [low, high] bounds, converted with 'tt' first.
            # NOTE(review): for a regex type 'tt' is the matcher function, so a
            # 'range' on a regex type lands here and calls the matcher on the
            # bounds -- probably unintended; confirm.
            if atom_val > tt(rr[1]) or atom_val < tt(rr[0]):
                ev.record_error(
                    prefix,
                    "Value out of range: expecting within %s, but got '%s'" %
                    (str(rr), str(atom_val)))
            else:
                debug_print("OK - {} is within {}".format(
                    str(atom_val),
                    str(rr),
                ))
    # _check
    try:
        cc = valid._check
    except AttributeError:
        pass
    else:
        debug_print("   external checking...")
        # An unregistered checker name raises KeyError here.
        if isinstance(cc, List):
            for e in cc:
                debug_print("      %s: " % e.val, False)
                __xcheck[e.val](atom, valid, ev, prefix)
        elif cc.val != "":
            debug_print("      %s: " % cc.val, False)
            __xcheck[cc.val](atom, valid, ev, prefix)
def _check_list(map, valid, ev, prefix, tag):
    """
    Checks the validity of list.

    Stages (problems are recorded in 'ev' via 'ev.record_error'):

    1. type: 'valid.type' must name the "list" type.
    2. size: if 'valid.size' is present, a positive size requires exactly that
       many elements, a negative size requires at least -size elements, and
       0 disables the check.
    3. elem: if 'valid.elem' is a list of criteria, element i is checked
       against the i-th criteria, and elements past its end against the last
       one. Otherwise every element is checked against 'valid.elem'.
    4. _check: external checkers named in 'valid._check' run on the whole
       list.

    Only the absence of an optional criteria attribute is tolerated. Errors
    raised while checking, including AttributeError from nested checks or
    external checkers, propagate to the caller.

    :param map: The 'sea.List' to be checked.
    :param valid: The validation criteria for the list.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix (path) of the checked parameter.
    :param tag: Tags passed down to the element checks.
    """
    # type
    debug_print(prefix + ":")
    debug_print("   checking its type...", False)
    try:
        t = ev(valid.type)
    except AttributeError:
        ev.record_error(
            prefix,
            "Wrong type: expecting a composite parameter, but got a list")
        return
    try:
        tt = __TYPE[t]
    except KeyError:
        # Unknown type name: report it like '_check_atom' does.
        ev.record_error(
            prefix,
            "Wrong type: %s. 'type' is likely a parameter than a description." %
            t)
        return
    if tt != list:
        ev.record_error(
            prefix, "Wrong type: expecting %s, but got <type 'list'>" % str(tt))
        return
    debug_print("OK - list")

    # size
    debug_print("   checking its size...", False)
    try:
        size_expr = valid.size
    except AttributeError:
        pass
    else:
        size = ev(size_expr)
        ll = len(map)
        if size > 0 and ll != size:
            ev.record_error(
                prefix,
                "Wrong list length: expecting %d elements, but got %d" % (
                    size,
                    ll,
                ))
        elif size < 0 and ll < -size:
            ev.record_error(
                prefix,
                "Wrong list length: expecting at least %d elements, but got %d"
                % (
                    -size,
                    ll,
                ))
        else:
            debug_print("OK - %d" % ll)

    # elem
    debug_print("   checking each element in list...", False)
    try:
        elem_valid = valid.elem
    except AttributeError:
        debug_print("OK - No requirement for elements")
    else:
        if isinstance(elem_valid, List):
            # Positional criteria; extra elements reuse the last criteria.
            # An empty criteria list imposes no requirement.
            num_valid = len(elem_valid)
            if num_valid:
                for i, elem in enumerate(map):
                    v = elem_valid[min(i, num_valid - 1)]
                    _check_map(elem, v, ev, "%s[%d]" % (prefix, i), tag)
        else:
            for i, elem in enumerate(map):
                _check_map(elem, elem_valid, ev, "%s[%d]" % (prefix, i), tag)

    # _check
    try:
        cc = valid._check
    except AttributeError:
        pass
    else:
        debug_print("   external checking for the whole list...")
        if isinstance(cc, List):
            for e in cc:
                debug_print("      %s: " % e.val, False)
                __xcheck[e.val](map, valid, ev, prefix)
        elif cc.val != "":
            debug_print("      %s: " % cc.val, False)
            __xcheck[cc.val](map, valid, ev, prefix)
def _check_map(map, valid, ev, prefix="", tag=set()):  # noqa: M511
    """
    Checks the validity of a map.

    Dispatch:

    * If 'valid' has an '_if' expression that evaluates to false, nothing is
      checked.
    * If 'valid' is a 'sea.List', it holds alternative criteria and
      '_match' picks the best-fitting one.
    * Atom and list parameters are handled by '_check_atom' / '_check_list'.
    * For a 'sea.Map' parameter the special criteria keys are honored:
      '_skip' (keys not to check, or "all"), '_mapcheck' (external checkers
      run on the whole map) and '_enforce' (keys that must be present). Every
      remaining key from 'map.key_value(tag)' is checked recursively against
      the criteria of the same name. Keys without criteria are recorded as
      unchecked in 'ev'.

    :param map: The parameter to be checked.
    :param valid: The validation criteria.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix (path) of the checked parameter; child keys are
        appended as '.<key>'.
    :param tag: Tags passed to 'map.key_value' and to the recursive checks.
    :raises ValueError: If '_skip' or '_enforce' has the wrong form.
    """
    # _if
    try:
        _if = ev(valid._if)
    except AttributeError:
        pass
    else:
        debug_print("_if: {} = {}".format(
            str(valid._if),
            _if,
        ), False)
        if _if:
            debug_print("True")
        else:
            debug_print("False - Skip checking the whole map.")
            return
    if isinstance(valid, List):
        return _match(map, valid, ev, prefix, tag)
    if isinstance(map, Atom):
        _check_atom(map, valid, ev, prefix)
    elif isinstance(map, List):
        _check_list(map, valid, ev, prefix, tag)
    elif isinstance(map, Map):
        # _skip
        try:
            skip = valid._skip.val
        except AttributeError:
            skip = []
        else:
            if not isinstance(skip, list) and skip != "all":
                raise ValueError(
                    "_skip must be either a list of strings or the string \"all\""
                )
        # _mapcheck
        try:
            cc = valid._mapcheck
        except AttributeError:
            pass
        else:
            debug_print(prefix + ":")
            debug_print("   external checking for the whole map...")
            if isinstance(cc, List):
                for e in cc:
                    debug_print("      %s: " % e.val, False)
                    __xcheck[e.val](map, valid, ev, prefix)
            elif cc.val != "":
                debug_print("      %s: " % cc.val, False)
                __xcheck[cc.val](map, valid, ev, prefix)
        # _enforce
        try:
            cc = valid._enforce
        except AttributeError:
            pass
        else:
            if not isinstance(cc, List):
                raise ValueError("_enforce must be a list of strings")
            debug_print(prefix + ":")
            debug_print("   enforcing keys...", False)
            # join() leaves no trailing separator, so the full string is
            # reported (no slicing, which would cut the last key name).
            missing_key = ", ".join(e for e in cc.val if e not in map)
            if missing_key == "":
                debug_print("OK - All enforced keys present")
            else:
                debug_print("Error\nMissing keys: " + missing_key)
                ev.record_error(prefix, "Missing keys: " + missing_key)
        if skip != "all":
            # Key-value pairs
            key_value = [
                (k, kk) for k, kk in map.key_value(tag) if k not in skip
            ]
            for k, kk in key_value:
                try:
                    vv = valid[k]
                except KeyError:
                    ev._unchecked_map.append(prefix + '.' + k)
                    continue
                _check_map(kk, vv, ev, prefix + '.' + k, tag)