Source code for nengo_dl.tensor_node

"""
TensorNodes allow parts of a model to be defined using TensorFlow and smoothly
integrated with the rest of a Nengo model.

See `the documentation <https://www.nengo.ai/nengo-dl/tensor-node.html>`_ for more
details.
"""

import contextlib
import warnings

from nengo import Node, Connection, Ensemble, builder
from nengo.base import NengoObject
from nengo.builder.operator import Reset
from nengo.config import Config
from nengo.exceptions import ValidationError, SimulationError
from nengo.neurons import NeuronType
from nengo.params import Default, ShapeParam, Parameter, BoolParam
import numpy as np
import tensorflow as tf
from tensorflow.python.eager import context

from nengo_dl.builder import Builder, OpBuilder, NengoBuilder
from nengo_dl.compat import default_transform, eager_enabled
from nengo_dl.config import configure_settings


def validate_output(output, minibatch_size=None, output_d=None, dtype=None):
    """
    Performs validation on the output of a TensorNode ``tensor_func``.

    Parameters
    ----------
    output : ``tf.Tensor`` or ``tf.TensorSpec``
        Output from the ``tensor_func``.
    minibatch_size : int
        Expected minibatch size for the simulation.
    output_d
        Expected output dimensionality for the function.
    dtype
        Expected dtype of the function output.
    """

    if not isinstance(output, (tf.Tensor, tf.TensorSpec)):
        raise ValidationError(
            "TensorNode function must return a Tensor (got %s)" % type(output),
            attr="tensor_func",
        )

    if minibatch_size is not None and output.shape[0] != minibatch_size:
        raise ValidationError(
            "TensorNode output should have batch size %d (got %d)"
            % (minibatch_size, output.shape[0]),
            attr="tensor_func",
        )

    if output_d is not None and np.prod(output.shape[1:]) != output_d:
        raise ValidationError(
            "TensorNode output should have size %d (got shape %s with size %d)"
            % (minibatch_size, output.shape[1:], np.prod(output.shape[1:])),
            attr="tensor_func",
        )

    if dtype is not None and output.dtype != dtype:
        raise ValidationError(
            "TensorNode output should have dtype %s "
            "(got %s)" % (dtype, output.dtype),
            attr="tensor_func",
        )


class TensorFuncParam(Parameter):
    """Parameter for the ``tensor_func`` parameter of a `.TensorNode`."""

    def __init__(self, name, readonly=False):
        super().__init__(name, optional=False, readonly=readonly)

    def coerce(self, node, func):
        """
        Performs validation on the function passed to TensorNode, and sets
        ``shape_out`` if necessary.

        Parameters
        ----------
        node : `.TensorNode`
            The node whose ``tensor_func`` parameter is being set.
        func : callable
            The function being assigned to the TensorNode.

        Returns
        -------
        output : callable
            The function after validation is applied.
        """

        output = super().coerce(node, func)

        if not callable(func):
            raise ValidationError(
                "TensorNode output must be a function or Keras Layer",
                attr=self.name,
                obj=node,
            )

        if node.shape_out is None:
            if isinstance(func, tf.keras.layers.Layer):
                # we can use Keras' static shape inference to get the
                # output shape, which avoids having to build/call the layer
                if node.pass_time:
                    input_spec = [tf.TensorSpec(())]
                else:
                    input_spec = []
                if node.shape_in is not None:
                    input_spec += [tf.TensorSpec((1,) + node.shape_in)]

                if len(input_spec) == 1:
                    input_spec = input_spec[0]

                ctx = contextlib.suppress() if eager_enabled() else context.eager_mode()

                try:
                    with ctx:
                        result = func.compute_output_signature(input_spec)
                except Exception as e:
                    raise ValidationError(
                        "Attempting to automatically determine TensorNode output shape "
                        "by calling Layer.compute_output_signature produced an error. "
                        "If you would like to avoid this step, try manually setting "
                        "`TensorNode(..., shape_out=x)`. The error is shown below:\n%s"
                        % repr(e),
                        attr=self.name,
                        obj=node,
                    )

            else:
                if node.pass_time:
                    args = (tf.constant(0.0),)
                else:
                    args = ()
                if node.shape_in is not None:
                    args += (tf.zeros((1,) + node.shape_in),)

                try:
                    result = func(*args)
                except Exception as e:
                    raise ValidationError(
                        "Attempting to automatically determine TensorNode output shape "
                        "by calling TensorNode function produced an error. "
                        "If you would like to avoid this step, try manually setting "
                        "`TensorNode(..., shape_out=x)`. The error is shown below:\n%s"
                        % e,
                        attr=self.name,
                        obj=node,
                    )

            validate_output(result)

            node.shape_out = result.shape[1:]

        return output


[docs]class TensorNode(Node): """ Inserts TensorFlow code into a Nengo model. Parameters ---------- tensor_func : callable A function that maps node inputs to outputs shape_in : tuple of int Shape of TensorNode input signal (not including batch dimension). shape_out : tuple of int Shape of TensorNode output signal (not including batch dimension). If None, value will be inferred by calling ``tensor_func``. pass_time : bool If True, pass current simulation time to TensorNode function (in addition to the standard input). label : str (Default: None) A name for the node, used for debugging and visualization """ tensor_func = TensorFuncParam("tensor_func") shape_in = ShapeParam("shape_in", default=None, low=1, optional=True) shape_out = ShapeParam("shape_out", default=None, low=1, optional=True) pass_time = BoolParam("pass_time", default=True) def __init__( self, tensor_func, shape_in=Default, shape_out=Default, pass_time=Default, label=Default, ): # pylint: disable=non-parent-init-called,super-init-not-called # note: we bypass the Node constructor, because we don't want to # perform validation on `output` NengoObject.__init__(self, label=label, seed=None) self.shape_in = shape_in self.shape_out = shape_out self.pass_time = pass_time if not (self.shape_in or self.pass_time): raise ValidationError( "Must specify either shape_in or pass_time", "TensorNode" ) self.tensor_func = tensor_func @property def output(self): """ Ensures that nothing tries to evaluate the `output` attribute (indicating that something is trying to simulate this as a regular `nengo.Node` rather than a TensorNode). """ def output_func(*_): raise SimulationError( "Cannot call TensorNode output function (this probably means " "you are trying to use a TensorNode inside a Simulator other " "than NengoDL)" ) return output_func @property def size_in(self): """Number of input elements (flattened).""" return 0 if self.shape_in is None else np.prod(self.shape_in) @property def size_out(self): """Number of output elements (flattened).""" return 0 if self.shape_out is None else np.prod(self.shape_out)
[docs]@NengoBuilder.register(TensorNode) def build_tensor_node(model, node): """This is the Nengo build function, so that Nengo knows what to do with TensorNodes.""" # time signal if node.pass_time: time_in = model.time else: time_in = None # input signal if node.shape_in is not None: sig_in = builder.Signal(shape=(node.size_in,), name="%s.in" % node) model.add_op(Reset(sig_in)) else: sig_in = None sig_out = builder.Signal(shape=(node.size_out,), name="%s.out" % node) model.sig[node]["in"] = sig_in model.sig[node]["out"] = sig_out model.params[node] = None model.operators.append( SimTensorNode(node.tensor_func, time_in, sig_in, sig_out, node.shape_in) )
[docs]class SimTensorNode(builder.Operator): # pylint: disable=abstract-method """Operator for TensorNodes (constructed by `.build_tensor_node`). Parameters ---------- func : callable The TensorNode function (``tensor_func``). time : `~nengo.builder.Signal` or None Signal representing the current simulation time (or None if ``pass_time`` is False). input : `~nengo.builder.Signal` or None Input Signal for the TensorNode (or None if no inputs). output : `~nengo.builder.Signal` Output Signal for the TensorNode. shape_in : tuple of int or None Shape of input to TensorNode (if None, will leave the shape of input signal unchanged). tag : str A label associated with the operator, for debugging Notes ----- 1. sets ``[output]`` 2. incs ``[]`` 3. reads ``[time]`` (if ``pass_time=True``) + ``[input]`` (if ``input`` is not None) 4. updates ``[]`` """ def __init__(self, func, time, input, output, shape_in, tag=None): super().__init__(tag=tag) self.func = func self.time = time self.input = input self.output = output self.shape_in = shape_in self.sets = [output] self.incs = [] if time is None: self.reads = [] else: self.reads = [time] if input is not None: self.reads += [input] self.updates = []
[docs]@Builder.register(SimTensorNode) class SimTensorNodeBuilder(OpBuilder): """Builds a `~.tensor_node.SimTensorNode` operator into a NengoDL model."""
[docs] def build_pre(self, signals, config): super().build_pre(signals, config) # SimTensorNodes should never be merged assert len(self.ops) == 1 op = self.ops[0] if op.time is None: self.time_data = None else: self.time_data = signals[op.time].reshape(()) if op.input is None: self.src_data = None else: self.src_data = signals[op.input] assert self.src_data.ndim == 1 if op.shape_in is not None: self.src_data = self.src_data.reshape(op.shape_in) self.dst_data = signals[op.output] self.func = op.func
[docs] def build_step(self, signals): if self.time_data is None: inputs = [] else: inputs = [signals.gather(self.time_data)] if self.src_data is not None: inputs += [signals.gather(self.src_data)] if isinstance(self.func, tf.keras.layers.Layer): if len(inputs) == 1: inputs = inputs[0] kwargs = ( dict(training=self.config.training) if self.func._expects_training_arg else {} ) output = self.func.call(inputs, **kwargs) else: output = self.func(*inputs) validate_output( output, minibatch_size=signals.minibatch_size, output_d=self.dst_data.shape[0], dtype=signals.dtype, ) signals.scatter(self.dst_data, output)
[docs]class Layer: """ A wrapper for constructing TensorNodes. This is designed to mimic and integrate with the ``tf.keras.layers.Layer`` API, e.g. .. testcode:: with nengo.Network(): a = nengo.Ensemble(10, 1) b = nengo_dl.Layer(tf.keras.layers.Dense(units=10))(a) c = nengo_dl.Layer(lambda x: x + 1)(b) d = nengo_dl.Layer(nengo.LIF())(c) Parameters ---------- layer_func : callable or ``tf.keras.Layer`` or `~nengo.neurons.NeuronType` A function or Keras Layer that takes the value from an input (represented as a ``tf.Tensor``) and maps it to some output value, or a Nengo neuron type (which will be instantiated in a Nengo Ensemble and applied to the input). """ def __init__(self, layer_func): self.layer_func = layer_func
[docs] def __call__( self, input, transform=default_transform, shape_in=None, synapse=None, return_conn=False, **layer_args ): """ Apply the TensorNode layer to the given input object. Parameters ---------- input : ``NengoObject`` Object providing input to the layer. transform : `~numpy.ndarray` Transform matrix to apply on connection from ``input`` to this layer. shape_in : tuple of int If not None, reshape the input to the given shape. synapse : float or `~nengo.synapses.Synapse` Synapse to apply on connection from ``input`` to this layer. return_conn : bool If True, also return the connection linking this layer to ``input``. layer_args : dict These arguments will be passed to `.TensorNode` if ``layer_func`` is a callable or Keras Layer, or `~nengo.Ensemble` if ``layer_func`` is a `~nengo.neurons.NeuronType`. Returns ------- obj : `.TensorNode` or `~nengo.ensemble.Neurons` A TensorNode that implements the given layer function (if ``layer_func`` was a callable/Keras layer), or a Neuron object with the given neuron type, connected to ``input``. conn : `~nengo.Connection` If ``return_conn`` is True, also returns the connection object linking ``input`` and ``obj``. Notes ----- The input connection created for the new TensorNode will be marked as non-trainable by default. """ if shape_in is not None and all(x is not None for x in shape_in): size_in = np.prod(shape_in) elif isinstance(transform, np.ndarray) and transform.ndim == 2: size_in = transform.shape[0] else: size_in = input.size_out if isinstance(self.layer_func, NeuronType): obj = Ensemble( size_in, 1, neuron_type=self.layer_func, **layer_args ).neurons else: obj = TensorNode( self.layer_func, shape_in=(size_in,) if shape_in is None else shape_in, pass_time=False, **layer_args, ) conn = Connection(input, obj, synapse=synapse, transform=transform) # set connection to non-trainable cfg = Config.context[0][conn] if not hasattr(cfg, "trainable"): configure_settings(trainable=None) cfg.trainable = False return (obj, conn) if return_conn else obj
[docs] def __str__(self): return "Layer(%s)" % getattr( self.layer_func, "name", getattr(self.layer_func, "__name__", self.layer_func), )
def tensor_layer(input, layer_func, **kwargs): """Deprecated, use `.Layer` instead.""" warnings.warn( "nengo_dl.tensor_layer is deprecated; use nengo_dl.Layer instead", DeprecationWarning, ) return Layer(layer_func)(input, **kwargs)