import inspect
import typing
from typing import Dict
from typing import List
from typing import Set
from typing import Tuple
from schrodinger.models import parameters
from schrodinger.utils import scollections
def selective_set_value(src_model, dest_model, *, exclude=tuple()):
    """
    Set the value of `dest_model` to the value of `src_model`, excluding any
    params specified by `exclude`.

    Values are copied atomic subparam by atomic subparam; excluding a compound
    param excludes all of its atomic subparams.

    :param src_model: The param to take values from.
    :type src_model: `CompoundParam`
    :param dest_model: The param to apply values to.
    :type dest_model: `CompoundParam`
    :param exclude: An iterable of abstract params specifying values to ignore
        when applying src_model to dest_model. Each one is first rerooted onto
        the type of `src_model` (e.g. `Model.coord.x` becomes `Coord.x` when
        copying `Coord` instances).
    :type exclude: iterable[Param]
    :raises TypeError: If `src_model` and `dest_model` are not of the exact
        same type.
    :raises ValueError: If an excluded abstract param cannot be rerooted onto
        the type of `src_model`.
    """
    if type(src_model) is not type(dest_model):
        err_msg = (
            "Selective_set_value can only be called with params of the "
            "exact same type.\n"
            f"src_model type is {type(src_model)} while dest_model is {type(dest_model)}"
        )
        raise TypeError(err_msg)

    # Collect the atomic abstract params to skip, keyed by identity.
    exclude_set = scollections.IdSet()
    for excluded_abs_param in exclude:
        excluded_abs_param = _reroot_abstract_param(excluded_abs_param,
                                                    type(src_model))
        if isinstance(excluded_abs_param, parameters.CompoundParam):
            excluded_atomic_params = parameters.get_all_atomic_subparams(
                excluded_abs_param)
            exclude_set.update(excluded_atomic_params)
        else:
            exclude_set.add(excluded_abs_param)

    abstract_model = type(dest_model)
    all_params_set = scollections.IdSet(
        parameters.get_all_atomic_subparams(abstract_model))
    for abs_param in (all_params_set - exclude_set):
        src_value = abs_param.getParamValue(src_model)
        abs_param.setParamValue(dest_model, src_value)
def map_subparams(map_func, compound_param, subparam_type):
    """
    Find all subparams of type `subparam_type` in `compound_param` and apply
    `map_func` to them. The `compound_param` is modified in place.

    An example::

        class Model(parameters.CompoundParam):
            workflow_runtimes: List[float]
            total_idle_time: float
            min_max_runtimes: Tuple[float, float] = None

        model = Model(workflow_runtimes=[60, 120, 180],
                      total_idle_time=90,
                      min_max_runtimes=(60, 180))

        def seconds_to_minutes(seconds):
            return seconds / 60

        map_subparams(seconds_to_minutes, model, subparam_type=float)
        model.workflow_runtimes  # [1.0, 2.0, 3.0]
        model.total_idle_time  # 1.5
        model.min_max_runtimes  # (1.0, 3.0)

    Optionally, the map_func may accept a second argument named
    ``abstract_param``; where the owning abstract param is available it is
    passed by keyword. This may be useful in error messages or debugging.
    Example::

        def seconds_to_minutes(seconds, abstract_param):
            if seconds is None:
                print(f'{abstract_param} was not set.')
                return None
            return seconds / 60

    Note that the argument must be named 'abstract_param' for it to get picked
    up.

    Subparams are matched using the class type hints of `compound_param`; a
    subparam without a class type hint falls back to its own type hint if it
    is a `BaseMutableParam`, and is otherwise left untouched.

    :param map_func: Callable applied to each matching value; its return value
        replaces the original value.
    :param compound_param: The compound param to modify in place.
    :param subparam_type: The type of values `map_func` is applied to.
    :return: `compound_param`, after modification.
    :raises RuntimeError: If `map_func` (or the mapping of a subparam) raises;
        the original exception is chained as the cause.
    """
    abstract_subparams = compound_param.getAbstractParam().getSubParams()
    subparams = compound_param.getSubParams()
    type_hints = typing.get_type_hints(type(compound_param))
    for param_name, param in subparams.items():
        abstract_subparam = abstract_subparams[param_name]
        type_hint = type_hints.get(param_name)
        if not type_hint:
            # No class-level annotation: only mutable params carry their own
            # type hint; anything else is skipped unchanged.
            if not isinstance(abstract_subparam, parameters.BaseMutableParam):
                continue
            type_hint = abstract_subparam.getTypeHint()
        try:
            retval = _mapGeneric(map_func, param, subparam_type, type_hint,
                                 abstract_subparam)
        except Exception as e:
            raise RuntimeError(f'Exception raised from {map_func} for '
                               f'{abstract_subparam}:\n{str(e)}') from e
        setattr(compound_param, param_name, retval)
    return compound_param
#===========================================================================
# Private functions
#===========================================================================
def _reroot_abstract_param(abstract_param, root_param):
    """
    Return the abstract param equivalent to `abstract_param`, but rooted at
    the top-level param class `root_param`.

    For example, ``_reroot_abstract_param(Model.coord.x, Coord)`` returns
    ``Coord.x``. If `root_param` is (a subclass of) the first entry of
    `abstract_param`'s owner chain, `abstract_param` is returned as-is.

    Used by `selective_set_value`.

    :param abstract_param: The abstract param to reroot.
    :param root_param: The top-level param class to reroot onto.
    :raises ValueError: If the walk up the owner chain finds no owner whose
        type `root_param` subclasses (plain `parameters.Param` never counts).
    """
    if issubclass(root_param, abstract_param.ownerChain()[0]):
        return abstract_param

    # Walk from the param itself toward the top of its owner chain, recording
    # attribute names until we hit an owner whose type is the new root.
    attr_names_leaf_first = []
    reached_new_root = False
    for owner in abstract_param.ownerChain()[::-1]:
        owner_type = type(owner)
        if (issubclass(root_param, owner_type) and
                owner_type is not parameters.Param):
            reached_new_root = True
            break
        attr_names_leaf_first.append(owner.paramName())

    if not reached_new_root:
        raise ValueError(
            f'Root param "{root_param}" not in owner chain of "{abstract_param}"'
        )

    # Replay the recorded names from the new root down to the leaf.
    rerooted_param = root_param
    for attr_name in reversed(attr_names_leaf_first):
        rerooted_param = getattr(rerooted_param, attr_name)
    return rerooted_param
def _accepts_abstract_param(map_func):
    """
    Return whether `map_func` declares a parameter named 'abstract_param'.

    Callables whose signature cannot be introspected (`inspect.signature`
    raises ValueError or TypeError for some builtins and non-callables) are
    treated as not accepting it rather than aborting the mapping.
    """
    try:
        signature = inspect.signature(map_func)
    except (TypeError, ValueError):
        return False
    return 'abstract_param' in signature.parameters


def _mapGeneric(map_func,
                obj,
                subparam_type,
                _obj_type,
                abstract_subparam=None):
    """
    Recursively apply `map_func` to every value of type `subparam_type` found
    in `obj`. Used by `map_subparams`.

    :param map_func: Callable applied to each matching value.
    :param obj: The value to map.
    :param subparam_type: The type of values `map_func` is applied to.
    :param _obj_type: The declared type of `obj`, e.g. a type hint such as
        ``List[float]``. If None, ``type(obj)`` is used.
    :param abstract_subparam: The abstract param owning `obj`. It is passed to
        `map_func` as ``abstract_param`` if `map_func` accepts that argument,
        and is propagated to the items of container values.
    :return: The mapped value. List/Dict/Set/Tuple values are rebuilt;
        compound params are modified in place and returned.
    """
    if _obj_type is None:
        _obj_type = type(obj)
    if parameters.permissive_issubclass(_obj_type, parameters.CompoundParam):
        # An unset (None) compound value has no subparams to map.
        if obj is None:
            return obj
        return map_subparams(map_func, obj, subparam_type)
    if parameters.permissive_issubclass(_obj_type, List):
        return _mapListOfSubparams(map_func, obj, subparam_type, _obj_type,
                                   abstract_subparam)
    if parameters.permissive_issubclass(_obj_type, Dict):
        return _mapDictOfSubparams(map_func, obj, subparam_type, _obj_type,
                                   abstract_subparam)
    if parameters.permissive_issubclass(_obj_type, Set):
        return _mapSetOfSubparams(map_func, obj, subparam_type, _obj_type,
                                  abstract_subparam)
    if parameters.permissive_issubclass(_obj_type, Tuple):
        return _mapTupleOfSubparams(map_func, obj, subparam_type, _obj_type,
                                    abstract_subparam)
    if parameters.permissive_issubclass(_obj_type, subparam_type):
        if (abstract_subparam is not None and
                _accepts_abstract_param(map_func)):
            return map_func(obj, abstract_param=abstract_subparam)
        return map_func(obj)
    return obj


def _mapListOfSubparams(map_func,
                        obj,
                        subparam_type,
                        _obj_type,
                        abstract_subparam=None):
    """
    Map each item of a list typed ``List[X]``. Used by `map_subparams`.

    Returns `obj` unchanged if it is None or `_obj_type` has no type args.
    """
    type_args = getattr(_obj_type, '__args__', None)
    if obj is None or not type_args:
        return obj
    item_type = type_args[0]
    return [
        _mapGeneric(map_func, item, subparam_type, item_type,
                    abstract_subparam) for item in obj
    ]


def _mapDictOfSubparams(map_func,
                        obj,
                        subparam_type,
                        _obj_type,
                        abstract_subparam=None):
    """
    Map the values and then the keys of a dict typed ``Dict[K, V]``, recursing
    into nested types like the other container helpers. Used by
    `map_subparams`.

    Returns `obj` unchanged if it is None or `_obj_type` has no type args.
    """
    type_args = getattr(_obj_type, '__args__', None)
    if obj is None or not type_args:
        return obj
    key_type, value_type = type_args[0], type_args[1]
    # Values are mapped before keys, preserving the original call order.
    mapped_values = {
        k: _mapGeneric(map_func, v, subparam_type, value_type,
                       abstract_subparam) for k, v in obj.items()
    }
    return {
        _mapGeneric(map_func, k, subparam_type, key_type, abstract_subparam):
            v for k, v in mapped_values.items()
    }


def _mapSetOfSubparams(map_func,
                       obj,
                       subparam_type,
                       _obj_type,
                       abstract_subparam=None):
    """
    Map each item of a set typed ``Set[X]``. Used by `map_subparams`.

    Returns `obj` unchanged if it is None or `_obj_type` has no type args.
    """
    type_args = getattr(_obj_type, '__args__', None)
    if obj is None or not type_args:
        return obj
    item_type = type_args[0]
    return {
        _mapGeneric(map_func, item, subparam_type, item_type,
                    abstract_subparam) for item in obj
    }


def _mapTupleOfSubparams(map_func,
                         obj,
                         subparam_type,
                         _obj_type,
                         abstract_subparam=None):
    """
    Map each item of a tuple typed ``Tuple[X, Y, ...]`` (fixed length) or
    ``Tuple[X, ...]`` (variable length). Used by `map_subparams`.

    Returns `obj` unchanged if it is None or `_obj_type` has no type args.
    """
    type_args = getattr(_obj_type, '__args__', None)
    if obj is None or not type_args:
        return obj
    if len(type_args) == 2 and type_args[1] is Ellipsis:
        # Tuple[X, ...]: every item has type X. Zipping against the raw args
        # would silently drop all items after the second.
        obj = tuple(obj)
        item_types = [type_args[0]] * len(obj)
    else:
        item_types = type_args
    return tuple(
        _mapGeneric(map_func, item, subparam_type, item_type,
                    abstract_subparam)
        for item, item_type in zip(obj, item_types))