# Source code for nengo_dl.learning_rule_builders

"""
Build classes for Nengo learning rule operators.
"""

from nengo.builder import Signal
from nengo.builder.learning_rules import (
    SimBCM,
    SimOja,
    SimPES,
    SimVoja,
    get_post_ens,
    build_or_passthrough,
)
from nengo.builder.operator import Reset, DotInc, Copy
from nengo.learning_rules import PES
import numpy as np
import tensorflow as tf

from nengo_dl.builder import Builder, OpBuilder, NengoBuilder


@Builder.register(SimBCM)
class SimBCMBuilder(OpBuilder):
    """Build a group of `~nengo.builder.learning_rules.SimBCM` operators."""

    def build_pre(self, signals, config):
        """Gather the signal views and constants needed for the BCM update.

        Parameters
        ----------
        signals : SignalDict
            Mapping from `~nengo.builder.Signal` to built tensors.
        config : Config
            Build configuration.
        """
        super().build_pre(signals, config)

        # post-synaptic activities, with a trailing axis of size 1 so they
        # broadcast against the pre activities below
        self.post_data = signals.combine([op.post_filtered for op in self.ops])
        self.post_data = self.post_data.reshape(self.post_data.shape + (1,))

        # modification threshold theta, in the same (n_post, 1) layout as post
        self.theta_data = signals.combine([op.theta for op in self.ops])
        self.theta_data = self.theta_data.reshape(self.theta_data.shape + (1,))

        # repeat each op's pre activities once per post neuron, then reshape
        # to (n_post, pre_dim) so that post * pre forms the outer product
        # (mergeable() guarantees all ops share the same pre dimensionality)
        self.pre_data = signals.combine(
            [
                op.pre_filtered
                for op in self.ops
                for _ in range(op.post_filtered.shape[0])
            ]
        )
        self.pre_data = self.pre_data.reshape(
            (self.post_data.shape[0], self.ops[0].pre_filtered.shape[0])
        )

        # per-op learning rates, broadcastable along the post axis
        self.learning_rate = signals.op_constant(
            self.ops,
            [op.post_filtered.shape[0] for op in self.ops],
            "learning_rate",
            signals.dtype,
            shape=(1, -1, 1),
        )

        # the delta signals that receive the computed weight updates
        self.output_data = signals.combine([op.delta for op in self.ops])

    def build_step(self, signals):
        """Compute the BCM weight update for one simulation step."""
        pre = signals.gather(self.pre_data)
        post = signals.gather(self.post_data)
        theta = signals.gather(self.theta_data)

        # delta = learning_rate * dt * post * (post - theta) * pre
        post = self.learning_rate * signals.dt * post * (post - theta)

        signals.scatter(self.output_data, post * pre)

    @staticmethod
    def mergeable(x, y):
        # pre inputs must have the same dimensionality so that we can broadcast
        # them when computing the outer product
        return x.pre_filtered.shape[0] == y.pre_filtered.shape[0]
@Builder.register(SimOja)
class SimOjaBuilder(OpBuilder):
    """Build a group of `~nengo.builder.learning_rules.SimOja` operators."""

    def build_pre(self, signals, config):
        """Set up the gathered signal views and constants for the Oja rule."""
        super().build_pre(signals, config)

        # post activities with a trailing axis so they broadcast against pre
        combined_post = signals.combine([op.post_filtered for op in self.ops])
        self.post_data = combined_post.reshape(combined_post.shape + (1,))

        # repeat each op's pre activities once per post neuron and lay them
        # out as (n_post, pre_dim), giving the outer-product structure
        # (mergeable() ensures every op has the same pre dimensionality)
        repeated_pre = signals.combine(
            [
                op.pre_filtered
                for op in self.ops
                for _ in range(op.post_filtered.shape[0])
            ]
        )
        self.pre_data = repeated_pre.reshape(
            (self.post_data.shape[0], self.ops[0].pre_filtered.shape[0])
        )

        # current connection weights and the delta outputs
        self.weights_data = signals.combine([op.weights for op in self.ops])
        self.output_data = signals.combine([op.delta for op in self.ops])

        # per-op scalar parameters, broadcastable along the post axis
        post_sizes = [op.post_filtered.shape[0] for op in self.ops]
        self.learning_rate = signals.op_constant(
            self.ops,
            post_sizes,
            "learning_rate",
            signals.dtype,
            shape=(1, -1, 1),
        )
        self.beta = signals.op_constant(
            self.ops,
            post_sizes,
            "beta",
            signals.dtype,
            shape=(1, -1, 1),
        )

    def build_step(self, signals):
        """Apply one step of the Oja learning rule."""
        pre_acts = signals.gather(self.pre_data)
        post_acts = signals.gather(self.post_data)
        weights = signals.gather(self.weights_data)

        alpha = self.learning_rate * signals.dt

        # forgetting term: (alpha * post^2) * (-beta * w)
        decay = alpha * post_acts ** 2 * (-self.beta * weights)
        # Hebbian term: alpha * post * pre
        update = decay + alpha * post_acts * pre_acts

        signals.scatter(self.output_data, update)

    @staticmethod
    def mergeable(x, y):
        # ops can only merge when their pre inputs have matching
        # dimensionality, since the pre activities are broadcast together
        # when computing the outer product
        return x.pre_filtered.shape[0] == y.pre_filtered.shape[0]
@Builder.register(SimVoja)
class SimVojaBuilder(OpBuilder):
    """Build a group of `~nengo.builder.learning_rules.SimVoja` operators."""

    def build_pre(self, signals, config):
        """Gather the signal views and constants needed for the Voja update.

        Parameters
        ----------
        signals : SignalDict
            Mapping from `~nengo.builder.Signal` to built tensors.
        config : Config
            Build configuration.
        """
        super().build_pre(signals, config)

        # post-synaptic activities, shaped (n_post, 1) so they broadcast
        # against the pre values below
        self.post_data = signals.combine([op.post_filtered for op in self.ops])
        self.post_data = self.post_data.reshape(self.post_data.shape + (1,))

        # repeat each op's decoded pre values once per post neuron, laid out
        # as (n_post, pre_dim) to form the outer product with post
        # (mergeable() guarantees all ops share the same pre dimensionality)
        self.pre_data = signals.combine(
            [
                op.pre_decoded
                for op in self.ops
                for _ in range(op.post_filtered.shape[0])
            ]
        )
        self.pre_data = self.pre_data.reshape(
            (self.post_data.shape[0], self.ops[0].pre_decoded.shape[0])
        )

        # scalar learning signal, also repeated per post neuron, with a
        # trailing axis so it broadcasts over the encoder dimensions
        self.learning_data = signals.combine(
            [
                op.learning_signal
                for op in self.ops
                for _ in range(op.post_filtered.shape[0])
            ]
        )
        self.learning_data = self.learning_data.reshape(self.learning_data.shape + (1,))

        # current scaled encoders and the delta outputs
        self.encoder_data = signals.combine([op.scaled_encoders for op in self.ops])
        self.output_data = signals.combine([op.delta for op in self.ops])

        # per-neuron encoder scale factors, concatenated across ops along the
        # neuron axis (axis 1) to line up with the (1, n_post, 1) layout
        self.scale = tf.constant(
            np.concatenate([op.scale[None, :, None] for op in self.ops], axis=1),
            dtype=signals.dtype,
        )

        # per-op learning rates, broadcastable along the post axis
        self.learning_rate = signals.op_constant(
            self.ops,
            [op.post_filtered.shape[0] for op in self.ops],
            "learning_rate",
            signals.dtype,
            shape=(1, -1, 1),
        )

    def build_step(self, signals):
        """Compute the Voja encoder update for one simulation step."""
        pre = signals.gather(self.pre_data)
        post = signals.gather(self.post_data)
        learning_signal = signals.gather(self.learning_data)
        scaled_encoders = signals.gather(self.encoder_data)

        # delta = alpha * (scale * post * pre - post * scaled_encoders)
        alpha = self.learning_rate * signals.dt * learning_signal
        update = alpha * (self.scale * post * pre - post * scaled_encoders)

        signals.scatter(self.output_data, update)

    @staticmethod
    def mergeable(x, y):
        # pre inputs must have the same dimensionality so that we can broadcast
        # them when computing the outer product
        return x.pre_decoded.shape[0] == y.pre_decoded.shape[0]
@NengoBuilder.register(PES)
def build_pes(model, pes, rule):
    """
    Builds a `nengo.PES` object into a Nengo model.

    Overrides the standard Nengo PES builder in order to avoid slicing
    on axes > 0 (not currently supported in NengoDL).

    Parameters
    ----------
    model : Model
        The model to build into.
    pes : PES
        Learning rule type to build.
    rule : LearningRule
        The learning rule object corresponding to the neuron type.

    Notes
    -----
    Does not modify ``model.params[]`` and can therefore be called more than
    once with the same `nengo.PES` instance.
    """

    conn = rule.connection

    # Create input error signal
    error = Signal(shape=(rule.size_in,), name="PES:error")
    model.add_op(Reset(error))
    model.sig[rule]["in"] = error  # error connection will attach here

    # filter the presynaptic activities through the rule's pre-synapse
    # (or pass them through unchanged if no synapse is set)
    acts = build_or_passthrough(model, pes.pre_synapse, model.sig[conn.pre_obj]["out"])

    if not conn.is_decoded:
        # multiply error by post encoders to get a per-neuron error
        post = get_post_ens(conn)
        encoders = model.sig[post]["encoders"]

        if conn.post_obj is not conn.post:
            # in order to avoid slicing encoders along an axis > 0, we pad
            # `error` out to the full base dimensionality and then do the
            # dotinc with the full encoder matrix
            padded_error = Signal(shape=(encoders.shape[1],))
            model.add_op(Copy(error, padded_error, dst_slice=conn.post_slice))
        else:
            padded_error = error

        # error = dot(encoders, error)
        local_error = Signal(shape=(post.n_neurons,))
        model.add_op(Reset(local_error))
        model.add_op(DotInc(encoders, padded_error, local_error, tag="PES:encode"))
    else:
        local_error = error

    # note: encoders=None because any encoding was already applied above
    model.operators.append(
        SimPES(acts, local_error, model.sig[rule]["delta"], pes.learning_rate)
    )

    # expose these for probes
    model.sig[rule]["error"] = error
    model.sig[rule]["activities"] = acts
@Builder.register(SimPES)
class SimPESBuilder(OpBuilder):
    """Build a group of `~nengo.builder.learning_rules.SimPES` operators."""

    def build_pre(self, signals, config):
        """Collect error/activity views and precompute the PES gain."""
        super().build_pre(signals, config)

        n_ops = len(self.ops)
        pre_size = self.ops[0].pre_filtered.shape[0]
        error_size = self.ops[0].error.shape[0]

        # error as a column vector per op: (n_ops, error_d, 1)
        errors = signals.combine([op.error for op in self.ops])
        self.error_data = errors.reshape((n_ops, error_size, 1))

        # pre activities as a row vector per op: (n_ops, 1, pre_d), so that
        # error * pre broadcasts into the full outer product
        activities = signals.combine([op.pre_filtered for op in self.ops])
        self.pre_data = activities.reshape((n_ops, 1, pre_size))

        # fold dt and the 1/pre_size normalization into the (negated)
        # per-op learning rates
        rates = signals.op_constant(
            self.ops,
            [1 for _ in self.ops],
            "learning_rate",
            signals.dtype,
            shape=(1, -1, 1, 1),
        )
        self.alpha = rates * (-signals.dt_val / pre_size)

        # encoders were already folded into the error in `build_pes`, so
        # none of these ops should carry their own encoders
        assert all(op.encoders is None for op in self.ops)

        self.output_data = signals.combine([op.delta for op in self.ops])

    def build_step(self, signals):
        """Compute delta = alpha * error * pre for one simulation step."""
        activities = signals.gather(self.pre_data)
        scaled_error = signals.gather(self.error_data) * self.alpha
        signals.scatter(self.output_data, scaled_error * activities)

    @staticmethod
    def mergeable(x, y):
        # ops can only merge when both the pre activities and the error
        # signals have matching dimensionality, so the broadcasted outer
        # product has a uniform shape across ops
        return (
            x.pre_filtered.shape[0] == y.pre_filtered.shape[0]
            and x.error.shape[0] == y.error.shape[0]
        )