import collections
import inspect
import numpy as np
from nengo.exceptions import ConfigError, ObsoleteError, ReadonlyError, ValidationError
from nengo.rc import rc
from nengo.utils.numpy import (
array_hash,
compare,
is_array,
is_array_like,
is_integer,
is_number,
)
from nengo.utils.stdlib import WeakKeyIDDictionary, checked_call
class DefaultType:
    """Sentinel type whose instances stand in for a parameter's default value.

    Parameters
    ----------
    name : str
        Label stored on the instance; it is also what ``repr`` returns.
    """

    def __init__(self, name):
        # The label is the only state a sentinel carries.
        self.name = name

    def __repr__(self):
        # Show the bare label so sentinels read naturally in messages.
        return self.name
# Module-level sentinel instances. Any ``DefaultType`` instance is rejected by
# ``Parameter.coerce`` when passed as an explicit parameter value.

# Placeholder representing "use the default value" for a parameter.
Default = DefaultType("Default")
# NOTE(review): not referenced in this part of the file; presumably a
# connection-specific variant of ``Default`` -- confirm against callers.
ConnectionDefault = DefaultType("ConnectionDefault")
# Used as a ``Parameter`` default to mark the parameter as having no
# configurable default (see ``Parameter.configurable``).
Unconfigurable = DefaultType("Unconfigurable")
def is_param(obj):
    """Return True when ``obj`` is a `.Parameter` instance (subclasses included)."""
    result = isinstance(obj, Parameter)
    return result
def iter_params(obj):
    """Yield the attribute names of every non-obsolete parameter of ``obj``.

    Parameters
    ----------
    obj : object or type
        A class, or an instance whose class is inspected.

    Returns
    -------
    generator of str
        Names (as listed by ``dir``) whose class-level attribute is a
        `.Parameter` but not an `.ObsoleteParam`.
    """
    # Parameters are descriptors living on the class, so always inspect the class.
    cls = obj if inspect.isclass(obj) else type(obj)

    def _is_live_param(name):
        attr = getattr(cls, name)
        return is_param(attr) and not isinstance(attr, ObsoleteParam)

    return (name for name in dir(cls) if _is_live_param(name))
def equal(a, b):
    """Compare two objects, using array comparison if either is array-like.

    Plain objects are compared with ``==``; as soon as one operand is
    array-like the comparison is delegated to ``np.array_equal`` so that a
    single boolean is returned.
    """
    if not (is_array_like(a) or is_array_like(b)):
        return a == b
    return np.array_equal(a, b)
class Parameter:
    """Descriptor that stores a validated configuration value per instance.

    Values set on objects and defaults set through the config system are
    kept in two separate weak-key dictionaries owned by the descriptor.

    Parameters
    ----------
    name : str
        Name of the parameter.
    default : object
        The value returned if the parameter hasn't been explicitly set.
        The ``Unconfigurable`` sentinel (the default) marks the parameter
        as having no usable default.
    optional : bool, optional
        Whether this parameter accepts the value None. By default,
        parameters are not optional (i.e., cannot be set to ``None``).
    readonly : bool, optional
        If true, the parameter can only be set once. When left as ``None``,
        the parameter is readonly exactly when ``default`` is
        ``Unconfigurable``.

    Attributes
    ----------
    coerce_defaults : bool
        If True, validate values for this parameter when they are set in a
        `.Config` object. Setting a parameter directly on an object will
        always be validated.
    equatable : bool
        If True, parameter values can be compared for equality
        (``a==b``); otherwise equality checks will just compare object
        identity (``a is b``).
    """

    coerce_defaults = True
    equatable = False

    def __init__(self, name, default=Unconfigurable, optional=False, readonly=None):
        if readonly is None:
            # Unconfigurable parameters are frozen unless told otherwise.
            readonly = default is Unconfigurable

        if not isinstance(name, str):
            raise ValueError("'name' must be a string (got %r)" % name)
        if not isinstance(optional, bool):
            raise ValueError("'optional' must be boolean (got %r)" % optional)
        if not isinstance(readonly, bool):
            raise ValueError("'readonly' must be boolean (got %r)" % readonly)

        self.name = name
        self.default = default
        self.optional = optional
        self.readonly = readonly

        # Defaults registered through the config system, keyed by object.
        self._defaults = WeakKeyIDDictionary()
        # Values explicitly set on objects, keyed by object.
        self.data = WeakKeyIDDictionary()

    def __getstate__(self):
        """Return the instance dict with the weak dictionaries as plain dicts."""
        state = dict(self.__dict__)
        for key in ("_defaults", "data"):
            state[key] = dict(state[key].items())
        return state

    def __setstate__(self, state):
        """Restore attributes, rebuilding the two weak-key dictionaries."""
        for key, value in state.items():
            if key in ("_defaults", "data"):
                value = WeakKeyIDDictionary(value)
            setattr(self, key, value)

    def __contains__(self, key):
        """Report whether ``key`` has either a stored value or a stored default."""
        return any(key in store for store in (self.data, self._defaults))

    def __delete__(self, instance):
        """Drop the value explicitly stored for ``instance``."""
        del self.data[instance]

    def __get__(self, instance, type_):
        """Return the descriptor (class access) or the value for ``instance``.

        Raises
        ------
        ValidationError
            If the parameter is unconfigurable and no value has been set
            on ``instance``.
        """
        if instance is None:
            # Class-level access hands back the descriptor for inspection.
            return self
        if self.configurable or instance in self.data:
            return self.data.get(instance, self.default)
        raise ValidationError(
            "Unconfigurable parameters have no defaults. Please ensure the"
            " value of the parameter is set before trying to access it.",
            attr=self.name,
            obj=instance,
        )

    def __set__(self, instance, value):
        """Validate ``value`` through `coerce` and store it for ``instance``."""
        self.data[instance] = self.coerce(instance, value)

    def __repr__(self):
        fields = (
            type(self).__name__,
            self.name,
            self.default,
            self.optional,
            self.readonly,
        )
        return "%s(%r, default=%s, optional=%s, readonly=%s)" % fields

    @property
    def configurable(self):
        """True unless the default is the ``Unconfigurable`` sentinel."""
        return self.default is not Unconfigurable

    def del_default(self, obj):
        """Remove the default registered for ``obj``."""
        del self._defaults[obj]

    def get_default(self, obj):
        """Return the default registered for ``obj``, else the parameter default."""
        return self._defaults.get(obj, self.default)

    def set_default(self, obj, value):
        """Register ``value`` as the default for ``obj``.

        The value is passed through `coerce` only when ``coerce_defaults``
        is true.

        Raises
        ------
        ConfigError
            If the parameter is not configurable.
        """
        if not self.configurable:
            raise ConfigError("Parameter '%s' is not configurable" % self)
        if self.coerce_defaults:
            value = self.coerce(obj, value)
        self._defaults[obj] = value

    def check_type(self, instance, value, type_):
        """Raise `.ValidationError` unless ``value`` is None or a ``type_``.

        ``type_`` may be a single type or a tuple of types.
        """
        if value is None or isinstance(value, type_):
            return

        if isinstance(type_, tuple):
            type_str = " or ".join((t.__name__ for t in type_))
        else:
            type_str = type_.__name__
        raise ValidationError(
            "Must be of type %r (got type %r)." % (type_str, type(value).__name__),
            attr=self.name,
            obj=instance,
        )

    def coerce(self, instance, value):
        """Validate ``value`` and return it unchanged.

        Raises
        ------
        ValidationError
            If ``value`` is a `.DefaultType` sentinel, or is None while the
            parameter is not optional.
        ReadonlyError
            If the parameter is readonly and ``instance`` already has a value.
        """
        if isinstance(value, DefaultType):
            raise ValidationError(
                "Default is not a valid value. To reset a parameter, use 'del'.",
                attr=self.name,
                obj=instance,
            )

        already_set = self.readonly and instance in self.data
        if already_set:
            raise ReadonlyError(attr=self.name, obj=instance)

        if value is None and not self.optional:
            raise ValidationError(
                "Parameter is not optional; cannot set to None",
                attr=self.name,
                obj=instance,
            )
        return value

    def equal(self, instance_a, instance_b):
        """Compare this parameter's value on two instances.

        Uses the module-level ``equal`` when the parameter is equatable and
        object identity otherwise.
        """
        a = self.__get__(instance_a, None)
        b = self.__get__(instance_b, None)
        return equal(a, b) if self.equatable else a is b

    def hashvalue(self, instance):
        """Returns a hashable value (`hash` can be called on the output).

        Equatable parameters return the value itself; others return the
        ``id`` of the value, mirroring the identity comparison in `equal`.
        """
        value = self.__get__(instance, None)
        return value if self.equatable else id(value)
class ObsoleteParam(Parameter):
    """A parameter that is no longer supported.

    Reading it from an instance, or setting it to anything other than the
    ``Unconfigurable`` sentinel, raises `.ObsoleteError`.

    Parameters
    ----------
    name : str
        Name of the parameter.
    short_msg : str
        Message passed to the raised `.ObsoleteError`.
    since : object, optional
        Forwarded to `.ObsoleteError` as ``since``.
    url : object, optional
        Forwarded to `.ObsoleteError` as ``url``.
    """

    def __init__(self, name, short_msg, since=None, url=None):
        self.short_msg = short_msg
        self.since = since
        self.url = url
        super().__init__(name, optional=True)

    def __get__(self, instance, type_):
        if instance is None:
            # Class-level access still returns the descriptor for inspection.
            return self
        self.raise_error()

    def coerce(self, instance, value):
        if value is Unconfigurable:
            # The unconfigurable default is the only value that is tolerated.
            return value
        self.raise_error()
        return value

    def raise_error(self):
        """Raise the `.ObsoleteError` describing this parameter."""
        raise ObsoleteError(self.short_msg, since=self.since, url=self.url)
class BoolParam(Parameter):
    """A parameter whose value must be a ``bool`` (or None when optional)."""

    equatable = True

    def coerce(self, instance, value):
        """Type-check ``value`` as ``bool``, then apply the base validation."""
        self.check_type(instance, value, bool)
        validated = super().coerce(instance, value)
        return validated
class NumberParam(Parameter):
    """A parameter whose value is a number, optionally bounded.

    Parameters
    ----------
    name : str
        Name of the parameter.
    default : object
        Value returned when the parameter has not been set.
    low : number, optional
        Lower bound; ``None`` disables the check.
    high : number, optional
        Upper bound; ``None`` disables the check.
    low_open : bool, optional
        If True, ``low`` itself is rejected (strict inequality).
    high_open : bool, optional
        If True, ``high`` itself is rejected (strict inequality).
    optional : bool, optional
        Whether None is accepted.
    readonly : bool, optional
        Whether the value can only be set once.
    """

    equatable = True

    def __init__(
        self,
        name,
        default=Unconfigurable,
        low=None,
        high=None,
        low_open=False,
        high_open=False,
        optional=False,
        readonly=None,
    ):
        self.low, self.high = low, high
        self.low_open, self.high_open = low_open, high_open
        super().__init__(name, default, optional, readonly)

    def coerce(self, instance, num):
        """Validate ``num`` against the numeric type and bounds.

        Zero-dimensional arrays are unwrapped to Python scalars first.

        Raises
        ------
        ValidationError
            If ``num`` is not a number or violates ``low`` / ``high``.
        """
        if num is None:
            # Bounds do not apply; the base class decides whether None is allowed.
            return super().coerce(instance, num)

        if is_array(num) and num.shape == ():
            num = num.item()  # convert scalar array to Python object

        if not is_number(num):
            raise ValidationError(
                "Must be a number; got '%s'" % num, attr=self.name, obj=instance
            )

        if self.low is not None:
            # An open bound also rejects equality (compare() == 0).
            low_limit = 0 if self.low_open else -1
            if compare(num, self.low) <= low_limit:
                raise ValidationError(
                    "Value must be greater than %s%s (got %s)"
                    % ("" if self.low_open else "or equal to ", self.low, num),
                    attr=self.name,
                    obj=instance,
                )

        if self.high is not None:
            high_limit = 0 if self.high_open else 1
            if compare(num, self.high) >= high_limit:
                raise ValidationError(
                    "Value must be less than %s%s (got %s)"
                    % ("" if self.high_open else "or equal to ", self.high, num),
                    attr=self.name,
                    obj=instance,
                )

        return super().coerce(instance, num)
class IntParam(NumberParam):
    """A number parameter restricted to ``int`` / ``np.integer`` values."""

    def coerce(self, instance, num):
        """Type-check ``num`` as an integer, then apply the number validation."""
        integer_types = (int, np.integer)
        self.check_type(instance, num, integer_types)
        return super().coerce(instance, num)
class StringParam(Parameter):
    """A parameter whose value must be a ``str`` (or None when optional)."""

    equatable = True

    def coerce(self, instance, string):
        """Type-check ``string`` as ``str``, then apply the base validation."""
        self.check_type(instance, string, (str,))
        validated = super().coerce(instance, string)
        return validated
class EnumParam(StringParam):
    """A parameter where the value must be one of a finite set of strings.

    Parameters
    ----------
    name : str
        Name of the parameter.
    default : object
        Value returned when the parameter has not been set.
    values : sequence of str
        The allowed strings. They must be unique (after lower-casing when
        ``lower`` is true).
    lower : bool, optional
        If True, both the allowed values and any coerced value are
        lower-cased, making the comparison case-insensitive.
    optional : bool, optional
        Whether None is accepted.
    readonly : bool, optional
        Whether the value can only be set once.
    """

    def __init__(
        self,
        name,
        default=Unconfigurable,
        values=(),
        lower=True,
        optional=False,
        readonly=None,
    ):
        assert all(isinstance(s, str) for s in values)
        if lower:
            values = tuple(s.lower() for s in values)
        value_set = set(values)
        # Duplicates (possibly introduced by lower-casing) are a programming error.
        assert len(values) == len(value_set)
        self.values = values
        self.value_set = value_set
        self.lower = lower
        super().__init__(name, default, optional, readonly)

    def coerce(self, instance, string):
        """Validate ``string`` and return it (lower-cased when ``lower`` is set).

        Returns
        -------
        str or None
            The normalized string, or None when the parameter is optional
            and None was passed.

        Raises
        ------
        ValidationError
            If ``string`` is not a string, is None for a non-optional
            parameter, or is not one of the allowed values.
        """
        string = super().coerce(instance, string)
        if string is None:
            # Only reachable for optional parameters (the base class rejects
            # None otherwise). None has no ``lower`` method and is never a
            # member of the value set, so pass it through untouched.
            return None

        if self.lower:
            string = string.lower()
        if string not in self.value_set:
            raise ValidationError(
                "String %r must be one of %s" % (string, list(self.values)),
                attr=self.name,
                obj=instance,
            )
        return string
class TupleParam(Parameter):
    """A parameter whose value is stored as a tuple.

    Parameters
    ----------
    name : str
        Name of the parameter.
    default : object
        Value returned when the parameter has not been set.
    length : int, optional
        Required number of items; ``None`` accepts any length.
    optional : bool, optional
        Whether None is accepted.
    readonly : bool, optional
        Whether the value can only be set once.
    """

    def __init__(
        self, name, default=Unconfigurable, length=None, optional=False, readonly=None
    ):
        self.length = length
        super().__init__(name, default, optional, readonly)

    def coerce(self, instance, value):
        """Cast ``value`` to a tuple and enforce the configured length.

        Raises
        ------
        ValidationError
            If ``value`` cannot be turned into a tuple or has the wrong
            number of items.
        """
        value = super().coerce(instance, value)
        if value is None:
            return value

        try:
            value = tuple(value)
        except TypeError:
            raise ValidationError(
                "Value must be castable to a tuple", attr=self.name, obj=instance
            )

        wrong_length = self.length is not None and len(value) != self.length
        if wrong_length:
            raise ValidationError(
                "Must be %d items (got %d)" % (self.length, len(value)),
                attr=self.name,
                obj=instance,
            )
        return value
class ShapeParam(TupleParam):
    """A tuple parameter whose items must all be integers.

    Parameters
    ----------
    name : str
        Name of the parameter.
    default : object
        Value returned when the parameter has not been set.
    length : int, optional
        Required number of items; ``None`` accepts any length.
    low : int, optional
        Smallest value allowed for each element (0 by default); ``None``
        disables the check.
    optional : bool, optional
        Whether None is accepted.
    readonly : bool, optional
        Whether the value can only be set once.
    """

    equatable = True

    def __init__(
        self,
        name,
        default=Unconfigurable,
        length=None,
        low=0,
        optional=False,
        readonly=None,
    ):
        super().__init__(
            name, default=default, length=length, optional=optional, readonly=readonly
        )
        self.low = low

    def coerce(self, instance, value):
        """Apply the tuple validation, then check every element.

        Raises
        ------
        ValidationError
            If an element is not an integer or is smaller than ``low``.
        """
        value = super().coerce(instance, value)
        if value is None:
            return value

        for i, v in enumerate(value):
            if not is_integer(v):
                raise ValidationError(
                    "Element %d must be an int (got type %r)"
                    % (i, type(v).__name__),
                    attr=self.name,
                    obj=instance,
                )
            below_minimum = self.low is not None and v < self.low
            if below_minimum:
                raise ValidationError(
                    "Element %d must be >= %d (got %d)" % (i, self.low, v),
                    attr=self.name,
                    obj=instance,
                )
        return value
class DictParam(Parameter):
    """A parameter whose value must be a ``dict`` (or None when optional)."""

    def coerce(self, instance, value):
        """Type-check ``value`` as ``dict``, then apply the base validation."""
        self.check_type(instance, value, dict)
        validated = super().coerce(instance, value)
        return validated
class NdarrayParam(Parameter):
    """A parameter where the value is a NumPy ndarray.

    If the passed value is an ndarray, a view onto that array is stored.
    If the passed value is not an ndarray, it will be cast to an ndarray
    of ``dtype`` and stored.

    Parameters
    ----------
    name : str
        Name of the parameter.
    default : object
        Value returned when the parameter has not been set.
    shape : sequence, optional
        Expected shape. Each entry is an integer, ``"*"`` (any size),
        ``"..."`` (any number of dimensions; at most one), or the name of
        an attribute of the owning instance that holds the expected size.
        ``None`` disables shape checking.
    dtype : object, optional
        dtype used when casting non-ndarray values. Falls back to
        ``rc.float_dtype`` when not given.
    optional : bool, optional
        Whether None is accepted.
    readonly : bool, optional
        Whether the value can only be set once. Readonly arrays are also
        flagged as non-writeable.
    """

    equatable = True

    def __init__(
        self,
        name,
        default=Unconfigurable,
        shape=None,
        dtype=None,
        optional=False,
        readonly=None,
    ):
        if shape is not None:
            assert shape.count("...") <= 1, "Cannot have more than one ellipsis"
        self.shape = shape
        self._dtype = dtype
        super().__init__(name, default, optional, readonly)

    @property
    def coerce_defaults(self):
        """Whether defaults can be validated without an owning instance.

        True when there is no shape, or when every shape entry is an
        integer or a wildcard (i.e. none names an instance attribute).
        """
        if self.shape is None:
            return True
        wildcards = ("...", "*")
        return all(is_integer(dim) or dim in wildcards for dim in self.shape)

    @property
    def dtype(self):
        """The explicitly configured dtype, else ``rc.float_dtype``."""
        return rc.float_dtype if self._dtype is None else self._dtype

    def hashvalue(self, instance):
        """Return ``array_hash`` of the array stored for ``instance``."""
        stored = self.__get__(instance, None)
        return array_hash(stored)

    def coerce(self, instance, value):
        """Convert non-None values with `coerce_ndarray`, then validate."""
        if value is not None:
            value = self.coerce_ndarray(instance, value)
        return super().coerce(instance, value)

    def coerce_ndarray(self, instance, ndarray):  # noqa: C901
        """Turn ``ndarray`` into an array (view or cast) and check its shape.

        Raises
        ------
        ValidationError
            If the value cannot be cast to an array, has the wrong number
            of dimensions, has a mismatching dimension, or if a shape entry
            names an instance attribute that is not an integer.
        """
        if isinstance(ndarray, np.ndarray):
            ndarray = ndarray.view()
        else:
            try:
                ndarray = np.array(ndarray, dtype=self.dtype)
            except (ValueError, TypeError):
                raise ValidationError(
                    "Must be a %s NumPy array (got type %r)"
                    % (self.dtype, type(ndarray).__name__),
                    attr=self.name,
                    obj=instance,
                )

        if self.readonly:
            ndarray.setflags(write=False)

        if self.shape is None:
            return ndarray

        if "..." not in self.shape:
            shape = self.shape
        else:
            # Convert '...' to the appropriate number of '*'s
            nfixed = len(self.shape) - 1
            n = ndarray.ndim - nfixed
            if n < 0:
                raise ValidationError(
                    "ndarray must be at least %dD (got %dD)" % (nfixed, ndarray.ndim),
                    attr=self.name,
                    obj=instance,
                )
            split = self.shape.index("...")
            head = list(self.shape[:split])
            tail = list(self.shape[split + 1 :])
            shape = head + ["*"] * n + tail

        if ndarray.ndim != len(shape):
            raise ValidationError(
                "ndarray must be %dD (got %dD)" % (len(shape), ndarray.ndim),
                attr=self.name,
                obj=instance,
            )

        for i, attr in enumerate(shape):
            assert is_integer(attr) or isinstance(
                attr, str
            ), "shape can only be an int or str representing an attribute"

            if attr == "*":
                continue

            # Integer entries are literal sizes; strings name an attribute
            # on the owning instance that holds the expected size.
            desired = attr if is_integer(attr) else getattr(instance, attr)

            if not is_integer(desired):
                raise ValidationError(
                    "%s not yet initialized; cannot determine if shape is "
                    "correct. Consider using a distribution instead." % attr,
                    attr=self.name,
                    obj=instance,
                )

            actual = ndarray.shape[i]
            if actual != desired:
                raise ValidationError(
                    "shape[%d] should be %d (got %d)" % (i, desired, actual),
                    attr=self.name,
                    obj=instance,
                )
        return ndarray
# Pairs a function with the size of its output, as stored by FunctionParam.
FunctionInfo = collections.namedtuple("FunctionInfo", ("function", "size"))
class FunctionParam(Parameter):
    """A parameter where the value is a function.

    The stored value is a ``FunctionInfo`` pairing the function with the
    size of the output it produced when probed with `function_args`.
    """

    def determine_size(self, instance, function):
        """Call ``function`` once and return the size of its output.

        Raises
        ------
        ValidationError
            If ``checked_call`` reports that the function could not be
            invoked with the arguments from `function_args`.
        """
        probe_args = self.function_args(instance, function)
        value, invoked = checked_call(function, *probe_args)
        if invoked:
            return np.asarray(value).size
        raise ValidationError(
            "function '%s' must accept a single np.array argument" % function,
            attr=self.name,
            obj=instance,
        )

    def function_args(self, instance, function):
        """Return the positional arguments used to probe ``function``."""
        return (np.zeros(1),)

    def coerce(self, instance, function):
        """Validate ``function`` and wrap it in a ``FunctionInfo``.

        An existing ``FunctionInfo`` is returned as is; otherwise the
        output size is determined for callables and left as None for
        anything else.

        Raises
        ------
        ValidationError
            If the (wrapped) function is neither None nor callable.
        """
        function = super().coerce(instance, function)

        if isinstance(function, FunctionInfo):
            function_info = function
            function = function_info.function
        elif callable(function):
            function_info = FunctionInfo(
                function=function, size=self.determine_size(instance, function)
            )
        else:
            function_info = FunctionInfo(function=function, size=None)

        if not (function is None or callable(function)):
            raise ValidationError(
                "function '%s' must be callable" % function,
                attr=self.name,
                obj=instance,
            )
        return function_info
class FrozenObject:
    """An object with parameters that cannot change value after instantiation.

    Since such objects are read-only ("frozen"), they can be safely used in
    multiple locations, compared, etc. Equality and hashing are derived from
    the type and the values of all (non-obsolete) parameters.
    """

    # Order in which parameters have to be initialized.
    # Missing parameters will be initialized last in an undefined order.
    # This is needed for pickling and copying of Nengo objects when the
    # parameter initialization order matters.
    _param_init_order = []

    def __init__(self):
        # Collect every non-obsolete Parameter descriptor defined on the class.
        self._paramdict = collections.OrderedDict(
            (k, v)
            for k, v in inspect.getmembers(type(self))
            if isinstance(v, Parameter) and not isinstance(v, ObsoleteParam)
        )
        for p in self._params:
            if not p.readonly:
                msg = "All parameters of a FrozenObject must be readonly"
                raise ReadonlyError(attr=p, obj=self, msg=msg)
        # Cache for the repr arguments; filled lazily by ``_argreprs``.
        self.__argreprs = None

    @property
    def _params(self):
        """List of the parameter descriptors collected in ``__init__``."""
        return list(self._paramdict.values())

    def __eq__(self, other):
        if self is other:  # quick check for speed
            return True
        return type(self) == type(other) and all(
            p.equal(self, other) for p in self._params
        )

    def __hash__(self):
        return hash((type(self), tuple(p.hashvalue(self) for p in self._params)))

    def __getstate__(self):
        """Return the instance dict plus the current value of every parameter."""
        d = dict(self.__dict__)
        d.pop("_paramdict")  # do not pickle the param dict itself
        for k in self._paramdict:
            d[k] = getattr(self, k)
        return d

    def __setstate__(self, state):
        """Rebuild the param dict, then restore parameters and attributes.

        Parameters listed in ``_param_init_order`` are set first, in that
        order; the remaining ones follow in an undefined order.
        """
        FrozenObject.__init__(self)  # set up the param dict
        for attr in self._param_init_order:
            setattr(self, attr, state.pop(attr))
        for attr in set(self._paramdict).difference(self._param_init_order):
            setattr(self, attr, state.pop(attr))
        self.__dict__.update(state)

    def __repr__(self):
        argreprs = self._argreprs
        if isinstance(argreprs, str):
            # A string means the arguments could not be recovered.
            return "<%s at 0x%x>" % (type(self).__name__, id(self))
        return "%s(%s)" % (type(self).__name__, ", ".join(argreprs))

    @property
    def _argreprs(self):
        """``name=value`` strings for ``__init__`` arguments that differ from defaults.

        Returns a list of strings, or a single explanatory string when an
        ``__init__`` argument has no attribute of the same name on the
        object. The result is cached after the first successful computation.
        """
        if self.__argreprs is not None:
            return self.__argreprs

        # get arguments to display from __init__ functions
        spec = inspect.getfullargspec(type(self).__init__)
        defaults = {}
        if spec.defaults is not None:
            defaults.update(zip(spec.args[-len(spec.defaults) :], spec.defaults))

        # Build the list locally and only cache it once complete, so that an
        # exception raised while comparing or formatting a value cannot leave
        # a truncated list cached for later calls.
        argreprs = []
        for arg in spec.args[1:]:  # start at 1 to drop `self`
            if not hasattr(self, arg):
                # We rely on storing the initial arguments. If we don't have
                # them, we don't auto-generate a repr.
                self.__argreprs = "Cannot find %r" % arg
                return self.__argreprs

            value = getattr(self, arg)
            param = self._paramdict.get(arg, None)
            if arg in defaults:
                not_default = not equal(value, defaults[arg])
            elif param is not None and param.default is not Unconfigurable:
                not_default = not equal(value, param.default)
            else:
                not_default = True

            if not_default:
                argreprs.append("%s=%r" % (arg, value))

        self.__argreprs = argreprs
        return self.__argreprs