import copy
import enum
import functools
import inspect
import typing
from typing import Dict
from typing import List
from typing import Set
from typing import Tuple

from schrodinger import structure
from schrodinger.infra import util
from schrodinger.models import json
from schrodinger.models import jsonable
from schrodinger.Qt import QtCore
from schrodinger.utils.future import get_args
from schrodinger.utils.future import get_origin
from schrodinger.utils.qt_utils import get_signals
from schrodinger.utils.qt_utils import suppress_signals
from schrodinger.utils.scollections import IdDict
# Sentinel meaning "no default value was given". It is distinct from None,
# which is an explicit default value of None.
DEFAULT = object()
# Sentinel meaning "work out the param type from the calling context".
INFER = object()
# Builtin DataClasses whose constructors are not signature-checked by
# Param._validateDataClass. (`bytes` replaces the builtin *function* `bin`,
# which was listed here by mistake and can never be a DataClass.)
BUILTIN_TYPES = {
    int, float, set, frozenset, dict, tuple, list, str, bytes, bool
}
#===============================================================================
# Hybrid Methods Support
#===============================================================================
def classandinstancemethod(func):
    """
    Decorator used to indicate that a particular method in a class declaration
    should work as both a class method and an instance method. Only works when
    used in a HybridMethodsMixin subclass. Example::

        class Foo(HybridMethodsMixin):
            x = 1

            @classandinstancemethod
            def bar(self):
                print(self.x)

        f = Foo()
        f.x = 3
        Foo.bar()  # prints '1'
        f.bar()    # prints '3'

    Accessed on the class, the method behaves like a classmethod (``self`` is
    the class). `HybridMethodsMixin` instances shadow it with a version bound
    to the instance.

    :param func: The function to decorate.
    :type func: callable
    :return: A classmethod tagged with ``_is_hybrid_method`` that keeps the
        undecorated function in ``_original_func``.
    :rtype: classmethod
    """
    clsmethod = classmethod(func)
    # Marker looked for by HybridMethodsMixin.__init_subclass__.
    clsmethod._is_hybrid_method = True
    # Kept so HybridMethodsMixin can bind the raw function to an instance.
    clsmethod._original_func = func
    return clsmethod
class HybridMethodsMixin:
    """
    Mixin that allows the use of the `classandinstancemethod` decorator.

    Hybrid methods declared on a class, or inherited from a base class, are
    recorded per class in ``_hybrid_methods``. When an instance is created,
    each one is rebound to the instance. Calling it on the instance then
    passes the instance rather than the class.
    """

    # Maps attribute name -> hybrid classmethod object. Every subclass gets
    # its own copy in __init_subclass__, so registrations never leak upwards.
    _hybrid_methods = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        hybrid_methods = dict(cls._hybrid_methods)
        for name, attr in cls.__dict__.items():
            if hasattr(attr, '_is_hybrid_method'):
                hybrid_methods[name] = attr
            else:
                # A plain attribute in this class overrides any inherited
                # hybrid method of the same name. Don't shadow the override
                # with the parent's function on instances.
                hybrid_methods.pop(name, None)
        cls._hybrid_methods = hybrid_methods

    def __init__(self, *args, **kwargs):
        self._addHybridMethods()
        super().__init__()

    def _addHybridMethods(self):
        """
        Shadow each registered hybrid classmethod with an instance-bound
        method of the same name.
        """
        for name, method in self._hybrid_methods.items():
            func = method._original_func  # undecorated function
            bound_method = func.__get__(self, self.__class__)
            setattr(self, name, bound_method)
#===============================================================================
# Param class declaration components
#===============================================================================
# Unique marker object. A frame whose locals bind
# ``_in_CompoundParam_class_construction`` to this object counts as the body
# of a CompoundParam class declaration.
_SENTINEL = object()


def _is_in_CompoundParam_class_construction():
    """
    Determine whether this function is being called from the class declaration
    body of a CompoundParam subclass.

    Walks up the call stack and looks for a frame whose locals contain
    ``_in_CompoundParam_class_construction`` bound to `_SENTINEL`.

    :return: True if such a frame is found, False otherwise.
    :rtype: bool
    """
    frame = inspect.currentframe()
    try:
        while frame is not None:
            marker = frame.f_locals.get('_in_CompoundParam_class_construction')
            if marker is _SENTINEL:
                return True
            frame = frame.f_back
        return False
    finally:
        # Release the frame reference explicitly so no reference cycle is
        # kept alive (as recommended by the `inspect` documentation).
        del frame
def abstract_only(method):
    """
    Decorator that restricts a Param method to abstract and descriptor params.

    :param method: The method to guard.
    :return: A wrapper with the same name and docstring as ``method``. The
        wrapper raises ValueError when called on a param whose
        ``_param_type`` is neither ``ParamType.ABSTRACT`` nor
        ``ParamType.DESCRIPTOR``.
    """

    @functools.wraps(method)
    def new_method(self, *args, **kwargs):
        if self._param_type not in (ParamType.ABSTRACT, ParamType.DESCRIPTOR):
            raise ValueError(f'{method.__name__} may only be called on an '
                             f'abstract param. {self} is {self._param_type}')
        return method(self, *args, **kwargs)

    return new_method
def concrete_only(method):
    """
    Decorator that restricts a Param method to concrete params.

    :param method: The method to guard.
    :return: A wrapper with the same name and docstring as ``method``. The
        wrapper raises ValueError when called on a param whose
        ``_param_type`` is not ``ParamType.CONCRETE``.
    """

    @functools.wraps(method)
    def new_method(self, *args, **kwargs):
        if self._param_type is not ParamType.CONCRETE:
            raise ValueError(f'{method.__name__} may only be called on a '
                             f'concrete param. {self} is {self._param_type}')
        return method(self, *args, **kwargs)

    return new_method
def descriptor_only(method):
    """
    Decorator that restricts a Param method to descriptor params.

    :param method: The method to guard.
    :return: A wrapper with the same name and docstring as ``method``. The
        wrapper raises ValueError when called on a param whose
        ``_param_type`` is not ``ParamType.DESCRIPTOR``.
    """

    @functools.wraps(method)
    def new_method(self, *args, **kwargs):
        if self._param_type is not ParamType.DESCRIPTOR:
            raise ValueError(f'{method.__name__} may only be called on a '
                             f'descriptor param. {self} is {self._param_type}')
        return method(self, *args, **kwargs)

    return new_method
class ParamType(enum.Enum):
    """
    The role a Param instance plays.

    When no type is given explicitly, a param becomes DESCRIPTOR if it is
    created inside a CompoundParam class body and CONCRETE otherwise
    (see ``Param._inferParamType``). ABSTRACT params are requested
    explicitly, e.g. by ``Param._instantiateAbstractParam``.
    """
    # Explicitly requested abstract params.
    ABSTRACT = 'Abstract'
    # Params created outside a CompoundParam class body.
    CONCRETE = 'Concrete'
    # Params created inside a CompoundParam class body.
    DESCRIPTOR = 'Descriptor'
class SignalType:
    """
    Names used to look up param signals.

    ``Changed`` and ``Replaced`` are appended to a subparam name to form the
    per-subparam signal names that CompoundParam creates on its class (for
    example ``'xChanged'``). ``Mutated`` and ``ItemChanged`` are attribute
    names looked up directly on a concrete param value.
    """
    # Suffix of "<name>Changed" signals.
    Changed = 'Changed'
    # Attribute name of the signal on the concrete param value itself.
    Mutated = 'mutated'
    # Attribute name of the signal on the concrete param value itself.
    ItemChanged = 'itemChanged'
    # Suffix of "<name>Replaced" signals.
    Replaced = 'Replaced'
#===============================================================================
# Base Param Classes
#===============================================================================
class NonParamAttribute:
    """
    This class can be used to declare a public attribute on a `CompoundParam`.
    Declared public attributes can be used without error.

    Example usage::

        class Coord(CompoundParam):
            x: int
            y: int
            note = NonParamAttribute()

        coord = Coord()
        coord.note = "hello"  # No error

    The value is stored on the instance under the private name
    ``_nonparam_attribute_<name>``. Reading the attribute before it has been
    set raises AttributeError.
    """

    def __set_name__(self, owner_class, name):
        self._public_name = name
        # The leading underscore lets CompoundParam.__setattr__ accept the
        # storage attribute.
        self._attr_name = '_nonparam_attribute_' + name

    def __get__(self, owner_instance, owner_class):
        if owner_instance is None:
            # Accessed on the class: return the descriptor itself.
            return self
        try:
            return getattr(owner_instance, self._attr_name)
        except AttributeError:
            # Report the public name rather than the private storage name.
            raise AttributeError(
                f'{type(owner_instance).__name__!r} object has no value set '
                f'for non-param attribute {self._public_name!r}') from None

    def __set__(self, owner_instance, value):
        setattr(owner_instance, self._attr_name, value)
[docs]class Param(HybridMethodsMixin, QtCore.QObject):
"""
Base class for all Param classes. A Param is a descriptor for storing data,
which means that a single Param instance will manage the data values for
multiple instances of the class that owns it. Example::

    class Coord(CompoundParam):
        x: int
        y: int

An instance of the Coord class can be created normally, and Params can be
accessed as normal attributes::

    coord = Coord()
    coord.x = 4

When a Param value is set, the :code:`valueChanged` signal is emitted.

Params can be serialized and deserialized to and from JSON. Params can also
be nested::

    class Atom(CompoundParam):
        coord: Coord
        element: str
"""
#TODO: Add validation: type, bounds, custom, etc.
# Attribute name this param was declared under; set by __set_name__.
_param_name = None
# Owning param (or CompoundParam class); set by _setOwner.
_owner = None
# One of ParamType; instances overwrite it in __init__.
_param_type = ParamType.ABSTRACT
# Type used to build values of this param; subclasses override it.
DataClass = object
#===========================================================================
# Param - Construction
#===========================================================================
def __set_name__(self, owner_class, name):
    """
    Record the attribute name this param was declared as.

    :raises TypeError: If ``owner_class`` is not a CompoundParam or
        CompoundParamMixin subclass.
    """
    allowed_owners = (CompoundParam, CompoundParamMixin)
    if issubclass(owner_class, allowed_owners):
        self._param_name = name
        return
    raise TypeError(f'Cannot add {name} of type {type(self)} to '
                    f'{owner_class}. Subparams may only be added to '
                    f'CompoundParams.')
def __new__(cls, *args, _param_type=INFER, **kwargs):
    """
    Create the instance and remember its constructor arguments.

    Positional and keyword arguments (except ``_param_type``) are stored in
    ``_init_args`` / ``_init_kwargs`` so that `_instantiateAbstractParam` can
    build new abstract params from the same arguments. Capturing them here,
    rather than in ``__init__``, works however subclasses change the
    ``__init__`` signature.
    """
    new_param = super().__new__(cls, *args, _param_type=_param_type, **kwargs)
    new_param._init_args, new_param._init_kwargs = args, kwargs
    return new_param
[docs] def __init__(self,
default_value=DEFAULT,
DataClass=None,
deepcopyable=True,
*,
_param_type=INFER):
"""
:param default_value: The value to use in constructing the default
    value for the param. If omitted (`DEFAULT`), the default value is
    `DataClass()` (or None when the DataClass is `object`). If None, the
    default value is None. Any other value is used as-is when it is
    already a `DataClass` instance, and otherwise converted with
    `DataClass(default_value)`.
:type default_value: object
:param DataClass: The type to use for values of this param.
:type DataClass: type
:param deepcopyable: Whether values of this param are deepcopyable. If
    this param is not deepcopyable and its owner param is deepcopied,
    the copy's subparam value will be identical.
:type deepcopyable: bool
:param _param_type: For internal use only.
:raises RuntimeError: If the param type is inferred and this atomic param
    is created outside of a CompoundParam class declaration.
:raises ValueError: If `DataClass` is given for a param class that
    already defines one.
"""
self._deepcopyable = deepcopyable
# With no explicit type, infer DESCRIPTOR (inside a CompoundParam class
# body) or CONCRETE (anywhere else). _assertValidScope then rejects
# scopes this param class does not allow.
if _param_type is INFER:
self._inferParamType()
self._assertValidScope()
else:
self._param_type = _param_type
super().__init__()
self._setupDataClass(DataClass)
# Raw default argument; _makeConcreteParam turns it into a value.
self._default_value_arg = default_value
if self.isAbstract():
self.initAbstract()
self._validateDataClass()
def _validateDataClass(self):
    """
    Check that a DataClass instance can be built from this param's default.

    Skipped for builtin DataClasses, for a default value of None, and for
    DataClasses whose constructor signature cannot be introspected.
    """
    skip_check = (self.DataClass in BUILTIN_TYPES or
                  self._default_value_arg is None)
    if skip_check:
        return
    try:
        signature = inspect.signature(self.DataClass)
    except ValueError:
        # There are a few classes (e.g. subclasses of builtin types or
        # QObject) that don't have proper signatures. We just skip them.
        return
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY,
                        inspect.Parameter.POSITIONAL_OR_KEYWORD)
    required_count = 0
    for parameter in signature.parameters.values():
        if (parameter.kind in positional_kinds and
                parameter.default is inspect.Parameter.empty):
            required_count += 1
    self._inspectDataclassSignature(signature, required_count)
def _inspectDataclassSignature(self, dataclass_constructor_signature,
                               num_required_args):
    """
    Make sure that we'll be able to instantiate a DataClass object when we
    instantiate a Param object.

    :param dataclass_constructor_signature: The signature for the DataClass
        constructor.
    :type dataclass_constructor_signature: `inspect.Signature`
    :param num_required_args: The number of required arguments for
        instantiating a DataClass object.
    :type num_required_args: int
    :raises TypeError: If no default value was given but the DataClass
        constructor requires positional arguments.
    """
    needs_arguments = num_required_args > 0
    if needs_arguments and self._default_value_arg is DEFAULT:
        raise TypeError("Default value undefined for DataClasses "
                        "that require a constructor argument. Try "
                        "specifying None as the default value "
                        "instead.")
def initAbstract(self):
    """
    Hook called from ``__init__`` when the param is abstract.

    Does nothing by default; subclasses may override it.
    """
#===========================================================================
# Param - Common API
#===========================================================================
def getTypeHint(self):
    """
    Return the type used for values of this param (its DataClass).
    """
    type_hint = self.DataClass
    return type_hint
@classandinstancemethod
def paramName(self):
    """
    Return the attribute name this param was declared under::

        # Can be called on an abstract param:
        print(Coord.x.paramName())  # 'x'

        # ...or on an instance of a CompoundParam
        a = Atom()
        a.coord.paramName()  # 'coord'

    :rtype: str or None
    """
    declared_name = self._param_name
    return declared_name
@classandinstancemethod
def _fullyQualifiedName(self) -> str:
    """
    Get the fully qualified name of this param. For example::

        print(Coord.x._fullyQualifiedName())    # 'Coord.x'
        print(Coord._fullyQualifiedName())      # 'Coord'
        print(Coord().x._fullyQualifiedName())  # 'Coord.x'
        print(Coord()._fullyQualifiedName())    # 'Coord'
    """
    if isinstance(self, type):
        # Called on a CompoundParam class.
        return self.__name__
    parent = self.owner()
    if parent is None:
        return type(self).__name__
    return parent._fullyQualifiedName() + '.' + self.paramName()
@classandinstancemethod
def ownerChain(self):
    """
    Return a list of param owners starting from the top-level param and
    ending with self. Examples:

    :code:`foo.bar.atom.coord.ownerChain()` returns
    :code:`[foo, foo.bar, foo.bar.atom, foo.bar.atom.coord]`, where every
    item is a concrete param.

    :code:`Foo.bar.atom.coord.x.ownerChain()` returns
    :code:`[Foo, Foo.bar, Foo.bar.atom, Foo.bar.atom.coord,
    Foo.bar.atom.coord.x]`. Every item is an abstract param, and the first
    item is the top-level CompoundParam class.

    :rtype: list
    """
    owner = self.owner()
    if owner is None:
        return [self]
    # Each owner contributes its own chain, ending with itself.
    return owner.ownerChain() + [self]
@classandinstancemethod
def owner(self):
    """
    Return the param that owns this one, or None if there is no owner::

        # Can be called on an abstract param:
        assert Coord.x.owner() == Coord

        # ...or on an instance of a CompoundParam
        a = Atom()
        assert a.coord.owner() == a
    """
    owning_param = self._owner
    return owning_param
@classandinstancemethod
def isAbstract(self):
    """
    Whether the param is an "abstract" param.

    Descriptor params count as abstract too.
    """
    abstract_kinds = (ParamType.ABSTRACT, ParamType.DESCRIPTOR)
    return self._param_type in abstract_kinds
#===========================================================================
# Param - Abstract API
#===========================================================================
@classandinstancemethod
@abstract_only
def getParamValue(self, obj):
    """
    Enables access to a param value on a compound param via an abstract
    param reference::

        a = Atom()
        assert Atom.coord.x.getParamValue(a) == 0  # ints default to 0
        a.coord.x = 3
        assert Atom.coord.x.getParamValue(a) == 3

    :param obj: The concrete owner param to get a param value from.
    :type obj: CompoundParam
    :return: The value this abstract param refers to within ``obj``.
    :raises TypeError: If ``obj`` does not contain the subparam path that
        this abstract param describes.
    """
    conc_param = obj
    for abs_param in self.ownerChain()[1:]:
        # Stop early once the current value has the same type as this
        # abstract param (e.g. when ``obj`` already is such a value).
        if type(self) is type(conc_param):
            return conc_param
        try:
            conc_param = conc_param.getSubParam(abs_param.paramName())
        except AttributeError as err:
            raise TypeError(
                f'Concrete param "{obj}" does not match abstract param "{self}".'
            ) from err
    return conc_param
@classandinstancemethod
@abstract_only
def setParamValue(self, obj, value):
    """
    Set the value of this param on a concrete owner object::

        # Setting the param value of a basic param
        a = Atom()
        Atom.coord.x.setParamValue(a, 5)
        assert a.coord.x == 5

        # setParamValue can also be used to set the value of CompoundParams
        c = Coord()
        c.x = 10
        Atom.coord.setParamValue(a, c)
        assert a.coord.x == 10

    A current value that supports in-place updates (a `_MutateToMatchMixin`
    subclass, such as CompoundParam) is mutated to match ``value``. Any
    other value is replaced on its concrete owner.

    :param obj: The concrete owner param to set a subparam value on.
    :param value: The value to set the subparam value to.
    """
    concrete_param = self.getParamValue(obj)
    if isinstance(concrete_param, _MutateToMatchMixin):
        concrete_param._mutateToMatch(value)
    else:
        concrete_owner = self.owner().getParamValue(obj)
        concrete_owner._setSubParam(self.paramName(), value)
@classandinstancemethod
@abstract_only
def defaultValue(self):
    """
    Return the default value for this abstract param::

        assert Atom.coord.x.defaultValue() == 0  # ints default to 0

    The value is produced by `_makeConcreteParam`.
    """
    return self._makeConcreteParam()
@classandinstancemethod
@abstract_only
def getParamSignal(self, obj, signal_type=SignalType.Changed):
    """
    Return the signal of type ``signal_type`` for this param on ``obj``.

    :param obj: The concrete param that this abstract param's owner chain
        is resolved against (see `getParamValue`).
    :param signal_type: One of the `SignalType` values.
    :raises ValueError: If a signal other than ``Changed`` is requested for
        a top-level param.
    """
    if self.owner() is None:
        if signal_type is not SignalType.Changed:
            raise ValueError(f'Cannot get {signal_type} signal on '
                             f'top-level param {self}.')
        return obj.valueChanged
    if signal_type in (SignalType.Mutated, SignalType.ItemChanged):
        # These signals are attributes of the concrete value itself.
        return getattr(self.getParamValue(obj), signal_type)
    concrete_owner = self.owner().getParamValue(obj)
    return concrete_owner._getSignal(self.paramName(), signal_type)
#===========================================================================
# Param - Concrete API
#===========================================================================
# There is no such thing as a concrete atomic Param.
#===========================================================================
# Param - Descriptor methods
#===========================================================================
def _setupDataClass(self, DataClass):
    """
    Apply a DataClass passed to the constructor.

    :raises ValueError: If a DataClass is passed for a param whose class
        already defines one.
    :raises NotImplementedError: If the resulting DataClass is None.
    """
    if DataClass is self.DataClass:
        return
    explicit = DataClass is not None
    if explicit and self.DataClass is not object:
        raise ValueError(
            'Attempting to set DataClass of a param with an already '
            'set DataClass')
    if explicit:
        self.DataClass = DataClass
    # NOTE(review): this check looks unreachable. A None argument with a None
    # class DataClass returns early above, and an explicit argument either
    # raises or replaces DataClass with a non-None value. Confirm intent.
    if self.DataClass is None:
        raise NotImplementedError('DataClass must be defined.')
@descriptor_only
def __get__(self, owner_instance, owner_class):
    """
    Return the subparam stored under this param's name on the instance.
    Accessed through the class, return the one stored on the class.
    """
    owner = owner_class if owner_instance is None else owner_instance
    return owner.getSubParam(self.paramName())
@descriptor_only
def __set__(self, owner_instance, value):
    """
    Store ``value`` as this subparam on ``owner_instance``.

    Assigning the object that is already stored is a no-op.
    """
    name = self.paramName()
    if value is not owner_instance.getSubParam(name):
        owner_instance._setSubParam(name, value)
@descriptor_only
def _makeAbstractParam(self):
    """
    Create an abstract param from this descriptor, with the same name.
    """
    new_param = self._instantiateAbstractParam()
    new_param._param_name = self.paramName()
    return new_param
@descriptor_only
def _instantiateAbstractParam(self):
    """
    Build a new ABSTRACT param of the same class, reusing the constructor
    arguments captured in ``__new__``.
    """
    return type(self)(*self._init_args,
                      _param_type=ParamType.ABSTRACT,
                      **self._init_kwargs)
@abstract_only
def _makeConcreteParam(self, owner=None):
    """
    Build the initial value for this param.

    :param owner: Passed by CompoundParam when populating an instance; not
        used by this implementation.
    :return: None for a None default, or when no default was given and the
        DataClass is `object`. Otherwise ``DataClass()`` when no default was
        given, or the prepared default value.
    """
    default_arg = self._default_value_arg
    if default_arg is None:
        return None
    if default_arg is not DEFAULT:
        return self._prepareDefaultValue(default_arg)
    # No default given: build one from the DataClass, if there is one.
    return None if self.DataClass is object else self.DataClass()
@abstract_only
def _prepareDefaultValue(self, def_arg):
    """
    Given the default value provided when this param was created, return a
    value to initially assign to the param.

    Values that are already DataClass instances are returned unchanged;
    anything else is passed to the DataClass constructor.
    """
    if not isinstance(def_arg, self.DataClass):
        def_arg = self.DataClass(def_arg)
    return def_arg
#===========================================================================
# Param - Internal Implementation Methods
#===========================================================================
def _inferParamType(self):
    """
    Set the param type: DESCRIPTOR inside a CompoundParam class body,
    CONCRETE anywhere else.
    """
    in_class_body = _is_in_CompoundParam_class_construction()
    self._param_type = (ParamType.DESCRIPTOR
                        if in_class_body else ParamType.CONCRETE)
def _assertValidScope(self):
    """
    Ensure an atomic param was declared inside a CompoundParam class body.

    :raises RuntimeError: If the param is not a descriptor.
    """
    if self._param_type is ParamType.DESCRIPTOR:
        return
    raise RuntimeError(f'Atomic param {self} may not be instantiated '
                       'outside of a CompoundParam class declaration.')
@classandinstancemethod
def _setOwner(self, owner):
    """
    Set the owner of this param. Passing None clears it.

    :raises ValueError: If a non-None owner is given while the param
        already has one.
    """
    already_owned = self._owner is not None
    if already_owned and owner is not None:
        raise ValueError(
            f'{repr(self)} already belongs to {repr(self._owner)}. Cannot '
            f'move the param to {repr(owner)}')
    self._owner = owner
@classandinstancemethod
def _chainString(self):
    """
    Return the owner chain as a dotted string, e.g. ``'Atom.coord.x'``.
    """
    names = [
        link.__name__ if isinstance(link, type) else str(link.paramName())
        for link in self.ownerChain()
    ]
    return '.'.join(names)
@classandinstancemethod
@abstract_only
def _setDefaultValueFrom(self, concrete_param):
    """
    Use ``concrete_param`` as this abstract param's default value argument.
    """
    self._default_value_arg = concrete_param
def __repr__(self):
    """
    Return e.g. ``'<Abstract:Atom.coord.x>'``.
    """
    chain_text = self._chainString()
    kind = self._param_type.value
    return f'<{kind}:{chain_text}>'
class FloatParam(Param):
    """
    Param whose values are floats.
    """
    DataClass = float
class IntParam(Param):
    """
    Param whose values are ints.
    """
    DataClass = int
class StructureParam(Param):
    """
    Param holding a `structure.Structure`.

    The default value is always None and values are not deep-copied
    (``deepcopyable=False``). Passing ``default_value`` or ``deepcopyable``
    explicitly is an error.
    """
    DataClass = structure.Structure

    def __init__(self, *args, **kwargs):
        super().__init__(*args,
                         default_value=None,
                         deepcopyable=False,
                         **kwargs)
class StringParam(Param):
    """
    Param whose values are strings.
    """
    DataClass = str
class BoolParam(Param):
    """
    Param whose values are bools.
    """
    DataClass = bool
class TupleParam(Param):
    """
    Param whose values are tuples.
    """
    DataClass = tuple
#===============================================================================
# Compound Param
#===============================================================================
class _MutateToMatchMixin:
    """
    Marks param values that can be updated in place.

    `Param.setParamValue` calls `_mutateToMatch` on current values that
    inherit from this mixin, instead of replacing them.
    """

    def _mutateToMatch(self, new_value):
        """
        Update this object in place so that it matches ``new_value``.

        :raises NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError
[docs]class CompoundParamMixin(metaclass=CompoundParamMeta):
"""
A base class for sharing params between classes. For example::

    class ColorMixin(CompoundParamMixin):
        color: str = 'red'

    class ColoredSquare(ColorMixin, CompoundParam):
        length: int = 5

    class ColoredTriangle(ColorMixin, CompoundParam):
        base: int = 5
        height: int = 3

Both `ColoredSquare` and `ColoredTriangle` will have a param `color`.
"""
def __init_subclass__(cls):
"""
Record the annotations and Param attributes declared on a plain mixin
subclass. CompoundParam._processAnnotations and
CompoundParam._populateClassParams use them later. Classes that are also
CompoundParams do not record anything here.
"""
if not issubclass(cls, CompoundParam):
# Only annotations declared on this class itself are recorded.
if '__annotations__' in cls.__dict__:
cls._mixin_annotations = _get_uninherited_type_hints(cls)
mixin_params = {}
for attr_name, attr in cls.__dict__.items():
if isinstance(attr, Param):
mixin_params[attr_name] = attr
if mixin_params:
cls._mixin_params = mixin_params
super().__init_subclass__()
[docs]class CompoundParam(json.JsonableClassMixin,
_MutateToMatchMixin,
Param,
metaclass=CompoundParamMeta):
"""
A Param composed of named subparams, declared as class attributes or
type annotations.

=============
Serialization
=============

All `CompoundParam` instances are automatically serializable if their
subparams are serializable. To serialize and deserialize, use the
schrodinger json module::

    from schrodinger.models import json

    class Coord(parameters.CompoundParam):
        x: int
        y: int

    c1 = Coord(x=1, y=2)
    c1_string = json.dumps(c1)
    c2 = json.loads(c1_string, DataClass=Coord)
    assert c1 == c2
"""
# Returned by getParamSignal for top-level params.
valueChanged = QtCore.pyqtSignal(object)
# Context managers toggling the _skip_eq_check / _block_signal_propagation
# flags. setValue reads _skip_eq_check when deciding whether to skip
# changed signals. (Exact flag semantics live in util.flag_context_manager.)
skip_eq_check = util.flag_context_manager('_skip_eq_check')
block_signal_propagation = util.flag_context_manager(
'_block_signal_propagation')
# (param1, param2) pairs recorded by setReference; replaced with a fresh
# list for each subclass in __init_subclass__.
_reference_defs = []
# Per-instance DataClass (set to the instance's class in __init__).
DataClass = NonParamAttribute()
def __setattr__(self, name, value):
    """
    Reject assignments to undeclared public attributes.

    Names starting with an underscore may always be set, as may names that
    exist on the class (params, `NonParamAttribute` descriptors, ...).

    :raises AttributeError: For any other public name.
    """
    if name.startswith('_') or hasattr(type(self), name):
        return super().__setattr__(name, value)
    raise AttributeError(
        f"Error while setting attribute \"{name}\". "
        "Undeclared public attributes are not allowed "
        "on CompoundParams. Use an attribute with a leading underscore "
        "instead or declare your attribute using "
        "parameters.NonParamAttribute.")
#===========================================================================
# CompoundParam - Construction
#===========================================================================
def __init_subclass__(cls):
"""
Set up a new CompoundParam subclass.

Steps, in order:

1. Convert annotations into params.
2. Run the rest of the ``__init_subclass__`` chain.
3. Register the class's subparams and create their signals.
4. Call `configureParam` with ``_in_configureParam`` set to True. It is
   reset to False afterwards, even on error.

`setReference` calls made from `configureParam` are appended to a fresh
``_reference_defs`` list for this class.
"""
cls._processAnnotations()
super().__init_subclass__()
cls._populateClassParams()
cls._populateSignalsOnClass()
# NOTE(review): this flag is not read in the code visible here;
# presumably consulted elsewhere during configureParam -- confirm.
cls._in_configureParam = True
# Fresh per-class list so setReference calls don't land on a parent's list.
cls._reference_defs = []
try:
cls.configureParam()
finally:
cls._in_configureParam = False
@classmethod
def setReference(cls, param1, param2):
    """
    Call this class method from configureParam to indicate that two params
    should be kept in sync. The initial values will start with the default
    value of :code:`param1`.

    Example::

        class Square(CompoundParam):
            width: float = 5
            height: float = 10

            @classmethod
            def configureParam(cls):
                super().configureParam()
                cls.setReference(cls.width, cls.height)

        square = Square()
        assert square.width == square.height == 5  # Default value of width
                                                   # takes priority
        square.height = 7
        assert square.width == square.height == 7
        square.width = 6
        assert square.width == square.height == 6

    This method only records the pair in ``_reference_defs``.

    :param param1: The first abstract param to keep synced
    :param param2: The second abstract param. After instantiation, this
        param will take on the value of param1.
    """
    reference_pair = (param1, param2)
    cls._reference_defs.append(reference_pair)
def __init__(self, default_value=DEFAULT, _param_type=INFER, **kwargs):
    """
    :param default_value: Value used to build this param's default value.
        May not be None.
    :param _param_type: For internal use only.
    :param kwargs: Non-param keyword arguments, which are separated out by
        `_popNonParamKwargs`. The remaining kwargs are kept in
        ``_param_kwargs``; `fromJsonImplementation` uses them to pass
        subparam values by name.
    :raises TypeError: If ``default_value`` is None.
    """
    # Validate before any initialization so that a rejected param is never
    # left half-constructed.
    if default_value is None:
        raise TypeError("Can't set the default value of a compound param "
                        "to None")
    self._max_param_repr_len = 80
    self._block_signal_propagation = False
    self._skip_eq_check = False
    self._sub_params = {}
    # Must be set before Param.__init__, which reads self.DataClass.
    self.DataClass = self.__class__
    super().__init__(default_value=default_value, _param_type=_param_type)
    self._non_param_kwargs = self._popNonParamKwargs(kwargs)
    self._param_kwargs = kwargs
    self._populateSubParamsOnInstance()
    if self.isAbstract():
        self._initializeAbstractParam()
    else:
        self._initializeConcreteParam()
def initializeValue(self):
    """
    Hook for setting up the default value of the param dynamically.

    Override it for default values that are only known at runtime. It is
    called any time the param is reset. Does nothing by default.
    """
def initConcrete(self):
    """
    Hook for customizing initialization of concrete params.

    Does nothing by default.
    """
#===========================================================================
# CompoundParam - Common API
#===========================================================================
@classandinstancemethod
def getSubParam(self, name):
    """
    Get the value of a subparam using the string name::

        c = Coord()
        assert c.getSubParam('x') == 0

    .. note::
        Using the string name to access params is generally discouraged,
        but can be useful for serializing/deserializing param data.

    :param name: The name of the subparam to get the value for.
    :type name: str
    :raises AttributeError: If there is no subparam called ``name``.
    """
    try:
        return self._sub_params[name]
    except KeyError:
        # The KeyError is an implementation detail; don't chain it.
        raise AttributeError(f'{self} has no subparam {name}') from None
@classandinstancemethod
def getSubParams(self):
    """
    Return the dictionary mapping subparam names to their values.

    This is the param's own storage, not a copy.
    """
    sub_params = self._sub_params
    return sub_params
#===========================================================================
# CompoundParam - Abstract API
#===========================================================================
@classandinstancemethod
@abstract_only
def defaultValue(self):
    """
    Return a concrete param holding the default value.

    Called on a CompoundParam class, this instantiates the class with no
    arguments. Called on an abstract param, it delegates to
    `_makeConcreteParam`.
    """
    called_on_class = isinstance(self, type)
    return self() if called_on_class else self._makeConcreteParam()
@classmethod
def getJsonBlacklist(cls):
    """
    Override to customize what params are serialized.

    Implementations should return a list of abstract params that should
    be omitted from serialization.

    .. NOTE::
        Returned abstract params must be direct child params of `cls`,
        e.g. `cls.name`, not `cls.coord.x`.

    :return: Abstract params to exclude; empty by default.
    :rtype: list
    """
    return []
#===========================================================================
# CompoundParam - Concrete API
#===========================================================================
@concrete_only
def getAbstractParam(self):
    """
    Return the corresponding abstract param for this instance.

    For a top-level param this is its class. For a nested param it is the
    same-named attribute of the owner's abstract param.
    """
    parent = self.owner()
    if parent is None:
        return self.__class__
    abstract_parent = parent.getAbstractParam()
    return getattr(abstract_parent, self.paramName())
@concrete_only
def isDefault(self):
    """
    Whether the current value of this instance matches the default value.
    """
    self._assertIsDefaultValid()
    default = self.getAbstractParam().defaultValue()
    return self == default
[docs] @concrete_only
def setValue(self, value=None, **kwargs):
"""
Set the value of this `CompoundParam` to match :code:`value`.
:param value: The value to set this `CompoundParam` to. It should be
    the same type as this `CompoundParam`. A dict is also accepted.
:param kwargs: For internal use only. When `value` is None, the kwargs
    are used as a dict value.
:raises ValueError: If neither `value` nor kwargs are given.
:raises TypeError: If `value` is neither a CompoundParam nor a dict.
"""
if value is None and not kwargs:
raise ValueError('Must supply a value to set to.')
skip_signals = False
if value is None:
value = kwargs
# prevent unnecessary "changed" signals
# (skipped when every given key already has the given value, unless
# the _skip_eq_check flag is set)
skip_signals = (not self._skip_eq_check and
value.items() <= self.toDict().items())
else:
# prevent unnecessary "changed" signals
skip_signals = (not self._skip_eq_check and self == value)
with self.block_signal_propagation():
if isinstance(value, CompoundParam):
self._setValueFromParam(value)
elif isinstance(value, dict):
self._setValueFromDict(value)
else:
raise TypeError(f'Cannot set value of {self} to {value}')
if not skip_signals:
self._emitChangedSignals()
@concrete_only
def toDict(self):
    """
    Return a dictionary version of this `CompoundParam`. The returned
    dictionary is fully nested and contains no `CompoundParam` instances::

        a = Atom()
        a_dict = a.toDict()
        assert a_dict['coord']['x'] == 0
        assert a_dict['coord'] == {'x': 0, 'y': 0}
    """
    result = ParamDict()
    for name, value in self.getSubParams().items():
        # Nested compound params are converted recursively.
        result[name] = (value.toDict()
                        if isinstance(value, CompoundParam) else value)
    return result
@concrete_only
def reset(self, *abstract_params):
    """
    Resets this compound param to its default value::

        class Line(CompoundParam):
            start = Coord(x=1, y=2)
            end = Coord(x=4, y=5)

        line = Line()
        line.start.x = line.end.x = 10
        assert line.start.x == line.end.x == 10
        line.reset()
        assert line.start.x == 1
        assert line.end.x == 4

    Any number of abstract params may be passed in to perform a partial
    reset of only the specified params::

        line.start.x = line.end.x = 10
        line.reset(Line.start.x)  # resets just start.x
        assert line.start.x == 1
        assert line.end.x == 10

        line.reset(Line.end)  # resets the entire end point
        assert line.end.x == 4

        line.start.y = line.end.y = 10
        line.reset(Line.start.y, Line.end.y)  # resets the y-coord of both
        assert line.start.y == 2
        assert line.end.y == 5

    Called on a nested param, the reset is forwarded to the top-level
    param. Called with no arguments on a top-level param, `initializeValue`
    runs afterwards.

    :param abstract_params: Abstract params to reset. If none are given,
        the whole param is reset.
    :raises TypeError: If an argument is not an abstract param, or cannot
        be found in this param.
    """
    no_params_spec = not abstract_params
    if no_params_spec:
        abstract_params = [self.getAbstractParam()]
    owner_chain = self.ownerChain()
    if len(owner_chain) > 1:
        # `reset` needs to be called through the top-level param, so we
        # re-express each abstract param relative to the top-level param and
        # pass the call on.
        normalized_params = [
            _normalize_abstract_param(abs_param, self)
            for abs_param in abstract_params
        ]
        owner_chain[0].reset(*normalized_params)
        # NOTE(review): on this path initializeValue() is not called on self,
        # even for a no-argument reset -- confirm this is intended.
        return
    # Build the (possibly expensive) default only when it will be used.
    default_value = self.getAbstractParam().defaultValue()
    for abstract_param in abstract_params:
        if not abstract_param.isAbstract():
            raise TypeError('Arguments to reset must be abstract params.')
        top_abstract_param = abstract_param.ownerChain()[0]
        for obj in owner_chain:
            if issubclass(type(obj), top_abstract_param):
                break
        else:
            raise TypeError(f'Cannot find {abstract_param} in {self}.')
        abstract_param.setParamValue(
            self, abstract_param.getParamValue(default_value))
    if no_params_spec:
        self.initializeValue()
@concrete_only
def toJsonImplementation(self):
    """
    Returns a JSON representation of this value object: a dict of subparam
    values, minus the params returned by `getJsonBlacklist`.

    .. WARNING:: This should never be called directly.
    """
    # TODO: Add a way to allow users to extend this for non-param attributes
    excluded_names = {param.paramName() for param in self.getJsonBlacklist()}
    return {
        name: value
        for name, value in self.getSubParams().items()
        if name not in excluded_names
    }
@classmethod
def fromJsonImplementation(cls, json_dict):
    """
    Create a new concrete param from a JSON dict.

    Each value is decoded into its subparam's type hint (None values stay
    None) and passed to the constructor as a keyword argument. The given
    ``json_dict`` is not modified.

    .. WARNING:: This should never be called directly.

    :param json_dict: Mapping of subparam names to serialized values.
    :type json_dict: dict
    :return: A new concrete instance of ``cls``.
    :raises AttributeError: If a key does not name a subparam of ``cls``.
    """
    decoded = {}
    for name, raw_value in json_dict.items():
        # Look the subparam up even for None values so that unknown keys
        # still fail.
        param = cls.getSubParam(name)
        if raw_value is None:
            decoded[name] = None
        else:
            decoded[name] = json.decode(raw_value,
                                        DataClass=param.getTypeHint())
    return cls(_param_type=ParamType.CONCRETE, **decoded)
#===========================================================================
# CompoundParam - Descriptor methods
#===========================================================================
@descriptor_only
def __set__(self, owner_instance, value):
    """
    Replace this subparam's value on ``owner_instance``. If the value
    actually changes, emit ``<name>Replaced`` with the old and new values.

    :raises TypeError: If ``value`` is not an instance of this param's
        DataClass.
    """
    if not isinstance(value, self.DataClass):
        raise TypeError(f"Cannot set {self.DataClass.__name__} to {value}")
    name = self.paramName()
    previous = owner_instance.getSubParam(name)
    if value is previous:
        return
    super().__set__(owner_instance, value)
    owner_instance._getSignal(name, SignalType.Replaced).emit(previous, value)
@descriptor_only
def _makeAbstractParam(self):
    """
    Create an abstract param from this descriptor and let
    `_initializeAbstractParam` finish setting it up.
    """
    new_abstract = super()._makeAbstractParam()
    self._initializeAbstractParam(new_abstract)
    return new_abstract
#===========================================================================
# CompoundParam - Internal Implementation Methods
#===========================================================================
@classmethod
def _processAnnotations(cls):
"""
Process type annotations and turn them into params. This allows
the following to be equivalent::

    class Coord(parameters.CompoundParam):
        x = FloatParam()
        y = FloatParam(10)
        label = StringParam('my_coord')

    class Coord(parameters.CompoundParam):
        x: float
        y: float = 10
        label: str = 'my_coord'

Mixin annotations (``_mixin_annotations``) from bases are included. The
exception is those contributed before a CompoundParam base, because that
base already processed them. Each collected annotation is unpacked as a
``(type, default)`` pair, and the resulting param is set on ``cls`` under
the annotated name.
"""
annots = {}
for base_cls in cls.mro()[:0:-1]:
annots.update(getattr(base_cls, '_mixin_annotations', {}))
if issubclass(base_cls, CompoundParam):
# If the inherited class is a subclass of CompoundParam, then
# it has already processed the mixin annotations.
annots = {}
# Drop mixin annotations that this class overrides with a class attribute.
for attr in cls.__dict__:
if attr in annots:
annots.pop(attr)
annots.update(_get_uninherited_type_hints(cls))
if not annots:
# Nothing to convert. Only mixin annotations and annotations declared
# directly on `cls` (not inherited ones) are collected above, so
# params already defined by a parent CompoundParam are not redefined.
return
for attr_name, (type_, default_value) in annots.items():
# Split generics such as List[int] into origin type and arguments.
if get_origin(type_):
type_args = get_args(type_)
type_ = get_origin(type_)
else:
type_, type_args = type_, tuple()
if (default_value is not DEFAULT and
permissive_issubclass(type_, CompoundParam)):
# my_param: MyParam = MyParam(foo=1, bar=2)
param = type_(default_value)
elif not isinstance(type_, type) and isinstance(
type_, CompoundParam):
# my_param: MyParam(foo=1, bar=2)
param = type_
elif permissive_issubclass(type_, CompoundParam):
# my_param: MyParam
param = type_()
elif cls._isEnumType(type_):
# my_shape_param: enum.Enum('Shape', 'square triangle')
param = EnumParam(type_, default_value=default_value)
elif permissive_issubclass(type_, _SUPPORTED_MUTABLE_TYPES):
# my_list_param: list
# my_list_param: List[int]
# etc for other mutable params
param = cls._getMutableParamInstance(type_, type_args,
default_value)
else:
# my_basic_param: int
# my_basic_param: float
# my_basic_param: str
# etc
param = Param(DataClass=type_, default_value=default_value)
param.__set_name__(cls, attr_name)
setattr(cls, attr_name, param)
@classmethod
def _getMutableParamInstance(cls, type_, type_args, default_value):
"""
Figure out what type of mutable param corresponds to `type_` and
return an instance of it. Also check to see if the item type is
specified and use it when initializing the mutable param.

:param type_: A subclass of list, set or dict (e.g. list, set, dict, or
    a container class such as scollections.IdDict).
:type type_: type
:param type_args: Generic arguments of the annotation, e.g. ``(int,)``
    for ``List[int]``. For dicts, the second argument is used as the
    value class.
:type type_args: tuple
:param default_value: Default value passed on to the new param.
:return: For plain list/set/dict: an instance of ParamListParam (list
    of CompoundParams), ListParam, SetParam or DictParam. For a strict
    subclass of list/set/dict: an instance of a ListParam/SetParam/
    DictParam subclass created here with ``DataClass = type_``.
:raises RuntimeError: If `type_` is not a subclass of list, set or dict.
"""
if issubclass(type_, list):
is_list_subclass = not (type_ is list or List in type_.mro())
if not is_list_subclass:
if type_args:
item_class = type_args[0]
else:
item_class = None
if permissive_issubclass(item_class, CompoundParam):
param = ParamListParam(item_class,
default_value=default_value)
else:
param = ListParam(item_class, default_value=default_value)
return param
else:
# Parametrize a throwaway ListParam subclass with the custom type.
class ListSubclassParam(ListParam):
DataClass = type_
return ListSubclassParam(default_value=default_value)
elif issubclass(type_, set):
is_set_subclass = not (type_ is set or Set in type_.mro())
if not is_set_subclass:
if type_args:
item_class = type_args[0]
else:
item_class = None
return SetParam(item_class, default_value=default_value)
else:
class SetSubclassParam(SetParam):
DataClass = type_
return SetSubclassParam(default_value=default_value)
elif issubclass(type_, dict):
is_dict_subclass = not (type_ is dict or Dict in type_.mro())
if not is_dict_subclass:
# For Dict[K, V], only the value class V is used.
if type_args:
value_class = type_args[1]
else:
value_class = None
return DictParam(value_class, default_value=default_value)
else:
class DictSubclassParam(DictParam):
DataClass = type_
return DictSubclassParam(default_value=default_value)
else:
raise RuntimeError('A subclass of (list, set, dict) should have '
f'been passed to this method, not {type_}')
@classmethod
def _isEnumType(cls, type_):
    """
    Return whether `type_` should be treated as an enum class.

    A plain `issubclass(type_, enum.Enum)` check is tried first. Classes from
    `enum_speedup` fail that check, so as a fallback any iterable type whose
    first member is an `enum.Enum` instance is also accepted.
    """
    is_enum = issubclass(type_, enum.Enum)
    if not is_enum:
        try:
            members = iter(type_)
            first_member = next(members)
        except (TypeError, StopIteration):
            # Not iterable, or iterable but empty.
            first_member = None
        is_enum = isinstance(first_member, enum.Enum)
    return is_enum
@classmethod
def _populateClassParams(cls):
    """
    Register every abstract sub-param of this class.

    Base classes are visited from the root of the MRO down to (but not
    including) `cls`. Because `addSubParam` overwrites by name, sub-params
    from bases closer to `cls` replace those of the same name further up.
    Sub-params of CompoundParam/CompoundParamMixin bases and entries in any
    base's `_mixin_params` are registered without changing their owner.
    Param attributes defined directly on `cls` are registered last, with
    `cls` as their owner.
    """
    for base in cls.mro()[1:][::-1]:
        if issubclass(base, (CompoundParam, CompoundParamMixin)):
            # abstract params on class are descriptors
            for name, descriptor in base._sub_params.items():
                cls.addSubParam(name, descriptor, update_owner=False)
        if hasattr(base, '_mixin_params'):
            for name, descriptor in list(base._mixin_params.items()):
                cls.addSubParam(name, descriptor, update_owner=False)
    for name, attr in list(cls.__dict__.items()):
        if isinstance(attr, Param):
            cls.addSubParam(name, attr)
def _assertValidScope(self):
    """
    No-op override of the Param scope check.

    @overrides: Param
    """
    # A CompoundParam may be instantiated outside of a class scope, so
    # there is nothing to validate here.
@classmethod
def _populateSignalsOnClass(cls):
    """
    Add two pyqtSignals to the class for every registered sub-param.

    The signal names are the sub-param name with `SignalType.Changed` or
    `SignalType.Replaced` appended. The Changed signal carries one object;
    the Replaced signal carries two.
    """
    for name in cls._sub_params:
        setattr(cls, name + SignalType.Changed, QtCore.pyqtSignal(object))
        setattr(cls, name + SignalType.Replaced,
                QtCore.pyqtSignal(object, object))
def _populateSubParamsOnInstance(self):
    """
    Give this instance its own sub-params, built from the class-level ones.

    An abstract instance receives `_makeAbstractParam()` copies. A concrete
    instance receives `_makeConcreteParam(owner=self)` results.
    """
    abstract = self.isAbstract()
    for name, class_param in type(self).getSubParams().items():
        if abstract:
            new_param = class_param._makeAbstractParam()
        else:
            new_param = class_param._makeConcreteParam(owner=self)
        self.addSubParam(name, new_param)
def _popNonParamKwargs(self, kwargs):
    """
    Split off the keyword arguments that don't name a sub-param.

    Every key of `kwargs` that is not a sub-param of this class is removed
    from `kwargs`, which is modified in place. Those entries are returned
    in a new dict, in their original order.
    """
    extra_keys = [
        key for key in kwargs if key not in self.__class__.getSubParams()
    ]
    return {key: kwargs.pop(key) for key in extra_keys}
@classandinstancemethod
def addSubParam(self, name, param, update_owner=True):
    """
    Register `param` as the sub-param called `name`.

    Usable on both the class and its instances.

    :param name: name of the sub-param
    :type name: str
    :param param: the sub-param to store
    :param update_owner: when True and `param` is a Param, set `self` as
        its owner before storing it
    :type update_owner: bool
    """
    if update_owner and isinstance(param, Param):
        param._setOwner(self)
    self._sub_params[name] = param
@classandinstancemethod
@abstract_only
def _applyDefaultValueTo(self, concrete_param):
    """
    Copy this abstract param's sub-param defaults onto `concrete_param`.

    Compound sub-params are handled recursively. Defaults that are mutable
    containers (`_MutateToMatchMixin`) are matched in place on the existing
    concrete sub-param; all other defaults are assigned. Finally
    `concrete_param.initializeValue()` is called.
    """
    for name, abs_subparam in self.getSubParams().items():
        if isinstance(abs_subparam, CompoundParam):
            # Recurse so nested compound params get their own defaults.
            target = concrete_param.getSubParam(name)
            abs_subparam._applyDefaultValueTo(target)
            continue
        default = abs_subparam.defaultValue()
        if isinstance(default, _MutateToMatchMixin):
            concrete_param.getSubParam(name)._mutateToMatch(default)
        else:
            concrete_param._setSubParam(name, default)
    concrete_param.initializeValue()
@classandinstancemethod
@abstract_only
def _setDefaultValueFrom(self, value):
    """
    Set this abstract param's defaults from `value`.

    :param value: a CompoundParam (abstract or concrete), or a dict mapping
        sub-param names to values
    :raise TypeError: if `value` is neither a CompoundParam nor a dict
    """
    if isinstance(value, CompoundParam) and value.isAbstract():
        # Sub-param values are read from a concrete param, so build one.
        value = value._makeConcreteParam()

    if isinstance(value, CompoundParam):
        self._setDefaultValueFromParam(value)
    elif isinstance(value, dict):
        self._setDefaultValueFromDict(value)
    else:
        raise TypeError(f'{value} is not a valid default value for a '
                        'param.')
@classandinstancemethod
@abstract_only
def _setDefaultValueFromParam(self, concrete_param):
    """
    Use each sub-param value of `concrete_param` as the default of the
    matching abstract sub-param.
    """
    for name, abs_subparam in self.getSubParams().items():
        abs_subparam._setDefaultValueFrom(concrete_param.getSubParam(name))
@classandinstancemethod
@abstract_only
def _setDefaultValueFromDict(self, value_dict):
    """
    Set the default of each sub-param named in `value_dict` to the
    corresponding value.
    """
    for name, new_default in value_dict.items():
        self.getSubParam(name)._setDefaultValueFrom(new_default)
@abstract_only
def _initializeAbstractParam(self, abs_param=None):
    """
    Apply this param's constructor arguments to an abstract param.

    :param abs_param: the abstract param to initialize; defaults to `self`
    """
    target = self if abs_param is None else abs_param
    target._non_param_kwargs = self._non_param_kwargs
    if self._default_value_arg is not DEFAULT:
        target._setDefaultValueFrom(self._default_value_arg)
    if self._param_kwargs:
        # Applied after default_value, so these win for any sub-param that
        # is set by both.
        target._setDefaultValueFrom(self._param_kwargs)
@abstract_only
def _makeConcreteParam(self, owner=None):
    """
    Build a new concrete instance of this param's class.

    The defaults of this abstract param are applied first, then any
    sub-param keyword arguments this param was created with.

    :param owner: not used by this implementation; concrete sub-params get
        their owner when they are added with `addSubParam`
    :return: the new concrete param
    """
    concrete = self.__class__(_param_type=ParamType.CONCRETE,
                              **self._non_param_kwargs)
    concrete._param_type = ParamType.CONCRETE
    self._applyDefaultValueTo(concrete)
    concrete._param_name = self.paramName()
    if self._param_kwargs:
        with concrete.skip_eq_check():
            concrete.setValue(self._param_kwargs)
    return concrete
@concrete_only
def _initializeConcreteParam(self):
    """
    Finish setting up a newly created concrete param.

    Steps, in order:
    1. Apply the `default_value` constructor argument (if given), then call
       `initializeValue`.
    2. Apply any sub-param keyword arguments.
    3. Call `initConcrete` with the non-param keyword arguments.
    4. Set up references.
    """
    if self._default_value_arg is not DEFAULT:
        with self.skip_eq_check():
            self.setValue(self._default_value_arg)
        self.initializeValue()
    if self._param_kwargs:
        with self.skip_eq_check():
            abstract_self = self.getAbstractParam()
            for key, val in self._param_kwargs.items():
                # Mutable params are mutated in place to match the value;
                # every other param is replaced by the value passed in.
                if isinstance(abstract_self.getSubParam(key),
                              BaseMutableParam):
                    self.getSubParam(key)._mutateToMatch(val)
                else:
                    self._setSubParam(key, val)
    self.initConcrete(**self._non_param_kwargs)
    self._setUpReferences()
@concrete_only
def _setUpReferences(self):
    """
    Create `_reference_mapper` for the `(param1, param2)` pairs in
    `_reference_defs`.

    Each `param1` is mapped to a `ParamTargetSpec(self, param2)` target on
    a TargetParamMapper whose model is this param. Nothing is created when
    there are no reference definitions.
    """
    from schrodinger.models import mappers
    if self._reference_defs:
        self._reference_mapper = mappers.TargetParamMapper(_display_ok=False)
        for param1, param2 in self._reference_defs:
            target = mappers.ParamTargetSpec(self, param2)
            self._reference_mapper.addMapping(target, param1)
        self._reference_mapper.setModel(self)
@concrete_only
def _assertIsDefaultValid(self):
    """
    Confirm `isDefault` is safe to check.

    Every atomic sub-param, checked recursively through compound
    sub-params, must have a DataClass that is a subclass of one of
    `_SUPPORTS_ISDEFAULT_DATACLASSES` (e.g. ints, floats, lists).

    :raise TypeError: if any atomic sub-param has an unsupported DataClass
    """
    for name, abs_subparam in self.getAbstractParam().getSubParams().items():
        data_class = abs_subparam.DataClass
        if isinstance(abs_subparam, CompoundParam):
            self.getSubParam(name)._assertIsDefaultValid()
        elif not issubclass(data_class, _SUPPORTS_ISDEFAULT_DATACLASSES):
            raise TypeError(f"Cannot use isDefault because {abs_subparam}"
                            f" is of type {data_class}.")
@concrete_only
def _setSubParam(self, name, param):
    """
    Replace the sub-param called `name` with `param`.

    When `param` is a CompoundParam, the outgoing value is detached (owner
    and name cleared) and `param` is adopted by this param. Changed signals
    are emitted if the value differs, unless equality checks are currently
    skipped.
    """
    previous = self.getSubParam(name)
    if isinstance(param, CompoundParam):
        previous._setOwner(None)
        previous._param_name = None
        param._setOwner(self)
        param._param_name = name
    self._sub_params[name] = param
    if not self._skip_eq_check and previous != param:
        self._emitChangedSignals(name, param)
@concrete_only
def _getSignal(self, name, signal_type=SignalType.Changed):
    """
    Return the signal for sub-param `name` of the given `signal_type`.

    :raise ValueError: if `signal_type` is neither Changed nor Replaced
    """
    allowed_types = (SignalType.Changed, SignalType.Replaced)
    if signal_type not in allowed_types:
        raise ValueError('Invalid signal type.')
    return getattr(self, name + signal_type)
@concrete_only
def _emitChangedSignals(self, name=None, value=None):
    """
    Emit change notifications for this param.

    If `name` is given, that sub-param's Changed signal is emitted with
    `value`. Unless propagation is blocked, `valueChanged` is then emitted
    and the notification is forwarded to this param's owner, if it has one.
    """
    if name is not None:
        self._getSignal(name).emit(value)
    if not self._block_signal_propagation:
        self.valueChanged.emit(self)
        owner = self.owner()
        if owner is not None:
            owner._emitChangedSignals(self.paramName(), self)
@concrete_only
def _setValueFromDict(self, value):
    """
    Update the sub-params named in the dict `value`.

    Compound sub-params are set recursively with equality checks skipped.
    Mutable sub-params are mutated in place. Everything else is replaced.
    """
    for key, new_val in value.items():
        current = self.getSubParam(key)
        if isinstance(current, CompoundParam):
            with current.skip_eq_check():
                current.setValue(new_val)
        elif isinstance(current, _MutateToMatchMixin):
            current._mutateToMatch(new_val)
        else:
            self._setSubParam(key, new_val)
@concrete_only
def _setValueFromParam(self, value):
    """
    Make this param's sub-params match those of the param `value`.

    An abstract `value` is first converted to a concrete one. Sub-params
    that support mutating to match are updated in place; others are
    replaced.
    """
    if value.isAbstract():
        value = value._makeConcreteParam()
    for name, source in value.getSubParams().items():
        if not isinstance(source, _MutateToMatchMixin):
            self._setSubParam(name, source)
            continue
        target = self.getSubParam(name)
        if self._skip_eq_check and isinstance(target, CompoundParam):
            # Propagate the skipped equality check into the nested param.
            with target.skip_eq_check():
                target._mutateToMatch(source)
        else:
            target._mutateToMatch(source)
def _mutateToMatch(self, value):
    """Make this param match `value` by delegating to `setValue`."""
    set_value = self.setValue
    set_value(value)
@concrete_only
def __copy__(self):
    """Return a shallow copy made by constructing this class from `self`."""
    param_class = type(self)
    return param_class(self)
@concrete_only
def __deepcopy__(self, memo):
    """
    Create a special deepcopy of this param. The normal deepcopy will
    raise an exception if any of the values are unpicklable. The deepcopy
    returned here will simply skip copying any unpicklable values. This
    means that the deepcopy value will refer to the same original value.
    This is done so that params can be used with unpicklable datatypes
    (e.g. QPixMap)

    :param memo: the `copy.deepcopy` memo dictionary
    :type memo: dict
    :return: the new concrete param
    """
    cp_inst = self.__class__(_param_type=ParamType.CONCRETE)
    # Register the copy before recursing (standard `copy` memo protocol).
    # Any reference back to this param met while deep-copying sub-param
    # values then resolves to the new copy instead of recursing forever.
    memo[id(self)] = cp_inst
    self_class = type(self)
    for subparam_name, subparam in self.getSubParams().items():
        if self_class.getSubParam(subparam_name)._deepcopyable:
            new_value = copy.deepcopy(subparam, memo)
        else:
            # Unpicklable values are shared with the original.
            new_value = subparam
        setattr(cp_inst, subparam_name, new_value)
    return cp_inst
@concrete_only
def __eq__(self, other):
    """
    Return True if `other` is this very param, or if it is an instance of
    this param's class whose `toDict()` equals this param's `toDict()`.
    """
    if self is other:
        return True
    return (isinstance(other, self.__class__) and
            self.toDict() == other.toDict())
def __hash__(self):
    """
    Hash abstract params by identity.

    :raise TypeError: for concrete params, which are unhashable
    """
    if not self.isAbstract():
        raise TypeError(
            'Concrete parameters are unhashable. Consider '
            'using an scollections.IdDict or IdSet if you need to '
            'use this as a key.')
    return id(self)
@concrete_only
def __ne__(self, other):
    """Return the negation of `__eq__`."""
    equal = self.__eq__(other)
    return not equal
def __repr__(self):
    """
    Describe a concrete param by its abstract chain and its dict value.

    The dict text is truncated to `_max_param_repr_len` characters. Abstract
    params use the default repr. If an attribute is missing partway through
    (e.g. during construction), an error note is prefixed to the default
    repr.
    """
    try:
        if self._param_type is ParamType.CONCRETE:
            chain = self.getAbstractParam()._chainString()
            value_text = str(self.toDict())
            limit = self._max_param_repr_len
            if len(value_text) > limit:
                value_text = value_text[:limit - 4] + '...}'
            return f'<Concrete:{chain} {value_text}>'
    except AttributeError:
        return 'Error encountered during __repr__. ' + super().__repr__()
    return super().__repr__()
# Deprecated alias kept for backwards compatibility.
ParamModel = CompoundParam
"""
DEPRECATED alias of `CompoundParam`.

It was originally added as a placeholder in case any logic specific to the
top-level model was needed.
"""
#===============================================================================
# Param Helpers
#===============================================================================
class ParamAttributeError(RuntimeError):
    """
    Re-raised in place of an `AttributeError` from inside DataClass
    initialization or a param's `__get__`/`__set__`.

    A plain `AttributeError` would be swallowed by `QObject.__getattr__`,
    which discards the useful traceback and raises a different, vague
    exception instead. This error usually means that code tried to access
    an attribute that doesn't exist on a param's DataClass.
    """
# Container classes that are turned into mutable params when used as a type
# hint on a CompoundParam.
_SUPPORTED_MUTABLE_TYPES = (list, set, dict, IdDict)
# `typing` aliases of the supported containers.
_SUPPORTED_ANNOTATIONS = (List, Set, Dict)
# DataClasses for which `isDefault` may be checked (see
# `_assertIsDefaultValid`).
_SUPPORTS_ISDEFAULT_DATACLASSES = (
    set, list, dict,
    enum.Enum,
    int, float, bool, str,
)
class ParamDict(dict):
    """
    Marker subclass of dict that tells dicts representing params apart from
    ordinary dicts. It adds no behavior of its own and should be used
    sparingly.
    """
def get_all_replaced_signals(obj, abs_param=None):
    """
    Return the Replaced signals on `obj` for every compound sub-param.

    The search is depth-first: each compound sub-param's signal comes
    before the signals of its own compound sub-params.

    :param obj: the param whose signals are returned
    :param abs_param: the abstract param to search; defaults to the class
        of `obj`
    :return: list of signals
    """
    if abs_param is None:
        abs_param = obj.__class__
    found = []
    for child in abs_param.getSubParams().values():
        if not isinstance(child, CompoundParam):
            continue
        found.append(
            child.getParamSignal(obj, signal_type=SignalType.Replaced))
        found += get_all_replaced_signals(obj, child)
    return found
def get_all_atomic_subparams(abs_param):
    """
    Return every non-compound sub-param of `abs_param`, descending into
    compound sub-params depth-first.
    """
    atomic = []
    for child in abs_param.getSubParams().values():
        if isinstance(child, CompoundParam):
            atomic += get_all_atomic_subparams(child)
        else:
            atomic.append(child)
    return atomic
def get_all_compound_subparams(param):
    """
    Yield every compound sub-param of `param`, depth-first, with each one
    yielded before its own compound sub-params.
    """
    compounds = (child for child in param.getSubParams().values()
                 if isinstance(child, CompoundParam))
    for compound in compounds:
        yield compound
        yield from get_all_compound_subparams(compound)
def _normalize_abstract_param(abs_param, concrete_param):
    """
    Given an abstract param and a concrete param of some submodel, return an
    abstract param whose toplevel matches the toplevel of the submodel. For
    example::

        _normalize_abstract_param(Coord.x, my_model.coord) -> Model.coord.x
        _normalize_abstract_param(Coord, my_model.coord) -> Model.coord
        _normalize_abstract_param(Model, my_model.coord) -> Model
    """
    param_root = abs_param.ownerChain()[0]
    model_root = concrete_param.getAbstractParam().ownerChain()[0]
    if param_root is model_root:
        return abs_param
    value_param = abs_param.getParamValue(concrete_param)
    if isinstance(value_param, CompoundParam):
        return value_param.getAbstractParam()
    # Atomic param: find it by name on the normalized parent param.
    parent_value = abs_param.ownerChain()[-2].getParamValue(concrete_param)
    return getattr(parent_value.getAbstractParam(), abs_param.paramName())
def _get_uninherited_type_hints(cls) -> Dict[str, Tuple[type, object]]:
    """
    Get all type hints and their values defined specifically on `cls`. This is
    a strict subset of the type hints returned by `typing.get_type_hints`,
    which returns _all_ type hints of a class, including ones defined by super-
    classes.

    :param cls: the class to inspect
    :type cls: type
    :return: maps each attribute annotated directly on `cls` to a tuple of
        (resolved type hint, value assigned on `cls`, or DEFAULT if none)
    :rtype: dict
    """
    get_annotations = getattr(inspect, 'get_annotations', None)
    if get_annotations is not None:
        # Python 3.10+: returns only the annotations defined on `cls`
        # itself. From Python 3.14 (lazily evaluated annotations), they are
        # no longer guaranteed to be stored in `cls.__dict__`.
        own_annotations = get_annotations(cls)
    else:
        own_annotations = cls.__dict__.get('__annotations__', {})
    if not own_annotations:
        return {}
    # NOTE(review): string annotations are resolved against this module's
    # globals rather than those of the module defining `cls`; confirm this
    # is intended.
    all_annotations = typing.get_type_hints(cls, globals())
    return {
        annot_name: (hint, cls.__dict__.get(annot_name, DEFAULT))
        for annot_name, hint in all_annotations.items()
        if annot_name in own_annotations
    }
def permissive_issubclass(obj, cls):
    """
    Like `issubclass(obj, cls)`, but return False instead of raising when
    `obj` is not a class.

    `typing` constructs that have an origin (e.g. `List[int]`) are checked
    using that origin.
    """
    origin = get_origin(obj)
    if origin:
        obj = origin
    if not isinstance(obj, type):
        return False
    return issubclass(obj, cls)
#===============================================================================
# Primitive Param Classes
#===============================================================================
class EnumParam(Param):
    """
    Param whose value is a member of an enum class (or None).
    """

    def __init__(self, enum_class, default_value=DEFAULT, *args, **kwargs):
        """
        Create a param based on `enum_class`, optionally with a default.

        :param enum_class: the enum class to base this param on
        :type enum_class: enum.Enum
        :param default_value: The default enum value. If not provided, the
            param will default to the first member of the enum.
        :type default_value: a member of the enum_class
        """
        self.DataClass = enum_class
        if default_value is DEFAULT:
            # If no default_value is provided, use the first member of
            # the enum as the default.
            default_value = list(enum_class)[0]
        super().__init__(*args, default_value=default_value, **kwargs)

    @descriptor_only
    def __set__(self, owner_instance, value):
        """
        Reject values that are neither members of the enum class nor None.

        :raise TypeError: if `value` has the wrong type
        """
        if not isinstance(value, (self.DataClass, type(None))):
            # Name the enum class itself; `type(self.DataClass)` would name
            # its metaclass (e.g. EnumMeta), which misleads the reader.
            raise TypeError(f"Value should be of type {self.DataClass} "
                            f"or None. Got {value}.")
        super().__set__(owner_instance, value)
#===============================================================================
# Mutable Param Framework
#===============================================================================
class _SignalContainer(QtCore.QObject):
    """
    Plain QObject that owns the signals of a mutable param value.

    Mutable param values subclass builtin containers such as dict, which
    cannot always be combined with QObject through multiple inheritance, so
    the signals live on this separate object instead.

    :cvar mutated: emitted when the mutable param object (the container) is
        mutated by adding or removing an item. Emitted with the current and
        the previous state of the container.
    :cvar itemChanged: emitted when the value of an item in the container
        changes. Not emitted when the item is removed from the container.
    """
    mutated = QtCore.pyqtSignal(object, object)
    itemChanged = QtCore.pyqtSignal(CompoundParam)
def generate_method(parent_cls, method_name, signal_name):
    """
    Create a replacement for `parent_cls.<method_name>` that emits a signal.

    The returned function:
    1. calls the original method;
    2. emits the signal attribute named `signal_name` with the instance and
       its `_echo`;
    3. replays the same call on `_echo` using `DataClass`'s method.

    It expects to be installed on a class that provides `DataClass`,
    `ITERABLE_ARG_METHOD_NAMES`, `_echo` and the named signal attribute.

    :param parent_cls: the class whose `method_name` implementation is wrapped
    :type parent_cls: type
    :param method_name: the name of the method to wrap
    :type method_name: str
    :param signal_name: name of the signal attribute that is emitted after
        the wrapped method succeeds
    :type signal_name: str
    :return: the wrapping method
    :rtype: function
    """
    import functools

    og_method = getattr(parent_cls, method_name)

    @functools.wraps(og_method)
    def new_method(self, *args, **kwargs):
        if method_name in self.ITERABLE_ARG_METHOD_NAMES:
            # Put iterable positional arguments into a data structure so
            # they aren't exhausted before being replayed on the echo.
            # Keyword arguments are left untouched: for these methods they
            # are individual items (e.g. dict.update(key=value)), and
            # passing them through DataClass would reject or corrupt them.
            args = [self.DataClass(arg) for arg in args]
        ret = og_method(self, *args, **kwargs)
        signal = getattr(self, signal_name)
        signal.emit(self, self._echo)
        getattr(self.DataClass, method_name)(self._echo, *args, **kwargs)
        return ret

    return new_method
class BaseMutableParam(Param):
    """
    Base class for mutable params (eg lists, sets, dicts, etc). Child classes
    should specify the names of mutation methods on the class, the method used
    to match an instance of the class to another instance (eg 'extend' for
    list), what methods take iterables, and the actual DataClass. Upon subclass
    declaration, the `DataClass` will be dynamically wrapped so signals are
    emitted whenever an instance is mutated.
    """
    # DataClass methods that mutate the container; each is wrapped so that
    # it emits `mutated`.
    MUTATE_METHOD_NAMES = ()
    # DataClass method used to copy another value's contents in when
    # mutating to match (e.g. 'extend' for lists).
    REPLACE_METHOD_NAME = None
    # Mutation methods whose positional arguments are iterables that
    # `generate_method` materializes before use.
    ITERABLE_ARG_METHOD_NAMES = None
    DataClass = None

    def __init_subclass__(cls):
        # A DataClass that is already wrapped (e.g. one inherited unchanged
        # from a parent param class) must not be wrapped a second time.
        if 'WithSignal' in cls.DataClass.__name__:
            return

        def materialize_iterators(args):
            """
            Return `args` with one-shot iterators converted to lists so
            they can be consumed twice (by the container and by its echo).
            """
            materialized = []
            for arg in args:
                try:
                    is_iterator = iter(arg) is arg
                except TypeError:
                    is_iterator = False
                materialized.append(list(arg) if is_iterator else arg)
            return materialized

        class DataClassWithSignal(_MutateToMatchMixin, cls.DataClass):
            DataClass = cls.DataClass
            ITERABLE_ARG_METHOD_NAMES = cls.ITERABLE_ARG_METHOD_NAMES

            def __init__(self, *args, **kwargs):
                self._signals = _SignalContainer()
                for name, signal in get_signals(self._signals).items():
                    setattr(self, name, signal)
                self._starting_state = None  # For use with emitMutated
                # Otherwise an iterator argument would be exhausted by the
                # first constructor call, and the echo would start empty.
                args = materialize_iterators(args)
                super().__init__(*args, **kwargs)
                # We keep a normal copy of the DataClass that gets
                # updated on a delay. This is then used in the `mutated`
                # signal so we can emit what this instance looked like
                # before a mutation.
                self._echo = self.DataClass(*args, **kwargs)

            def __deepcopy__(self, memo):
                return copy.deepcopy(copy.copy(self), memo)

            def __copy__(self):
                return self.copy()

            def blockSignals(self, should_block):
                if should_block and not self._signals.signalsBlocked():
                    # Remember the state at the start of blocking so that
                    # emitMutated can report it later.
                    self._starting_state = copy.copy(self)
                return self._signals.blockSignals(should_block)

            def emitMutated(self):
                """
                This method can be called after signals are blocked and
                unblocked on the object. It will cause the object to
                emit the `mutated` signal with the state of the object at
                the start of signal blocking. This is useful if slots
                need to be notified that the object was modified after
                blocking signals and doing multiple mutations at once.

                Example usage::

                    my_model.list_param = [1,2,3]
                    lp = my_model.list_param
                    lp.blockSignals(True)
                    lp.extend([4,5])
                    lp.pop(0)
                    lp.blockSignals(False)
                    lp.emitMutated() # lp.mutated will be emitted with [1,2,3]

                If signals haven't been blocked, then the mutated signal is
                just emitted with the current state of the object.
                """
                if self._starting_state is not None:
                    self.mutated.emit(self, self._starting_state)
                    self._starting_state = None
                else:
                    self.mutated.emit(self, copy.copy(self))

            def _mutateToMatch(self,
                               new_value,
                               *,
                               suppress_mutated_signal=False):
                """
                Mutate this param so it equals `new_value`.

                :param new_value: The value to mutate to
                :type new_value: DataClassWithSignal
                :param suppress_mutated_signal: If True, the `mutated` signal
                    will not be emitted even if `new_value` doesn't equal
                    `self` (and `_starting_state` won't be updated). In this
                    scenario, the calling code is responsible for calling
                    `emitMutated`.
                :type suppress_mutated_signal: bool
                """
                if new_value is self:
                    return
                should_emit_signals = (not suppress_mutated_signal and
                                       self != new_value)
                # We suppress signals and then just emit one ourselves so we
                # don't get multiple signal emissions for setting one value.
                with suppress_signals(self):
                    self.clear()
                    getattr(self, cls.REPLACE_METHOD_NAME)(new_value)
                if should_emit_signals:
                    self.emitMutated()

        for method_name in cls.MUTATE_METHOD_NAMES:
            new_method = generate_method(DataClassWithSignal, method_name,
                                         'mutated')
            setattr(DataClassWithSignal, method_name, new_method)
        DataClassWithSignal.__name__ = cls.DataClass.__name__ + 'WithSignal'
        cls.DataClass = DataClassWithSignal

    def __init__(self, default_value=DEFAULT, *args, **kwargs):
        """
        :param default_value: the default value; must not be None
        :raise ValueError: if `default_value` is None
        """
        if default_value is None:
            raise ValueError("Can't set the default value of a mutable param "
                             "to None")
        super().__init__(*args, default_value=default_value, **kwargs)

    def __set__(self, owner_instance, new_value):
        """
        The value is initialized as an empty instance of `DataClass` with
        attached signals. All assignments will then simply mutate the original
        value to match the new value. This is done to preserve signal-slot
        connections.

        :raise TypeError: if `new_value` doesn't implement
            `REPLACE_METHOD_NAME`
        """
        if not hasattr(new_value, self.REPLACE_METHOD_NAME):
            err_msg = (
                f'{self.paramName()} can only be set as an instance of classes '
                f'that implement {self.REPLACE_METHOD_NAME}.')
            raise TypeError(err_msg)
        old_value = owner_instance.getSubParam(self.paramName())
        old_value._mutateToMatch(new_value)

    def _makeConcreteParam(self, owner=None):
        """
        Create the concrete value. When `owner` is given, its change signals
        are emitted whenever the value is mutated or one of its items
        changes.
        """
        new_value = super()._makeConcreteParam(owner)
        if owner is not None:

            def slot():
                owner._emitChangedSignals(self.paramName(), new_value)

            new_value.mutated.connect(slot)
            new_value.itemChanged.connect(slot)
        return new_value

    @abstract_only
    def _prepareDefaultValue(self, def_arg):
        """
        See parent class for method documentation. Note that here, we need to
        convert def_arg from DataClass to DataClassWithSignal. We also need to
        generate a shallow copy of def_arg; otherwise, mutations in one param
        instance would affect other param instances.
        """
        return self.DataClass(def_arg)

    def _inspectDataclassSignature(self, dataclass_constructor_signature,
                                   num_required_args):
        # See parent class for method documentation
        # Note that this method currently has no effect due to PANEL-14810
        super()._inspectDataclassSignature(dataclass_constructor_signature,
                                           num_required_args)
        if self._default_value_arg is not DEFAULT:
            num_args = len(dataclass_constructor_signature.parameters)
            if num_required_args > 1:
                raise TypeError("More than one arguments are required to "
                                "instantiate the DataClass. Try specifying "
                                "None as the default value.")
            if num_args == 0:
                raise TypeError(
                    "The default value is passed into the DataClass's "
                    "constructor but the constructor does not "
                    "accept any arguments.")
#===============================================================================
# Mutable Params
#===============================================================================
class DictParam(BaseMutableParam):
    """
    A Param to represent dictionaries. Values of this param have a `mutated`
    signal that is emitted whenever any mutation method is called.

    The constructor optionally takes a `value_class` keyword argument
    giving the class of the dictionary's values. If given, it is used when
    jsonifying the dictionary. (Non-string keys are not currently supported
    for jsonification; this may change in the future. See PANEL-13029.)
    """
    DataClass = dict
    MUTATE_METHOD_NAMES = ('__setitem__', '__delitem__', 'pop', 'popitem',
                           'clear', 'update', 'setdefault')
    ITERABLE_ARG_METHOD_NAMES = {'update'}
    REPLACE_METHOD_NAME = 'update'

    def __init__(self, value_class=None, **kwargs):
        self.value_class = value_class
        super().__init__(**kwargs)

    def getTypeHint(self):
        """
        Return `Dict[str, value_class]`, or a bare `Dict` when no value
        class was given.
        """
        return Dict[str, self.value_class] if self.value_class else Dict
class ItemClassMixin:
    """
    Mixin that stores the `item_class` given to the constructor.
    """

    def __init__(self, item_class, **kwargs):
        super().__init__(**kwargs)
        # Stored only after the cooperative __init__ chain has run.
        self.item_class = item_class
class _ListWithoutIadd(list):
    """List subclass that refuses in-place concatenation (``+=``)."""

    def __iadd__(self, other):
        raise NotImplementedError(
            "iadd not supported for ListParams. Use extend instead.")
class ListParam(ItemClassMixin, BaseMutableParam):
    """
    A Param to represent lists. Values of this param have a `mutated`
    signal that is emitted whenever any mutation method is called.

    The constructor optionally takes an `item_class` keyword argument giving
    the class of the items in the list. If given, it is used when jsonifying
    the list.
    """

    class DataClass(_ListWithoutIadd):
        pass

    # Named 'list' so that the wrapped class is named 'listWithSignal'.
    DataClass.__name__ = 'list'

    MUTATE_METHOD_NAMES = ('__setitem__', 'append', 'insert', '__delitem__',
                           'pop', 'remove', 'extend', 'reverse', 'sort',
                           'clear', '__iadd__')
    ITERABLE_ARG_METHOD_NAMES = {'extend'}
    REPLACE_METHOD_NAME = 'extend'

    def __init__(self, item_class=None, **kwargs):
        self.item_class = item_class
        super().__init__(item_class, **kwargs)

    def getTypeHint(self):
        """
        Return `List[item_class]`, or a bare `List` when no item class was
        given.
        """
        if not self.item_class:
            return List
        return List[self.item_class]
class SetParam(ItemClassMixin, BaseMutableParam):
    """
    A Param to represent sets. Values of this param have a `mutated` signal
    that is emitted whenever an element is added to or removed from the set.

    The constructor optionally takes an `item_class` keyword argument giving
    the class of the items in the set. If given, it is used when jsonifying
    the set.
    """
    DataClass = jsonable.JsonableSet
    REPLACE_METHOD_NAME = 'update'
    # Mutation methods that take iterables.
    ITERABLE_ARG_METHOD_NAMES = {
        'update', 'intersection_update', 'difference_update',
        'symmetric_difference_update'
    }
    MUTATE_METHOD_NAMES = ITERABLE_ARG_METHOD_NAMES | {
        'add', 'remove', 'discard', 'pop', 'clear'
    }

    def __init__(self, item_class=None, **kwargs):
        self.item_class = item_class
        super().__init__(item_class, **kwargs)

    def getTypeHint(self):
        """
        Return `Set[item_class]`, or a bare `Set` when no item class was
        given.
        """
        if not self.item_class:
            return Set
        return Set[self.item_class]
class _PLPSignalContainer(QtCore.QObject):
    """
    Extra signals for ParamListParam.DataClass. plptable uses them to emit
    the matching QAbstractItemModel signals when the list data changes.

    :ivar itemsAboutToBeInserted: emitted just before items are inserted.
        Carries the inclusive range of list indices the new items will
        occupy; e.g. `(2, 4)` means three items are about to be inserted.
    :ivar itemsInserted: emitted just after items are inserted, with the
        same values as the corresponding `itemsAboutToBeInserted` signal.
    :ivar itemsAboutToBeRemoved: emitted just before items are removed.
        Carries the first and last list indices of the items to be removed;
        e.g. `(2, 4)` means three items are about to be removed.
    :ivar itemsRemoved: emitted just after items are removed, with the same
        values as the corresponding `itemsAboutToBeRemoved` signal.
    :ivar itemsAboutToBeReset: emitted just before the list is completely
        changed: cleared, reordered, or reverted back to the default values.
    :ivar itemsReset: emitted just after the list has been completely
        changed. See `itemsAboutToBeReset` for what this can entail.
    :ivar itemsAtIndicesReplaced: emitted when items in the list have been
        replaced, with the first and last list indices of the replaced items.
    """
    itemsAboutToBeInserted = QtCore.pyqtSignal(int, int)
    itemsInserted = QtCore.pyqtSignal(int, int)
    itemsAboutToBeRemoved = QtCore.pyqtSignal(int, int)
    itemsRemoved = QtCore.pyqtSignal(int, int)
    itemsAboutToBeReset = QtCore.pyqtSignal()
    itemsReset = QtCore.pyqtSignal()
    itemsAtIndicesReplaced = QtCore.pyqtSignal(int, int)
class ParamListParam(ListParam):
    """
    A list param that contains `CompoundParam` instances. Signals will be
    emitted any time an item in the list changes or the contents of the list
    itself change. See `_SignalContainer` and `_PLPSignalContainer` for
    information on specific signals.
    """

    class DataClass(_ListWithoutIadd):
        """
        List that emits the `_PLPSignalContainer` signals around each
        structural change. Slice assignment and slice deletion are not
        supported.
        """

        def __init__(self, init_value=None, **kwargs):
            self._plp_signals = _PLPSignalContainer()
            for name, signal in get_signals(self._plp_signals).items():
                setattr(self, name, signal)
            # list() rejects None, so treat a missing initial value as empty.
            if init_value is None:
                init_value = []
            super().__init__(init_value, **kwargs)

        def _canonicalizeSliceIndex(self, index):
            """
            Convert a list index used in a slice (or `.insert`, which
            handles indices the same way as slices) so it's between 0 and
            `len(list)`. Note that this method will never raise an
            IndexError, as slice indices are allowed to extend past the end
            of the list.

            :param index: The index to canonicalize
            :type index: int
            :return: The canonicalized index
            :rtype: int
            """
            if index < 0:
                return max(0, len(self) + index)
            else:
                return min(len(self), index)

        def clear(self):
            self.itemsAboutToBeReset.emit()
            super().clear()
            self.itemsReset.emit()

        def extend(self, other):
            # we can assume that other is a non-empty list because of the
            # ListWithSignal.extend implementation below
            start = len(self)
            end = start + len(other) - 1
            self.itemsAboutToBeInserted.emit(start, end)
            super().extend(other)
            self.itemsInserted.emit(start, end)

        def append(self, item):
            index = len(self)
            self.itemsAboutToBeInserted.emit(index, index)
            super().append(item)
            self.itemsInserted.emit(index, index)

        def remove(self, item):
            index = self.index(item)
            item = self[index]
            self.itemsAboutToBeRemoved.emit(index, index)
            super().remove(item)
            self.itemsRemoved.emit(index, index)

        def insert(self, index, item):
            canonical_index = self._canonicalizeSliceIndex(index)
            self.itemsAboutToBeInserted.emit(canonical_index, canonical_index)
            super().insert(index, item)
            self.itemsInserted.emit(canonical_index, canonical_index)

        def pop(self, index=-1):
            canonical_index = index % len(self)
            self.itemsAboutToBeRemoved.emit(canonical_index, canonical_index)
            retval = super().pop(index)
            self.itemsRemoved.emit(canonical_index, canonical_index)
            return retval

        def __delitem__(self, index):
            if isinstance(index, slice):
                # TODO: handle slices (PANEL-19047)
                raise RuntimeError("Slice deletion not yet implemented "
                                   "for ParamListParam")
            canonical_index = index % len(self)
            self.itemsAboutToBeRemoved.emit(canonical_index, canonical_index)
            super().__delitem__(index)
            self.itemsRemoved.emit(canonical_index, canonical_index)

        def __setitem__(self, index, new_item):
            if isinstance(index, slice):
                # TODO: handle slices (PANEL-19047)
                raise RuntimeError("Slice assignment not yet implemented "
                                   "for ParamListParam")
            super().__setitem__(index, new_item)
            # Report a non-negative index even when a negative one was used
            # (e.g. `lst[-1] = item`). The assignment above succeeded, so
            # the list is non-empty and `index` is in range.
            canonical_index = index % len(self)
            self.itemsAtIndicesReplaced.emit(canonical_index, canonical_index)

        def reverse(self):
            self.itemsAboutToBeReset.emit()
            super().reverse()
            self.itemsReset.emit()

        def sort(self, *args, **kwargs):
            self.itemsAboutToBeReset.emit()
            super().sort(*args, **kwargs)
            self.itemsReset.emit()

    DataClass.__name__ = 'list'

    def __init__(self, item_class, *args, **kwargs):
        super().__init__(item_class, *args, **kwargs)

        class ListWithSignal(self.DataClass):
            """
            List of params that also emits `itemChanged` (with the item)
            whenever a contained item emits `valueChanged`.
            """

            def __init__(self, init_value=None, **kwargs):
                init_value = self._processInitValue(init_value)
                super().__init__(init_value, **kwargs)
                # Maps each contained item to the valueChanged slots
                # connected for it (one per occurrence in the list).
                self._value_changed_slots = IdDict()
                self.item_class = item_class
                for item in self:
                    self._connectItem(item)

            def _processInitValue(self, init_value):
                """
                Return a new list of the initial items, with abstract
                params replaced by concrete ones.
                """
                if init_value is None:
                    init_value = []
                processed_items = []
                for item in init_value:
                    if item.isAbstract():
                        item = item._makeConcreteParam()
                    processed_items.append(item)
                return processed_items

            def clear(self):
                for item in self:
                    self._disconnectItem(item)
                super().clear()

            def _disconnectItem(self, item):
                slots = self._value_changed_slots[item]
                item.valueChanged.disconnect(slots.pop())
                if not slots:
                    # Drop the entry once no occurrence of `item` remains,
                    # so the bookkeeping doesn't accumulate entries for
                    # items that have left the list.
                    del self._value_changed_slots[item]

            def _connectItem(self, item):

                def value_changed_slot():
                    self.itemChanged.emit(item)

                item.valueChanged.connect(value_changed_slot)
                self._value_changed_slots.setdefault(
                    item, []).append(value_changed_slot)

            def extend(self, other):
                # make sure we don't consume an iterator
                other = list(other)
                if not other:
                    return
                if any(not isinstance(item, Param) for item in other):
                    raise ValueError(
                        'New members of a ParamListParam must be Params')
                for item in other:
                    self._connectItem(item)
                super().extend(other)

            def append(self, item):
                if not isinstance(item, Param):
                    raise ValueError(
                        'New members of a ParamListParam must be Params')
                self._connectItem(item)
                super().append(item)

            def remove(self, item):
                item = self[self.index(item)]
                self._disconnectItem(item)
                super().remove(item)

            def insert(self, index, item):
                if not isinstance(item, Param):
                    raise ValueError(
                        'New members of a ParamListParam must be Params')
                self._connectItem(item)
                super().insert(index, item)

            def pop(self, index=-1):
                item = self[index]
                self._disconnectItem(item)
                return super().pop(index)

            def __delitem__(self, index):
                # Slices go straight to the base DataClass, which rejects
                # them with a RuntimeError before anything is removed.
                if not isinstance(index, slice):
                    self._disconnectItem(self[index])
                super().__delitem__(index)

            def __setitem__(self, index, new_item):
                # Slices go straight to the base DataClass, which rejects
                # them with a RuntimeError before anything is replaced.
                if not isinstance(index, slice):
                    if not isinstance(new_item, Param):
                        raise ValueError(
                            'New members of a ParamListParam must be Params')
                    old_item = self[index]
                    self._disconnectItem(old_item)
                    self._connectItem(new_item)
                super().__setitem__(index, new_item)

            def _mutateToMatch(self, new_value):
                # we override the superclass method to make sure that the
                # appropriate signals are emitted and that they're emitted in
                # the right order
                if new_value is self:
                    return
                should_emit_signals = (self != new_value)
                if should_emit_signals:
                    self.itemsAboutToBeReset.emit()
                with suppress_signals(self, self._plp_signals):
                    super()._mutateToMatch(new_value,
                                           suppress_mutated_signal=True)
                if should_emit_signals:
                    self.itemsReset.emit()
                    self.emitMutated()

        self.DataClass = ListWithSignal