# Source code for nengo_dl.op_builders

"""
Build classes for basic Nengo operators.
"""

from collections import defaultdict
from distutils.version import LooseVersion
import logging
import warnings

from nengo.builder.operator import Reset, Copy, ElementwiseInc, DotInc, SimPyFunc
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import gen_sparse_ops

from nengo_dl import utils
from nengo_dl.builder import Builder, OpBuilder
from nengo_dl.compat import tf_compat, SparseDotInc, SparseMatrix

logger = logging.getLogger(__name__)


class ResetInc(Reset):
    """
    A version of `~nengo.builder.operator.Reset` that increments the target
    value rather than setting it.
    """

    @property
    def dst(self):
        """Overridden to return from incs rather than sets."""
        return self.incs[0]
@Builder.register(Reset)
@Builder.register(ResetInc)
class ResetBuilder(OpBuilder):
    """
    Build a group of `~nengo.builder.operator.Reset` operators.
    """

    def __init__(self, ops, signals, config):
        super(ResetBuilder, self).__init__(ops, signals, config)

        logger.debug("val %s", [op.value for op in ops])
        logger.debug("dst %s", [op.dst for op in ops])

        # ResetInc increments its target rather than overwriting it
        self.mode = "inc" if type(ops[0]) == ResetInc else "update"

        # use the simulation float dtype for floating-point reset values
        dtype = np.asarray(ops[0].value).dtype
        if np.issubdtype(dtype, np.floating):
            dtype = signals.dtype.as_numpy_dtype

        # unlike other ops, Reset signals might be spread across multiple
        # bases, which we need to handle
        scatters = defaultdict(list)
        for op in ops:
            scatters[signals[op.dst].key] += [op]
        # self.scatters: list of (combined dst signal, constant value tensor)
        # pairs, one per base
        self.scatters = []
        for group in scatters.values():
            # broadcast each op's reset value up to its dst shape, then
            # concatenate all the values in this base into one array
            value = np.concatenate(
                [
                    np.resize(np.asarray(x.value).astype(dtype), x.dst.shape)
                    for x in group
                ],
                axis=0,
            )
            # tile the value along a trailing minibatch dimension
            value = np.tile(
                value[..., None],
                tuple(1 for _ in value.shape) + (signals.minibatch_size,),
            )
            self.scatters += [
                (signals.combine([x.dst for x in group]), signals.constant(value))
            ]

        logger.debug("scatters")
        logger.debug("\n".join([str(x) for x in self.scatters]))

    def build_step(self, signals):
        # write (or increment) the precomputed constants into each base
        for data, val in self.scatters:
            signals.scatter(data, val, mode=self.mode)

    @staticmethod
    def mergeable(x, y):
        # any two Reset ops can be merged (multi-base case handled above)
        return True
@Builder.register(Copy)
class CopyBuilder(OpBuilder):
    """
    Build a group of `~nengo.builder.operator.Copy` operators.
    """

    def __init__(self, ops, signals, config):
        super(CopyBuilder, self).__init__(ops, signals, config)

        logger.debug("src %s", [op.src for op in ops])
        logger.debug("src_slice %s", [getattr(op, "src_slice", None) for op in ops])
        logger.debug("dst %s", [op.dst for op in ops])
        logger.debug("dst_slice %s", [getattr(op, "dst_slice", None) for op in ops])

        self.mode = "inc" if ops[0].inc else "update"

        # collect the (sliced) source/destination views of each op and
        # merge them into single combined signals
        self.src_data = signals.combine(
            [signals[op.src][op.src_slice] for op in ops]
        )
        self.dst_data = signals.combine(
            [signals[op.dst][op.dst_slice] for op in ops]
        )

        if self.dst_data.minibatched and not self.src_data.minibatched:
            # broadcast indices so that the un-minibatched src data gets
            # copied to each minibatch dimension in dst
            self.src_data = self.src_data.broadcast(-1, signals.minibatch_size)

    def build_step(self, signals):
        # read the merged source values and write them into the destination
        src_vals = signals.gather(self.src_data)
        signals.scatter(self.dst_data, src_vals, mode=self.mode)

    @staticmethod
    def mergeable(x, y):
        # any two Copy ops can be merged
        return True
# class ElementwiseSet(ElementwiseInc): # @property # def Y(self): # return self.sets[0]
@Builder.register(ElementwiseInc)
# @Builder.register(ElementwiseSet)
class ElementwiseIncBuilder(OpBuilder):
    """
    Build a group of `~nengo.builder.operator.ElementwiseInc` operators.
    """

    def __init__(self, ops, signals, config):
        super(ElementwiseIncBuilder, self).__init__(ops, signals, config)

        logger.debug("dst %s", [op.Y for op in ops])
        logger.debug("A %s", [op.A for op in ops])
        logger.debug("X %s", [op.X for op in ops])

        self.mode = "inc" if type(ops[0]) == ElementwiseInc else "update"

        self.Y_data = signals.combine([op.Y for op in ops])

        # merge all the A's and all the X's into combined signals
        self.A_data = signals.combine([op.A for op in ops])
        self.X_data = signals.combine([op.X for op in ops])

        if self.A_data.shape[0] != self.X_data.shape[0]:
            # separate the data from each op along the first dimension
            self.A_data = self.A_data.reshape(
                (len(ops), -1) + self.A_data.shape[1:]
            )
            self.X_data = self.X_data.reshape(
                (len(ops), -1) + self.X_data.shape[1:]
            )

        # append empty trailing dimensions to A for elementwise broadcasting
        missing_dims = self.X_data.ndim - self.A_data.ndim
        if missing_dims > 0:
            self.A_data = self.A_data.reshape(
                self.A_data.shape + (1,) * missing_dims
            )

        if self.X_data.minibatched and not self.A_data.minibatched:
            # add a broadcast dimension for the minibatch
            self.A_data = self.A_data.reshape(self.A_data.shape + (1,))

    def build_step(self, signals):
        # elementwise (broadcast) product of A and X, accumulated into Y
        product = tf.multiply(
            signals.gather(self.A_data), signals.gather(self.X_data)
        )
        signals.scatter(self.Y_data, product, mode=self.mode)

    @staticmethod
    def mergeable(x, y):
        # for these operations we enforce that the first dimensions
        # match (we know all the other dimensions match due to the generic
        # checks).
        # this allows us to stack all the arguments into continuous array
        # blocks, allowing for more efficient multiplication (mainly
        # because it allows us to take advantage of broadcasting)
        def first_dim(sig):
            return sig.shape[0] if sig.shape != () else 1

        return all(
            first_dim(s0) == first_dim(s1)
            for s0, s1 in zip(x.all_signals, y.all_signals)
        )
def sparse_matmul(A_indices, A_data, A_shape, X):
    """
    Matrix multiplication between sparse matrix A and dense matrix X.

    Parameters
    ----------
    A_indices : ``tf.Tensor``
        (N, 2) array of [row, col] indices of the non-zero entries of A
    A_data : ``tf.Tensor``
        (N,) array of data in the nonzero entries specified in ``A_indices``
    A_shape : tuple of int
        Shape of full A matrix
    X : ``tf.Tensor``
        Dense matrix being multiplied by A

    Returns
    -------
    dot : ``tf.Tensor``
        Result of matrix multiplication between A and X
    """

    # the GPU kernel only supports float32, so we need to downcast (with a
    # warning) if the data would otherwise run as float64 on the GPU
    must_downcast = A_data.dtype.base_dtype != tf.float32 and (
        "gpu" in A_data.device.lower()
        or (A_data.device == "" and utils.tf_gpu_installed)
    )
    if must_downcast:
        assert A_data.dtype.base_dtype == X.dtype.base_dtype
        warnings.warn(
            "Downcasting data to float32 in sparse_matmul, since "
            "only float32 is supported on the GPU."
        )
        A = tf.cast(A_data, tf.float32)
        X = tf.cast(X, tf.float32)
    else:
        A = A_data

    # the raw sparse matmul op was made public (renamed) in TF 1.7
    if LooseVersion(tf.__version__) < LooseVersion("1.7.0"):
        mat_mul = gen_sparse_ops._sparse_tensor_dense_mat_mul
    else:
        mat_mul = gen_sparse_ops.sparse_tensor_dense_mat_mul

    dot = mat_mul(A_indices, A, A_shape, X)

    if must_downcast:
        # restore the caller's original dtype so downstream ops are unaffected
        dot = tf.cast(dot, A_data.dtype.base_dtype)

    return dot
# class DotSet(DotInc): # @property # def Y(self): # return self.sets[0]
@Builder.register(DotInc)
# @Builder.register(DotSet)
class DotIncBuilder(OpBuilder):
    """
    Build a group of `~nengo.builder.operator.DotInc` operators.
    """

    def __init__(self, ops, signals, config):
        super(DotIncBuilder, self).__init__(ops, signals, config)

        logger.debug("dst %s", [op.Y for op in ops])
        logger.debug("A %s", [op.A for op in ops])
        logger.debug("X %s", [op.X for op in ops])

        self.mode = "inc" if type(ops[0]) == DotInc else "update"

        # check if all the signals have the same size for the first dimension
        self.len_match = True
        for i, s0 in enumerate(ops[0].all_signals):
            shape0 = s0.shape[0] if s0.shape != () else 1
            for op in ops:
                s1 = op.all_signals[i]
                shape1 = s1.shape[0] if s1.shape != () else 1
                if shape0 != shape1:
                    self.len_match = False
                    break
            if not self.len_match:
                break

        self.Y_data = signals.combine([op.Y for op in ops])

        # group all the A's and X's
        A_data = signals.combine([op.A for op in ops])
        X_data = signals.combine([op.X for op in ops])

        if self.len_match:
            # if the first dimensions all match, then we can use the
            # (batched) matrix multiplication op

            # separate data from each op along the first dimension
            self.A_data = A_data.reshape((len(ops), -1, A_data.shape[1]))
            self.X_data = X_data.reshape((len(ops), -1))

            if self.A_data.minibatched:
                # add broadcast dimension to X
                self.X_data = self.X_data.reshape(self.X_data.shape + (1,))

                # precompute transposition indices (used in build_step to
                # move the minibatch dimension for the batched matmul)
                self.perm = tf.constant((0, 3, 1, 2))
                self.perm_inv = tf.constant((0, 2, 3, 1))
        else:
            # if the first dimensions don't match, then we create a block
            # diagonal matrix out of all the op matrices, and then multiply
            # them using a sparse matrix multiplication
            self.A_data = A_data.reshape((-1,))
            self.X_data = X_data

            assert not self.A_data.minibatched
            assert self.X_data.minibatched and self.Y_data.minibatched

            # build the [row, col] indices of each op's dense block, offset
            # so that the blocks lie along the diagonal of the combined matrix
            sparse_indices = []
            corner = np.zeros(2, dtype=np.int64)
            for op in ops:
                block_shape = (op.A.shape[0], op.A.shape[1])
                idxs = np.reshape(
                    np.dstack(
                        np.meshgrid(
                            np.arange(block_shape[0]),
                            np.arange(block_shape[1]),
                            indexing="ij",
                        )
                    ),
                    (-1, 2),
                )
                idxs += corner
                corner += block_shape
                sparse_indices += [idxs]

            sparse_indices = np.concatenate(sparse_indices, axis=0)
            # use int32 indices when they fit, for efficiency
            self.sparse_indices = signals.constant(
                sparse_indices,
                dtype=(
                    tf.int32
                    if np.all(sparse_indices < np.iinfo(np.int32).max)
                    else tf.int64
                ),
            )
            # `corner` has accumulated the overall block-diagonal matrix shape
            self.A_shape = tf.constant(corner, dtype=tf.int64)

    def build_step(self, signals):
        A = signals.gather(self.A_data)
        X = signals.gather(self.X_data)

        if self.len_match:
            if self.A_data.minibatched and self.X_data.minibatched:
                # dot = tf.einsum("ijkl,ikl->ijl", A, X)

                # note: this is just a duplicate of what einsum does
                # internally; we do it manually so that we can move the
                # perm/perm_inv constants into the pre-build step
                A = tf.transpose(a=A, perm=self.perm)
                X = tf.transpose(a=X, perm=self.perm)
                dot = tf.matmul(A, X)
                dot = tf.transpose(a=dot, perm=self.perm_inv)
                dot.set_shape(self.A_data.shape[:2] + (1, signals.minibatch_size))
            elif not self.A_data.minibatched and self.X_data.minibatched:
                dot = tf.matmul(A, X)
            else:
                # note: these cases never come up (so far) in nengo, since X
                # is always minibatched. but preserving them here for
                # posterity, in case they are ever used

                # A minibatched, X not minibatched
                # dot = tf.einsum("ijkl,ik->ijl", A, X)

                # A not minibatched, X not minibatched
                # dot = tf.einsum("ijk,ik->ij", A, X)
                raise NotImplementedError
        else:
            # multiply through the precomputed block-diagonal sparse matrix
            dot = sparse_matmul(self.sparse_indices, A, self.A_shape, X)
            dot.set_shape(self.Y_data.shape + (signals.minibatch_size,))

        signals.scatter(self.Y_data, dot, mode=self.mode)

    @staticmethod
    def mergeable(x, y):
        # if the matrix (A) is minibatched, then the first dimensions need
        # to match up (to allow us to transpose the dimensions)
        if x.A.minibatched:
            for s0, s1 in zip(x.all_signals, y.all_signals):
                shape0 = s0.shape[0] if s0.shape != () else 1
                shape1 = s1.shape[0] if s1.shape != () else 1
                if shape0 != shape1:
                    return False
        return True
@Builder.register(SimPyFunc)
class SimPyFuncBuilder(OpBuilder):
    """
    Build a group of `~nengo.builder.operator.SimPyFunc` operators.
    """

    def __init__(self, ops, signals, config):
        super(SimPyFuncBuilder, self).__init__(ops, signals, config)

        logger.debug("t %s", [op.t for op in ops])
        logger.debug("x %s", [op.x for op in ops])
        logger.debug("fn %s", [op.fn for op in ops])

        # whether the functions receive the simulation time as input
        self.time_input = ops[0].t is not None
        self.input_data = signals.combine([op.x for op in ops])

        if ops[0].output is not None:
            self.output_data = signals.combine([op.output for op in ops])
            self.output_dtype = self.output_data.dtype
        else:
            # ops with no output are run only for their side effects
            self.output_data = None
            self.output_dtype = signals.dtype

        def merged_func(time, inputs):  # pragma: no cover (runs in TF)
            # call each op's python function on its slice of the combined
            # input, once per minibatch item, and concatenate the results
            outputs = []
            offset = 0
            for op in ops:
                if op.output is None:
                    func = op.fn
                else:
                    # wrap fn so its output matches the expected shape/dtype
                    func = utils.align_func(op.output.shape, self.output_dtype)(
                        op.fn
                    )

                func_input = inputs[offset : offset + op.x.shape[0]]
                offset += op.x.shape[0]

                mini_out = []
                for j in range(signals.minibatch_size):
                    if op.t is None:
                        func_out = func(func_input[..., j])
                    else:
                        func_out = func(time, func_input[..., j])

                    if op.output is None:
                        # just return time as a noop (since we need to
                        # return something)
                        func_out = time
                    mini_out += [func_out]
                outputs += [np.stack(mini_out, axis=-1)]

            return np.concatenate(outputs, axis=0)

        self.merged_func = merged_func
        # give the merged function a debuggable name built from the op fns
        self.merged_func.__name__ = "_".join(
            [utils.function_name(op.fn) for op in ops]
        )
        self.output_shape = (
            (len(ops),) if self.output_data is None else self.output_data.shape
        )
        self.output_shape += (signals.minibatch_size,)

    def build_step(self, signals):
        time = signals.time if self.time_input else []
        inputs = [] if self.input_data is None else signals.gather(self.input_data)

        # py_func must run on the CPU (it executes arbitrary python code)
        with tf.device("/cpu:0"):
            node_outputs = tf_compat.py_func(
                self.merged_func,
                [time, inputs],
                self.output_dtype,
                name=self.merged_func.__name__,
            )
        # py_func loses static shape information, so restore it
        node_outputs.set_shape(self.output_shape)

        if self.output_data is not None:
            signals.scatter(self.output_data, node_outputs)

        # note: we only need to run the node for side effects, not the
        # assignment operator. if the result of the assignment is actually
        # used anywhere, then it will be run as part of the normal graph.
        return node_outputs

    @staticmethod
    def mergeable(x, y):
        # for these we need to make a special check that the functions
        # all do/do not get time as input, otherwise we could end
        # up confusing a node that only gets a scalar float input with
        # a node that only gets time as input
        return x.t == y.t
@Builder.register(SparseDotInc)
class SparseDotIncBuilder(OpBuilder):
    """
    Build a group of `~nengo.builder.operator.SparseDotInc` operators.
    """

    def __init__(self, ops, signals, config):
        super().__init__(ops, signals, config)

        self.Y_data = signals.combine([op.Y for op in ops])

        # group all the A's and X's
        self.A_data = signals.combine([op.A for op in ops])
        self.X_data = signals.combine([op.X for op in ops])

        # the only way A would be minibatched is if it is targeted by an
        # online learning rule, which isn't supported for sparse transforms
        assert not self.A_data.minibatched

        assert self.X_data.minibatched and self.Y_data.minibatched

        # arrange the sparse matrices into a (sparse) block diagonal matrix
        # by adding an offset to each sparse matrix's indices
        sparse_indices = []
        corner = np.zeros(2, dtype=np.int64)
        for op in ops:
            # extract the [row, col] nonzero indices of this op's matrix;
            # A may be stored as a nengo SparseMatrix or a scipy sparse matrix
            if isinstance(op.A.initial_value, SparseMatrix):
                idxs = np.array(op.A.initial_value.indices)
            else:
                initial_value = op.A.initial_value.tocoo()
                idxs = np.stack((initial_value.row, initial_value.col), axis=1)

            block_shape = (op.A.shape[0], op.A.shape[1])
            idxs += corner
            corner += block_shape
            sparse_indices += [idxs]

        sparse_indices = np.concatenate(sparse_indices, axis=0)
        # use int32 indices when they fit, for efficiency
        self.sparse_indices = signals.constant(
            sparse_indices,
            dtype=(
                tf.int32
                if np.all(sparse_indices < np.iinfo(np.int32).max)
                else tf.int64
            ),
        )
        # `corner` has accumulated the overall block-diagonal matrix shape
        self.A_shape = tf.constant(corner, dtype=tf.int64)

    def build_step(self, signals):
        A = signals.gather(self.A_data)
        X = signals.gather(self.X_data)

        # multiply through the block-diagonal sparse matrix and restore the
        # static shape information lost by the sparse op
        dot = sparse_matmul(self.sparse_indices, A, self.A_shape, X)
        dot.set_shape(self.Y_data.shape + (signals.minibatch_size,))

        signals.scatter(self.Y_data, dot, mode="inc")

    @staticmethod
    def mergeable(x, y):
        # any two SparseDotInc ops can be merged
        return True