API reference
Layers
Components for building spiking models in Keras.
KerasSpikingCell: Base class for RNN cells in KerasSpiking.
KerasSpikingLayer: Base class for KerasSpiking layers.
SpikingActivationCell: RNN cell for converting an arbitrary activation function to a spiking equivalent.
SpikingActivation: Layer for converting an arbitrary activation function to a spiking equivalent.
LowpassCell: RNN cell for a lowpass filter.
Lowpass: Layer implementing a lowpass filter.
AlphaCell: RNN cell for an alpha filter.
Alpha: Layer implementing an alpha filter.

class keras_spiking.layers.KerasSpikingCell(*args, **kwargs)
Base class for RNN cells in KerasSpiking.
The important feature of this class is that it allows cells to define a different implementation to be used in training versus inference.
Parameters
size : int or tuple of int or tf.TensorShape
    Input/output shape of the layer (not including batch/time dimensions).
state_size : int or tuple of int or tf.TensorShape
    Shape of the cell state. If None, uses size.
dt : float
    Length of time (in seconds) represented by one time step. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).
always_use_inference : bool
    If True, this layer will use its call_inference behaviour during training, rather than call_training.
kwargs : dict
    Passed on to tf.keras.layers.Layer.

call(inputs, states, training=None)
Call function that defines a different forward pass during training versus inference.
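The training/inference split can be pictured as a small dispatch step. The pure-Python sketch below is not KerasSpiking code: ToyCell and its toy arithmetic are hypothetical stand-ins that only mimic the documented roles of call_training, call_inference and always_use_inference. Treating training=None as inference is an assumption of the sketch.

```python
class ToyCell:
    """Hypothetical stand-in for a cell with separate training/inference paths."""

    def __init__(self, always_use_inference=False):
        self.always_use_inference = always_use_inference

    def call_training(self, inputs, states):
        # e.g. a smooth, differentiable approximation used while training
        return [2.0 * x for x in inputs], states

    def call_inference(self, inputs, states):
        # e.g. the "real" (possibly non-differentiable) behaviour
        return [float(round(2.0 * x)) for x in inputs], states

    def call(self, inputs, states, training=None):
        # training=None is treated like inference here (an assumption of this sketch)
        if training and not self.always_use_inference:
            return self.call_training(inputs, states)
        return self.call_inference(inputs, states)


cell = ToyCell()
print(cell.call([0.3], [], training=True))   # ([0.6], [])
print(cell.call([0.3], [], training=False))  # ([1.0], [])
```

With always_use_inference=True the inference path is taken even when training=True.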

class keras_spiking.layers.KerasSpikingLayer(*args, **kwargs)
Base class for KerasSpiking layers.
The main role of this class is to wrap a KerasSpikingCell in a tf.keras.layers.RNN.
Parameters
dt : float
    Length of time (in seconds) represented by one time step. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).
return_sequences : bool
    Whether to return the full sequence of output values (default), or just the values on the last timestep.
return_state : bool
    Whether to return the state in addition to the output.
stateful : bool
    If False (default), each time the layer is called it will begin from the same initial conditions. If True, each call will resume from the terminal state of the previous call (my_layer.reset_states() can be called to reset the state to initial conditions).
unroll : bool
    If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed up computations, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences.
time_major : bool
    The shape format of the input and output tensors. If True, the inputs and outputs will be in shape (timesteps, batch, ...), whereas in the False case, it will be (batch, timesteps, ...). Using time_major=True is a bit more efficient because it avoids transposes at the beginning and end of the layer calculation. However, most TensorFlow data is batch-major, so by default this layer accepts input and emits output in batch-major form.
kwargs : dict
    Passed on to tf.keras.layers.Layer.
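To make the time_major axis ordering concrete, here is a small sketch on plain nested lists rather than TensorFlow tensors. In TensorFlow the same reordering would be a transpose of the first two axes, e.g. tf.transpose with perm [1, 0, 2]. The helper name to_time_major is hypothetical.

```python
def to_time_major(batch_major):
    """Convert nested lists shaped (batch, timesteps, features) to (timesteps, batch, features)."""
    n_batch = len(batch_major)
    n_steps = len(batch_major[0]) if n_batch else 0
    return [[batch_major[b][t] for b in range(n_batch)] for t in range(n_steps)]


# two sequences (batch=2), three timesteps each, one feature per step
x = [[[1], [2], [3]], [[10], [20], [30]]]
print(to_time_major(x))  # [[[1], [10]], [[2], [20]], [[3], [30]]]
```

Because the operation only swaps the first two axes, applying it twice returns the original batch-major layout.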

build(input_shape)
Builds the RNN/cell layers contained within this layer.
Notes
This method should not be called manually; rather, use the implicit layer callable behaviour (like my_layer(inputs)), which will apply this method with some additional bookkeeping.

call(inputs, training=None, initial_state=None, constants=None)
Apply this layer to inputs.
Notes
This method should not be called manually; rather, use the implicit layer callable behaviour (like my_layer(inputs)), which will apply this method with some additional bookkeeping.

class keras_spiking.SpikingActivationCell(*args, **kwargs)
RNN cell for converting an arbitrary activation function to a spiking equivalent.
Neurons will spike at a rate proportional to the output of the base activation function. For example, if the activation function is outputting a value of 10, then the wrapped SpikingActivationCell will output spikes at a rate of 10 Hz (i.e., 10 spikes per 1 simulated second, where 1 simulated second is equivalent to 1/dt time steps). Each spike will have height 1/dt (so that the integral of the spiking output will be the same as the integral of the base activation output). Note that if the base activation is outputting a negative value then the spikes will have height -1/dt. Multiple spikes per timestep are also possible, in which case the output will be n/dt (where n is the number of spikes).
Parameters
size : int or tuple of int or tf.TensorShape
    Input/output shape of the layer (not including batch/time dimensions).
activation : callable
    Activation function to be converted to spiking equivalent.
dt : float
    Length of time (in seconds) represented by one time step. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).
seed : int
    Seed for random state initialization.
spiking_aware_training : bool
    If True (default), use the spiking activation function for the forward pass and the base activation function for the backward pass. If False, use the base activation function for the forward and backward pass during training.
kwargs : dict
    Passed on to tf.keras.layers.Layer.
Notes
This cell needs to be wrapped in a tf.keras.layers.RNN, like
my_layer = tf.keras.layers.RNN(
    keras_spiking.SpikingActivationCell(size=10, activation=tf.nn.relu)
)
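The rate and spike-height rules described for this cell can be reproduced with a simple integrate, floor, subtract update: accumulate activation * dt into a voltage, emit floor(voltage) spikes of height 1/dt, and keep the fractional remainder. This pure-Python sketch assumes that update and a seeded uniform initial voltage in [0, 1); it is an illustration, not necessarily the exact KerasSpiking implementation, and spiking_step and simulate are hypothetical names.

```python
import math
import random


def spiking_step(rate, voltage, dt):
    """Advance one timestep; returns (output, new_voltage)."""
    voltage += rate * dt            # integrate the base activation
    n_spikes = math.floor(voltage)  # whole spikes accumulated (negative if voltage < 0)
    voltage -= n_spikes             # keep the fractional remainder in [0, 1)
    return n_spikes / dt, voltage   # each spike has height 1/dt (or -1/dt)


def simulate(rate, dt, n_steps, seed=0):
    # seeded uniform initial voltage in [0, 1), as the cell's initial state is described
    voltage = random.Random(seed).random()
    outputs = []
    for _ in range(n_steps):
        out, voltage = spiking_step(rate, voltage, dt)
        outputs.append(out)
    return outputs


dt = 0.001
out = simulate(rate=10.0, dt=dt, n_steps=1000)  # 1 simulated second
print(sum(1 for o in out if o != 0))  # about 10 spikes
print(sum(out) * dt)                  # integral close to 10, the activation value
```

A negative rate produces spikes of height -1/dt through the same floor operation, matching the note on negative activations.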

get_initial_state(inputs=None, batch_size=None, dtype=None)
Set up initial spiking state.
Initial state is chosen from a uniform distribution, seeded based on the seed passed on construction (if one was given).
Note: state will be initialized automatically; users do not need to call this method themselves.

call_training(inputs, states)
Compute layer output when training and always_use_inference=False.

call_inference(inputs, states)
Compute spiking output, with custom gradient for spiking aware training.
Parameters
inputs : tf.Tensor
    Input to the activation function.
voltage : tf.Tensor
    Spiking voltage state.
Returns
spikes : tf.Tensor
    Output spike values (0 or n/dt for each element in inputs, where n is the number of spikes).
voltage : tf.Tensor
    Updated voltage state.
grad : callable
    Custom gradient function for spiking aware training.
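Spiking aware training pairs a spiking forward pass with the gradient of the base activation in the backward pass. The sketch below shows that pairing in pure Python, assuming ReLU as the base activation and an integrate, floor, subtract spike update. The hand-written grad closure stands in for framework machinery (in TensorFlow this kind of pairing is typically expressed with tf.custom_gradient); all names are hypothetical.

```python
import math


def relu(x):
    return max(x, 0.0)


def relu_grad(x):
    # derivative of the base activation (taken as 0 at x == 0)
    return 1.0 if x > 0 else 0.0


def spiking_relu_forward(inputs, voltage, dt):
    """Return (spikes, new_voltage, grad); grad maps upstream gradients to input gradients."""
    spikes, new_voltage = [], []
    for x, v in zip(inputs, voltage):
        v = v + relu(x) * dt
        n = math.floor(v)
        spikes.append(n / dt)
        new_voltage.append(v - n)

    def grad(upstream):
        # ignore the non-differentiable floor() and use d relu(x) / dx instead
        return [g * relu_grad(x) for g, x in zip(upstream, inputs)]

    return spikes, new_voltage, grad


spikes, voltage, grad = spiking_relu_forward([-1.0, 500.0], [0.0, 0.9], dt=0.001)
print(spikes)            # [0.0, 1000.0]: only the second unit crosses threshold
print(grad([1.0, 1.0]))  # [0.0, 1.0]
```

The second unit's gradient is 1 even on timesteps where it emits no spike, which is what lets training proceed through the spiking nonlinearity.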

class keras_spiking.SpikingActivation(*args, **kwargs)
Layer for converting an arbitrary activation function to a spiking equivalent.
Neurons will spike at a rate proportional to the output of the base activation function. For example, if the activation function is outputting a value of 10, then the wrapped SpikingActivationCell will output spikes at a rate of 10 Hz (i.e., 10 spikes per 1 simulated second, where 1 simulated second is equivalent to 1/dt time steps). Each spike will have height 1/dt (so that the integral of the spiking output will be the same as the integral of the base activation output). Note that if the base activation is outputting a negative value then the spikes will have height -1/dt. Multiple spikes per timestep are also possible, in which case the output will be n/dt (where n is the number of spikes).
When applying this layer to an input, make sure that the input has a time axis (the time_major option controls whether it comes before or after the batch axis). The spiking output will be computed along the time axis. The number of simulation timesteps will depend on the length of that time axis. The number of timesteps does not need to be the same during training/evaluation/inference. In particular, it may be more efficient to use one timestep during training and multiple timesteps during inference (often with spiking_aware_training=False, and apply_during_training=False on any Lowpass layers).
Parameters
activation : callable
    Activation function to be converted to spiking equivalent.
dt : float
    Length of time (in seconds) represented by one time step. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).
seed : int
    Seed for random state initialization.
spiking_aware_training : bool
    If True (default), use the spiking activation function for the forward pass and the base activation function for the backward pass. If False, use the base activation function for the forward and backward pass during training.
return_sequences : bool
    Whether to return the full sequence of output spikes (default), or just the spikes on the last timestep.
return_state : bool
    Whether to return the state in addition to the output.
stateful : bool
    If False (default), each time the layer is called it will begin from the same initial conditions. If True, each call will resume from the terminal state of the previous call (my_layer.reset_states() can be called to reset the state to initial conditions).
unroll : bool
    If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed up computations, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences.
time_major : bool
    The shape format of the input and output tensors. If True, the inputs and outputs will be in shape (timesteps, batch, ...), whereas in the False case, it will be (batch, timesteps, ...). Using time_major=True is a bit more efficient because it avoids transposes at the beginning and end of the layer calculation. However, most TensorFlow data is batch-major, so by default this layer accepts input and emits output in batch-major form.
kwargs : dict
    Passed on to tf.keras.layers.Layer.
Notes
This is equivalent to tf.keras.layers.RNN(SpikingActivationCell(...) ...); it just takes care of the RNN construction automatically.
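Why use one timestep for training and many for inference? A spiking output only averages out to the base activation over a long enough time window, so short windows give coarse rate estimates. The pure-Python sketch below assumes a ReLU base activation and a simple integrate, floor, subtract spike update with a seeded random initial voltage; it illustrates the effect of the time-axis length and is not the library's code. spiking_relu is a hypothetical name.

```python
import math
import random


def spiking_relu(x, n_steps, dt, seed=0):
    """Spiking ReLU run for n_steps timesteps on a constant input x."""
    voltage = random.Random(seed).random()
    out = []
    for _ in range(n_steps):
        voltage += max(x, 0.0) * dt
        n = math.floor(voltage)
        voltage -= n
        out.append(n / dt)
    return out


x, dt = 20.0, 0.001
rates = {}
for n_steps in (1, 10, 1000):
    out = spiking_relu(x, n_steps, dt)
    rates[n_steps] = sum(out) / n_steps  # average over the time axis
print(rates)  # only the 1000-step estimate is close to relu(20) = 20
```

With a single timestep the estimate can only be 0 or 1/dt, which is one reason training often uses the non-spiking activation (spiking_aware_training=False) on one timestep.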

class keras_spiking.LowpassCell(*args, **kwargs)
RNN cell for a lowpass filter.
The initial filter state and filter time constants are both trainable parameters. However, if apply_during_training=False then the parameters are not part of the training loop, and so will never be updated.
Parameters
size : int or tuple of int or tf.TensorShape
    Input/output shape of the layer (not including batch/time dimensions).
tau_initializer : float or str or tf.keras.initializers.Initializer
    Initial value of time constant of filter (in seconds). Passing a float will initialize it to that value, or any standard Keras initializer can be used.
dt : float
    Length of time (in seconds) represented by one time step. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).
apply_during_training : bool
    If False, this layer will effectively be ignored during training (this often makes sense in concert with the swappable training behaviour in, e.g., SpikingActivation, since if the activations are not spiking during training then we often don’t need to filter them either).
level_initializer : str or tf.keras.initializers.Initializer
    Initializer for filter state.
initial_level_constraint : str or tf.keras.constraints.Constraint
    Constraint for initial_level.
tau_constraint : str or tf.keras.constraints.Constraint
    Constraint for tau. For example, Mean will share the same time constant across all of the lowpass filters. The time constant is always clipped to be positive in the forward pass for numerical stability.
kwargs : dict
    Passed on to tf.keras.layers.Layer.
Notes
This cell needs to be wrapped in a tf.keras.layers.RNN, like
my_layer = tf.keras.layers.RNN(
    keras_spiking.LowpassCell(size=10, tau_initializer=0.01)
)
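A lowpass filter with transfer function H(s) = 1/(tau*s + 1) has to be turned into a per-timestep update for an RNN cell. One standard choice, assumed here, is zero-order-hold discretization: y[k] = a*y[k-1] + (1 - a)*x[k] with a = exp(-dt/tau). KerasSpiking's internal update may be written differently. The function lowpass is hypothetical, and the tiny lower bound on tau is an arbitrary stand-in for the positivity clipping mentioned above.

```python
import math


def lowpass(inputs, tau, dt, initial_level=0.0):
    """First-order lowpass filter, zero-order-hold discretization."""
    tau = max(tau, 1e-8)  # keep tau positive (the bound itself is an arbitrary choice)
    a = math.exp(-dt / tau)
    y, out = initial_level, []
    for x in inputs:
        y = a * y + (1 - a) * x
        out.append(y)
    return out


tau, dt = 0.01, 0.001
step_response = lowpass([1.0] * 100, tau=tau, dt=dt)
n = round(tau / dt)          # number of steps in one time constant
print(step_response[n - 1])  # about 1 - exp(-1), i.e. 0.632
```

Under this discretization the step response after exactly one time constant matches the continuous-time value 1 - exp(-1), independent of dt.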

call_inference(inputs, states)
Compute layer output when testing or always_use_inference=True.

class keras_spiking.Lowpass(*args, **kwargs)
Layer implementing a lowpass filter.
The impulse-response function (time domain) and transfer function are:
\[\begin{aligned} h(t) &= \frac{1}{\tau} \exp(-t / \tau) \\ H(s) &= \frac{1}{\tau s + 1} \end{aligned}\]
The initial filter state and filter time constants are both trainable parameters. However, if apply_during_training=False then the parameters are not part of the training loop, and so will never be updated.
When applying this layer to an input, make sure that the input has a time axis (the time_major option controls whether it comes before or after the batch axis).
Parameters
tau_initializer : float or str or tf.keras.initializers.Initializer
    Initial value of time constant of filter (in seconds). Passing a float will initialize it to that value, or any standard Keras initializer can be used.
dt : float
    Length of time (in seconds) represented by one time step. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).
apply_during_training : bool
    If False, this layer will effectively be ignored during training (this often makes sense in concert with the swappable training behaviour in, e.g., SpikingActivation, since if the activations are not spiking during training then we often don’t need to filter them either).
level_initializer : str or tf.keras.initializers.Initializer
    Initializer for filter state.
initial_level_constraint : str or tf.keras.constraints.Constraint
    Constraint for initial_level.
tau_constraint : str or tf.keras.constraints.Constraint
    Constraint for tau. For example, Mean will share the same time constant across all of the lowpass filters. The time constant is always clipped to be positive in the forward pass for numerical stability.
return_sequences : bool
    Whether to return the full sequence of filtered output (default), or just the output on the last timestep.
return_state : bool
    Whether to return the state in addition to the output.
stateful : bool
    If False (default), each time the layer is called it will begin from the same initial conditions. If True, each call will resume from the terminal state of the previous call (my_layer.reset_states() can be called to reset the state to initial conditions).
unroll : bool
    If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed up computations, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences.
time_major : bool
    The shape format of the input and output tensors. If True, the inputs and outputs will be in shape (timesteps, batch, ...), whereas in the False case, it will be (batch, timesteps, ...). Using time_major=True is a bit more efficient because it avoids transposes at the beginning and end of the layer calculation. However, most TensorFlow data is batch-major, so by default this layer accepts input and emits output in batch-major form.
kwargs : dict
    Passed on to tf.keras.layers.Layer.
Notes
This is equivalent to tf.keras.layers.RNN(LowpassCell(...) ...); it just takes care of the RNN construction automatically.
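The impulse response h(t) = (1/tau) exp(-t/tau) has unit area (H(0) = 1), so a single spike of height 1/dt passed through the filter should integrate to about 1. The sketch below checks this numerically. It assumes a zero-order-hold update y[k] = a*y[k-1] + (1 - a)*x[k] with a = exp(-dt/tau), which may differ from the library's exact discretization. It also compares the discrete response with h(t), which agree more closely as dt shrinks.

```python
import math

tau, dt, n_steps = 0.01, 0.001, 2000
a = math.exp(-dt / tau)  # per-step decay for the assumed zero-order-hold update

# a single spike of height 1/dt at t = 0 (unit area), then silence
inputs = [1.0 / dt] + [0.0] * (n_steps - 1)
y, response = 0.0, []
for x in inputs:
    y = a * y + (1 - a) * x
    response.append(y)

# continuous-time impulse response h(t) = (1 / tau) * exp(-t / tau)
h = [math.exp(-k * dt / tau) / tau for k in range(n_steps)]

area = sum(response) * dt  # approaches H(0) = 1
print(area)
print(response[10] / h[10])  # discrete vs continuous, about 0.95 at this dt
```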

class keras_spiking.AlphaCell(*args, **kwargs)
RNN cell for an alpha filter.
The initial filter state and filter time constants are both trainable parameters. However, if apply_during_training=False then the parameters are not part of the training loop, and so will never be updated.
Parameters
size : int or tuple of int or tf.TensorShape
    Input/output shape of the layer (not including batch/time dimensions).
tau_initializer : float or str or tf.keras.initializers.Initializer
    Initial value of time constant of filter (in seconds). Passing a float will initialize it to that value, or any standard Keras initializer can be used.
dt : float
    Length of time (in seconds) represented by one time step. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).
apply_during_training : bool
    If False, this layer will effectively be ignored during training (this often makes sense in concert with the swappable training behaviour in, e.g., SpikingActivation, since if the activations are not spiking during training then we often don’t need to filter them either).
level_initializer : str or tf.keras.initializers.Initializer
    Initializer for filter state.
initial_level_constraint : str or tf.keras.constraints.Constraint
    Constraint for initial_level.
tau_constraint : str or tf.keras.constraints.Constraint
    Constraint for tau. For example, Mean will share the same time constant across all of the alpha filters. The time constant is always clipped to be positive in the forward pass for numerical stability.
kwargs : dict
    Passed on to tf.keras.layers.Layer.
Notes
This cell needs to be wrapped in a tf.keras.layers.RNN, like
my_layer = tf.keras.layers.RNN(
    keras_spiking.AlphaCell(size=10, tau_initializer=0.01)
)
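Because the alpha transfer function factors as H(s) = 1/(tau*s + 1)^2, an alpha filter can be realized as two identical first-order lowpass stages in series. The sketch below does exactly that in pure Python. It assumes a zero-order-hold update for each stage (the library's own discretization may differ), and alpha_filter is a hypothetical name. It checks two properties of h(t) = (t/tau^2) exp(-t/tau): the response peaks near t = tau, and it has unit area.

```python
import math


def alpha_filter(inputs, tau, dt):
    """Alpha filter as two identical first-order lowpass stages in series."""
    a = math.exp(-dt / tau)  # zero-order-hold decay per step (assumed discretization)
    y1 = y2 = 0.0
    out = []
    for x in inputs:
        y1 = a * y1 + (1 - a) * x   # first lowpass stage
        y2 = a * y2 + (1 - a) * y1  # second stage filters the first
        out.append(y2)
    return out


tau, dt, n_steps = 0.01, 0.001, 2000
response = alpha_filter([1.0 / dt] + [0.0] * (n_steps - 1), tau, dt)
peak_step = max(range(n_steps), key=lambda k: response[k])
print(peak_step * dt)      # close to tau, where h(t) peaks
print(sum(response) * dt)  # close to 1 (unit area, H(0) = 1)
```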

call_inference(inputs, states)
Compute layer output when testing or always_use_inference=True.

class keras_spiking.Alpha(*args, **kwargs)
Layer implementing an alpha filter.
The impulse-response function (time domain) and transfer function are:
\[\begin{aligned} h(t) &= \frac{t}{\tau^2} \exp(-t / \tau) \\ H(s) &= \frac{1}{(\tau s + 1)^2} \end{aligned}\]
The initial filter state and filter time constants are both trainable parameters. However, if apply_during_training=False then the parameters are not part of the training loop, and so will never be updated.
When applying this layer to an input, make sure that the input has a time axis (the time_major option controls whether it comes before or after the batch axis).
Parameters
tau_initializer : float or str or tf.keras.initializers.Initializer
    Initial value of time constant of filter (in seconds). Passing a float will initialize it to that value, or any standard Keras initializer can be used.
dt : float
    Length of time (in seconds) represented by one time step. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).
apply_during_training : bool
    If False, this layer will effectively be ignored during training (this often makes sense in concert with the swappable training behaviour in, e.g., SpikingActivation, since if the activations are not spiking during training then we often don’t need to filter them either).
level_initializer : str or tf.keras.initializers.Initializer
    Initializer for filter state.
initial_level_constraint : str or tf.keras.constraints.Constraint
    Constraint for initial_level.
tau_constraint : str or tf.keras.constraints.Constraint
    Constraint for tau. For example, Mean will share the same time constant across all of the alpha filters. The time constant is always clipped to be positive in the forward pass for numerical stability.
return_sequences : bool
    Whether to return the full sequence of filtered output (default), or just the output on the last timestep.
return_state : bool
    Whether to return the state in addition to the output.
stateful : bool
    If False (default), each time the layer is called it will begin from the same initial conditions. If True, each call will resume from the terminal state of the previous call (my_layer.reset_states() can be called to reset the state to initial conditions).
unroll : bool
    If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed up computations, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences.
time_major : bool
    The shape format of the input and output tensors. If True, the inputs and outputs will be in shape (timesteps, batch, ...), whereas in the False case, it will be (batch, timesteps, ...). Using time_major=True is a bit more efficient because it avoids transposes at the beginning and end of the layer calculation. However, most TensorFlow data is batch-major, so by default this layer accepts input and emits output in batch-major form.
kwargs : dict
    Passed on to tf.keras.layers.Layer.
Notes
This is equivalent to tf.keras.layers.RNN(AlphaCell(...) ...); it just takes care of the RNN construction automatically.
Regularizers
Regularization methods designed to work with spiking layers.
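RangedRegularizer: A regularizer that penalizes values that fall outside a range.
L1L2: A version of tf.keras.regularizers.L1L2 that allows the user to specify a nonzero target output.
L1: A version of tf.keras.regularizers.L1 that allows the user to specify a nonzero target output.
L2: A version of tf.keras.regularizers.L2 that allows the user to specify a nonzero target output.
Percentile: A regularizer that penalizes a percentile of a tensor.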

class keras_spiking.regularizers.RangedRegularizer(target=0, regularizer=<keras.regularizers.L1L2 object>)
A regularizer that penalizes values that fall outside a range.
This allows regularized values to fall anywhere within the range, as opposed to standard regularizers that penalize any departure from some fixed point.
Parameters
target : float or tuple
    The value that we want the regularized outputs to be driven towards. Can be a float, in which case all outputs will be driven towards that value, or a tuple specifying a range (min, max), in which case outputs outside that range will be driven towards that range (but outputs within the range will not be penalized).
regularizer : tf.keras.regularizers.Regularizer
    Regularization penalty that will be applied to the outputs with respect to target.
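The range behaviour can be read as: measure how far each value lies outside [min, max] (zero inside), then feed that distance into an ordinary L1/L2 penalty. The pure-Python sketch below follows that reading. Summing over elements mirrors tf.keras.regularizers.L1L2, but how RangedRegularizer combines the pieces internally is an assumption here, and ranged_penalty is a hypothetical name.

```python
def clip(x, lo, hi):
    return min(max(x, lo), hi)


def ranged_penalty(values, target=0.0, l1=0.0, l2=0.0):
    """Penalize only the part of each value that lies outside the target range.

    target may be a float (a single point) or a (min, max) tuple.
    """
    lo, hi = target if isinstance(target, tuple) else (target, target)
    total = 0.0
    for v in values:
        deviation = v - clip(v, lo, hi)  # zero inside [lo, hi]
        total += l1 * abs(deviation) + l2 * deviation ** 2
    return total


print(ranged_penalty([6.0, 9.5], target=(5, 10), l2=1.0))   # 0.0: both inside the range
print(ranged_penalty([2.0, 12.0], target=(5, 10), l2=1.0))  # 13.0 = 3**2 + 2**2
```

A scalar target collapses the range to a point, which recovers a standard "distance from target" penalty.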

class keras_spiking.regularizers.L1L2(l1=0.0, l2=0.0, target=0, **kwargs)
A version of tf.keras.regularizers.L1L2 that allows the user to specify a nonzero target output.
Parameters
l1 : float
    Weight on L1 regularization penalty.
l2 : float
    Weight on L2 regularization penalty.
target : float or tuple
    The value that we want the regularized outputs to be driven towards. Can be a float, in which case all outputs will be driven towards that value, or a tuple specifying a range (min, max), in which case outputs outside that range will be driven towards that range (but outputs within the range will not be penalized).
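For a scalar target, a natural reading of "L1L2 with a target" is the usual Keras penalty applied to x - target, i.e. l1 * sum(|x - target|) + l2 * sum((x - target)^2). With target=0 this reduces to the standard tf.keras.regularizers.L1L2 penalty. The sketch below is that formula in pure Python. l1l2_with_target is hypothetical, and only the scalar-target case is shown.

```python
def l1l2_with_target(values, l1=0.0, l2=0.0, target=0.0):
    """l1 * sum(|x - target|) + l2 * sum((x - target)**2) for a scalar target."""
    return sum(l1 * abs(v - target) + l2 * (v - target) ** 2 for v in values)


x = [1.0, 4.0]
# with target=0 this is the ordinary L1L2 penalty: 0.1*(1+4) + 0.01*(1+16)
print(l1l2_with_target(x, l1=0.1, l2=0.01))  # about 0.67
# a nonzero target moves the minimum: values equal to the target cost nothing
print(l1l2_with_target([3.0, 3.0], l1=0.1, l2=0.01, target=3.0))  # 0.0
```

L1 and L2 below are the special cases l2=0 and l1=0 of the same expression.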

class keras_spiking.regularizers.L1(l1=0.01, target=0, **kwargs)
A version of tf.keras.regularizers.L1 that allows the user to specify a nonzero target output.
Parameters
l1 : float
    Weight on L1 regularization penalty.
target : float or tuple
    The value that we want the regularized outputs to be driven towards. Can be a float, in which case all outputs will be driven towards that value, or a tuple specifying a range (min, max), in which case outputs outside that range will be driven towards that range (but outputs within the range will not be penalized).

class keras_spiking.regularizers.L2(l2=0.01, target=0, **kwargs)
A version of tf.keras.regularizers.L2 that allows the user to specify a nonzero target output.
Parameters
l2 : float
    Weight on L2 regularization penalty.
target : float or tuple
    The value that we want the regularized outputs to be driven towards. Can be a float, in which case all outputs will be driven towards that value, or a tuple specifying a range (min, max), in which case outputs outside that range will be driven towards that range (but outputs within the range will not be penalized).

class keras_spiking.regularizers.Percentile(percentile=100, axis=0, target=0, l1=0, l2=0)
A regularizer that penalizes a percentile of a tensor.
This regularizer finds the requested percentile of the data over the axis, and then applies a regularizer to the percentile values with respect to target. This can be useful as it makes the computed regularization penalty more invariant to outliers.
Parameters
percentile : float
    Percentile to compute over the axis. Defaults to 100, which is equivalent to taking the maximum across the specified axis.
    Note: for percentile != 100, requires tensorflow-probability.
axis : int or tuple of int
    Axis or axes to take the percentile over.
target : float or tuple
    The value that we want the regularized outputs to be driven towards. Can be a float, in which case all outputs will be driven towards that value, or a tuple specifying a range (min, max), in which case outputs outside that range will be driven towards that range (but outputs within the range will not be penalized).
l1 : float
    Weight on L1 regularization penalty applied to percentiles.
l2 : float
    Weight on L2 regularization penalty applied to percentiles.
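The outlier argument can be seen on a tiny example: reduce each neuron's values over axis 0 to a percentile, then apply a range penalty to those percentiles. The sketch below uses linear interpolation between sorted values, which is NumPy's default percentile method; the interpolation used by tensorflow-probability, and whether the penalty is summed or averaged, are not specified here and may differ. percentile and percentile_penalty are hypothetical names.

```python
import math


def percentile(values, q):
    """q-th percentile (0-100) with linear interpolation between sorted values."""
    s = sorted(values)
    pos = (len(s) - 1) * q / 100.0
    lo, hi = math.floor(pos), math.ceil(pos)
    return s[lo] + (s[hi] - s[lo]) * (pos - lo)


def percentile_penalty(rows, q=100, target=(5, 10), l2=0.01):
    """Percentile over axis 0 (one value per column), then an L2 range penalty."""
    lo, hi = target
    total = 0.0
    for column in zip(*rows):  # iterate over columns, i.e. reduce over axis 0
        p = percentile(column, q)
        deviation = p - min(max(p, lo), hi)
        total += l2 * deviation ** 2
    return total


# rows = examples (axis 0), columns = neurons; the last row contains an outlier
rates = [[6.0, 2.0], [7.0, 3.0], [8.0, 4.0], [9.0, 50.0]]
print(percentile_penalty(rates, q=100))  # max per neuron: 9 (in range) and 50 (outlier)
print(percentile_penalty(rates, q=50))   # medians 7.5 and 3.5: only 3.5 is out of range
```

With q=100 the single outlier dominates the penalty (16.0), while the median-based penalty (0.0225) reflects the typical activity of each neuron.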
Examples
In the following example, we use Percentile to ensure the neuron activities (a.k.a., firing rates) fall in the desired range of 5-10 Hz when computing the product of two inputs.
import numpy as np
import tensorflow as tf
import keras_spiking

train_x = np.random.uniform(-1, 1, size=(1024 * 100, 2))
train_y = train_x[:, :1] * train_x[:, 1:]
test_x = np.random.uniform(-1, 1, size=(128, 2))
test_y = test_x[:, :1] * test_x[:, 1:]

# train using one timestep, to speed things up
train_seq = train_x[:, None]

# test using 10 timesteps
n_steps = 10
test_seq = np.tile(test_x[:, None], (1, n_steps, 1))

inp = x = tf.keras.Input((None, 2))
x = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(50))(x)
x = spikes = keras_spiking.SpikingActivation(
    "relu",
    dt=1,
    activity_regularizer=keras_spiking.regularizers.Percentile(
        target=(5, 10), l2=0.01
    ),
)(x)
x = tf.keras.layers.GlobalAveragePooling1D()(x)
x = tf.keras.layers.Dense(1)(x)

model = tf.keras.Model(inp, (x, spikes))

model.compile(
    # note: we use a dict to specify loss/metrics because we only want to
    # apply these to the final dense output, not the spike layer
    optimizer="rmsprop",
    loss={"dense_1": "mse"},
    metrics={"dense_1": "mae"},
)
model.fit(train_seq, train_y, epochs=5)

outputs, spikes = model.predict(test_seq)

# estimate rates by averaging over time
rates = spikes.mean(axis=1)
max_rates = rates.max(axis=0)
print("Max rates: %s, %s" % (max_rates.mean(), max_rates.std()))

error = np.mean(np.abs(outputs - test_y))
print("MAE: %s" % (error,))
Constraints
Custom constraints for weight tensors in Keras models.
Mean: Constrains weight tensors to be their mean.
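A mean constraint replaces every element of a weight tensor with the mean of all its elements. Applied to a filter's time constants, this makes all filters in the layer share one tau, as the tau_constraint descriptions in the filter layers mention. The sketch below works on a plain list rather than a TensorFlow weight tensor, and mean_constraint is a hypothetical name.

```python
def mean_constraint(weights):
    """Replace every element with the mean of all elements."""
    m = sum(weights) / len(weights)
    return [m] * len(weights)


taus = [0.01, 0.03, 0.02]
print(mean_constraint(taus))  # every filter now shares tau = 0.02 (up to float rounding)
```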
Callbacks
Callbacks for use with KerasSpiking models.
DtScheduler: A callback for updating Layer dt attributes during training.

class keras_spiking.callbacks.DtScheduler(dt, scheduler, verbose=False)
A callback for updating Layer dt attributes during training.
This uses the same scheduler interface as TensorFlow’s learning rate schedulers, so any of those built-in schedules can be used to adjust dt, or a custom function implementing the same interface.
When using this functionality, dt should be initialized as a tf.Variable, and that Variable should be passed as the dt parameter to any Layers that should be affected by this callback.
For example:
import numpy as np
import tensorflow as tf
import keras_spiking

dt = tf.Variable(1.0)

inp = tf.keras.Input((None, 10))
x = keras_spiking.SpikingActivation("relu", dt=dt)(inp)
x = keras_spiking.Lowpass(0.1, dt=dt)(x)
model = tf.keras.Model(inp, x)

callback = keras_spiking.callbacks.DtScheduler(
    dt,
    tf.optimizers.schedules.ExponentialDecay(
        1.0, decay_steps=5, decay_rate=0.9
    ),
)

model.compile(loss="mse", optimizer="sgd")
model.fit(
    np.ones((100, 2, 10)),
    np.ones((100, 2, 10)),
    epochs=10,
    batch_size=20,
    callbacks=[callback],
)
Parameters
dt : tf.Variable
    Variable representing dt that has been passed to other Layers.
scheduler : tf.optimizers.schedules.LearningRateSchedule
    A schedule class that will update dt based on the training step (one training step is one mini-batch worth of training).
verbose : bool
    If True, print out some information about dt updates during training.
Notes
Because Variable values persist over time, any changes made to dt by this callback will persist after training completes. For example, if you call fit with this callback and then predict later on, that predict call will be using the last dt value set by this callback.
Configuration
Configuration options for KerasSpiking layers.
DefaultManager: Manages the default parameter values for KerasSpiking layers.

class keras_spiking.config.DefaultManager(dt=0.001)
Manages the default parameter values for KerasSpiking layers.
Parameters
dt : float
    Length of time (in seconds) represented by one time step. Defaults to 0.001s.
Notes
Do not instantiate this class directly; instead, access it through keras_spiking.default.
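The layers' "if None, uses keras_spiking.default.dt" rule is a fallback pattern: an explicit dt wins, and otherwise the shared default is read. The sketch below models that pattern with a hypothetical Defaults object and resolve_dt helper. Whether keras_spiking.default.dt is meant to be reassigned, and whether real layers read the default at construction or later, are not stated in this section and are assumptions.

```python
class Defaults:
    """Hypothetical stand-in for a shared defaults object such as keras_spiking.default."""

    def __init__(self, dt=0.001):
        self.dt = dt


default = Defaults()


def resolve_dt(dt=None):
    # dt=None falls back to the shared default; an explicit value always wins
    return default.dt if dt is None else dt


print(resolve_dt())       # 0.001
default.dt = 0.01         # change the shared default (an assumed usage pattern)
print(resolve_dt())       # 0.01
print(resolve_dt(0.005))  # 0.005
```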
Energy estimation
Estimate energy usage on various devices for Keras models.
Compute statistics and device energy estimates for a Keras model. 

Fallback for computing stats on an unknown layer (assumes none). 

Compute activation layer stats. 

Compute 

Compute 

Compute 

Compute 

class keras_spiking.ModelEnergy(model, example_data=None)
Compute statistics and device energy estimates for a Keras model.
Computes the following statistics on each layer:
“connections”: The number of connections from all input elements to all activation units. The number of synaptic operations per second (“synops”) will be computed by multiplying this number by the average firing rate of the input to the layer.
“neurons”: The number of neuron updates per timestep performed by the layer. The number of neuron updates per inference will be computed by multiplying this number by the number of timesteps per inference.
Using expected average firing rates for each layer in the network, along with the above statistics, this class can estimate the energy usage on one of the following types of devices (see total_energy and summary):
“cpu”: Estimate energy usage on a CPU (Intel i7-4960X), assuming each synop/neuron update is one MAC [1].
“gpu”: Estimate energy usage on a GPU (Nvidia GTX Titan Black), assuming each synop/neuron update is one MAC [1]. Note that this assumes significant parallelism (e.g., inputs being processed in large batches).
“arm”: Estimate energy usage on an ARM Cortex-A, assuming each synop/neuron update is one MAC [1].
“loihi”: Estimate energy usage on the Intel Loihi chip [2].
“spinnaker” and “spinnaker2”: Estimate energy usage on SpiNNaker or SpiNNaker 2 [3].
Note: on non-spiking devices (“cpu”/“gpu”/“arm”) this assumes the model is being run as a traditional non-spiking ANN (computing every synapse each timestep), not taking advantage of spike-based computation. This estimate is therefore independent of example_data. On spiking devices (“loihi”/“spinnaker”/“spinnaker2”), we assume that the model has been fully converted to a spiking implementation in some way, even if model contains non-spiking elements.
For example:
import tensorflow as tf
import keras_spiking

inp = tf.keras.Input(784)
dense = tf.keras.layers.Dense(units=128, activation="relu")(inp)
model = tf.keras.Model(inp, dense)

energy = keras_spiking.ModelEnergy(model)
energy.summary(line_length=80)
Layer (type)          Output shape   Param #  Conn #  Neuron #  J/inf (cpu)
----------------------------------------------------------------------
input_3 (InputLayer)  [(None, 784)]  0        0       0         0
dense_2 (Dense)       (None, 128)    100480   100352  128       0.00086
======================================================================
Total energy per inference [Joules/inf] (cpu): 8.64e-04
...
Additional devices or different energy assumptions for a given device can be added with register_device, e.g.
keras_spiking.ModelEnergy.register_device(
    "mycpu", energy_per_synop=1e-10, energy_per_neuron=2e-9, spiking=False
)
energy.summary(
    columns=("name", "energy cpu", "energy mycpu"), line_length=80
)
Layer (type)          J/inf (cpu)  J/inf (mycpu)
-----------------------------------------------
input_3 (InputLayer)  0            0
dense_2 (Dense)       0.00086      1e-05
===============================================
Total energy per inference [Joules/inf] (cpu): 8.64e-04
Total energy per inference [Joules/inf] (mycpu): 1.03e-05
...
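The totals in these summaries can be reproduced by hand for a non-spiking device: every connection and every neuron update is computed each timestep, so energy = timesteps * (connections * energy_per_synop + neurons * energy_per_neuron). The sketch below checks this against the printed numbers. The 8.6e-9 J per operation for "cpu" is back-calculated from the summary (8.64e-04 J over 100,480 operations), not read from the library's device table. nonspiking_energy is a hypothetical helper.

```python
def nonspiking_energy(connections, neurons, energy_per_synop, energy_per_neuron, timesteps=1):
    """Energy per inference on a non-spiking device (every connection computed each timestep)."""
    return timesteps * (connections * energy_per_synop + neurons * energy_per_neuron)


# Dense(128) on a 784-dimensional input, as in the summary
connections = 784 * 128  # 100352, the "Conn #" column
neurons = 128            # the "Neuron #" column

cpu = nonspiking_energy(connections, neurons, 8.6e-9, 8.6e-9)  # per-op cost inferred from the table
mycpu = nonspiking_energy(connections, neurons, 1e-10, 2e-9)   # the register_device example
print("%.2e %.2e" % (cpu, mycpu))  # 8.64e-04 1.03e-05
```

Spiking devices instead scale synaptic operations by the input firing rates estimated from example_data, which this non-spiking formula deliberately ignores.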
Parameters
model : tf.keras.Model
    The model to compute statistics and energy estimates for.
example_data : array_like
    Input used to estimate average firing rates of each layer (used to estimate the number of synaptic events). It is passed directly to model.predict (see tf.keras.Model.predict for all acceptable types of input data). This is required to estimate energy on spiking devices, but does not affect non-spiking devices.
Notes
It is important to keep in mind that actual power usage depends heavily on the specific details of the underlying software and hardware implementation. The numbers provided by ModelEnergy should be taken as very rough estimates only, and they rely on a number of assumptions:
Device specifications: In order to estimate the energy used by a model on a particular device, we need to know how much energy is used per synaptic operation/neuron update. We rely on published data for these numbers (see our sources below). Energy numbers in practice can differ significantly from published results.
Overhead: We do not account for any overhead in the energy estimates (e.g., the cost of transferring data on and off a device). We only estimate the energy usage of internal model computations (synaptic operations and neuron updates). In practice, overhead can be a significant contributor to the energy usage of a model.
Spiking implementation: We assume that the model being estimated has been fully converted to a spiking implementation when estimating the energy usage on a spiking device (even if the input model has non-spiking elements). For example, if the model contains tf.keras.layers.Activation("relu") layers (non-spiking), we assume that on a spiking device those layers will be converted to something equivalent to keras_spiking.SpikingActivation("relu"), and that any connecting layers (e.g. tf.keras.layers.Dense) are applied in an event-based fashion (i.e., processing only occurs when the input layer emits a spike). In practice, it is not trivial to map a neural network to a spiking device in this way, and implementation details can significantly affect energy usage. [Nengo](https://www.nengo.ai/nengo/) and [NengoDL](https://www.nengo.ai/nengo-dl/) are designed to make this easier.
References
[1] Degnan, Brian, Bo Marr, and Jennifer Hasler. “Assessing trends in performance per watt for signal processing applications.” IEEE Transactions on Very Large Scale Integration (VLSI) Systems 24.1 (2015): 58-66. https://ieeexplore.ieee.org/abstract/document/7054508
[2] Davies, Mike, et al. “Loihi: A neuromorphic manycore processor with on-chip learning.” IEEE Micro 38.1 (2018): 82-99. https://redwood.berkeley.edu/wp-content/uploads/2021/08/Davies2018.pdf
[3] Höppner, Sebastian, et al. “Dynamic power management for neuromorphic many-core systems.” IEEE Transactions on Circuits and Systems I: Regular Papers 66.8 (2019): 2973-2986. https://arxiv.org/abs/1903.08941

classmethod compute_layer_stats(layer, node=None)
Compute statistics for a given layer.
Examples
import tensorflow as tf
import keras_spiking

inp = tf.keras.Input([None, 10])
layer = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(5, "relu"))
model = tf.keras.Model(inp, [layer(inp)])
print(keras_spiking.ModelEnergy.compute_layer_stats(layer))
{'connections': 50, 'neurons': 5, 'spiking': False}
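The numbers above follow from simple arithmetic: the wrapped `Dense(5)` layer connects each of its 10 input features to each of its 5 units (10 * 5 = 50 connections), performs one neuron update per unit, and its ReLU nonlinearity makes it non-spiking. The sketch below reproduces that reasoning in plain Python; `dense_stats` and its treatment of a missing activation (no neuron updates, spiking-compatible) are assumptions for illustration, not the library's rules.

```python
def dense_stats(n_inputs, units, activation=None):
    """Hypothetical sketch of per-timestep stats for a Dense layer.

    Assumes every input connects to every unit, and that a non-None
    activation (e.g. "relu") means one neuron update per unit and a
    non-spiking nonlinearity.
    """
    has_activation = activation is not None
    return {
        "connections": n_inputs * units,
        "neurons": units if has_activation else 0,
        "spiking": not has_activation,
    }


print(dense_stats(10, 5, "relu"))
# {'connections': 50, 'neurons': 5, 'spiking': False}
```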

classmethod register_layer(layer_class)
Decorator to register a statistic calculator for a layer.
The input to the decorated function is a node from a layer for which we want to compute statistics.
The decorated function should return a dictionary with the following entries:
“connections”: The number of connections from all input elements to all activation units. Defaults to 0.
“neurons”: The number of neuron updates per timestep performed by this layer. Defaults to 0.
“spiking”: Whether or not this layer could be implemented on a spiking device as-is (e.g. tf.keras.layers.ReLU returns spiking=False because non-spiking nonlinearities like ReLU can’t be directly implemented on a spiking device). Defaults to True.
Examples
If we know that the tf.keras.layers.UpSampling2D layer uses one synaptic operation per output element, we can register the following function:
import numpy as np
import tensorflow as tf
import keras_spiking

@keras_spiking.ModelEnergy.register_layer(tf.keras.layers.UpSampling2D)
def upsampling2d_stats(node):
    # note: ignore the batch dimension when computing output size
    output_size = np.prod(node.output_shapes[1:])
    return {"connections": output_size, "neurons": 0}

# use our registered stat calculator
inp = tf.keras.Input([4, 4, 3])
layer = tf.keras.layers.UpSampling2D(size=(2, 2))
model = tf.keras.Model(inp, [layer(inp)])
print(keras_spiking.ModelEnergy.compute_layer_stats(layer))
{'connections': 192, 'neurons': 0}
We see that the number of synaptic operations (connections) is 192, which equals the number of pixels in the upsampled (8, 8) image times the number of channels (3): 8 * 8 * 3 = 192.

classmethod register_device(device_name, energy_per_synop, energy_per_neuron, spiking)
Register a new device type for estimating energy consumption.
Parameters
device_name : str
The string to use to refer to the device.
energy_per_synop : float
The energy (in Joules) used by a single synaptic operation on the device.
energy_per_neuron : float
The energy (in Joules) used by a single neuron update on the device.
spiking : bool
Whether the device is spiking (event-based), and thus only computes synaptic updates for incoming spikes, rather than on every timestep.
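Each registered device thus reduces to three numbers. The sketch below stores them in a small dataclass and shows how the `spiking` flag changes the per-timestep count of synaptic operations: a non-spiking device computes every connection, while a spiking device is assumed to process about connections * rate * dt events, following the description of `dt` in `layer_energy`. `Device`, `synops_per_timestep`, and all energy values are hypothetical placeholders, not KerasSpiking code or published device data.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Device:
    """Hypothetical record of the three numbers passed to register_device."""
    energy_per_synop: float   # Joules per synaptic operation
    energy_per_neuron: float  # Joules per neuron update
    spiking: bool             # True for event-based (spiking) hardware


def synops_per_timestep(device, connections, rate_hz, dt=0.001):
    """Assumed synaptic-operation count for one layer on one timestep."""
    if device.spiking:
        # event-based: each connection carries rate_hz * dt input spikes per step on average
        return connections * rate_hz * dt
    # non-spiking: every connection is computed on every timestep
    return connections


# placeholder energy values, for illustration only
dense_chip = Device(energy_per_synop=1e-9, energy_per_neuron=1e-9, spiking=False)
event_chip = Device(energy_per_synop=1e-11, energy_per_neuron=1e-11, spiking=True)
print(synops_per_timestep(dense_chip, connections=1000, rate_hz=20))  # 1000
print(synops_per_timestep(event_chip, connections=1000, rate_hz=20))  # about 20
```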

layer_energy(layer, device, timesteps_per_inference=1, dt=None)
Estimate the energy used by one layer.
Parameters
layer : tf.keras.layers.Layer
Layer to estimate energy for. Note that if the same layer is reused multiple times in the model, this will return the total energy summed over all the applications.
device : str
Device to estimate energy for. Can be a supported device (see ModelEnergy for a list), or another device added with ModelEnergy.register_device.
timesteps_per_inference : int
Timesteps used per inference (for example, 10 if the model is classifying images and we want to present each image for 10 timesteps).
dt : float
The length of one timestep, in seconds, used by the device. Used to compute the number of synaptic events based on the firing rates (in Hz). Can differ from the dt used on any keras_spiking layers in the model. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).
Returns
synop_energy : float
Estimated energy used (in Joules) for synaptic computations per inference.
neuron_energy : float
Estimated energy used (in Joules) for neuron updates per inference.
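Combining these quantities, a per-inference estimate for one layer scales the per-timestep counts by the per-operation energies and by `timesteps_per_inference`. The function below is a simplified sketch under two stated assumptions (a spiking device performs connections * rate * dt synaptic operations per timestep, and neuron updates happen on every timestep on every device); `estimate_layer_energy` and the energy values are hypothetical, and this is not the library's implementation.

```python
def estimate_layer_energy(connections, neurons, rate_hz, energy_per_synop,
                          energy_per_neuron, spiking,
                          timesteps_per_inference=1, dt=0.001):
    """Hypothetical sketch: (synop_energy, neuron_energy) in Joules per inference."""
    if spiking:
        # assumption: on event-based hardware only incoming spikes cost energy
        synops_per_step = connections * rate_hz * dt
    else:
        # assumption: non-spiking hardware computes every connection each step
        synops_per_step = connections
    synop_energy = synops_per_step * energy_per_synop * timesteps_per_inference
    neuron_energy = neurons * energy_per_neuron * timesteps_per_inference
    return synop_energy, neuron_energy


# 50 connections and 5 neurons, as in the compute_layer_stats example,
# presented for 10 timesteps on a non-spiking device (placeholder energies)
synop_e, neuron_e = estimate_layer_energy(
    connections=50, neurons=5, rate_hz=0.0,
    energy_per_synop=1e-9, energy_per_neuron=2e-9,
    spiking=False, timesteps_per_inference=10,
)
print(synop_e, neuron_e)  # approximately 5e-07 and 1e-07
```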

total_energy(device, timesteps_per_inference=1, dt=None)
Estimate the energy usage for a whole model.
Parameters
device : str
Device to estimate energy for. Can be a supported device (see ModelEnergy for a list), or another device added with ModelEnergy.register_device.
timesteps_per_inference : int
Timesteps used per inference (for example, 10 if the model is classifying images and we want to present each image for 10 timesteps).
dt : float
The length of one timestep, in seconds, used by the device. Used to compute the number of synaptic events based on the firing rates (in Hz). Can differ from the dt used on any keras_spiking layers in the model. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).
Returns
energy : float
Total estimated energy used by the model (in Joules) per inference.
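Because each layer's estimate is already per inference, a whole-model figure can be sketched as the sum of every layer's synaptic and neuron-update energy, consistent with the summary example above, where the reported total equals the sum of the per-layer column. `sum_model_energy` and its input values are hypothetical placeholders.

```python
def sum_model_energy(layer_energies):
    """Hypothetical helper: total Joules per inference for a whole model.

    `layer_energies` is an iterable of (synop_energy, neuron_energy) pairs,
    one per layer.
    """
    return sum(synop + neuron for synop, neuron in layer_energies)


# placeholder values: an input layer with no cost, then one dense layer
total = sum_model_energy([(0.0, 0.0), (8.6e-4, 4.3e-6)])
print(f"{total:.3e} J/inf")  # 8.643e-04 J/inf
```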

summary(columns=('name', 'output_shape', 'params', 'connections', 'neurons', 'energy cpu'), timesteps_per_inference=1, dt=None, line_length=98, print_warnings=True)
Print a per-layer summary of computation statistics and energy estimates.
Parameters
columns : list or tuple of str
Columns to display. Can be any combination of the following:
“name”: The layer name.
“output_shape”: The output shape of the layer.
“params”: The number of parameters in the layer.
“connections”: The number of synaptic connections from inputs to the neurons of this layer (see ModelEnergy for the definition of “connections”).
“neurons”: The number of neuron updates performed by the layer each timestep (see ModelEnergy for the definition of “neuron update”).
“rate”: The average input firing rate to the layer, in spikes per second. Note that this is only relevant for layers that perform synaptic operations; for other layers (e.g. an activation layer that gets input from a convolutional layer), this number has no effect.
“synop_energy <device>”: The estimated energy in Joules per inference used by the layer on <device> for synaptic operations.
“neuron_energy <device>”: The estimated energy in Joules per inference used by the layer on <device> for neuron updates.
“energy <device>”: The total estimated energy in Joules per inference used by the layer on <device>.
Here, <device> can be any of the supported devices (see ModelEnergy). Additional devices can be added with ModelEnergy.register_device.
timesteps_per_inference : int
Timesteps used per inference (for example, 10 if the model is classifying images and we want to present each image for 10 timesteps).
dt : float
The length of one timestep, in seconds, used by the device. Used to compute the number of synaptic events based on the firing rates (in Hz). Can differ from the dt used on any keras_spiking layers in the model. If None, uses keras_spiking.default.dt (which is 0.001 seconds by default).
line_length : int
The length of each printed line.
print_warnings : bool
Set to False to disable the warnings regarding assumptions made in the energy calculations.

summary_string(columns=('name', 'output_shape', 'params', 'connections', 'neurons', 'energy cpu'), timesteps_per_inference=1, dt=None, line_length=98, print_warnings=True)
Returns a per-layer summary of computation statistics and energy estimates.
The same as summary, except returns the summary as a string, rather than printing it. For documentation on parameters and other features, see summary.

keras_spiking.model_energy.layer_stats(node, **_)
Fallback for computing stats on an unknown layer (assumes no connections and no neuron updates).

keras_spiking.model_energy.spikingactivation_stats(node)
Compute SpikingActivation layer stats.

keras_spiking.model_energy.timedistributed_stats(node)
Compute TimeDistributed layer stats.
Calls ModelEnergy.compute_layer_stats on the wrapped layer.