Source code for nengo.connection

import logging

import numpy as np

from nengo.base import NengoObject, NengoObjectParam, ObjView
from nengo.dists import Distribution, DistOrArrayParam
from nengo.ensemble import Ensemble, Neurons
from nengo.exceptions import ValidationError
from nengo.learning_rules import LearningRuleType, LearningRuleTypeParam
from nengo.neurons import Direct
from nengo.node import Node
from nengo.params import (Default, Unconfigurable, ObsoleteParam,
                          BoolParam, FunctionInfo, Parameter)
from nengo.solvers import LstsqL2, SolverParam
from nengo.synapses import Lowpass, SynapseParam
from nengo.utils.compat import is_array_like, is_iterable, iteritems
from nengo.utils.functions import function_name
from nengo.utils.stdlib import checked_call

logger = logging.getLogger(__name__)


class PrePostParam(NengoObjectParam):
    """Parameter holding a connection's ``pre`` or ``post`` object.

    Connections themselves are rejected as endpoints; everything else is
    validated by ``NengoObjectParam``.
    """

    def coerce(self, conn, nengo_obj):
        is_connection = isinstance(nengo_obj, Connection)
        if not is_connection:
            return super(PrePostParam, self).coerce(conn, nengo_obj)
        raise ValidationError(
            "Cannot connect to or from connections. "
            "Did you mean to connect to the connection's learning rule?",
            attr=self.name, obj=conn)


class ConnectionLearningRuleTypeParam(LearningRuleTypeParam):
    """Connection-specific validation for learning rules.

    On top of the generic checks in ``LearningRuleTypeParam.check_rule``,
    this verifies that the connection's pre/post objects, decoding mode and
    transform shape are compatible with what the rule modifies.
    """

    coerce_defaults = False

    def check_rule(self, conn, rule):
        """Raise ValidationError if ``rule`` cannot be applied to ``conn``."""
        super(ConnectionLearningRuleTypeParam, self).check_rule(conn, rule)

        # --- Check pre object
        pre = conn.pre_obj
        if rule.modifies in ('decoders', 'weights'):
            # pre object must be neural
            if not isinstance(pre, (Ensemble, Neurons)):
                raise ValidationError(
                    "'pre' must be of type 'Ensemble' or 'Neurons' for "
                    "learning rule '%s' (got type %r)" % (
                        rule, type(pre).__name__),
                    attr=self.name, obj=conn)
            if (isinstance(pre, Ensemble)
                    and isinstance(pre.neuron_type, Direct)):
                raise ValidationError(
                    "'pre' cannot have neuron type 'Direct'. Connections from "
                    "'Direct' ensembles do not have decoders or weights.",
                    attr=self.name, obj=conn)

        # --- Check post object
        post = conn.post_obj
        if rule.modifies == 'encoders':
            if not isinstance(post, Ensemble):
                # Report the type of the offending *post* object (this
                # previously reported the pre object's type by mistake).
                raise ValidationError(
                    "'post' must be of type 'Ensemble' (got %r) "
                    "for learning rule '%s'"
                    % (type(post).__name__, rule),
                    attr=self.name, obj=conn)
        else:
            if not isinstance(post, (Ensemble, Neurons, Node)):
                raise ValidationError(
                    "'post' must be of type 'Ensemble', 'Neurons' or 'Node' "
                    "(got %r) for learning rule '%s'"
                    % (type(post).__name__, rule),
                    attr=self.name, obj=conn)

        if rule.modifies == 'weights':
            # If the rule modifies 'weights', then it must have full weights
            if conn.is_decoded:
                raise ValidationError(
                    "Learning rule '%s' can not be applied to decoded "
                    "connections. Try setting solver.weights to True or "
                    "connecting between two Neurons objects." % rule,
                    attr=self.name, obj=conn)

            # transform matrix must be 2D with shape (post, pre); a weight
            # solver computes the full matrix itself, so skip it then.
            pre_size = (
                pre.n_neurons if isinstance(pre, Ensemble)
                else conn.pre.size_out)
            post_size = conn.post.size_in
            expected_shape = (post_size, pre_size)
            if (not conn.solver.weights and
                    conn.transform.shape != expected_shape):
                # Format arguments follow the "post x pre" order stated in
                # the message and used in the shape check above.
                raise ValidationError(
                    "Transform must be 2D array with shape post_neurons x "
                    "pre_neurons (%d, %d)" % expected_shape,
                    attr=self.name, obj=conn)


class ConnectionSolverParam(SolverParam):
    """Connection-specific validation for decoder solvers.

    A solver with ``weights`` set (one that computes a full weight matrix)
    is only accepted when both the pre and post objects are Ensembles.
    """

    coerce_defaults = False

    def coerce(self, conn, solver):
        solver = super(ConnectionSolverParam, self).coerce(conn, solver)
        if solver is None or not solver.weights:
            # Nothing further to check for decoder solvers (or no solver).
            return solver

        pre_obj = conn.pre_obj
        if not isinstance(pre_obj, Ensemble):
            raise ValidationError(
                "weight solvers only work for connections from ensembles "
                "(got %r)" % type(pre_obj).__name__,
                attr=self.name, obj=conn)

        post_obj = conn.post_obj
        if not isinstance(post_obj, Ensemble):
            raise ValidationError(
                "weight solvers only work for connections to ensembles "
                "(got %r)" % type(post_obj).__name__,
                attr=self.name, obj=conn)

        return solver


class EvalPointsParam(DistOrArrayParam):
    """Evaluation points parameter, restricted to connections from ensembles."""

    coerce_defaults = False

    def coerce(self, conn, distorarray):
        """Eval points are only valid when pre is an ensemble."""
        if distorarray is None or isinstance(conn.pre, Ensemble):
            return super(EvalPointsParam, self).coerce(conn, distorarray)

        msg = ("eval_points are only valid on connections from ensembles "
               "(got type '%s')" % type(conn.pre).__name__)
        raise ValidationError(msg, attr=self.name, obj=conn)


class ConnectionFunctionParam(Parameter):
    """Connection-specific validation for functions.

    ``coerce`` always returns a ``FunctionInfo(function, size)``.
    ``function`` is either None, a callable, or a float64 array of target
    points. ``size`` is the function's output dimensionality, or None when
    no function is given.
    """

    coerce_defaults = False

    def check_array(self, conn, ndarray):
        """Validate a function given as explicit target points.

        ``conn.eval_points`` must already be an explicit array. ``ndarray``
        must be 2D, with exactly one row per evaluation point.
        """
        if not isinstance(conn.eval_points, np.ndarray):
            raise ValidationError(
                "In order to set 'function' to specific points, 'eval_points' "
                "must also be set to specific points.",
                attr=self.name, obj=conn)

        if ndarray.ndim != 2:
            raise ValidationError("array must be 2D (got %dD)" % ndarray.ndim,
                                  attr=self.name, obj=conn)

        if ndarray.shape[0] != conn.eval_points.shape[0]:
            # Values are listed in the same order as the message wording:
            # evaluation points first, then function points.
            raise ValidationError(
                "Number of evaluation points must match number "
                "of function points (%d != %d)"
                % (conn.eval_points.shape[0], ndarray.shape[0]),
                attr=self.name, obj=conn)

    def check_function_can_be_applied(self, conn, function_info):
        """Check ``function_info`` against the connection's pre and transform.

        Raises ValidationError in any of these cases:

        * a function is given but ``pre`` is neither an Ensemble nor a Node;
        * a function is given and ``pre`` is a passthrough Node
          (its ``output`` is None);
        * the function's output size (or ``size_in`` when there is no
          function) does not fit an array transform.
        """
        function, size = function_info
        type_pre = type(conn.pre_obj).__name__

        if function is not None:
            if not isinstance(conn.pre_obj, (Node, Ensemble)):
                raise ValidationError(
                    "function can only be set for connections from an Ensemble"
                    " or Node (got type %r)" % type_pre,
                    attr=self.name, obj=conn)

            if isinstance(conn.pre_obj, Node) and conn.pre_obj.output is None:
                raise ValidationError(
                    "Cannot apply functions to passthrough nodes",
                    attr=self.name, obj=conn)

        # Without a function, the signal keeps the pre object's output size.
        size_mid = conn.size_in if size is None else size
        transform = conn.transform

        if isinstance(transform, np.ndarray):
            # For scalar and 1D transforms, the output must already have
            # the post object's size.
            if transform.ndim < 2 and size_mid != conn.size_out:
                raise ValidationError(
                    "function output size is incorrect; should return a "
                    "vector of size %d" % conn.size_out, attr=self.name,
                    obj=conn)

            if transform.ndim == 2 and size_mid != transform.shape[1]:
                # check input dimensionality matches transform
                raise ValidationError(
                    "%s output size (%d) not equal to transform input size "
                    "(%d)" % (type_pre, size_mid, transform.shape[1]),
                    attr=self.name, obj=conn)

    def coerce(self, conn, function):
        """Normalize ``function`` to a validated ``FunctionInfo``.

        Accepted inputs are:

        * None;
        * an existing FunctionInfo;
        * an array-like of target points (checked by ``check_array``);
        * a callable, whose output size is found by calling it once.

        Raises ValidationError for any other type.
        """
        function = super(ConnectionFunctionParam, self).coerce(conn, function)

        if function is None:
            function_info = FunctionInfo(function=None, size=None)
        elif isinstance(function, FunctionInfo):
            function_info = function
        elif is_array_like(function):
            array = np.array(function, copy=False, dtype=np.float64)
            self.check_array(conn, array)
            function_info = FunctionInfo(function=array, size=array.shape[1])
        elif callable(function):
            function_info = FunctionInfo(
                function=function, size=self.determine_size(conn, function))
            # NOTE(review): re-runs the base-class coerce on the wrapped
            # value; it looks redundant (the TODO questioned it) but is kept
            # in case the base class performs checks not visible here.
            super(ConnectionFunctionParam, self).coerce(conn, function_info)
        else:
            raise ValidationError("Invalid connection function type %r "
                                  "(must be callable or array-like)"
                                  % type(function).__name__,
                                  attr=self.name, obj=instance_or(conn))

        self.check_function_can_be_applied(conn, function_info)

        return function_info

    def determine_size(self, instance, function):
        """Call ``function`` once on a sample input and return its output size.

        Raises ValidationError if ``checked_call`` reports that the call
        could not be made with a single array argument.
        """
        args = self.function_args(instance, function)
        value, invoked = checked_call(function, *args)
        if not invoked:
            raise ValidationError("function '%s' must accept a single "
                                  "np.array argument" % function,
                                  attr=self.name, obj=instance)
        return np.asarray(value).size

    def function_args(self, conn, function):
        """Return a 1-tuple with a sample input for ``function``.

        The sample is the first explicit eval point if eval points are
        iterable, otherwise a zero vector of length ``conn.size_in``.
        """
        x = (conn.eval_points[0] if is_iterable(conn.eval_points)
             else np.zeros(conn.size_in))
        return (x,)


def instance_or(conn):
    """Return ``conn`` unchanged (identity helper for error reporting)."""
    return conn


class TransformParam(DistOrArrayParam):
    """The transform additionally validates size_out.

    Array-like transforms are converted to float64 arrays. Their rank (0, 1
    or 2) selects the shape that the parent class then validates. None and
    Distribution values go straight to the parent class.
    """

    coerce_defaults = False

    def __init__(self, name, default, optional=False, readonly=False):
        empty_sample_shape = ()
        super(TransformParam, self).__init__(
            name, default, empty_sample_shape, optional, readonly)

    def coerce(self, conn, transform):
        needs_shape_check = (
            transform is not None and not isinstance(transform, Distribution))

        if needs_shape_check:
            transform = np.asarray(transform, dtype=np.float64)
            rank = transform.ndim
            if rank > 2:
                raise ValidationError(
                    "Cannot handle transforms with dimensions > 2",
                    attr=self.name, obj=conn)

            # ``self.shape`` belongs to this parameter object, which is a
            # class attribute shared by every Connection. It is overwritten
            # on each call and read by the parent ``coerce`` below. A 2D
            # transform is really (size_out, size_mid), but size_mid is
            # validated by the function parameter, hence '*'.
            self.shape = [(), ('size_out',), ('size_out', '*')][rank]

            if rank == 2:
                # List selections with repeated indices don't work for
                # two-dimensional transforms.
                def has_duplicates(selection):
                    if isinstance(selection, slice):
                        return False
                    return np.unique(selection).size != len(selection)

                if has_duplicates(conn.pre_slice):
                    raise ValidationError(
                        "Input object selection has repeated indices",
                        attr=self.name, obj=conn)
                if has_duplicates(conn.post_slice):
                    raise ValidationError(
                        "Output object selection has repeated indices",
                        attr=self.name, obj=conn)

        return super(TransformParam, self).coerce(conn, transform)


class Connection(NengoObject):
    """Connects two objects together.

    The connection between the two objects is unidirectional,
    transmitting information from the first argument, ``pre``,
    to the second argument, ``post``.

    Almost any Nengo object can act as the pre or post side of a connection.
    Additionally, you can use Python slice syntax to access only some of the
    dimensions of the pre or post object.

    For example, if ``node`` has ``size_out=2`` and ``ensemble`` has
    ``size_in=1``, we could not create the following connection::

        nengo.Connection(node, ensemble)

    But, we could create either of these two connections::

        nengo.Connection(node[0], ensemble)
        nengo.Connection(node[1], ensemble)

    Parameters
    ----------
    pre : Ensemble or Neurons or Node
        The source Nengo object for the connection.
    post : Ensemble or Neurons or Node or Probe
        The destination object for the connection.
    synapse : Synapse or None, optional \
              (Default: ``nengo.synapses.Lowpass(tau=0.005)``)
        Synapse model to use for filtering (see `~nengo.synapses.Synapse`).
        If *None*, no synapse will be used and information will be transmitted
        without any delay (if supported by the backend---some backends may
        introduce a single time step delay).

        Note that at least one connection must have a synapse that is not
        *None* if components are connected in a cycle. Furthermore, a synaptic
        filter with a zero time constant is different from a *None* synapse
        as a synaptic filter will always add a delay of at least one time step.
    function : callable or (n_eval_points, size_mid) array_like, \
               optional (Default: None)
        Function to compute across the connection. Note that ``pre`` must be
        an Ensemble or a non-passthrough Node to apply a function across the
        connection.
        If an array is passed, the function is implicitly defined by the
        points in the array and the provided ``eval_points``, which have a
        one-to-one correspondence. In that case ``eval_points`` must also be
        given as an explicit array.
    transform : (size_out, size_mid) array_like, optional \
                (Default: ``np.array(1.0)``)
        Linear transform mapping the pre output to the post input.
        This transform is in terms of the sliced size; if either pre
        or post is a slice, the transform must be shaped according to
        the sliced dimensionality. Additionally, the function is applied
        before the transform, so if a function is computed across the
        connection, the transform must be of shape ``(size_out, size_mid)``.
    solver : Solver, optional (Default: ``nengo.solvers.LstsqL2()``)
        Solver instance to compute decoders or weights
        (see `~nengo.solvers.Solver`). If ``solver.weights`` is True, a full
        connection weight matrix is computed instead of decoders; this is
        only allowed for connections from and to Ensembles.
    learning_rule_type : LearningRuleType or iterable of LearningRuleType, \
                         optional (Default: None)
        Modifies the decoders or connection weights during simulation.
    eval_points : (n_eval_points, size_in) array_like or Distribution, \
                  optional (Default: None)
        Points at which to evaluate ``function`` when computing decoders,
        spanning the interval (-pre.radius, pre.radius) in each dimension.
        If None, will use the eval_points associated with ``pre``.
        Only valid on connections from an Ensemble.
    scale_eval_points : bool, optional (Default: True)
        Indicates whether the evaluation points should be scaled
        by the radius of the pre Ensemble.
    label : str, optional (Default: None)
        A descriptive label for the connection.
    seed : int, optional (Default: None)
        The seed used for random number generation.

    Attributes
    ----------
    is_decoded : bool
        True if and only if the connection is decoded. This will not occur
        when ``solver.weights`` is True or both pre and post are
        `~nengo.ensemble.Neurons`.
    function : callable or ndarray or None
        The given function (or array of target points), or None if no
        function was given.
    label : str or None
        A human-readable connection label for debugging and visualization.
        When None, ``str(connection)`` describes the pre and post objects
        (and the function, if any) instead.
    learning_rule_type : instance or list or dict of LearningRuleType, optional
        The learning rule types.
    post : Ensemble or Neurons or Node or Probe or ObjView
        The given post object.
    post_obj : Ensemble or Neurons or Node or Probe
        The underlying post object, even if ``post`` is an ``ObjView``.
    post_slice : slice or list
        The slice associated with ``post`` if it is an ObjView, or
        ``slice(None)`` otherwise.
    pre : Ensemble or Neurons or Node or ObjView
        The given pre object.
    pre_obj : Ensemble or Neurons or Node
        The underlying pre object, even if ``pre`` is an ``ObjView``.
    pre_slice : slice or list
        The slice associated with ``pre`` if it is an ObjView, or
        ``slice(None)`` otherwise.
    seed : int
        The seed used for random number generation.
    solver : Solver
        The Solver instance that will be used to compute decoders or weights
        (see ``nengo.solvers``).
    synapse : Synapse
        The Synapse model used for filtering across the connection
        (see ``nengo.synapses``).
    transform : (size_out, size_mid) array_like
        Linear transform mapping the pre function output to the post input.

    Properties
    ----------
    size_in : int
        The number of output dimensions of the pre object.
        Also the input size of the function, if one is specified.
    size_mid : int
        The number of output dimensions of the function, if specified.
        If the function is not specified, then ``size_in == size_mid``.
    size_out : int
        The number of input dimensions of the post object.
        Also the number of output dimensions of the transform.
    """

    probeable = ('output', 'input', 'weights')

    pre = PrePostParam('pre', nonzero_size_out=True)
    post = PrePostParam('post', nonzero_size_in=True)
    synapse = SynapseParam('synapse', default=Lowpass(tau=0.005))
    function_info = ConnectionFunctionParam(
        'function', default=None, optional=True)
    transform = TransformParam('transform', default=np.array(1.0))
    solver = ConnectionSolverParam('solver', default=LstsqL2())
    learning_rule_type = ConnectionLearningRuleTypeParam(
        'learning_rule_type', default=None, optional=True)
    eval_points = EvalPointsParam('eval_points',
                                  default=None,
                                  optional=True,
                                  sample_shape=('*', 'size_in'))
    scale_eval_points = BoolParam('scale_eval_points', default=True)
    modulatory = ObsoleteParam(
        'modulatory',
        "Modulatory connections have been removed. "
        "Connect to a learning rule instead.",
        since="v2.1.0",
        url="https://github.com/nengo/nengo/issues/632#issuecomment-71663849")

    _param_init_order = [
        'pre', 'post', 'synapse', 'transform', 'eval_points',
        'function_info', 'solver', 'learning_rule_type'
    ]

    def __init__(self, pre, post, synapse=Default, function=Default,
                 transform=Default, solver=Default, learning_rule_type=Default,
                 eval_points=Default, scale_eval_points=Default,
                 label=Default, seed=Default, modulatory=Unconfigurable):
        super(Connection, self).__init__(label=label, seed=seed)

        # The order of these assignments matters: several parameters
        # validate against previously-set ones.
        self.pre = pre
        self.post = post

        self.synapse = synapse
        self.transform = transform  # read by the function check
        self.scale_eval_points = scale_eval_points
        self.eval_points = eval_points  # Must be set before function
        self.function_info = function  # Must be set after transform
        self.solver = solver  # Must be set before learning rule
        self.learning_rule_type = learning_rule_type
        self.modulatory = modulatory

    def __str__(self):
        return "<Connection %s>" % self._str

    def __repr__(self):
        return "<Connection at 0x%x %s>" % (id(self), self._str)

    @property
    def _str(self):
        """The label, or a 'from <pre> to <post> [computing <fn>]' string."""
        if self.label is not None:
            return self.label

        desc = "" if self.function is None else " computing '%s'" % (
            function_name(self.function))
        return "from %s to %s%s" % (self.pre, self.post, desc)

    @property
    def function(self):
        """The connection's function (callable, array, or None)."""
        return self.function_info.function

    @function.setter
    def function(self, function):
        self.function_info = function

    @property
    def is_decoded(self):
        """(bool) False for weight solvers or Neurons-to-Neurons connections."""
        return not (self.solver.weights or (
            isinstance(self.pre_obj, Neurons)
            and isinstance(self.post_obj, Neurons)))

    @property
    def _label(self):
        # Identical text to ``_str``; kept as a separate name for callers.
        return self._str

    @property
    def learning_rule(self):
        """(LearningRule or iterable) Connectable learning rule object(s).

        Mirrors the structure of ``learning_rule_type``:
        * a dict of types gives a dict of rules of the same dict type;
        * an iterable of types gives a list of rules;
        * a single type gives a single rule.

        New ``LearningRule`` objects are built on every access; they compare
        equal to earlier ones via ``LearningRule.__eq__``.
        """
        if self.learning_rule_type is None:
            return None

        types = self.learning_rule_type
        if isinstance(types, dict):
            learning_rule = type(types)()  # dict of same type
            for k, v in iteritems(types):
                learning_rule[k] = LearningRule(self, v)
        elif is_iterable(types):
            learning_rule = [LearningRule(self, v) for v in types]
        elif isinstance(types, LearningRuleType):
            learning_rule = LearningRule(self, types)
        else:
            raise ValidationError(
                "Invalid type %r" % type(types).__name__,
                attr='learning_rule_type', obj=self)

        return learning_rule

    @property
    def post_obj(self):
        return self.post.obj if isinstance(self.post, ObjView) else self.post

    @property
    def post_slice(self):
        return (self.post.slice if isinstance(self.post, ObjView)
                else slice(None))

    @property
    def pre_obj(self):
        return self.pre.obj if isinstance(self.pre, ObjView) else self.pre

    @property
    def pre_slice(self):
        return self.pre.slice if isinstance(self.pre, ObjView) else slice(None)

    @property
    def size_in(self):
        """(int) The number of output dimensions of the pre object.

        Also the input size of the function, if one is specified.
        """
        return self.pre.size_out

    @property
    def size_mid(self):
        """(int) The number of output dimensions of the function, if specified.

        If the function is not specified, then ``size_in == size_mid``.
        """
        size = self.function_info.size
        return self.size_in if size is None else size

    @property
    def size_out(self):
        """(int) The number of input dimensions of the post object.

        Also the number of output dimensions of the transform.
        """
        return self.post.size_in


class LearningRule(object):
    """An interface for making connections to a learning rule.

    Connections to a learning rule are to allow elements of the network to
    affect the learning rule. For example, learning rules that use error
    information can obtain that information through a connection.

    Learning rule objects should only ever be accessed through the
    ``learning_rule`` attribute of a connection.
    """

    def __init__(self, connection, learning_rule_type):
        self._connection = connection
        self.learning_rule_type = learning_rule_type

    def __repr__(self):
        return "<LearningRule at 0x%x modifying %r with type %r>" % (
            id(self), self.connection, self.learning_rule_type)

    def __str__(self):
        return "<LearningRule modifying %s with type %s>" % (
            self.connection, self.learning_rule_type)

    def __eq__(self, other):
        # Two rules are equal when they wrap the very same connection and
        # equal rule types. Comparisons with other kinds of objects are
        # delegated instead of raising AttributeError.
        if not isinstance(other, LearningRule):
            return NotImplemented
        return (self._connection is other._connection
                and self.learning_rule_type == other.learning_rule_type)

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __hash__(self):
        # +1 to avoid collision with ensemble
        return hash(self._connection) + hash(self.learning_rule_type) + 1

    def __getitem__(self, key):
        return ObjView(self, key)

    @property
    def connection(self):
        """(Connection) The connection modified by the learning rule."""
        return self._connection

    @property
    def modifies(self):
        """(str) The variable modified by the learning rule."""
        return self.learning_rule_type.modifies

    @property
    def probeable(self):
        """(tuple) Signals that can be probed in the learning rule."""
        return self.learning_rule_type.probeable

    @property
    def size_in(self):
        """(int) Input dimensionality of the learning rule.

        Resolved from ``learning_rule_type.size_in``:

        * ``'pre'``, ``'mid'`` and ``'post'`` map to the connection's
          ``size_in``, ``size_mid`` and ``size_out``.
        * ``'pre_state'`` and ``'post_state'`` use the owning ensemble's
          dimensionality when that side is a Neurons object.
        * Any other value is returned unchanged (expected to be an int).
        """
        conn = self.connection
        size_in = self.learning_rule_type.size_in
        if size_in == 'pre':
            return conn.size_in
        elif size_in == 'mid':
            return conn.size_mid
        elif size_in == 'post':
            return conn.size_out
        elif size_in == 'pre_state':
            return (conn.pre_obj.ensemble.size_out
                    if isinstance(conn.pre_obj, Neurons) else conn.size_in)
        elif size_in == 'post_state':
            return (conn.post_obj.ensemble.size_in
                    if isinstance(conn.post_obj, Neurons) else conn.size_out)
        else:
            return size_in  # should be an integer

    @property
    def size_out(self):
        """(int) Cannot connect from learning rules, so always 0."""
        return 0  # since a learning rule can't connect to anything
# TODO: allow probing individual learning rules