import inspect
import typing
from typing import Dict
from typing import List
from typing import Set
from typing import Tuple
from schrodinger.models import parameters
from schrodinger.utils import scollections
def selective_set_value(src_model, dest_model, *, exclude=tuple()):
    """
    Set the value of `dest_model` to the value of `src_model`, excluding any
    params specified by `exclude`.

    Values are copied one atomic subparam at a time (every atomic subparam of
    ``type(dest_model)`` that is not excluded). Each entry of `exclude` is
    first rerooted onto ``type(src_model)`` with `_reroot_abstract_param`, so
    it may also be written relative to an enclosing model (e.g.
    ``Model.coord.x`` when copying between ``Coord`` instances). Excluding a
    compound param excludes all of its atomic subparams.

    :param src_model: The param to take values from.
    :type  src_model: `CompoundParam`
    :param dest_model: The param to apply values to.
    :type  dest_model: `CompoundParam`
    :param exclude: An iterable of abstract params specifying values to ignore
        when applying src_model to dest_model.
    :type  exclude: iterable[Param]
    :raises TypeError: If `src_model` and `dest_model` are not of the exact
        same type (subclasses are rejected too).
    """
    if type(src_model) is not type(dest_model):
        err_msg = (
            "Selective_set_value can only be called with params of the "
            "exact same type.\n"
            f"src_model type is {type(src_model)} while dest_model is {type(dest_model)}"
        )
        raise TypeError(err_msg)

    # Identity-based set: abstract params are compared by object identity.
    exclude_set = scollections.IdSet()
    for excluded_abs_param in exclude:
        excluded_abs_param = _reroot_abstract_param(excluded_abs_param,
                                                    type(src_model))
        if isinstance(excluded_abs_param, parameters.CompoundParam):
            # A compound exclusion covers every atomic param beneath it.
            excluded_atomic_params = parameters.get_all_atomic_subparams(
                excluded_abs_param)
            exclude_set.update(excluded_atomic_params)
        else:
            exclude_set.add(excluded_abs_param)

    abstract_model = type(dest_model)
    all_params_set = scollections.IdSet(
        parameters.get_all_atomic_subparams(abstract_model))
    for abs_param in (all_params_set - exclude_set):
        src_value = abs_param.getParamValue(src_model)
        abs_param.setParamValue(dest_model, src_value)
def map_subparams(map_func, compound_param, subparam_type):
    """
    Find all subparams of type `subparam_type` in `compound_param` and apply
    `map_func` to each of them. `compound_param` is modified in place and is
    also returned.

    The declared type of each subparam is taken from the class type hints of
    `compound_param`. For a subparam with no (or a falsy) type hint whose
    abstract param is a `parameters.BaseMutableParam`, the abstract param's
    own type hint is used instead. Any other subparam is left untouched.
    Nested compound params and ``List``/``Dict``/``Set``/``Tuple`` containers
    are recursed into.

    An example::

        class Model(parameters.CompoundParam):
            workflow_runtimes: List[float]
            total_idle_time: float
            min_max_runtimes: Tuple[float, float] = None

        model = Model(workflow_runtimes=[60, 120, 180],
                      total_idle_time=90,
                      min_max_runtimes=(60, 180))

        def seconds_to_minutes(seconds):
            return seconds / 60

        map_subparams(seconds_to_minutes, model, subparam_type=float)
        model.workflow_runtimes  # [1.0, 2.0, 3.0]
        model.total_idle_time  # 1.5
        model.min_max_runtimes  # (1.0, 3.0)

    Optionally, the map_func may accept a second argument of abstract_param.
    This may be useful in error messages or debugging. Example::

        def seconds_to_minutes(seconds, abstract_param):
            if seconds is None:
                print(f'{abstract_param} was not set.')
                return None
            return seconds / 60

    Note that the argument must be named 'abstract_param' for it to get picked
    up. It is only passed for direct subparams of a compound param, not for
    items inside containers.

    :param map_func: Callable applied to every value of type `subparam_type`.
    :param compound_param: The param to map over; modified in place.
    :type  compound_param: `CompoundParam`
    :param subparam_type: The leaf type that `map_func` should be applied to.
    :return: `compound_param`.
    :raises RuntimeError: If mapping any subparam raises. The original
        exception is chained as ``__cause__``.
    """
    abstract_subparams = compound_param.getAbstractParam().getSubParams()
    subparams = compound_param.getSubParams()
    type_hints = typing.get_type_hints(type(compound_param))
    for param_name, param in subparams.items():
        abstract_subparam = abstract_subparams[param_name]
        type_hint = type_hints.get(param_name)
        if not type_hint:
            if not isinstance(abstract_subparam, parameters.BaseMutableParam):
                # No declared type to dispatch on; leave this subparam alone.
                continue
            type_hint = abstract_subparam.getTypeHint()
        try:
            retval = _mapGeneric(map_func, param, subparam_type, type_hint,
                                 abstract_subparam)
        except Exception as e:
            # Add which subparam failed; keep the original traceback as cause.
            raise RuntimeError(f'Exception raised from {map_func} for '
                               f'{abstract_subparam}:\n{str(e)}') from e
        setattr(compound_param, param_name, retval)
    return compound_param
#===========================================================================
# Private functions
#===========================================================================
def _reroot_abstract_param(abstract_param, root_param):
    """
    Return the abstract param equivalent to `abstract_param`, but rooted at
    `root_param` (a top-level param class).

    For example, to get `Coord.x` from `Model.coord.x`, call
    ``_reroot_abstract_param(Model.coord.x, Coord)``.

    If `root_param` is (a subclass of) the first entry of
    ``abstract_param.ownerChain()``, `abstract_param` is returned unchanged.
    Otherwise the owner chain is walked from its last entry back toward its
    first. Param names are collected until an owner is reached whose type
    `root_param` subclasses, and those names are then looked up on
    `root_param` in order. Owners whose type is exactly `parameters.Param`
    never count as a match. NOTE(review): presumably this is because
    `root_param` itself subclasses `Param`, so such owners would match
    trivially; confirm.

    Used by `selective_set_value`.

    :raises ValueError: If no owner in the chain matches `root_param`.
    """
    owner_chain = abstract_param.ownerChain()
    if issubclass(root_param, owner_chain[0]):
        return abstract_param

    # Names collected from the end of the chain, stopping at the matching owner.
    tail_names = []
    for owner in owner_chain[::-1]:
        owner_type = type(owner)
        is_match = (issubclass(root_param, owner_type) and
                    owner_type is not parameters.Param)
        if is_match:
            break
        tail_names.append(owner.paramName())
    else:
        raise ValueError(
            f'Root param "{root_param}" not in owner chain of "{abstract_param}"'
        )

    rerooted = root_param
    for attr_name in reversed(tail_names):
        rerooted = getattr(rerooted, attr_name)
    return rerooted
def _mapGeneric(map_func,
                obj,
                subparam_type,
                _obj_type,
                abstract_subparam=None):
    """
    Apply `map_func` to `obj`, recursing into compound params and containers.

    Used by `map_subparams`.

    :param map_func: Callable applied to leaf values of type `subparam_type`.
    :param obj: The value to map.
    :param subparam_type: The leaf type that `map_func` should be applied to.
    :param _obj_type: The declared type of `obj`. If None, ``type(obj)`` is
        used instead.
    :param abstract_subparam: The abstract param that `obj` belongs to, if
        known. It is forwarded to `map_func` as the ``abstract_param`` keyword
        when `map_func` declares a parameter of that name.
    :return: The mapped value. This is `obj` itself when its type matches
        none of the handled cases.
    """
    if _obj_type is None:
        _obj_type = type(obj)
    if parameters.permissive_issubclass(_obj_type, parameters.CompoundParam):
        obj = map_subparams(map_func, obj, subparam_type)
    elif parameters.permissive_issubclass(_obj_type, List):
        obj = _mapListOfSubparams(map_func, obj, subparam_type, _obj_type)
    elif parameters.permissive_issubclass(_obj_type, Dict):
        obj = _mapDictOfSubparams(map_func, obj, subparam_type, _obj_type)
    elif parameters.permissive_issubclass(_obj_type, Set):
        obj = _mapSetOfSubparams(map_func, obj, subparam_type, _obj_type)
    elif parameters.permissive_issubclass(_obj_type, Tuple):
        obj = _mapTupleOfSubparams(map_func, obj, subparam_type, _obj_type)
    elif parameters.permissive_issubclass(_obj_type, subparam_type):
        if abstract_subparam is not None:
            try:
                map_func_params = inspect.signature(map_func).parameters
            except (TypeError, ValueError):
                # Some builtins/extension callables (e.g. ``int``, ``str``)
                # expose no signature. Treat them as one-argument callables
                # rather than failing the whole mapping.
                map_func_params = {}
            if 'abstract_param' in map_func_params:
                return map_func(obj, abstract_param=abstract_subparam)
        return map_func(obj)
    return obj
def _mapListOfSubparams(map_func, obj, subparam_type, _obj_type):
    """
    Return a new list with `_mapGeneric` applied to every element of `obj`.

    Used by `map_subparams`. Elements are mapped with the first type argument
    of `_obj_type` as their declared type. When `_obj_type` has no
    ``__args__`` attribute, or an empty one, `obj` is returned untouched.
    """
    type_args = getattr(_obj_type, '__args__', None)
    if not type_args:
        return obj
    element_type = type_args[0]
    return [
        _mapGeneric(map_func, element, subparam_type, element_type)
        for element in obj
    ]
def _mapDictOfSubparams(map_func, obj, subparam_type, _obj_type):
    """
    Return a new dict with `map_func` applied through its keys and values.

    Used by `map_subparams`. `_obj_type` is expected to be a parameterized
    ``Dict[KeyType, ValueType]``. Every key and value is routed through
    `_mapGeneric` with its declared type. This means nested containers and
    compound params (e.g. ``Dict[str, List[float]]``) are recursed into, just
    as the list/set/tuple mappers recurse into their items. Keys or values
    whose declared type is unrelated to `subparam_type` come back from
    `_mapGeneric` with the same content.

    When `_obj_type` has no ``__args__`` attribute, or an empty one, `obj` is
    returned untouched.
    """
    type_args = getattr(_obj_type, '__args__', None)
    if not type_args:
        return obj
    key_type = type_args[0]
    value_type = type_args[1]
    # Values are mapped before keys, preserving the original call order of
    # map_func.
    mapped_values = {
        k: _mapGeneric(map_func, v, subparam_type, value_type)
        for k, v in obj.items()
    }
    return {
        _mapGeneric(map_func, k, subparam_type, key_type): v
        for k, v in mapped_values.items()
    }
def _mapSetOfSubparams(map_func, obj, subparam_type, _obj_type):
    """
    Return a new set with `_mapGeneric` applied to every member of `obj`.

    Used by `map_subparams`. Members are mapped with the first type argument
    of `_obj_type` as their declared type. When `_obj_type` has no
    ``__args__`` attribute, or an empty one, `obj` is handed back as-is.
    """
    type_args = getattr(_obj_type, '__args__', None)
    if not type_args:
        return obj
    member_type = type_args[0]
    return {
        _mapGeneric(map_func, member, subparam_type, member_type)
        for member in obj
    }
def _mapTupleOfSubparams(map_func, obj, subparam_type, _obj_type):
    """
    Return a new tuple with `_mapGeneric` applied to each item of `obj`.

    Used by `map_subparams`. Two annotation forms are supported:

    - Fixed-length ``Tuple[A, B]``: the i-th item is mapped with the i-th
      type argument.
    - Variable-length ``Tuple[A, ...]``: every item is mapped with ``A``.

    Items beyond the annotated length are kept unchanged instead of being
    dropped. A ``None`` value (e.g. an unset ``Tuple[...] = None`` param) is
    returned as-is. So is `obj` when `_obj_type` has no ``__args__``
    attribute, or an empty one.
    """
    type_args = getattr(_obj_type, '__args__', None)
    if obj is None or not type_args:
        return obj
    items = tuple(obj)
    if len(type_args) == 2 and type_args[1] is Ellipsis:
        # Tuple[X, ...] is homogeneous and of arbitrary length.
        item_types = (type_args[0],) * len(items)
    else:
        item_types = type_args
    mapped = [
        _mapGeneric(map_func, item, subparam_type, item_type)
        for item, item_type in zip(items, item_types)
    ]
    # zip() stops at the shorter sequence; keep any unannotated trailing
    # items rather than silently losing them.
    mapped.extend(items[len(mapped):])
    return tuple(mapped)