Source code for keras_spiking.layers

"""
Components for building spiking models in Keras.
"""

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import smart_cond
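# editorial note (not in the original source): ``smart_cond`` is an internal
# TensorFlow helper; it resolves the predicate at trace time when it is a plain
# Python boolean, and falls back to ``tf.cond`` when it is a tensor (as with
# the ``training`` flags used below)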


class SpikingActivationCell(tf.keras.layers.Layer):
    """
    RNN cell for converting an arbitrary activation function to a spiking equivalent.

    Neurons will spike at a rate proportional to the output of the base activation
    function. For example, if the activation function is outputting a value of 10,
    then the wrapped SpikingActivationCell will output spikes at a rate of 10Hz
    (i.e., 10 spikes per 1 simulated second, where 1 simulated second is equivalent
    to ``1/dt`` time steps). Each spike will have height ``1/dt`` (so that the
    integral of the spiking output will be the same as the integral of the base
    activation output). Note that if the base activation is outputting a negative
    value then the spikes will have height ``-1/dt``. Multiple spikes per timestep
    are also possible, in which case the output will be ``n/dt`` (where ``n`` is
    the number of spikes).

    Notes
    -----
    This cell needs to be wrapped in a ``tf.keras.layers.RNN``, like

    .. testcode::

        my_layer = tf.keras.layers.RNN(
            keras_spiking.SpikingActivationCell(units=10, activation=tf.nn.relu)
        )

    Parameters
    ----------
    units : int
        Dimensionality of layer.
    activation : callable
        Activation function to be converted to spiking equivalent.
    dt : float
        Length of time (in seconds) represented by one time step.
    seed : int
        Seed for random state initialization.
    spiking_aware_training : bool
        If True (default), use the spiking activation function for the forward
        pass and the base activation function for the backward pass. If False,
        use the base activation function for the forward and backward pass
        during training.
    kwargs : dict
        Passed on to `tf.keras.layers.Layer
        <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer>`_.
    """

    def __init__(
        self,
        units,
        activation,
        dt=0.001,
        seed=None,
        # TODO: should this default to True or False?
        spiking_aware_training=True,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.dt = dt
        self.seed = seed
        self.spiking_aware_training = spiking_aware_training

        self.output_size = (units,)
        self.state_size = (units,)
    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        """
        Set up initial spiking state.

        Initial state is chosen from a uniform distribution, seeded based on the
        seed passed on construction (if one was given).

        Note: state will be initialized automatically; the user does not need to
        call this method themselves.
        """
        seed = (
            tf.random.uniform((), maxval=np.iinfo(np.int32).max, dtype=tf.int32)
            if self.seed is None
            else self.seed
        )

        # TODO: we could make the initial voltages trainable
        return tf.random.stateless_uniform(
            (batch_size, self.units), seed=(seed, seed), dtype=dtype
        )
    def call(self, inputs, states, training=None):
        """Compute layer output."""

        if training is None:
            training = tf.keras.backend.learning_phase()

        voltage = states[0]

        return smart_cond.smart_cond(
            tf.logical_and(
                tf.cast(training, tf.bool), not self.spiking_aware_training
            ),
            lambda: (self.activation(inputs), voltage),
            lambda: self._compute_spikes(inputs, voltage),
        )
    @tf.custom_gradient
    def _compute_spikes(self, inputs, voltage):
        """
        Compute spiking output, with custom gradient for spiking aware training.

        Parameters
        ----------
        inputs : ``tf.Tensor``
            Input to the activation function.
        voltage : ``tf.Tensor``
            Spiking voltage state.

        Returns
        -------
        spikes : ``tf.Tensor``
            Output spike values (0 or ``n/dt`` for each element in ``inputs``,
            where ``n`` is the number of spikes).
        voltage : ``tf.Tensor``
            Updated voltage state.
        grad : callable
            Custom gradient function for spiking aware training.
        """
        with tf.GradientTape() as g:
            g.watch(inputs)
            rates = self.activation(inputs)

        # accumulate the rates into the voltage, and emit one spike (of height
        # ``1/dt``) for each integer threshold crossing
        voltage = voltage + rates * self.dt
        n_spikes = tf.floor(voltage)
        voltage -= n_spikes
        spikes = n_spikes / self.dt

        def grad(grad_spikes, grad_voltage):
            # on the backward pass, substitute the gradient of the base
            # (non-spiking) activation function
            return (
                g.gradient(rates, inputs) * grad_spikes,
                None,
            )

        return (spikes, voltage), grad
    def get_config(self):
        """Return config of layer (for serialization during model saving/loading)."""

        cfg = super().get_config()
        cfg.update(
            dict(
                units=self.units,
                activation=tf.keras.activations.serialize(self.activation),
                dt=self.dt,
                seed=self.seed,
                spiking_aware_training=self.spiking_aware_training,
            )
        )

        return cfg
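# The following example function is an editorial addition (not part of the
# original keras_spiking source); it sketches the rate-equivalence property
# described in the docstring above: a constant input of 10 through a ReLU
# activation should yield roughly 10 spikes per simulated second, each of
# height ``1/dt``.
def _example_spiking_cell_rate():
    dt = 0.001
    layer = tf.keras.layers.RNN(
        SpikingActivationCell(units=1, activation=tf.nn.relu, dt=dt),
        return_sequences=True,
    )
    # run for 1000 timesteps (1 simulated second) with a constant input of 10
    spikes = layer(tf.ones((1, 1000, 1)) * 10)
    # summing the spike train and multiplying by dt recovers the underlying
    # rate (approximately 10 here)
    return tf.reduce_sum(spikes) * dt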
class SpikingActivation(tf.keras.layers.Layer):
    """
    Layer for converting an arbitrary activation function to a spiking equivalent.

    Neurons will spike at a rate proportional to the output of the base activation
    function. For example, if the activation function is outputting a value of 10,
    then this SpikingActivation layer will output spikes at a rate of 10Hz
    (i.e., 10 spikes per 1 simulated second, where 1 simulated second is equivalent
    to ``1/dt`` time steps). Each spike will have height ``1/dt`` (so that the
    integral of the spiking output will be the same as the integral of the base
    activation output). Note that if the base activation is outputting a negative
    value then the spikes will have height ``-1/dt``. Multiple spikes per timestep
    are also possible, in which case the output will be ``n/dt`` (where ``n`` is
    the number of spikes).

    Notes
    -----
    This is equivalent to ``tf.keras.layers.RNN(SpikingActivationCell(...) ...)``;
    it just takes care of the RNN construction automatically.

    Parameters
    ----------
    activation : callable
        Activation function to be converted to spiking equivalent.
    dt : float
        Length of time (in seconds) represented by one time step.
    seed : int
        Seed for random state initialization.
    spiking_aware_training : bool
        If True (default), use the spiking activation function for the forward
        pass and the base activation function for the backward pass. If False,
        use the base activation function for the forward and backward pass
        during training.
    return_sequences : bool
        If False (default), return only the last output in the output sequence;
        if True, return the full sequence.
    return_state : bool
        Whether to return the state in addition to the output.
    stateful : bool
        If False (default), each time the layer is called it will begin from the
        same initial conditions. If True, each call will resume from the terminal
        state of the previous call (``my_layer.reset_states()`` can be called to
        reset the state to initial conditions).
    unroll : bool
        If True, the network will be unrolled, else a symbolic loop will be used.
        Unrolling can speed up computations, although it tends to be more
        memory-intensive. Unrolling is only suitable for short sequences.
    time_major : bool
        The shape format of the input and output tensors. If True, the inputs
        and outputs will be in shape ``(timesteps, batch, ...)``, whereas in the
        False case, it will be ``(batch, timesteps, ...)``. Using
        ``time_major=True`` is a bit more efficient because it avoids transposes
        at the beginning and end of the layer calculation. However, most
        TensorFlow data is batch-major, so by default this layer accepts input
        and emits output in batch-major form.
    kwargs : dict
        Passed on to `tf.keras.layers.Layer
        <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer>`_.
    """

    def __init__(
        self,
        activation,
        dt=0.001,
        seed=None,
        spiking_aware_training=True,
        return_sequences=False,
        return_state=False,
        stateful=False,
        unroll=False,
        time_major=False,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.activation = tf.keras.activations.get(activation)
        self.dt = dt
        self.seed = seed
        self.spiking_aware_training = spiking_aware_training
        self.return_sequences = return_sequences
        self.return_state = return_state
        self.stateful = stateful
        self.unroll = unroll
        self.time_major = time_major
        self.layer = None
    def build(self, input_shapes):
        """
        Builds the RNN/SpikingActivationCell layers contained within this layer.

        Notes
        -----
        This method should not be called manually; rather, use the implicit
        layer callable behaviour (like ``my_layer(inputs)``), which will apply
        this method with some additional bookkeeping.
        """
        super().build(input_shapes)

        # we initialize these here, rather than in ``__init__``, so that we can
        # determine ``units`` automatically
        self.layer = tf.keras.layers.RNN(
            SpikingActivationCell(
                activation=self.activation,
                units=input_shapes[-1],
                dt=self.dt,
                seed=self.seed,
                spiking_aware_training=self.spiking_aware_training,
            ),
            return_sequences=self.return_sequences,
            return_state=self.return_state,
            stateful=self.stateful,
            unroll=self.unroll,
            time_major=self.time_major,
        )
        self.layer.build(input_shapes)
    def call(self, inputs, training=None, initial_state=None, constants=None):
        """
        Apply this layer to inputs.

        Notes
        -----
        This method should not be called manually; rather, use the implicit
        layer callable behaviour (like ``my_layer(inputs)``), which will apply
        this method with some additional bookkeeping.
        """
        return self.layer.call(
            inputs, training=training, initial_state=initial_state, constants=constants
        )
    def reset_states(self, states=None):
        """
        Reset the internal state of the layer (only necessary if ``stateful=True``).

        Parameters
        ----------
        states : `~numpy.ndarray`
            Optional state array that can be used to override the values returned
            by `.SpikingActivationCell.get_initial_state`.
        """
        self.layer.reset_states(states=states)
    def get_config(self):
        """Return config of layer (for serialization during model saving/loading)."""

        cfg = super().get_config()
        cfg.update(
            dict(
                activation=tf.keras.activations.serialize(self.activation),
                dt=self.dt,
                seed=self.seed,
                spiking_aware_training=self.spiking_aware_training,
                return_sequences=self.return_sequences,
                return_state=self.return_state,
                stateful=self.stateful,
                unroll=self.unroll,
                time_major=self.time_major,
            )
        )

        return cfg
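# Editorial sketch (not part of the original source): using the SpikingActivation
# layer in place of a standard activation inside a Keras model. The input shape
# and layer sizes here are illustrative only.
def _example_spiking_activation_model(n_steps=100):
    inp = tf.keras.Input((n_steps, 32))
    # apply a Dense transform to each timestep, then a spiking ReLU
    x = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(64))(inp)
    x = SpikingActivation("relu", return_sequences=True)(x)
    return tf.keras.Model(inp, x)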
class LowpassCell(tf.keras.layers.Layer):
    """
    RNN cell for a lowpass filter.

    The initial filter state and filter time constants are both trainable
    parameters. However, if ``apply_during_training=False`` then the parameters
    are not part of the training loop, and so will never be updated.

    Notes
    -----
    This cell needs to be wrapped in a ``tf.keras.layers.RNN``, like

    .. testcode::

        my_layer = tf.keras.layers.RNN(keras_spiking.LowpassCell(units=10, tau=0.01))

    Parameters
    ----------
    units : int
        Dimensionality of layer.
    tau : float
        Time constant of filter (in seconds).
    dt : float
        Length of time (in seconds) represented by one time step.
    apply_during_training : bool
        If False, this layer will effectively be ignored during training (this
        often makes sense in concert with the swappable training behaviour in,
        e.g., `.SpikingActivation`, since if the activations are not spiking
        during training then we often don't need to filter them either).
    level_initializer : str or ``tf.keras.initializers.Initializer``
        Initializer for filter state.
    kwargs : dict
        Passed on to `tf.keras.layers.Layer
        <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer>`_.
    """

    def __init__(
        self,
        units,
        tau,
        dt=0.001,
        # TODO: better name for this parameter?
        apply_during_training=True,
        level_initializer="zeros",
        **kwargs
    ):
        super().__init__(**kwargs)

        if tau <= 0:
            raise ValueError("tau must be a positive number")

        self.units = units
        self.tau = tau
        self.dt = dt
        self.apply_during_training = apply_during_training
        self.level_initializer = tf.initializers.get(level_initializer)

        self.state_size = units
        self.output_size = units

        # apply ZOH discretization
        tau = np.exp(-dt / tau)

        # compute inverse sigmoid of tau, so that when we apply the sigmoid
        # later we'll get the tau value specified
        self.smoothing_init = np.log(tau / (1 - tau))
    def build(self, input_shapes):
        """Build parameters associated with this layer."""

        super().build(input_shapes)

        self.initial_level = self.add_weight(
            name="initial_level",
            shape=(1, self.units),
            initializer=self.level_initializer,
            trainable=self.apply_during_training,
        )
        self.smoothing = self.add_weight(
            name="level_smoothing",
            shape=(1, self.units),
            initializer=tf.initializers.constant(
                np.ones(self.units) * self.smoothing_init
            ),
            trainable=self.apply_during_training,
        )
    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        """Get initial filter state."""
        return tf.tile(self.initial_level, (batch_size, 1))
    def call(self, inputs, states, training=None):
        """
        Apply this layer to inputs.

        Notes
        -----
        This method should not be called manually; rather, use the implicit
        layer callable behaviour (like ``my_layer(inputs)``), which will apply
        this method with some additional bookkeeping.
        """
        if training is None:
            training = tf.keras.backend.learning_phase()

        def apply():
            # one step of the discretized lowpass filter:
            # x_t = (1 - a) * input_t + a * x_{t-1}
            smoothing = tf.nn.sigmoid(self.smoothing)
            x = (1 - smoothing) * inputs + smoothing * states[0]
            return x, (x,)

        return smart_cond.smart_cond(
            tf.logical_and(
                tf.cast(training, tf.bool), not self.apply_during_training
            ),
            lambda: (inputs, states),
            apply,
        )
    def get_config(self):
        """Return config of layer (for serialization during model saving/loading)."""

        config = super().get_config()
        config.update(
            dict(
                units=self.units,
                tau=self.tau,
                dt=self.dt,
                apply_during_training=self.apply_during_training,
                level_initializer=tf.keras.initializers.serialize(
                    self.level_initializer
                ),
            )
        )

        return config
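# Editorial sketch (not part of the original source): the ZOH discretization in
# ``LowpassCell.__init__`` maps the continuous time constant ``tau`` to a
# discrete smoothing coefficient ``a = exp(-dt / tau)``, stored through an
# inverse sigmoid so that the trainable parameter is unconstrained. Applying
# the sigmoid recovers ``a`` exactly:
def _example_lowpass_discretization(tau=0.01, dt=0.001):
    a = np.exp(-dt / tau)  # ZOH smoothing coefficient, in (0, 1)
    smoothing_init = np.log(a / (1 - a))  # logit (inverse sigmoid) of a
    # sigmoid(smoothing_init) == a, so each filter step computes
    # x_t = (1 - a) * input_t + a * x_{t-1}
    assert np.allclose(1 / (1 + np.exp(-smoothing_init)), a)
    return a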
# TODO: reduce code duplication between this and SpikingActivation
class Lowpass(tf.keras.layers.Layer):
    """
    Layer implementing a lowpass filter.

    The initial filter state and filter time constants are both trainable
    parameters. However, if ``apply_during_training=False`` then the parameters
    are not part of the training loop, and so will never be updated.

    Notes
    -----
    This is equivalent to ``tf.keras.layers.RNN(LowpassCell(...) ...)``; it just
    takes care of the RNN construction automatically.

    Parameters
    ----------
    tau : float
        Time constant of filter (in seconds).
    dt : float
        Length of time (in seconds) represented by one time step.
    apply_during_training : bool
        If False, this layer will effectively be ignored during training (this
        often makes sense in concert with the swappable training behaviour in,
        e.g., `.SpikingActivation`, since if the activations are not spiking
        during training then we often don't need to filter them either).
    level_initializer : str or ``tf.keras.initializers.Initializer``
        Initializer for filter state.
    return_sequences : bool
        If False (default), return only the last output in the output sequence;
        if True, return the full sequence.
    return_state : bool
        Whether to return the state in addition to the output.
    stateful : bool
        If False (default), each time the layer is called it will begin from the
        same initial conditions. If True, each call will resume from the terminal
        state of the previous call (``my_layer.reset_states()`` can be called to
        reset the state to initial conditions).
    unroll : bool
        If True, the network will be unrolled, else a symbolic loop will be used.
        Unrolling can speed up computations, although it tends to be more
        memory-intensive. Unrolling is only suitable for short sequences.
    time_major : bool
        The shape format of the input and output tensors. If True, the inputs
        and outputs will be in shape ``(timesteps, batch, ...)``, whereas in the
        False case, it will be ``(batch, timesteps, ...)``. Using
        ``time_major=True`` is a bit more efficient because it avoids transposes
        at the beginning and end of the layer calculation. However, most
        TensorFlow data is batch-major, so by default this layer accepts input
        and emits output in batch-major form.
    kwargs : dict
        Passed on to `tf.keras.layers.Layer
        <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer>`_.
    """

    def __init__(
        self,
        tau,
        dt=0.001,
        apply_during_training=True,
        level_initializer="zeros",
        return_sequences=False,
        return_state=False,
        stateful=False,
        unroll=False,
        time_major=False,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.tau = tau
        self.dt = dt
        self.apply_during_training = apply_during_training
        self.level_initializer = tf.keras.initializers.get(level_initializer)
        self.return_sequences = return_sequences
        self.return_state = return_state
        self.stateful = stateful
        self.unroll = unroll
        self.time_major = time_major
        self.layer = None
    def build(self, input_shapes):
        """
        Builds the RNN/LowpassCell layers contained within this layer.

        Notes
        -----
        This method should not be called manually; rather, use the implicit
        layer callable behaviour (like ``my_layer(inputs)``), which will apply
        this method with some additional bookkeeping.
        """
        super().build(input_shapes)

        # we initialize these here, rather than in ``__init__``, so that we can
        # determine ``units`` automatically
        self.layer = tf.keras.layers.RNN(
            LowpassCell(
                units=input_shapes[-1],
                tau=self.tau,
                dt=self.dt,
                apply_during_training=self.apply_during_training,
                level_initializer=self.level_initializer,
            ),
            return_sequences=self.return_sequences,
            return_state=self.return_state,
            stateful=self.stateful,
            unroll=self.unroll,
            time_major=self.time_major,
        )
        self.layer.build(input_shapes)
    def call(self, inputs, training=None, initial_state=None, constants=None):
        """
        Apply this layer to inputs.

        Notes
        -----
        This method should not be called manually; rather, use the implicit
        layer callable behaviour (like ``my_layer(inputs)``), which will apply
        this method with some additional bookkeeping.
        """
        return self.layer.call(
            inputs, training=training, initial_state=initial_state, constants=constants
        )
    def reset_states(self, states=None):
        """
        Reset the internal state of the layer (only necessary if ``stateful=True``).

        Parameters
        ----------
        states : `~numpy.ndarray`
            Optional state array that can be used to override the values returned
            by `.LowpassCell.get_initial_state`.
        """
        self.layer.reset_states(states=states)
    def get_config(self):
        """Return config of layer (for serialization during model saving/loading)."""

        cfg = super().get_config()
        cfg.update(
            dict(
                tau=self.tau,
                dt=self.dt,
                apply_during_training=self.apply_during_training,
                level_initializer=tf.keras.initializers.serialize(
                    self.level_initializer
                ),
                return_sequences=self.return_sequences,
                return_state=self.return_state,
                stateful=self.stateful,
                unroll=self.unroll,
                time_major=self.time_major,
            )
        )

        return cfg
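# Editorial sketch (not part of the original source): the docstrings above
# suggest pairing SpikingActivation with Lowpass, so that noisy spike trains
# are smoothed before the readout. Shapes and hyperparameters here are
# illustrative only.
def _example_spiking_pipeline(n_steps=100):
    inp = tf.keras.Input((n_steps, 32))
    x = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(64))(inp)
    x = SpikingActivation("relu", return_sequences=True)(x)
    # smooth the spike train; with apply_during_training=False this filter is
    # skipped during training (matching the non-spiking training forward pass)
    x = Lowpass(tau=0.01, apply_during_training=False, return_sequences=True)(x)
    out = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(10))(x)
    return tf.keras.Model(inp, out)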