API reference

Layers

Components for building spiking models in Keras.

keras_spiking.layers.KerasSpikingCell

Base class for RNN cells in KerasSpiking.

keras_spiking.layers.KerasSpikingLayer

Base class for KerasSpiking layers.

keras_spiking.SpikingActivationCell

RNN cell for converting an arbitrary activation function to a spiking equivalent.

keras_spiking.SpikingActivation

Layer for converting an arbitrary activation function to a spiking equivalent.

keras_spiking.LowpassCell

RNN cell for a lowpass filter.

keras_spiking.Lowpass

Layer implementing a lowpass filter.

keras_spiking.AlphaCell

RNN cell for an alpha filter.

keras_spiking.Alpha

Layer implementing an alpha filter.

class keras_spiking.layers.KerasSpikingCell(*args, **kwargs)[source]

Base class for RNN cells in KerasSpiking.

The key feature of this class is that it allows cells to define separate implementations to be used during training and during inference.

Parameters
size : int or tuple of int or tf.TensorShape

Input/output shape of the layer (not including batch/time dimensions).

state_size : int or tuple of int or tf.TensorShape

Shape of the cell state. If None, use size.

dt : float

Length of time (in seconds) represented by one time step. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).

always_use_inference : bool

If True, this layer will use its call_inference behaviour during training, rather than call_training.

kwargs : dict

Passed on to tf.keras.layers.Layer.

call(inputs, states, training=None)[source]

Call function that defines a different forward pass during training versus inference.

call_training(inputs, states)[source]

Compute layer output when training and always_use_inference=False.

call_inference(inputs, states)[source]

Compute layer output when testing or always_use_inference=True.
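
For example, a custom cell might override these two methods to behave differently during training and inference. The following is a minimal sketch (not a class from the library), assuming the Keras RNN cell convention that each call method returns (output, new_states):

import tensorflow as tf
import keras_spiking

class NoisyCell(keras_spiking.layers.KerasSpikingCell):
    # hypothetical cell: pass-through during training, noisy at inference
    def call_training(self, inputs, states):
        # during training, pass the input straight through
        return inputs, list(states)

    def call_inference(self, inputs, states):
        # at inference, perturb the output with Gaussian noise
        return inputs + tf.random.normal(tf.shape(inputs)), list(states)

# like the built-in cells, this is wrapped in a standard Keras RNN layer
my_layer = tf.keras.layers.RNN(NoisyCell(size=10), return_sequences=True)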

class keras_spiking.layers.KerasSpikingLayer(*args, **kwargs)[source]

Base class for KerasSpiking layers.

The main role of this class is to wrap a KerasSpikingCell in a tf.keras.layers.RNN.

Parameters
dt : float

Length of time (in seconds) represented by one time step. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).

return_sequences : bool

Whether to return the full sequence of output values (default), or just the values on the last timestep.

return_state : bool

Whether to return the state in addition to the output.

stateful : bool

If False (default), each time the layer is called it will begin from the same initial conditions. If True, each call will resume from the terminal state of the previous call (my_layer.reset_states() can be called to reset the state to initial conditions).

unroll : bool

If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed up computations, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences.

time_major : bool

The shape format of the input and output tensors. If True, the inputs and outputs will be in shape (timesteps, batch, ...), whereas in the False case, it will be (batch, timesteps, ...). Using time_major=True is a bit more efficient because it avoids transposes at the beginning and end of the layer calculation. However, most TensorFlow data is batch-major, so by default this layer accepts input and emits output in batch-major form.

kwargs : dict

Passed on to tf.keras.layers.Layer.

build_cell(input_shape)[source]

Create and return the RNN cell.

build(input_shape)[source]

Builds the RNN/cell layers contained within this layer.

Notes

This method should not be called manually; rather, use the implicit layer callable behaviour (like my_layer(inputs)), which will apply this method with some additional bookkeeping.

call(inputs, training=None, initial_state=None, constants=None)[source]

Apply this layer to inputs.

Notes

This method should not be called manually; rather, use the implicit layer callable behaviour (like my_layer(inputs)), which will apply this method with some additional bookkeeping.

reset_states(states=None)[source]

Reset the internal state of the layer (only necessary if stateful=True).

Parameters
states : ndarray

Optional state array that can be used to override the values returned by cell.get_initial_state, where cell is returned by build_cell.
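
For example, a sketch of stateful usage (shapes are illustrative; with stateful=True the batch size must stay fixed across calls):

import numpy as np
import keras_spiking

layer = keras_spiking.Lowpass(0.1, stateful=True)

data = np.random.rand(4, 100, 8).astype(np.float32)
out_a = layer(data[:, :50])
out_b = layer(data[:, 50:])  # resumes from the state left by the first call
layer.reset_states()         # back to the initial conditions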

get_config()[source]

Return config of layer (for serialization during model saving/loading).

class keras_spiking.SpikingActivationCell(*args, **kwargs)[source]

RNN cell for converting an arbitrary activation function to a spiking equivalent.

Neurons will spike at a rate proportional to the output of the base activation function. For example, if the activation function is outputting a value of 10, then the wrapped SpikingActivationCell will output spikes at a rate of 10Hz (i.e., 10 spikes per 1 simulated second, where 1 simulated second is equivalent to 1/dt time steps). Each spike will have height 1/dt (so that the integral of the spiking output will be the same as the integral of the base activation output). Note that if the base activation is outputting a negative value then the spikes will have height -1/dt. Multiple spikes per timestep are also possible, in which case the output will be n/dt (where n is the number of spikes).

Parameters
size : int or tuple of int or tf.TensorShape

Input/output shape of the layer (not including batch/time dimensions).

activation : callable

Activation function to be converted to spiking equivalent.

dt : float

Length of time (in seconds) represented by one time step. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).

seed : int

Seed for random state initialization.

spiking_aware_training : bool

If True (default), use the spiking activation function for the forward pass and the base activation function for the backward pass. If False, use the base activation function for the forward and backward pass during training.

kwargs : dict

Passed on to tf.keras.layers.Layer.

Notes

This cell needs to be wrapped in a tf.keras.layers.RNN, like

my_layer = tf.keras.layers.RNN(
    keras_spiking.SpikingActivationCell(size=10, activation=tf.nn.relu)
)
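
As a rough sanity check (a sketch, not library output; the exact count varies with the random initial state), driving the wrapped cell with a constant input of 10 for one simulated second should produce roughly 10 spikes:

import numpy as np
import tensorflow as tf
import keras_spiking

layer = tf.keras.layers.RNN(
    keras_spiking.SpikingActivationCell(size=1, activation=tf.nn.relu),
    return_sequences=True,
)

# constant input of 10 for 1000 steps of dt=0.001 (one simulated second)
x = np.full((1, 1000, 1), 10.0, dtype=np.float32)
spikes = layer(x)

# each spike has height 1/dt, so summing and multiplying by dt counts spikes
print(np.sum(spikes.numpy()) * 0.001)  # approximately 10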
get_initial_state(inputs=None, batch_size=None, dtype=None)[source]

Set up initial spiking state.

Initial state is chosen from a uniform distribution, seeded based on the seed passed on construction (if one was given).

Note: the state will be initialized automatically; users do not need to call this method themselves.

call_training(inputs, states)[source]

Compute layer output when training and always_use_inference=False.

call_inference(inputs, states)

Compute spiking output, with custom gradient for spiking aware training.

Parameters
inputs : tf.Tensor

Input to the activation function.

voltage : tf.Tensor

Spiking voltage state.

Returns
spikes : tf.Tensor

Output spike values (0 or n/dt for each element in inputs, where n is the number of spikes).

voltage : tf.Tensor

Updated voltage state.

grad : callable

Custom gradient function for spiking aware training.

get_config()[source]

Return config of layer (for serialization during model saving/loading).

class keras_spiking.SpikingActivation(*args, **kwargs)[source]

Layer for converting an arbitrary activation function to a spiking equivalent.

Neurons will spike at a rate proportional to the output of the base activation function. For example, if the activation function is outputting a value of 10, then this layer will output spikes at a rate of 10 Hz (i.e., 10 spikes per 1 simulated second, where 1 simulated second is equivalent to 1/dt time steps). Each spike will have height 1/dt (so that the integral of the spiking output will be the same as the integral of the base activation output). Note that if the base activation is outputting a negative value then the spikes will have height -1/dt. Multiple spikes per timestep are also possible, in which case the output will be n/dt (where n is the number of spikes).

When applying this layer to an input, make sure that the input has a time axis (the time_major option controls whether it comes before or after the batch axis). The spiking output will be computed along the time axis. The number of simulation timesteps will depend on the length of that time axis. The number of timesteps does not need to be the same during training/evaluation/inference. In particular, it may be more efficient to use one timestep during training and multiple timesteps during inference (often with spiking_aware_training=False, and apply_during_training=False on any Lowpass layers).

Parameters
activation : callable

Activation function to be converted to spiking equivalent.

dt : float

Length of time (in seconds) represented by one time step. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).

seed : int

Seed for random state initialization.

spiking_aware_training : bool

If True (default), use the spiking activation function for the forward pass and the base activation function for the backward pass. If False, use the base activation function for the forward and backward pass during training.

return_sequences : bool

Whether to return the full sequence of output spikes (default), or just the spikes on the last timestep.

return_state : bool

Whether to return the state in addition to the output.

stateful : bool

If False (default), each time the layer is called it will begin from the same initial conditions. If True, each call will resume from the terminal state of the previous call (my_layer.reset_states() can be called to reset the state to initial conditions).

unroll : bool

If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed up computations, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences.

time_major : bool

The shape format of the input and output tensors. If True, the inputs and outputs will be in shape (timesteps, batch, ...), whereas in the False case, it will be (batch, timesteps, ...). Using time_major=True is a bit more efficient because it avoids transposes at the beginning and end of the layer calculation. However, most TensorFlow data is batch-major, so by default this layer accepts input and emits output in batch-major form.

kwargs : dict

Passed on to tf.keras.layers.Layer.

Notes

This is equivalent to tf.keras.layers.RNN(SpikingActivationCell(...) ...); it just takes care of the RNN construction automatically.
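
For example, a sketch of the train-with-one-timestep, infer-with-many pattern described above (shapes and sizes are illustrative):

import numpy as np
import tensorflow as tf
import keras_spiking

inp = tf.keras.Input((None, 32))  # time axis length left unspecified
x = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(64))(inp)
x = keras_spiking.SpikingActivation("relu", spiking_aware_training=False)(x)
model = tf.keras.Model(inp, x)

# train on sequences with a single timestep ...
train_seq = np.random.rand(8, 1, 32).astype(np.float32)
# ... and evaluate on sequences with many timesteps
test_seq = np.tile(train_seq, (1, 50, 1))
print(model.predict(test_seq).shape)  # (8, 50, 64)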

build_cell(input_shape)[source]

Create and return the RNN cell.

get_config()[source]

Return config of layer (for serialization during model saving/loading).

class keras_spiking.LowpassCell(*args, **kwargs)[source]

RNN cell for a lowpass filter.

The initial filter state and filter time constants are both trainable parameters. However, if apply_during_training=False then the parameters are not part of the training loop, and so will never be updated.

Parameters
size : int or tuple of int or tf.TensorShape

Input/output shape of the layer (not including batch/time dimensions).

tau_initializer : float or str or tf.keras.initializers.Initializer

Initial value of the filter time constant (in seconds). Passing a float will initialize it to that value; alternatively, any standard Keras initializer can be used.

dt : float

Length of time (in seconds) represented by one time step. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).

apply_during_training : bool

If False, this layer will effectively be ignored during training (this often makes sense in concert with the swappable training behaviour in, e.g., SpikingActivation, since if the activations are not spiking during training then we often don’t need to filter them either).

level_initializer : str or tf.keras.initializers.Initializer

Initializer for filter state.

initial_level_constraint : str or tf.keras.constraints.Constraint

Constraint for initial_level.

tau_constraint : str or tf.keras.constraints.Constraint

Constraint for tau. For example, Mean will share the same time constant across all of the lowpass filters. The time constant is always clipped to be positive in the forward pass for numerical stability.

kwargs : dict

Passed on to tf.keras.layers.Layer.

Notes

This cell needs to be wrapped in a tf.keras.layers.RNN, like

my_layer = tf.keras.layers.RNN(
    keras_spiking.LowpassCell(size=10, tau_initializer=0.01)
)
build(input_shape)[source]

Build parameters associated with this layer.

get_initial_state(inputs=None, batch_size=None, dtype=None)[source]

Get initial filter state.

call_inference(inputs, states)[source]

Compute layer output when testing or always_use_inference=True.

call_training(inputs, states)[source]

Compute layer output when training and always_use_inference=False.

get_config()[source]

Return config of layer (for serialization during model saving/loading).

class keras_spiking.Lowpass(*args, **kwargs)[source]

Layer implementing a lowpass filter.

The impulse-response function (time domain) and transfer function are:

\[\begin{split}h(t) &= (1 / \tau) \exp(-t / \tau) \\ H(s) &= \frac{1}{\tau s + 1}\end{split}\]

The initial filter state and filter time constants are both trainable parameters. However, if apply_during_training=False then the parameters are not part of the training loop, and so will never be updated.

When applying this layer to an input, make sure that the input has a time axis (the time_major option controls whether it comes before or after the batch axis).

Parameters
tau_initializer : float or str or tf.keras.initializers.Initializer

Initial value of the filter time constant (in seconds). Passing a float will initialize it to that value; alternatively, any standard Keras initializer can be used.

dt : float

Length of time (in seconds) represented by one time step. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).

apply_during_training : bool

If False, this layer will effectively be ignored during training (this often makes sense in concert with the swappable training behaviour in, e.g., SpikingActivation, since if the activations are not spiking during training then we often don’t need to filter them either).

level_initializer : str or tf.keras.initializers.Initializer

Initializer for filter state.

initial_level_constraint : str or tf.keras.constraints.Constraint

Constraint for initial_level.

tau_constraint : str or tf.keras.constraints.Constraint

Constraint for tau. For example, Mean will share the same time constant across all of the lowpass filters. The time constant is always clipped to be positive in the forward pass for numerical stability.

return_sequences : bool

Whether to return the full sequence of filtered output (default), or just the output on the last timestep.

return_state : bool

Whether to return the state in addition to the output.

stateful : bool

If False (default), each time the layer is called it will begin from the same initial conditions. If True, each call will resume from the terminal state of the previous call (my_layer.reset_states() can be called to reset the state to initial conditions).

unroll : bool

If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed up computations, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences.

time_major : bool

The shape format of the input and output tensors. If True, the inputs and outputs will be in shape (timesteps, batch, ...), whereas in the False case, it will be (batch, timesteps, ...). Using time_major=True is a bit more efficient because it avoids transposes at the beginning and end of the layer calculation. However, most TensorFlow data is batch-major, so by default this layer accepts input and emits output in batch-major form.

kwargs : dict

Passed on to tf.keras.layers.Layer.

Notes

This is equivalent to tf.keras.layers.RNN(LowpassCell(...) ...); it just takes care of the RNN construction automatically.
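
A minimal usage sketch (shapes and the tau value are illustrative), filtering a noisy constant signal with a 50 ms time constant:

import numpy as np
import keras_spiking

# a noisy constant signal: batch of 1, 1000 timesteps, 1 dimension
signal = np.ones((1, 1000, 1), dtype=np.float32)
signal += np.random.normal(scale=0.5, size=signal.shape).astype(np.float32)

# with the default dt=0.001, the output settles towards 1 over a few tau
filtered = keras_spiking.Lowpass(tau_initializer=0.05)(signal)
print(filtered.shape)  # (1, 1000, 1)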

build_cell(input_shape)[source]

Create and return the RNN cell.

get_config()[source]

Return config of layer (for serialization during model saving/loading).

class keras_spiking.AlphaCell(*args, **kwargs)[source]

RNN cell for an alpha filter.

The initial filter state and filter time constants are both trainable parameters. However, if apply_during_training=False then the parameters are not part of the training loop, and so will never be updated.

Parameters
size : int or tuple of int or tf.TensorShape

Input/output shape of the layer (not including batch/time dimensions).

tau_initializer : float or str or tf.keras.initializers.Initializer

Initial value of the filter time constant (in seconds). Passing a float will initialize it to that value; alternatively, any standard Keras initializer can be used.

dt : float

Length of time (in seconds) represented by one time step. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).

apply_during_training : bool

If False, this layer will effectively be ignored during training (this often makes sense in concert with the swappable training behaviour in, e.g., SpikingActivation, since if the activations are not spiking during training then we often don’t need to filter them either).

level_initializer : str or tf.keras.initializers.Initializer

Initializer for filter state.

initial_level_constraint : str or tf.keras.constraints.Constraint

Constraint for initial_level.

tau_constraint : str or tf.keras.constraints.Constraint

Constraint for tau. For example, Mean will share the same time constant across all of the alpha filters. The time constant is always clipped to be positive in the forward pass for numerical stability.

kwargs : dict

Passed on to tf.keras.layers.Layer.

Notes

This cell needs to be wrapped in a tf.keras.layers.RNN, like

my_layer = tf.keras.layers.RNN(
    keras_spiking.AlphaCell(size=10, tau_initializer=0.01)
)
build(input_shape)[source]

Build parameters associated with this layer.

get_initial_state(inputs=None, batch_size=None, dtype=None)[source]

Get initial filter state.

call_inference(inputs, states)[source]

Compute layer output when testing or always_use_inference=True.

call_training(inputs, states)[source]

Compute layer output when training and always_use_inference=False.

get_config()[source]

Return config of layer (for serialization during model saving/loading).

class keras_spiking.Alpha(*args, **kwargs)[source]

Layer implementing an alpha filter.

The impulse-response function (time domain) and transfer function are:

\[\begin{split}h(t) &= (t / \tau^2) \exp(-t / \tau) \\ H(s) &= \frac{1}{(\tau s + 1)^2}\end{split}\]

The initial filter state and filter time constants are both trainable parameters. However, if apply_during_training=False then the parameters are not part of the training loop, and so will never be updated.

When applying this layer to an input, make sure that the input has a time axis (the time_major option controls whether it comes before or after the batch axis).

Parameters
tau_initializer : float or str or tf.keras.initializers.Initializer

Initial value of the filter time constant (in seconds). Passing a float will initialize it to that value; alternatively, any standard Keras initializer can be used.

dt : float

Length of time (in seconds) represented by one time step. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).

apply_during_training : bool

If False, this layer will effectively be ignored during training (this often makes sense in concert with the swappable training behaviour in, e.g., SpikingActivation, since if the activations are not spiking during training then we often don’t need to filter them either).

level_initializer : str or tf.keras.initializers.Initializer

Initializer for filter state.

initial_level_constraint : str or tf.keras.constraints.Constraint

Constraint for initial_level.

tau_constraint : str or tf.keras.constraints.Constraint

Constraint for tau. For example, Mean will share the same time constant across all of the alpha filters. The time constant is always clipped to be positive in the forward pass for numerical stability.

return_sequences : bool

Whether to return the full sequence of filtered output (default), or just the output on the last timestep.

return_state : bool

Whether to return the state in addition to the output.

stateful : bool

If False (default), each time the layer is called it will begin from the same initial conditions. If True, each call will resume from the terminal state of the previous call (my_layer.reset_states() can be called to reset the state to initial conditions).

unroll : bool

If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed up computations, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences.

time_major : bool

The shape format of the input and output tensors. If True, the inputs and outputs will be in shape (timesteps, batch, ...), whereas in the False case, it will be (batch, timesteps, ...). Using time_major=True is a bit more efficient because it avoids transposes at the beginning and end of the layer calculation. However, most TensorFlow data is batch-major, so by default this layer accepts input and emits output in batch-major form.

kwargs : dict

Passed on to tf.keras.layers.Layer.

Notes

This is equivalent to tf.keras.layers.RNN(AlphaCell(...) ...); it just takes care of the RNN construction automatically.
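
A common pairing (sketched below with illustrative shapes) is to smooth spiking output with an alpha filter; because h(0) = 0, each spike produces a smoother response than under a first-order lowpass filter:

import tensorflow as tf
import keras_spiking

inp = tf.keras.Input((None, 16))
spikes = keras_spiking.SpikingActivation("relu")(inp)
smoothed = keras_spiking.Alpha(tau_initializer=0.02)(spikes)
model = tf.keras.Model(inp, smoothed)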

build_cell(input_shape)[source]

Create and return the RNN cell.

get_config()[source]

Return config of layer (for serialization during model saving/loading).

Regularizers

Regularization methods designed to work with spiking layers.

keras_spiking.regularizers.RangedRegularizer

A regularizer that penalizes values that fall outside a range.

keras_spiking.regularizers.L1L2

A version of tf.keras.regularizers.L1L2 that allows the user to specify a nonzero target output.

keras_spiking.regularizers.L1

A version of tf.keras.regularizers.L1 that allows the user to specify a nonzero target output.

keras_spiking.regularizers.L2

A version of tf.keras.regularizers.L2 that allows the user to specify a nonzero target output.

keras_spiking.regularizers.Percentile

A regularizer that penalizes a percentile of a tensor.

class keras_spiking.regularizers.RangedRegularizer(target=0, regularizer=<keras.regularizers.L1L2 object>)[source]

A regularizer that penalizes values that fall outside a range.

This allows regularized values to fall anywhere within the range, as opposed to standard regularizers that penalize any departure from some fixed point.

Parameters
target : float or tuple

The value that we want the regularized outputs to be driven towards. Can be a float, in which case all outputs will be driven towards that value, or a tuple specifying a range (min, max), in which case outputs outside that range will be driven towards that range (but outputs within the range will not be penalized).

regularizer : tf.keras.regularizers.Regularizer

Regularization penalty that will be applied to the outputs with respect to target.
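
For example, a sketch that penalizes activities only when they fall outside the range (5, 10):

import tensorflow as tf
import keras_spiking

reg = keras_spiking.regularizers.RangedRegularizer(
    target=(5, 10), regularizer=tf.keras.regularizers.L2(0.01)
)

# values inside (5, 10) incur no penalty; values outside are driven back in
layer = tf.keras.layers.Dense(32, activity_regularizer=reg)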

get_config()[source]

Return config (for serialization during model saving/loading).

classmethod from_config(config)[source]

Create a new instance from the serialized config.

class keras_spiking.regularizers.L1L2(l1=0.0, l2=0.0, target=0, **kwargs)[source]

A version of tf.keras.regularizers.L1L2 that allows the user to specify a nonzero target output.

Parameters
l1 : float

Weight on L1 regularization penalty.

l2 : float

Weight on L2 regularization penalty.

target : float or tuple

The value that we want the regularized outputs to be driven towards. Can be a float, in which case all outputs will be driven towards that value, or a tuple specifying a range (min, max), in which case outputs outside that range will be driven towards that range (but outputs within the range will not be penalized).
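
For example, a sketch that drives regularized outputs towards 10 (e.g., a target firing rate of 10 Hz) rather than towards zero:

import keras_spiking

# penalize squared distance from 10, rather than from 0
reg = keras_spiking.regularizers.L1L2(l2=0.01, target=10)
layer = keras_spiking.SpikingActivation("relu", activity_regularizer=reg)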

get_config()[source]

Return config (for serialization during model saving/loading).

classmethod from_config(config)[source]

Create a new instance from the serialized config.

class keras_spiking.regularizers.L1(l1=0.01, target=0, **kwargs)[source]

A version of tf.keras.regularizers.L1 that allows the user to specify a nonzero target output.

Parameters
l1 : float

Weight on L1 regularization penalty.

target : float or tuple

The value that we want the regularized outputs to be driven towards. Can be a float, in which case all outputs will be driven towards that value, or a tuple specifying a range (min, max), in which case outputs outside that range will be driven towards that range (but outputs within the range will not be penalized).

class keras_spiking.regularizers.L2(l2=0.01, target=0, **kwargs)[source]

A version of tf.keras.regularizers.L2 that allows the user to specify a nonzero target output.

Parameters
l2 : float

Weight on L2 regularization penalty.

target : float or tuple

The value that we want the regularized outputs to be driven towards. Can be a float, in which case all outputs will be driven towards that value, or a tuple specifying a range (min, max), in which case outputs outside that range will be driven towards that range (but outputs within the range will not be penalized).

class keras_spiking.regularizers.Percentile(percentile=100, axis=0, target=0, l1=0, l2=0)[source]

A regularizer that penalizes a percentile of a tensor.

This regularizer finds the requested percentile of the data over the axis, and then applies a regularizer to the percentile values with respect to target. This can be useful as it makes the computed regularization penalty more invariant to outliers.

Parameters
percentile : float

Percentile to compute over the axis. Defaults to 100, which is equivalent to taking the maximum across the specified axis.

Note

For percentile != 100, requires tensorflow-probability.

axis : int or tuple of int

Axis or axes to take the percentile over.

target : float or tuple

The value that we want the regularized outputs to be driven towards. Can be a float, in which case all outputs will be driven towards that value, or a tuple specifying a range (min, max), in which case outputs outside that range will be driven towards that range (but outputs within the range will not be penalized).

l1 : float

Weight on L1 regularization penalty applied to percentiles.

l2 : float

Weight on L2 regularization penalty applied to percentiles.

Examples

In the following example, we use Percentile to ensure that the neuron activities (i.e., firing rates) fall in the desired range of 5-10 Hz when computing the product of two inputs.

import numpy as np
import tensorflow as tf
import keras_spiking

train_x = np.random.uniform(-1, 1, size=(1024 * 100, 2))
train_y = train_x[:, :1] * train_x[:, 1:]
test_x = np.random.uniform(-1, 1, size=(128, 2))
test_y = test_x[:, :1] * test_x[:, 1:]

# train using one timestep, to speed things up
train_seq = train_x[:, None]

# test using 10 timesteps
n_steps = 10
test_seq = np.tile(test_x[:, None], (1, n_steps, 1))

inp = x = tf.keras.Input((None, 2))
x = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(50))(x)
x = spikes = keras_spiking.SpikingActivation(
    "relu",
    dt=1,
    activity_regularizer=keras_spiking.regularizers.Percentile(
        target=(5, 10), l2=0.01
    ),
)(x)
x = tf.keras.layers.GlobalAveragePooling1D()(x)
x = tf.keras.layers.Dense(1)(x)

model = tf.keras.Model(inp, (x, spikes))

model.compile(
    # note: we use a dict to specify loss/metrics because we only want to
    # apply these to the final dense output, not the spike layer
    optimizer="rmsprop", loss={"dense_1": "mse"}, metrics={"dense_1": "mae"}
)
model.fit(train_seq, train_y, epochs=5)

outputs, spikes = model.predict(test_seq)

# estimate rates by averaging over time
rates = spikes.mean(axis=1)
max_rates = rates.max(axis=0)
print("Max rates: %s, %s" % (max_rates.mean(), max_rates.std()))

error = np.mean(np.abs(outputs - test_y))
print("MAE: %s" % (error,))
get_config()[source]

Return config (for serialization during model saving/loading).

Constraints

Custom constraints for weight tensors in Keras models.

keras_spiking.constraints.Mean

Constrains weight tensors to be their mean.

class keras_spiking.constraints.Mean(axis=-1)[source]

Constrains weight tensors to be their mean.

Parameters
axis : int

Axis used to compute the mean and repeat its value. Defaults to the last axis.
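
For example, following the tau_constraint documentation above, a Lowpass layer can share a single trained time constant across all of its units:

import keras_spiking

layer = keras_spiking.Lowpass(
    tau_initializer=0.01,
    tau_constraint=keras_spiking.constraints.Mean(),
)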

get_config()[source]

Return config of layer (for serialization during model saving/loading).

Callbacks

Callbacks for use with KerasSpiking models.

keras_spiking.callbacks.DtScheduler

A callback for updating Layer dt attributes during training.

class keras_spiking.callbacks.DtScheduler(dt, scheduler, verbose=False)[source]

A callback for updating Layer dt attributes during training.

This uses the same scheduler interface as TensorFlow’s learning rate schedulers, so any of those built-in schedules, or a custom function implementing the same interface, can be used to adjust dt.

When using this functionality, dt should be initialized as a tf.Variable, and that Variable should be passed as the dt parameter to any Layers that should be affected by this callback.

For example:

import numpy as np
import tensorflow as tf
import keras_spiking

dt = tf.Variable(1.0)

inp = tf.keras.Input((None, 10))
x = keras_spiking.SpikingActivation("relu", dt=dt)(inp)
x = keras_spiking.Lowpass(0.1, dt=dt)(x)
model = tf.keras.Model(inp, x)

callback = keras_spiking.callbacks.DtScheduler(
    dt, tf.optimizers.schedules.ExponentialDecay(
        1.0, decay_steps=5, decay_rate=0.9
    )
)

model.compile(loss="mse", optimizer="sgd")
model.fit(
    np.ones((100, 2, 10)),
    np.ones((100, 2, 10)),
    epochs=10,
    batch_size=20,
    callbacks=[callback],
)
Parameters
dt : tf.Variable

Variable representing dt that has been passed to other Layers.

scheduler : tf.optimizers.schedules.LearningRateSchedule

A schedule class that will update dt based on the training step (one training step is one minibatch worth of training).

verbose : bool

If True, print out some information about dt updates during training.

Notes

Because Variable values persist over time, any changes made to dt by this callback will persist after training completes. For example, if you call fit with this callback and then predict later on, that predict call will be using the last dt value set by this callback.

on_epoch_begin(epoch, logs=None)[source]

Keep track of the current epoch so we can count the total number of steps.

on_train_batch_begin(batch, logs=None)[source]

Update dt variable based on the current training step.

Configuration

Configuration options for KerasSpiking layers.

keras_spiking.config.DefaultManager

Manages the default parameter values for KerasSpiking layers.

class keras_spiking.config.DefaultManager(dt=0.001)[source]

Manages the default parameter values for KerasSpiking layers.

Parameters
dt : float

Length of time (in seconds) represented by one time step. Defaults to 0.001s.

Notes

Do not instantiate this class directly; instead, access it through keras_spiking.default.
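
For example (a sketch, assuming the dt attribute on keras_spiking.default is writable):

import keras_spiking

# layers constructed with dt=None will now use 10 ms timesteps
keras_spiking.default.dt = 0.01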

Energy estimation

Estimate energy usage on various devices for Keras models.

keras_spiking.ModelEnergy

Compute statistics and device energy estimates for a Keras model.

keras_spiking.model_energy.layer_stats

Fallback for computing stats on an unknown layer (assumes none).

keras_spiking.model_energy.act_stats

Compute activation layer stats.

keras_spiking.model_energy.conv_stats

Compute Conv1D/Conv2D/Conv3D layer stats.

keras_spiking.model_energy.dense_stats

Compute Dense layer stats.

keras_spiking.model_energy.spikingactivation_stats

Compute SpikingActivation layer stats.

keras_spiking.model_energy.timedistributed_stats

Compute TimeDistributed layer stats.

class keras_spiking.ModelEnergy(model, example_data=None)[source]

Compute statistics and device energy estimates for a Keras model.

Computes the following statistics on each layer:

  • “connections”: The number of connections from all input elements to all activation units. The number of synaptic operations per second (“synops”) will be computed by multiplying this number by the average firing rate of the input to the layer.

  • “neurons”: The number of neuron updates per timestep performed by the layer. The number of neuron updates per inference will be computed by multiplying this number by the number of timesteps per inference.

Using expected average firing rates for each layer in the network, along with the above statistics, this class can estimate the energy usage on one of the following types of devices (see total_energy and summary):

  • “cpu”: Estimate energy usage on a CPU (Intel i7-4960X), assuming each synop/neuron update is one MAC [1].

  • “gpu”: Estimate energy usage on a GPU (Nvidia GTX Titan Black), assuming each synop/neuron update is one MAC [1]. Note that this assumes significant parallelism (e.g., inputs being processed in large batches).

  • “arm”: Estimate energy usage on an ARM Cortex-A, assuming each synop/neuron update is one MAC [1].

  • “loihi”: Estimate energy usage on the Intel Loihi chip [2].

  • “spinnaker” and “spinnaker2”: Estimate energy usage on SpiNNaker or SpiNNaker 2 [3].

Note: on non-spiking devices (“cpu”/“gpu”/“arm”) this assumes the model is being run as a traditional non-spiking ANN (computing every synapse each timestep), not taking advantage of spike-based computation. This estimate is therefore independent of example_data. On spiking devices (“loihi”/“spinnaker”/“spinnaker2”), we assume that the model has been fully converted to a spiking implementation in some way, even if the model contains non-spiking elements.

For example:

import tensorflow as tf
import keras_spiking

inp = tf.keras.Input(784)
dense = tf.keras.layers.Dense(units=128, activation="relu")(inp)
model = tf.keras.Model(inp, dense)

energy = keras_spiking.ModelEnergy(model)
energy.summary(line_length=80)
Layer (type)        |Output shape |Param #|Conn #|Neuron #|J/inf (cpu)
--------------------|-------------|-------|------|--------|-----------
input_3 (InputLayer)|[(None, 784)]|      0|     0|       0|          0
dense_2 (Dense)     |  (None, 128)| 100480|100352|     128|    0.00086
======================================================================
Total energy per inference [Joules/inf] (cpu): 8.64e-04
...

Additional devices or different energy assumptions for a given device can be added with register_device, e.g.

keras_spiking.ModelEnergy.register_device(
    "my-cpu", energy_per_synop=1e-10, energy_per_neuron=2e-9, spiking=False
)
energy.summary(
    columns=("name", "energy cpu", "energy my-cpu"), line_length=80
)
Layer (type)        |J/inf (cpu)|J/inf (my-cpu)
--------------------|-----------|--------------
input_3 (InputLayer)|          0|             0
dense_2 (Dense)     |    0.00086|         1e-05
===============================================
Total energy per inference [Joules/inf] (cpu): 8.64e-04
Total energy per inference [Joules/inf] (my-cpu): 1.03e-05
...
Parameters
model : tf.keras.Model

The model to compute statistics and energy estimates for.

example_data : array_like

Input used to estimate average firing rates of each layer (used to estimate the number of synaptic events). It is passed directly to model.predict (see tf.keras.Model.predict for all acceptable types of input data). This is required to estimate energy on spiking devices, but does not affect non-spiking devices.

Notes

It is important to keep in mind that actual power usage will be heavily dependent on the specific details of the underlying software and hardware implementation. The numbers provided by ModelEnergy should be taken as very rough estimates only, and they rely on a number of assumptions:

  • Device specifications: In order to estimate the energy used by a model on a particular device, we need to know how much energy is used per synaptic operation/neuron update. We rely on published data for these numbers (see our sources below). Energy numbers in practice can differ significantly from published results.

  • Overhead: We do not account for any overhead in the energy estimates (e.g., the cost of transferring data on and off a device). We only estimate the energy usage of internal model computations (synaptic operations and neuron updates). In practice, overhead can be a significant contributor to the energy usage of a model.

  • Spiking implementation: We assume that the model being estimated has been fully converted to a spiking implementation when estimating the energy usage on a spiking device (even if the input model has non-spiking elements). For example, if the model contains tf.keras.layers.Activation("relu") layers (non-spiking), we assume that on a spiking device those layers will be converted to something equivalent to keras_spiking.SpikingActivation("relu"), and that any connecting layers (e.g. tf.keras.layers.Dense) are applied in an event-based fashion (i.e., processing only occurs when the input layer emits a spike). In practice, it is not trivial to map a neural network to a spiking device in this way, and implementation details can significantly affect energy usage. Nengo (https://www.nengo.ai/nengo/) and NengoDL (https://www.nengo.ai/nengo-dl/) are designed to make this easier.

References

[1]

Degnan, Brian, Bo Marr, and Jennifer Hasler. “Assessing trends in performance per watt for signal processing applications.” IEEE Transactions on Very Large Scale Integration (VLSI) Systems 24.1 (2015): 58-66. https://ieeexplore.ieee.org/abstract/document/7054508

[2]

Davies, Mike, et al. “Loihi: A neuromorphic manycore processor with on-chip learning.” IEEE Micro 38.1 (2018): 82-99. https://redwood.berkeley.edu/wp-content/uploads/2021/08/Davies2018.pdf

[3]

Höppner, Sebastian, et al. “Dynamic power management for neuromorphic many-core systems.” IEEE Transactions on Circuits and Systems I: Regular Papers 66.8 (2019): 2973-2986. https://arxiv.org/abs/1903.08941

classmethod compute_layer_stats(layer, node=None)[source]

Compute statistics for a given layer.

Examples

import tensorflow as tf
import keras_spiking

inp = tf.keras.Input([None, 10])
layer = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(5, "relu"))
model = tf.keras.Model(inp, [layer(inp)])

print(keras_spiking.ModelEnergy.compute_layer_stats(layer))
{'connections': 50, 'neurons': 5, 'spiking': False}
classmethod register_layer(layer_class)[source]

Decorator to register a statistic calculator for a layer.

The input to the decorated function is a node from a layer for which we want to compute statistics.

The decorated function should return a dictionary with the following entries:

  • “connections”: The number of connections from all input elements to all activation units. Defaults to 0.

  • “neurons”: The number of neuron updates per timestep performed by this layer. Defaults to 0.

  • “spiking”: Whether or not this layer could be implemented on a spiking device as-is (e.g. tf.keras.layers.ReLU returns spiking=False because non-spiking nonlinearities like ReLU can’t be directly implemented on a spiking device). Defaults to True.

Examples

If we know that the tf.keras.layers.UpSampling2D layer uses one synaptic operation per output element, we can register the following function:

import numpy as np
import tensorflow as tf
import keras_spiking

@keras_spiking.ModelEnergy.register_layer(tf.keras.layers.UpSampling2D)
def upsampling2d_stats(node):
    # note: ignore the batch dimension when computing output size
    output_size = np.prod(node.output_shapes[1:])

    return {"connections": output_size, "neurons": 0}

# use our registered stat calculator
inp = tf.keras.Input([4, 4, 3])
layer = tf.keras.layers.UpSampling2D(size=(2, 2))
model = tf.keras.Model(inp, [layer(inp)])

print(keras_spiking.ModelEnergy.compute_layer_stats(layer))
{'connections': 192, 'neurons': 0}

We see that the number of connections is 192, which equals the number of pixels in the upsampled (8, 8) image times the number of channels (3).

classmethod register_device(device_name, energy_per_synop, energy_per_neuron, spiking)[source]

Register a new device type for estimating energy consumption.

Parameters
device_name : str

The string to use to refer to the device.

energy_per_synop : float

The energy (in Joules) used by a single synaptic operation on the device.

energy_per_neuron : float

The energy (in Joules) used by a single neuron update on the device.

spiking : bool

Whether the device is spiking (event-based), and thus only computes synaptic updates for incoming spikes, rather than on every timestep.

layer_energy(layer, device, timesteps_per_inference=1, dt=None)[source]

Estimate the energy used by one layer.

Parameters
layer : tf.keras.layers.Layer

Layer to estimate energy for. Note that if the same layer is being reused multiple times in the model, this will return the total energy summed over all the applications.

device : str

Device to estimate energy for. Can be a supported device (see ModelEnergy for a list), or another device added with ModelEnergy.register_device.

timesteps_per_inference : int

Timesteps used per inference (for example, if the model is classifying images and we want to present each image for 10 timesteps).

dt : float

The length of one timestep, in seconds, used by the device. Used to compute the number of synaptic events based on the firing rates (in Hz). Can differ from the dt used on any keras_spiking layers in the model. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).

Returns
synop_energy : float

Estimated energy used (in Joules) for synaptic computations per inference.

neuron_energy : float

Estimated energy used (in Joules) for neuron updates per inference.

total_energy(device, timesteps_per_inference=1, dt=None)[source]

Estimate the energy usage for a whole model.

Parameters
device : str

Device to estimate energy for. Can be a supported device (see ModelEnergy for a list), or another device added with ModelEnergy.register_device.

timesteps_per_inference : int

Timesteps used per inference (for example, if the model is classifying images and we want to present each image for 10 timesteps).

dt : float

The length of one timestep, in seconds, used by the device. Used to compute the number of synaptic events based on the firing rates (in Hz). Can differ from the dt used on any keras_spiking layers in the model. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).

Returns
energy : float

Total estimated energy used by the model (in Joules) per inference.
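
For example, reusing the model from the ModelEnergy example above (the printed value depends on the registered device parameters):

import tensorflow as tf
import keras_spiking

inp = tf.keras.Input(784)
dense = tf.keras.layers.Dense(units=128, activation="relu")(inp)
model = tf.keras.Model(inp, dense)

energy = keras_spiking.ModelEnergy(model)
# estimated energy for 10 timesteps per inference on the built-in CPU profile
print(energy.total_energy("cpu", timesteps_per_inference=10))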

summary(columns=('name', 'output_shape', 'params', 'connections', 'neurons', 'energy cpu'), timesteps_per_inference=1, dt=None, line_length=98, print_warnings=True)[source]

Print a per-layer summary of computation statistics and energy estimates.

Parameters
columns : list or tuple of string

Columns to display. Can be any combination of the following:

  • “name”: The layer name.

  • “output_shape”: The output shape of the layer.

  • “params”: The number of parameters in the layer.

  • “connections”: The number of synaptic connections from inputs to the neurons of this layer (see ModelEnergy for the definition of “connections”).

  • “neurons”: The number of neuron updates performed by the layer each timestep (see ModelEnergy for the definition of “neuron update”).

  • “rate”: The average input firing rate to the layer, in spikes per second. Note that this is only relevant for layers that perform synaptic operations; for other layers (e.g. an activation layer that gets input from a convolutional layer), this number has no effect.

  • “synop_energy <device>”: The estimated energy in Joules per inference used by the layer on <device> for synaptic operations.

  • “neuron_energy <device>”: The estimated energy in Joules per inference used by the layer on <device> for neuron updates.

  • “energy <device>”: The total estimated energy in Joules per inference used by the layer on <device>.

Here, <device> can be any of the supported devices (see ModelEnergy). Additional devices can be added with ModelEnergy.register_device.

timesteps_per_inference : int

Timesteps used per inference (for example, if the model is classifying images and we want to present each image for 10 timesteps).

dt : float

The length of one timestep, in seconds, used by the device. Used to compute the number of synaptic events based on the firing rates (in Hz). Can differ from the dt used on any keras_spiking layers in the model. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).

line_length : int

The length of each printed line.

print_warnings : bool

Set to False to disable the warnings regarding assumptions made in the energy calculations.

summary_string(columns=('name', 'output_shape', 'params', 'connections', 'neurons', 'energy cpu'), timesteps_per_inference=1, dt=None, line_length=98, print_warnings=True)[source]

Returns a per-layer summary of computation statistics and energy estimates.

The same as summary, except that it returns the summary as a string rather than printing it.

For documentation on parameters and other features, see summary.

keras_spiking.model_energy.layer_stats(node, **_)[source]

Fallback for computing stats on an unknown layer (assumes none).

keras_spiking.model_energy.act_stats(node)[source]

Compute activation layer stats.

keras_spiking.model_energy.conv_stats(node)[source]

Compute Conv1D/Conv2D/Conv3D layer stats.

keras_spiking.model_energy.dense_stats(node)[source]

Compute Dense layer stats.

keras_spiking.model_energy.spikingactivation_stats(node)[source]

Compute SpikingActivation layer stats.

keras_spiking.model_energy.timedistributed_stats(node)[source]

Compute TimeDistributed layer stats.

Calls ModelEnergy.compute_layer_stats on the wrapped layer.