{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Optimizing spiking neural networks\n", "\n", "Almost all deep learning methods are based on [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent), which means that the network being optimized needs to be differentiable. Deep neural networks are usually built using rectified linear or sigmoid neurons, as these are differentiable nonlinearities. However, in biological neural modelling we often want to use spiking neurons, which are not differentiable. So the challenge is how to apply deep learning methods to spiking neural networks.\n", "\n", "A method for accomplishing this is presented in [Hunsberger and Eliasmith (2016)](https://arxiv.org/abs/1611.05141). The basic idea is to use a differentiable approximation of the spiking neurons during the training process, and the actual spiking neurons during inference. NengoDL will perform these transformations automatically if the user tries to optimize a model containing a spiking neuron model that has an equivalent, differentiable rate-based implementation. In this example we will use these techniques to develop a network to classify handwritten digits ([MNIST](http://yann.lecun.com/exdb/mnist/)) in a spiking convolutional network. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "\n", "import gzip\n", "import pickle\n", "from urllib.request import urlretrieve\n", "import zipfile\n", "\n", "import nengo\n", "import nengo_dl\n", "import tensorflow as tf\n", "import numpy as np\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First we'll load the training data, the MNIST digits/labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "urlretrieve(\"http://deeplearning.net/data/mnist/mnist.pkl.gz\", \"mnist.pkl.gz\")\n", "with gzip.open(\"mnist.pkl.gz\") as f:\n", " train_data, _, test_data = pickle.load(f, encoding=\"latin1\")\n", "train_data = list(train_data)\n", "test_data = list(test_data)\n", "for data in (train_data, test_data):\n", " one_hot = np.zeros((data[0].shape[0], 10))\n", " one_hot[np.arange(data[0].shape[0]), data[1]] = 1\n", " data[1] = one_hot\n", "\n", "for i in range(3):\n", " plt.figure()\n", " plt.imshow(np.reshape(train_data[0][i], (28, 28)))\n", " plt.axis('off')\n", " plt.title(str(np.argmax(train_data[1][i])));" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will use [TensorNodes](https://www.nengo.ai/nengo-dl/tensor_node.html) to construct the network, as they allow us to easily include features such as convolutional connections. To make things even easier, we'll use `nengo_dl.tensor_layer`. This is a utility function for constructing `TensorNodes` that mimics the layer-based syntax of many deep learning packages (e.g. [tf.layers](https://www.tensorflow.org/api_docs/python/tf/layers)). The full documentation for this function can be found [here](https://www.nengo.ai/nengo-dl/tensor_node.html). \n", "\n", "`tensor_layer` is used to build a sequence of layers, where each layer takes the output of the previous layer and applies some transformation to it. So when we build a `tensor_layer` we pass it the input to the layer, the transformation we want to apply (expressed as a function that accepts a `tf.Tensor` as input and produces a `tf.Tensor` as output), and any arguments to that transformation function. 
`tensor_layer` also has optional `transform` and `synapse` parameters that set those respective values on the Connection from the previous layer to the one being constructed.\n", "\n", "Normally all signals in a Nengo model are (batched) vectors. However, certain layer functions, such as convolutional layers, may expect a different shape for their inputs. If the `shape_in` argument is specified for a `tensor_layer` then the inputs to\n", "the layer will automatically be reshaped to the given shape. Note that this shape does not include the batch dimension on the first axis, as that will be automatically set by the simulation.\n", "\n", "`tensor_layer` can also be passed a Nengo NeuronType, instead of a Tensor function. In this case `tensor_layer` will construct an Ensemble implementing the given neuron nonlinearity (the rest of the arguments work the same).\n", "\n", "Note that `tensor_layer` is just a syntactic wrapper for constructing `TensorNodes` or `Ensembles`; anything we build with a `tensor_layer` we could instead construct directly using those underlying components. `tensor_layer` just simplifies the construction of this common layer-based pattern." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with nengo.Network() as net:\n", " # set some default parameters for the neurons that will make\n", " # the training progress more smoothly\n", " net.config[nengo.Ensemble].max_rates = nengo.dists.Choice([100])\n", " net.config[nengo.Ensemble].intercepts = nengo.dists.Choice([0])\n", " neuron_type = nengo.LIF(amplitude=0.01)\n", " \n", " # we'll make all the nengo objects in the network\n", " # non-trainable. we could train them if we wanted, but they don't\n", " # add any representational power. note that this doesn't affect \n", " # the internal components of tensornodes, which will always be \n", " # trainable or non-trainable depending on the code written in \n", " # the tensornode.\n", " nengo_dl.configure_settings(trainable=False)\n", "\n", " # the input node that will be used to feed in input images \n", " inp = nengo.Node([0] * 28 * 28)\n", "\n", " # add the first convolutional layer\n", " x = nengo_dl.tensor_layer(\n", " inp, tf.layers.conv2d, shape_in=(28, 28, 1), filters=32,\n", " kernel_size=3)\n", "\n", " # apply the neural nonlinearity\n", " x = nengo_dl.tensor_layer(x, neuron_type)\n", "\n", " # add another convolutional layer\n", " x = nengo_dl.tensor_layer(\n", " x, tf.layers.conv2d, shape_in=(26, 26, 32), \n", " filters=64, kernel_size=3)\n", " x = nengo_dl.tensor_layer(x, neuron_type)\n", "\n", " # add a pooling layer\n", " x = nengo_dl.tensor_layer(\n", " x, tf.layers.average_pooling2d, shape_in=(24, 24, 64),\n", " pool_size=2, strides=2)\n", "\n", " # another convolutional layer\n", " x = nengo_dl.tensor_layer(\n", " x, tf.layers.conv2d, shape_in=(12, 12, 64),\n", " filters=128, kernel_size=3)\n", " x = nengo_dl.tensor_layer(x, neuron_type)\n", "\n", " # another pooling layer\n", " x = nengo_dl.tensor_layer(\n", " x, tf.layers.average_pooling2d, shape_in=(10, 10, 128),\n", " pool_size=2, strides=2)\n", "\n", " # linear readout\n", " x = nengo_dl.tensor_layer(x, tf.layers.dense, units=10)\n", "\n", " # we'll create two different output probes, one with a filter\n", " # (for when we're simulating the network over time and\n", " # accumulating spikes), and one without (for when we're\n", " # training the network using a rate-based approximation)\n", " out_p = nengo.Probe(x)\n", " out_p_filt = nengo.Probe(x, synapse=0.1)" ] 
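}, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the syntactic-wrapper point above concrete, here is a rough sketch of how the first convolutional layer and its neural nonlinearity could be constructed directly from the underlying `TensorNode` and `Ensemble` objects. This is a simplified illustration rather than code from this notebook (for instance, the real `tensor_layer` infers the output size automatically, so treat the exact details here as an approximation):\n", "\n", "```python\n", "# inside the same `with nengo.Network() as net:` block as above\n", "def conv_func(t, x):\n", "    # `x` has shape (minibatch_size, 28 * 28); restore the image\n", "    # structure, apply the convolution, then flatten back to a\n", "    # (minibatch_size, size_out) vector for Nengo\n", "    image = tf.reshape(x, (-1, 28, 28, 1))\n", "    conv = tf.layers.conv2d(image, filters=32, kernel_size=3)\n", "    return tf.reshape(conv, (-1, 26 * 26 * 32))\n", "\n", "conv_node = nengo_dl.TensorNode(conv_func, size_in=28 * 28, size_out=26 * 26 * 32)\n", "nengo.Connection(inp, conv_node, synapse=None)\n", "\n", "# the neural nonlinearity becomes an Ensemble with one neuron per\n", "# dimension, with the layer output connected directly to the neurons\n", "ens = nengo.Ensemble(26 * 26 * 32, 1, neuron_type=neuron_type)\n", "nengo.Connection(conv_node, ens.neurons, synapse=None)\n", "x = ens.neurons\n", "```" ]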
}, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we can construct a Simulator for that network." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "minibatch_size = 200\n", "sim = nengo_dl.Simulator(net, minibatch_size=minibatch_size)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we need to train this network to classify MNIST digits. First we load our input images and target labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# note that we need to add the time dimension (axis 1), which will be a\n", "# single step during training\n", "train_inputs = {inp: train_data[0][:, None, :]}\n", "train_targets = {out_p: train_data[1][:, None, :]}\n", "\n", "# when testing our network with spiking neurons we will need to run it over time,\n", "# so we repeat the input/target data for a number of timesteps. we're also going \n", "# to reduce the number of test images, just to speed up this example.\n", "n_steps = 30\n", "test_inputs = {inp: np.tile(test_data[0][:minibatch_size*2, None, :], \n", " (1, n_steps, 1))}\n", "test_targets = {out_p_filt: np.tile(test_data[1][:minibatch_size*2, None, :], \n", " (1, n_steps, 1))}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we need to define our objective (error) function. Because this is a classification task we'll use cross entropy, instead of the default mean squared error." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def objective(x, y): \n", " return tf.nn.softmax_cross_entropy_with_logits_v2(logits=x, labels=y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The last thing we need to specify is the optimizer. For this example we'll use RMSProp." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "opt = tf.train.RMSPropOptimizer(learning_rate=0.001)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order to quantify the network's performance we will also define a classification error function (the percentage of test images classified incorrectly). We could use the cross entropy objective, but classification error is easier to interpret. Note that we use the output from the network on the final timestep (as we are simulating the network over time)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def classification_error(outputs, targets):\n", " return 100 * tf.reduce_mean(\n", " tf.cast(tf.not_equal(tf.argmax(outputs[:, -1], axis=-1), \n", " tf.argmax(targets[:, -1], axis=-1)),\n", " tf.float32))\n", "\n", "print(\"error before training: %.2f%%\" % sim.loss(\n", " test_inputs, test_targets, {out_p_filt: classification_error}))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we are ready to train the network. In order to keep this example relatively quick we are going to download some pretrained weights. However, if you'd like to run the training yourself set `do_training=True` below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "do_training = False\n", "if do_training:\n", " # run training\n", " sim.train(train_inputs, train_targets, opt, objective={out_p: objective}, \n", " n_epochs=10)\n", " \n", " # save the parameters to file\n", " sim.save_params(\"./mnist_params\")\n", "else:\n", " # download pretrained weights\n", " urlretrieve(\n", " \"https://drive.google.com/uc?export=download&id=18ErAZh0LFRaISLHDM-dJjjnzqvYfyjo0\", \n", " \"mnist_params.zip\")\n", " with zipfile.ZipFile(\"mnist_params.zip\") as f:\n", " f.extractall()\n", " \n", " # load parameters\n", " sim.load_params(\"./mnist_params\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can check the classification error again, with the trained parameters." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"error after training: %.2f%%\" % sim.loss(\n", " test_inputs, test_targets, {out_p_filt: classification_error}))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that the spiking neural network is achieving ~1% error, which is what we would expect for MNIST. `n_steps` could be increased to further improve performance, since we would get a more accurate measure of each spiking neuron's output.\n", "\n", "We can also plot some example outputs from the network, to see how it is performing over time." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sim.run_steps(n_steps, input_feeds={inp: test_inputs[inp][:minibatch_size]})\n", "\n", "for i in range(5):\n", " plt.figure()\n", " plt.subplot(1, 2, 1)\n", " plt.imshow(np.reshape(test_data[0][i], (28, 28)))\n", " plt.axis('off')\n", "\n", " plt.subplot(1, 2, 2)\n", " plt.plot(sim.trange(), sim.data[out_p_filt][i])\n", " plt.legend([str(i) for i in range(10)], loc=\"upper left\")\n", " plt.xlabel(\"time\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sim.close()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.4" } }, "nbformat": 4, "nbformat_minor": 2 }