Source code for nengo_dl.signals

from collections import defaultdict
import warnings

from nengo.builder.signal import Signal
from nengo.exceptions import BuildError
from nengo.neurons import Direct
import numpy as np
import tensorflow as tf

from nengo_dl import DEBUG


class TensorSignal(object):
    """Represents a tensor as an indexed view into a base array.

    Parameters
    ----------
    indices : tuple or list or :class:`~numpy:numpy.ndarray` of int
        indices along the first axis of the base array corresponding to the
        data for this signal
    key : object
        key mapping to the base array that contains the data for this signal
    dtype : :class:`~numpy:numpy.dtype`
        dtype of the values represented by this signal
    shape : tuple of int
        view shape of this signal (may differ from shape of base array)
    minibatched : bool
        if True then this signal contains a minibatch dimension
    label : str, optional
        name for this signal, used to make debugging easier
    """

    def __init__(self, indices, key, dtype, shape, minibatched,
                 label="TensorSignal"):
        # make indices read-only (note: ``np.asarray`` does not copy an
        # existing ndarray, so an array passed in by the caller becomes
        # read-only as well)
        assert isinstance(indices, (tuple, list, np.ndarray))
        self._indices = np.asarray(indices)
        self._indices.flags.writeable = False

        # these are all populated by `load_indices`; they are initialized
        # here so that they can safely be checked against None beforehand
        self.tf_indices = None
        self.full_indices = None
        self.as_slice = None

        self.key = key
        self.dtype = dtype
        self.shape = shape
        self.minibatched = minibatched
        self.label = label

    @property
    def indices(self):
        """Read-only array of indices into the first axis of the base."""
        return self._indices

    @indices.setter
    def indices(self, val):
        raise BuildError("Indices are read only")

    @property
    def ndim(self):
        """Number of dimensions in the view shape of this signal."""
        return len(self.shape)

    def __repr__(self):
        return "TensorSignal(key=%s, shape=%s, label=%s)" % (
            self.key, self.shape, self.label)

    def __getitem__(self, indices):
        """Create a new TensorSignal representing a subset (slice or
        advanced indexing) of the indices of this TensorSignal.

        Parameters
        ----------
        indices : slice or list of int
            the desired subset of the indices in this TensorSignal

        Returns
        -------
        :class:`.signals.TensorSignal`
            a new TensorSignal representing the subset of this TensorSignal
            (``Ellipsis`` or ``None`` return this signal itself)
        """

        if indices is Ellipsis or indices is None:
            return self

        new_indices = self.indices[indices]
        return TensorSignal(
            new_indices, self.key, self.dtype,
            (len(new_indices),) + self.shape[1:], self.minibatched,
            label=self.label + ".slice")

    def reshape(self, shape):
        """Create a new TensorSignal representing a reshaped view of the
        same data in this TensorSignal (size of data must remain unchanged).

        Parameters
        ----------
        shape : tuple of int
            new shape for the signal (one dimension can be -1 to indicate
            an inferred dimension size, as in numpy)

        Returns
        -------
        :class:`.signals.TensorSignal`
            new TensorSignal representing the same data as this signal but
            with the given shape

        Raises
        ------
        BuildError
            if more than one dimension is -1, or if the requested shape
            cannot hold exactly the number of elements in this signal
        """

        # replace -1 with inferred dimension
        if shape.count(-1) > 1:
            raise BuildError("Only one inferred dimension allowed in reshape")

        # note: the products are computed as integers. the default
        # ``np.prod`` of an empty sequence is the float 1.0, which would
        # otherwise turn the inferred dimension into a float (e.g. for
        # ``reshape((-1,))``)
        n_elem = int(np.prod(self.shape, dtype=np.int64))
        n_shape = int(np.prod([x for x in shape if x != -1], dtype=np.int64))

        if n_shape == 0:
            # a zero-length dimension was requested, so nothing can be
            # inferred; the element count check below decides validity
            inferred = 0
        elif n_elem % n_shape != 0:
            raise BuildError("No valid length for inferred dimension")
        else:
            inferred = n_elem // n_shape

        shape = tuple(inferred if x == -1 else x for x in shape)

        if int(np.prod(shape, dtype=np.int64)) != n_elem:
            raise BuildError("Number of elements don't match in reshape")

        return TensorSignal(
            self.indices, self.key, self.dtype, shape, self.minibatched,
            label=self.label + ".reshape(%s)" % (shape,))

    def broadcast(self, axis, length):
        """Add a new dimension by broadcasting this signal along ``axis``
        for the given length.

        Parameters
        ----------
        axis : 0 or -1
            where to insert the new dimension (currently only supports either
            the beginning or end of the array)
        length : int
            the number of times to duplicate signal along the broadcast
            dimension

        Returns
        -------
        :class:`.signals.TensorSignal`
            TensorSignal with new broadcasted shape
        """

        assert axis in (0, -1)
        # this only works on vectors
        assert self.ndim == 1 and not self.minibatched

        # duplicate the indices along the new axis, then flatten them again
        # (indices always address the first axis of the base array)
        indices = np.stack([self.indices] * length, axis=axis)
        indices = np.reshape(indices, (-1,))

        if axis == -1:
            display_shape = self.shape + (length,)
        else:
            display_shape = (length,) + self.shape

        return TensorSignal(
            indices, self.key, self.dtype, display_shape, self.minibatched,
            label=self.label + ".broadcast(%d, %d)" % (axis, length))

    def load_indices(self):
        """Loads the indices for this signal into tensorflow, and if the
        indices form a contiguous slice then also loads the start/stop/step
        of that slice."""

        self.tf_indices = tf.constant(self.indices, dtype=tf.int32)

        start = self.indices[0]
        stop = self.indices[-1] + 1
        step = (self.indices[1] - self.indices[0] if len(self.indices) > 1
                else 1)

        # `array_equal` is used (rather than `==`) because the candidate
        # range may have a different length than the indices, in which case
        # an element-wise comparison is not well defined
        if step != 0 and np.array_equal(self.indices,
                                        np.arange(start, stop, step)):
            self.as_slice = (tf.constant([start]), tf.constant([stop]),
                             tf.constant([step]))
        else:
            self.as_slice = None
class SignalDict(object):
    """Handles the mapping from :class:`~nengo:nengo.builder.Signal` to
    ``tf.Tensor``.

    Takes care of gather/scatter logic to read/write signals within the base
    arrays.

    Parameters
    ----------
    sig_map : dict of {:class:`~nengo:nengo.builder.Signal`: \
                       :class:`.TensorSignal`}
        mapping from ``nengo`` signals to ``nengo_dl`` signals
    dtype : ``tf.DType``
        floating point precision used in signals
    minibatch_size : int
        number of items in each minibatch
    """

    def __init__(self, sig_map, dtype, minibatch_size):
        self.dtype = dtype
        self.sig_map = sig_map
        self.minibatch_size = minibatch_size

        # looked up by `TensorSignal.key` in `_scatter_f2`; expected to be
        # filled in (together with `bases`) from outside this class
        self.base_ranges = {}
        self.bases = None

        # tensors read from each base since that base was last written
        self.reads_by_base = defaultdict(list)

    def scatter(self, dst, val, mode="update"):
        """Updates the base data corresponding to `dst`.

        Parameters
        ----------
        dst : :class:`.TensorSignal`
            signal indicating the data to be modified in base array
        val : ``tf.Tensor``
            update data (same shape as ``dst``, i.e. a dense array <= the size
            of the base array)
        mode : "update" or "inc"
            overwrite/add the data at ``dst`` with ``val``

        Raises
        ------
        BuildError
            if the indices of ``dst`` have not been loaded, if ``dst`` is not
            minibatched, or if ``val`` has the wrong floating point dtype
        NotImplementedError
            if ``mode`` is not one of the supported modes
        """

        if dst.tf_indices is None:
            raise BuildError("Indices for %s have not been loaded into "
                             "Tensorflow" % dst)
        if not dst.minibatched:
            raise BuildError("Assigning to a trainable variable")
        if val.dtype.is_floating and val.dtype.base_dtype != self.dtype:
            raise BuildError("Tensor detected with wrong dtype (%s), should "
                             "be %s." % (val.dtype.base_dtype, self.dtype))

        # align val shape with dst base shape
        base = self.bases[dst.key]
        base.get_shape().assert_is_fully_defined()
        val.get_shape().assert_is_fully_defined()
        dst_shape = ((dst.shape[0],) +
                     tuple(base.get_shape().as_list()[1:]))
        if val.get_shape() != dst_shape:
            val = tf.reshape(val, dst_shape)

        # TODO: until tensorflow implements scatter_nd kernel for GPU, users
        # will have to train on CPU (using tensors/scatter_nd), but can still
        # do inference on GPU (using variables/scatter)

        if DEBUG:
            print("scatter")
            print("values", val)
            print("dst", dst)
            print("indices", dst.indices)
            print("dst base", base)
            print("reads_by_base", self.reads_by_base[base])

        # make sure that any reads to the target signal happen before this
        # write (note: this is only any reads that have happened since the
        # last write, since each write changes the base array object)
        with tf.control_dependencies(self.reads_by_base[base]):
            # `array_equal` rather than `==`, since `dst` will usually index
            # fewer rows than the base contains
            if np.array_equal(np.arange(base.get_shape()[0].value),
                              dst.indices):
                # the write covers the whole base array, so we can operate
                # on the base directly rather than scattering
                if mode == "update":
                    self.bases[dst.key] = tf.identity(val)
                elif mode == "inc":
                    self.bases[dst.key] = base + val
                else:
                    # previously unknown modes were silently ignored here,
                    # although the partial-write path below rejects them
                    raise NotImplementedError(
                        "Unsupported scatter mode %r" % (mode,))
            else:
                self.bases[dst.key] = self._scatter_f2(dst, val, mode=mode)

        if DEBUG:
            print("new dst base", self.bases[dst.key])

    def _scatter_f(self, dst, idxs, src, mode="update"):
        """Return a copy of tensor ``dst`` in which the rows ``idxs`` have
        been overwritten with ("update") or incremented by ("inc") ``src``.

        Note: nothing in this class calls this helper any more; `scatter`
        uses `_scatter_f2` instead.
        """

        if mode == "update":
            # `dynamic_stitch` lets later entries override earlier ones, so
            # stitching the whole of ``dst`` followed by ``src`` overwrites
            # just the rows in ``idxs``
            tmp = tf.dynamic_stitch([tf.range(dst.get_shape()[0]), idxs],
                                    [dst, src])
            tmp.set_shape(dst.get_shape())
            return tmp
        elif mode == "inc":
            idxs = tf.expand_dims(idxs, 1)
            return dst + tf.scatter_nd(idxs, src, dst.get_shape())
        else:
            raise NotImplementedError

    def _scatter_f2(self, dst, src, mode="update"):
        """Return the base array of ``dst`` with the rows indexed by ``dst``
        overwritten with ("update") or incremented by ("inc") ``src``."""

        base_idxs = self.base_ranges[dst.key]

        if mode == "update":
            result = tf.dynamic_stitch([base_idxs, dst.tf_indices],
                                       [self.bases[dst.key], src])
        elif mode == "inc":
            x = self.gather(dst)
            result = tf.dynamic_stitch([base_idxs, dst.tf_indices],
                                       [self.bases[dst.key], x + src])
        else:
            raise NotImplementedError

        result.set_shape(self.bases[dst.key].get_shape())

        return result

    def gather(self, src, force_copy=False):
        """Fetches the data corresponding to ``src`` from the base array.

        Parameters
        ----------
        src : :class:`.TensorSignal`
            signal indicating the data to be read from base array
        force_copy : bool, optional
            if True, always perform a gather, not a slice (this forces a
            copy). note that setting ``force_copy=False`` does not guarantee
            that a copy won't be performed.

        Returns
        -------
        ``tf.Tensor``
            tensor object corresponding to a dense subset of data from the
            base array

        Raises
        ------
        BuildError
            if the indices of ``src`` have not been loaded into tensorflow
        """

        if src.tf_indices is None:
            raise BuildError("Indices for %s have not been loaded into "
                             "Tensorflow" % src)

        base = self.bases[src.key]

        if DEBUG:
            print("gather")
            print("src", src)
            print("indices", src.indices)
            print("src base", base)

        # we prefer to get the data via `strided_slice` if possible, as it
        # is more efficient
        if force_copy or src.as_slice is None:
            result = tf.gather(base, src.tf_indices)
        elif np.array_equal(np.arange(base.get_shape()[0].value),
                            src.indices):
            # reading the whole base array
            result = tf.identity(base)
        else:
            result = tf.strided_slice(base, *src.as_slice)

        # for some reason the shape inference doesn't work in some cases,
        # and tensorflow loses track of the shape
        result.set_shape(src.tf_indices.get_shape()[:1].concatenate(
            base.get_shape()[1:]))

        # reshape the data according to the shape set in `src`, if there is
        # one, otherwise keep the shape of the base array
        src_shape = src.shape
        if src.minibatched:
            src_shape += (self.minibatch_size,)
        if result.get_shape() != src_shape:
            result = tf.reshape(result, src_shape)

        # whenever we read from an array we use this to mark it as "read"
        # (so that any future writes to the array will be scheduled after
        # the read)
        # TODO: we could store the indices as well, so that future writes are
        # only delayed if they write to the same part of the array
        self.reads_by_base[base].append(result)

        return result

    def combine(self, sigs, load_indices=True):
        """Combines several TensorSignals into one by concatenating along
        the first axis.

        Parameters
        ----------
        sigs : list of :class:`.TensorSignal` or \
                       :class:`~nengo:nengo.builder.Signal`
            signals to be combined
        load_indices : bool, optional
            if True, load the indices for the new signal into tensorflow right
            away (otherwise they will need to be manually loaded later)

        Returns
        -------
        :class:`.TensorSignal`
            new TensorSignal representing the concatenation of the data in
            ``sigs`` (note: an empty ``sigs`` returns an empty list instead)
        """

        if len(sigs) == 0:
            return []

        assert isinstance(sigs, (list, tuple))
        assert isinstance(sigs[0], (Signal, TensorSignal))

        # translate any nengo Signals into their TensorSignal equivalents
        sigs = [self.sig_map[s] if isinstance(s, Signal) else s for s in sigs]

        key = sigs[0].key
        # make sure all the signals have the same base
        assert all(s.key == key for s in sigs)

        indices = np.concatenate([s.indices for s in sigs], axis=0)

        # make sure all signals have the same shape (except first axis,
        # which we're concatenating along); note, this can fail even if they
        # all have the same base, due to reshaping
        assert all(s.shape[1:] == sigs[0].shape[1:] for s in sigs)
        shape = (np.sum([s.shape[0] for s in sigs]),) + sigs[0].shape[1:]

        output = TensorSignal(indices, key, sigs[0].dtype, shape,
                              sigs[0].minibatched)

        if load_indices:
            output.load_indices()

        return output

    def __str__(self):
        """Pretty-print the mapped signals and their TensorSignals."""

        # note: this class is not itself iterable/indexable, so the contents
        # are read from the underlying signal map
        return "\n".join("%s: %s" % (repr(k), repr(v))
                         for k, v in self.sig_map.items())
def mark_signals(model):
    """Mark all the signals in ``model`` according to whether they represent
    trainable parameters of the model (parameters that can be optimized by
    deep learning methods).

    Trainable parameters include connection weights, ensemble encoders, and
    neuron biases, unless one of those signals is targeted by a Nengo
    learning rule (otherwise the learning rule update conflicts with the deep
    learning optimization).

    Parameters
    ----------
    model : class:`~nengo:nengo.builder.Model`
        built Nengo model

    Raises
    ------
    NotImplementedError
        if a learning rule modifies something other than weights, decoders,
        or encoders
    """

    if model.toplevel is None:
        warnings.warn("No top-level network in model")
    else:
        for ens in model.toplevel.all_ensembles:
            model.sig[ens]["encoders"].trainable = True

            if not isinstance(ens.neuron_type, Direct):
                model.sig[ens.neurons]["bias"].trainable = True

        for conn in model.toplevel.all_connections:
            # note: this doesn't include probe connections, since they aren't
            # added to the network
            # TODO: should we disable training on connections to learning
            # rules?
            model.sig[conn]["weights"].trainable = True

            # parameters can't be modified by an online Nengo learning rule
            # and offline training at the same time. (it is possible in
            # theory, but it complicates things a lot and is probably not a
            # common use case).
            rule = conn.learning_rule
            if rule is None:
                continue

            # normalize the different ways of specifying rules to a list
            if isinstance(rule, dict):
                rule = list(rule.values())
            elif not isinstance(rule, list):
                rule = [rule]

            for r in rule:
                if r.modifies in ("weights", "decoders"):
                    model.sig[conn]["weights"].trainable = False
                elif r.modifies == "encoders":
                    model.sig[conn.post_obj]["encoders"].trainable = False
                else:
                    raise NotImplementedError(
                        "Learning rule %s modifies %r, which is not "
                        "supported" % (r, r.modifies))

    # mark everything as not trainable by default
    for op in model.operators:
        for sig in op.all_signals:
            for x in (sig, sig.base):
                if not hasattr(x, "trainable"):
                    x.trainable = False

                # at the moment minibatched is just the opposite of trainable,
                # but it could be the case that the two are independent
                # (note that the base is marked according to the trainability
                # of the signal through which it was reached)
                x.minibatched = not sig.trainable