Convolutional networks

In this example we will show how to build and train a convolutional network in NengoDL, and then deploy that network on Loihi.

We’ll use the basic MNIST dataset to demonstrate the steps. The input data are images of handwritten digits, and the goal is for the network to classify each image as 0-9.

We will assume here that the reader is somewhat familiar with NengoDL, and focus on the issue of how to use NengoDL to train a network for Loihi. For a more basic introduction to NengoDL, check out the documentation and examples.

[1]:
%matplotlib inline

import gzip
import os
import pickle
from urllib.request import urlretrieve

import nengo
import nengo_dl
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
try:
    import requests
    has_requests = True
except ImportError:
    has_requests = False

import nengo_loihi
/home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
/home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
/home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
/home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
/home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
/home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  np_resource = np.dtype([("resource", np.ubyte, 1)])
/home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
/home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
/home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
/home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
/home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
/home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  np_resource = np.dtype([("resource", np.ubyte, 1)])
WARNING: Logging before flag parsing goes to stderr.
W0812 13:19:44.139604 140158788081472 deprecation_wrapper.py:119] From /home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/nengo_dl/compat.py:30: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.

W0812 13:19:44.141319 140158788081472 deprecation_wrapper.py:119] From /home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/nengo_dl/__init__.py:39: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.

W0812 13:19:44.143020 140158788081472 deprecation_wrapper.py:119] From /home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/nengo_dl/__init__.py:39: The name tf.logging.WARN is deprecated. Please use tf.compat.v1.logging.WARN instead.

[2]:
# helper function for later
def download(fname, drive_id):
    """Download a file from Google Drive unless it already exists locally.

    Adapted from https://stackoverflow.com/a/39225039/1306923

    Parameters
    ----------
    fname : str
        Path the downloaded file is saved to. If this path already exists
        the function returns immediately without any network access.
    drive_id : str
        Google Drive ID of the file to download.

    Raises
    ------
    RuntimeError
        If ``fname`` does not exist and the ``requests`` package is not
        installed; the message contains a link for a manual download.
    requests.HTTPError
        If the server answers the download request with an error status.
    """
    def get_confirm_token(response):
        # the value of the cookie whose name starts with 'download_warning'
        # is sent back below as the 'confirm' request parameter
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                return value
        return None

    def save_response_content(response, destination):
        CHUNK_SIZE = 32768

        # Stream into a temporary file and rename it when complete, so that
        # an interrupted download never leaves a truncated file under the
        # final name (which the `os.path.exists` check would then accept).
        partial = destination + ".part"
        try:
            with open(partial, "wb") as f:
                for chunk in response.iter_content(CHUNK_SIZE):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
            os.replace(partial, destination)
        finally:
            # only present if the download failed part-way
            if os.path.exists(partial):
                os.remove(partial)

    if os.path.exists(fname):
        return
    if not has_requests:
        link = "https://drive.google.com/open?id=%s" % drive_id
        raise RuntimeError(
            "Cannot find '%s'. Download the file from\n  %s\n"
            "and place it in %s." % (fname, link, os.getcwd()))

    url = "https://docs.google.com/uc?export=download"
    with requests.Session() as session:
        response = session.get(url, params={'id': drive_id}, stream=True)
        token = get_confirm_token(response)
        if token is not None:
            params = {'id': drive_id, 'confirm': token}
            response = session.get(url, params=params, stream=True)
        # fail loudly instead of saving an HTTP error page as the file
        response.raise_for_status()
        save_response_content(response, fname)


# load mnist dataset (downloaded on first use, then read from the local copy)
mnist_path = 'mnist.pkl.gz'
if not os.path.exists(mnist_path):
    urlretrieve('http://deeplearning.net/data/mnist/mnist.pkl.gz',
                mnist_path)

# NOTE(review): unpickling can execute arbitrary code, and the archive is
# fetched over plain HTTP -- only run this against a source you trust.
with gzip.open(mnist_path) as f:
    train_data, _, test_data = pickle.load(f, encoding="latin1")

# turn each (images, labels) pair into a mutable list and replace the
# integer labels with one-hot rows (row k of the 10x10 identity matrix)
train_data = list(train_data)
test_data = list(test_data)
for data in (train_data, test_data):
    data[1] = np.eye(10)[data[1]]

# plot some examples: the first three training images, titled by label
for image, label in zip(train_data[0][:3], train_data[1][:3]):
    plt.figure()
    plt.imshow(image.reshape(28, 28))
    plt.axis('off')
    plt.title(str(label.argmax()));
../_images/examples_mnist_convnet_2_0.png
../_images/examples_mnist_convnet_2_1.png
../_images/examples_mnist_convnet_2_2.png

We’ll begin by defining a simple function to build a “convolutional layer”. This is just a nengo.Connection and nengo.Ensemble put together, but we’ll be doing this a lot so we’ll use this function to put them together in an easy-to-use bundle.

[3]:
def conv_layer(x, *args, activation=True, **kwargs):
    """Connect ``x`` to a newly created layer through a convolution.

    Parameters
    ----------
    x
        Object used as the source of the new ``nengo.Connection``.
    *args, **kwargs
        Forwarded to ``nengo.Convolution``, which is always created with
        ``channels_last=False``.
    activation : bool
        If true, the layer is the ``.neurons`` of a new ``nengo.Ensemble``;
        otherwise it is a ``nengo.Node`` with no output function.

    Returns
    -------
    tuple
        ``(layer, conv)``: the new layer object and the convolution
        transform that feeds it.
    """
    transform = nengo.Convolution(*args, channels_last=False, **kwargs)
    n_outputs = transform.output_shape.size

    # the layer has one element per output of the convolution: ensemble
    # neurons when a nonlinearity is wanted, a plain node otherwise
    layer = (nengo.Ensemble(n_outputs, 1).neurons if activation
             else nengo.Node(size_in=n_outputs))

    # wire the input object into the new layer
    nengo.Connection(x, layer, transform=transform)

    # report the shapes going into and out of this layer
    print("LAYER")
    print(transform.input_shape.shape, "->", transform.output_shape.shape)

    return layer, transform

Next we define the structure of our network. Because we need to keep the number of neurons and axons per core below the Loihi hardware limits, we adopt a somewhat unusual network architecture. We’ll have a relatively small core network, so that each layer fits on one Loihi core, and then repeat that network several times in parallel, summing their output. We can think of this as a variation on ensemble learning.

[4]:
# Build the network: `n_parallel` identical convolutional stacks that all
# read from one input node and all project onto one 10-dimensional output.
dt = 0.001  # simulation timestep
presentation_time = 0.1  # input presentation time
max_rate = 100  # neuron firing rates
# neuron spike amplitude (scaled so that the overall output is ~1)
amp = 1 / max_rate
# input image shape, passed to the first convolution (channels first,
# matching channels_last=False in conv_layer)
input_shape = (1, 28, 28)
n_parallel = 2  # number of parallel network repetitions

with nengo.Network(seed=0) as net:
    # set up the default parameters for ensembles/connections
    # (add_params must run first so that `on_chip` can be configured below)
    nengo_loihi.add_params(net)
    net.config[nengo.Ensemble].neuron_type = (
        nengo.SpikingRectifiedLinear(amplitude=amp))
    # every neuron gets the same max rate and an intercept of zero
    net.config[nengo.Ensemble].max_rates = nengo.dists.Choice([max_rate])
    net.config[nengo.Ensemble].intercepts = nengo.dists.Choice([0])
    # no synaptic filtering by default
    net.config[nengo.Connection].synapse = None

    # the input node that will be used to feed in input images
    # (each test image is presented for `presentation_time`)
    inp = nengo.Node(
        nengo.processes.PresentInput(test_data[0], presentation_time),
        size_out=28 * 28)

    # the output node provides the 10-dimensional classification
    out = nengo.Node(size_in=10)

    # build parallel copies of the network; every copy connects to the
    # same `out` node
    for _ in range(n_parallel):
        # 1x1 kernel with a weight of one and a single filter
        layer, conv = conv_layer(
            inp, 1, input_shape, kernel_size=(1, 1),
            init=np.ones((1, 1, 1, 1)))
        # first layer is off-chip to translate the images into spikes
        net.config[layer.ensemble].on_chip = False
        # two strided convolutions (6 then 24 filters); the printed shapes
        # show the spatial size shrinking at each one
        layer, conv = conv_layer(layer, 6, conv.output_shape,
                                 strides=(2, 2))
        layer, conv = conv_layer(layer, 24, conv.output_shape,
                                 strides=(2, 2))
        # readout weights from the last layer, Glorot-initialised
        nengo.Connection(layer, out, transform=nengo_dl.dists.Glorot())

    # two probes on the output: unfiltered, and filtered with an alpha synapse
    out_p = nengo.Probe(out)
    out_p_filt = nengo.Probe(out, synapse=nengo.Alpha(0.01))
LAYER
(1, 28, 28) -> (1, 28, 28)
LAYER
(1, 28, 28) -> (6, 13, 13)
LAYER
(6, 13, 13) -> (24, 6, 6)
LAYER
(1, 28, 28) -> (1, 28, 28)
LAYER
(1, 28, 28) -> (6, 13, 13)
LAYER
(6, 13, 13) -> (24, 6, 6)

The next step is to optimize the parameters of the network using NengoDL.

First we set up the input/target data for the training and test datasets.

[5]:
# set up training data: index with [:, None, :] to insert a length-1 axis
# between the first and last axes of the images and of the one-hot labels
minibatch_size = 200
train_images, train_labels = train_data
train_data = {inp: train_images[:, None, :],
              out_p: train_labels[:, None, :]}

# the test evaluation runs the network over time with spiking neurons,
# so every test input/target is repeated along that inserted axis, once
# per simulation step of a presentation
n_steps = int(presentation_time / dt)
test_images, test_labels = test_data
test_data = {
    inp: np.repeat(test_images[:, None, :], n_steps, axis=1),
    out_p_filt: np.repeat(test_labels[:, None, :], n_steps, axis=1),
}

Next we need to define our error functions. We’ll use two different error functions: classification error (the % of images classified incorrectly) as an intuitive measure of how well the network is doing, and crossentropy as the error measure the training process will seek to optimize.

[6]:
def crossentropy(outputs, targets):
    """Return the mean softmax cross-entropy of ``outputs`` vs ``targets``.

    ``outputs`` are passed as the logits and ``targets`` as the labels of
    ``tf.nn.softmax_cross_entropy_with_logits_v2``; the result is averaged
    over all elements.
    """
    per_item_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=targets, logits=outputs)
    return tf.reduce_mean(per_item_loss)


def classification_error(outputs, targets):
    """Return the percentage of items that are classified incorrectly.

    Only the last entry along axis 1 of ``outputs`` and ``targets`` is
    compared; an item counts as wrong when the argmax over the final axis
    differs between the two.
    """
    predicted = tf.argmax(outputs[:, -1], axis=-1)
    expected = tf.argmax(targets[:, -1], axis=-1)
    mismatches = tf.cast(tf.not_equal(predicted, expected), tf.float32)
    return 100 * tf.reduce_mean(mismatches)

Now we create the NengoDL simulator and run the training using the sim.train function.

More details on how to use NengoDL to optimize a model can be found here: https://www.nengo.ai/nengo-dl/training.html.

To speed up this example we can set do_training=False to load some pre-trained parameters. If you have the requests package installed, we will download these automatically. If not, the cell below raises an error that names the missing file and gives a link to it; download each such file to the directory containing this notebook.

Note that in order to run do_training=True, you will need to have TensorFlow installed with GPU support.

[7]:
do_training = False

# pre-trained parameter files and their Google Drive IDs
pretrained_files = (
    ("mnist_params.data-00000-of-00001",
     "1BaNU7Er_Q3SJt4i4Eqbv1Ln_TkmmCXvy"),
    ("mnist_params.index", "1w8GNylkamI-3yHfSe_L1-dBtvaQYjNlC"),
    ("mnist_params.meta", "1JiaoxIqmRupT4reQ5BrstuILQeHNffrX"),
)

with nengo_dl.Simulator(net, minibatch_size=minibatch_size, seed=0) as sim:
    if not do_training:
        # fetch whichever parameter files are missing, then load them
        for fname, drive_id in pretrained_files:
            download(fname, drive_id)
        sim.load_params("./mnist_params")
    else:
        # classification error on the test data, before and after training
        error = sim.loss(test_data, {out_p_filt: classification_error})
        print("error before training: %.2f%%" % error)

        # run training, minimising crossentropy on the unfiltered probe
        optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001)
        sim.train(train_data, optimizer,
                  objective={out_p: crossentropy}, n_epochs=5)

        error = sim.loss(test_data, {out_p_filt: classification_error})
        print("error after training: %.2f%%" % error)

        sim.save_params("./mnist_params")

    # store trained parameters back into the network
    sim.freeze_params(net)
Build finished in 0:00:00
Optimization finished in 0:00:00
/home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/nengo_dl/simulator.py:102: UserWarning: No GPU support detected. It is recommended that you install tensorflow-gpu (`pip install tensorflow-gpu`).
  "No GPU support detected. It is recommended that you "
W0812 13:19:51.147116 140158788081472 deprecation_wrapper.py:119] From /home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/nengo_dl/simulator.py:160: The name tf.set_random_seed is deprecated. Please use tf.compat.v1.set_random_seed instead.

|#                        Constructing graph                          | 0:00:00
W0812 13:19:51.152198 140158788081472 deprecation_wrapper.py:119] From /home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/nengo_dl/tensor_graph.py:200: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0812 13:19:51.154365 140158788081472 deprecation_wrapper.py:119] From /home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/nengo_dl/tensor_graph.py:205: The name tf.train.get_or_create_global_step is deprecated. Please use tf.compat.v1.train.get_or_create_global_step instead.

|         Constructing graph: creating base arrays (0%)        | ETA:  --:--:--
W0812 13:19:51.163998 140158788081472 deprecation_wrapper.py:119] From /home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/nengo_dl/tensor_graph.py:229: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.

W0812 13:19:51.165245 140158788081472 deprecation_wrapper.py:119] From /home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/nengo_dl/tensor_graph.py:230: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.

|             Constructing graph: build stage (0%)             | ETA:  --:--:--
/home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/nengo_dl/transform_builders.py:35: UserWarning: TensorFlow does not support convolution with channels_last=False on the CPU; inputs will be transformed to channels_last=True
  UserWarning,
Construction finished in 0:00:00
W0812 13:19:55.321410 140158788081472 deprecation_wrapper.py:119] From /home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/nengo_dl/simulator.py:1106: The name tf.train.Saver is deprecated. Please use tf.compat.v1.train.Saver instead.

W0812 13:19:55.340006 140158788081472 deprecation.py:323] From /home/travis/virtualenv/python3.5.2/lib/python3.5/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.

As we built it, the network has no synaptic filters on the neural connections. This works well during training, but we can see that the error is still somewhat high when we evaluate it using spiking neurons. We can improve performance by adding synaptic filters to our trained network.

[8]:
# give every connection in the network the same synapse value
synapse_tau = 0.005
for connection in net.all_connections:
    connection.synapse = synapse_tau

# re-evaluating on the test data is only done when we trained above
if do_training:
    with nengo_dl.Simulator(net, minibatch_size=minibatch_size) as sim:
        error = sim.loss(test_data, {out_p_filt: classification_error})
        print("error w/ synapse: %.2f%%" % error)

Now we can load our trained network, with synaptic filters, onto Loihi. This is as easy as passing the network to nengo_loihi.Simulator and running it; no extra work is required. We will give the network 50 test images, and use them to evaluate the classification error.

[9]:
# number of test images presented to the network
n_presentations = 50
with nengo_loihi.Simulator(net, dt=dt, precompute=False) as sim:
    # if running on Loihi, increase the max input spikes per step
    if 'loihi' in sim.sims:
        sim.sims['loihi'].snip_max_spikes_per_step = 120

    # run the simulation for the duration of all presentations
    sim.run(n_presentations * presentation_time)

    # check classification error: take the filtered output at the last
    # timestep of each presentation and compare its argmax with the label
    step = int(presentation_time / dt)
    output = sim.data[out_p_filt][step - 1::step]
    predicted = np.argmax(output, axis=-1)
    expected = np.argmax(test_data[out_p_filt][:n_presentations, -1],
                         axis=-1)
    # percentage of presentations where the prediction is wrong
    error = 100 * np.mean(predicted != expected)
    print("loihi error: %.2f%%" % error)
loihi error: 2.00%

We can also plot the output activity from the Loihi network as we show it different test images, to see what this performance looks like in practice.

[10]:
n_plots = 10
plt.figure()

# top panel: the first `n_plots` input images laid side by side
plt.subplot(2, 1, 1)
# take every `step`-th 28x28 frame of the reshaped test inputs
frames = test_data[inp].reshape(-1, 28, 28, 1)[::step]
height, width, channels = frames[0].shape
strip = np.zeros((height, width * n_plots, channels), dtype=frames.dtype)
for left, frame in zip(range(0, width * n_plots, width), frames):
    strip[:, left:left + width] = frame
if channels == 1:
    # drop the singleton channel axis so imshow gets a 2D image
    strip = strip[:, :, 0]
plt.imshow(strip, aspect='auto', interpolation='none', cmap='gray')

# bottom panel: filtered network output over the same presentations
plt.subplot(2, 1, 2)
n_shown = n_plots * step
plt.plot(sim.trange()[:n_shown], sim.data[out_p_filt][:n_shown])
plt.legend(['%d' % digit for digit in range(10)], loc='best');
../_images/examples_mnist_convnet_18_0.png