Inserting a TensorFlow/Keras network into a Nengo model

Often we may want to define one part of our model in Nengo, and another part in TensorFlow. For example, suppose we are building a biological reinforcement learning model, but we’d like the inputs to our model to be natural images rather than artificial vectors. We could load a vision network from TensorFlow, insert it into our model using NengoDL, and then build the rest of our model using normal Nengo syntax.

NengoDL supports this through the TensorNode class, which allows us to write code directly in TensorFlow and then insert it easily into Nengo. In this example we will demonstrate this in two different ways: first with a network defined in Keras, and second with a prebuilt vision network from the tensorflow/models repository.

[1]:
%matplotlib inline

import sys
import os
from urllib.request import urlopen
import io
import shutil
import stat

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import tensorflow as tf
from tensorflow import keras
import tensorflow.contrib.slim as slim

import nengo
import nengo_dl

Introduction to TensorNodes

nengo_dl.TensorNode works very similarly to nengo.Node, except instead of using the node to insert Python code into our model we will use it to insert TensorFlow code.

The first thing we need to do is define the output function of our TensorNode. This is a function that accepts the current simulation time (and, optionally, a batch of input vectors) and returns a batch of output vectors. All of these values are represented as tf.Tensor objects, so the internals of the function must be written with TensorFlow operations. For example, we could use a TensorNode to output a sine wave:

[2]:
with nengo.Network() as net:
    def sin_func(t):
        # compute sin wave (based on simulation time)
        out = tf.sin(t)

        # convert output to the expected batched vector shape
        # (with batch size of 1 and vector dimensionality 1)
        out = tf.reshape(out, (1, 1))

        return out

    node = nengo_dl.TensorNode(sin_func)
    p = nengo.Probe(node)

with nengo_dl.Simulator(net) as sim:
    sim.run(5.0)

plt.figure()
plt.plot(sim.trange(), sim.data[p]);
Build finished in 0:00:00
Optimization finished in 0:00:00
Construction finished in 0:00:00
Simulation finished in 0:00:00
[Figure: output of the sine-wave TensorNode over the 5 s simulation]

However, outputting a sine wave is something we could do more easily with a regular nengo.Node. The main use case for nengo_dl.TensorNode is to let us write more complex TensorFlow code and insert it into a NengoDL model. We will see two different examples of this below.
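To make the calling convention concrete, here is a NumPy-only sketch of the shapes involved. Note that NengoDL calls sin_func to construct TensorFlow operations, which are then evaluated at every timestep; the sketch skips TensorFlow and simply calls a NumPy version of the function directly at each step, roughly what a regular nengo.Node would do. The helper run_node is hypothetical, and it assumes Nengo's default timestep of dt = 0.001 s, with sample times running from dt up to the run time as in sim.trange().

```python
import numpy as np


def sin_func(t):
    # NumPy stand-in for the TensorFlow sin_func above:
    # returns a batch of 1 vector with dimensionality 1
    return np.reshape(np.sin(t), (1, 1))


def run_node(func, run_time, dt=0.001):
    # hypothetical helper: call the output function once per timestep
    n_steps = int(round(run_time / dt))
    trange = dt * np.arange(1, n_steps + 1)  # dt, 2*dt, ..., run_time
    # keep the first (only) batch element of each output
    outputs = np.stack([func(t)[0] for t in trange])
    return trange, outputs  # shapes (n_steps,) and (n_steps, size_out)


trange, outputs = run_node(sin_func, 5.0)
print(trange.shape, outputs.shape)  # (5000,) (5000, 1)
```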

Inserting a Keras network

Keras is a popular software package for building and training deep learning networks. It provides a high-level wrapper around TensorFlow (or other backends, such as Theano). Because a Keras model built on the TensorFlow backend is a TensorFlow network under the hood, we can define a network using Keras and then insert it into NengoDL using a TensorNode.

This example assumes familiarity with the Keras API. Specifically, it is based on the Keras introduction in the TensorFlow documentation, so if you are not yet familiar with Keras, you may find it helpful to read those tutorials first.

For this example we’ll train a neural network to classify images from the Fashion-MNIST dataset. This dataset contains images of clothing, and the goal of the network is to identify which type of clothing each image shows (e.g., T-shirt, trouser, coat).

[3]:
fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = (
    fashion_mnist.load_data())
num_classes = np.unique(test_labels).shape[0]

# normalize images so values are between 0 and 1
train_images = train_images / 255.0
test_images = test_images / 255.0

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

for i in range(3):
    plt.figure()
    plt.imshow(test_images[i], cmap="gray")
    plt.axis("off")
    plt.title(class_names[test_labels[i]]);
[Figure: the first three test images, each titled with its class name]

Next we build and train a simple two-layer, densely connected network using Keras.

Alternatively, we could define the network in Keras and then train it in NengoDL (using the Simulator.train function), but for this example we’ll do everything in Keras.
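For reference, the forward pass of this architecture (Flatten, then Dense(128) with ReLU, then Dense(num_classes) with softmax) can be written out in plain NumPy. This is an illustrative sketch only: the weights are random rather than trained, and the 0.05 initialization scale is an arbitrary assumption. It also shows where the parameter count comes from: 784 x 128 + 128 for the hidden layer and 128 x 10 + 10 for the output layer.

```python
import numpy as np

rng = np.random.RandomState(0)
image_shape = (28, 28)
n_hidden, num_classes = 128, 10

# randomly initialized weights (NOT the trained Keras weights)
w_hidden = rng.normal(scale=0.05, size=(np.prod(image_shape), n_hidden))
b_hidden = np.zeros(n_hidden)
w_out = rng.normal(scale=0.05, size=(n_hidden, num_classes))
b_out = np.zeros(num_classes)


def forward(images):
    # Flatten: (batch, 28, 28) -> (batch, 784)
    x = images.reshape((images.shape[0], -1))
    # Dense(128) + ReLU
    h = np.maximum(0, x @ w_hidden + b_hidden)
    # Dense(num_classes) + softmax (shifted by the max for numerical stability)
    logits = h @ w_out + b_out
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


probs = forward(rng.uniform(size=(4,) + image_shape))
print(probs.shape)  # (4, 10)

n_params = w_hidden.size + b_hidden.size + w_out.size + b_out.size
print(n_params)  # 101770
```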

[4]:
image_shape = (28, 28)

model = keras.Sequential([
    keras.layers.Flatten(input_shape=image_shape, name='flatten'),
    keras.layers.Dense(128, activation=tf.nn.relu, name='hidden'),
    keras.layers.Dense(num_classes, activation=tf.nn.softmax,
                       name='softmax')
])

model.compile(optimizer=tf.train.AdamOptimizer(),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=5);
Epoch 1/5
60000/60000 [==============================] - 4s 71us/sample - loss: 0.5002 - acc: 0.8259
Epoch 2/5
60000/60000 [==============================] - 4s 69us/sample - loss: 0.3755 - acc: 0.8644
Epoch 3/5
60000/60000 [==============================] - 4s 71us/sample - loss: 0.3367 - acc: 0.8779
Epoch 4/5
60000/60000 [==============================] - 4s 73us/sample - loss: 0.3127 - acc: 0.8863
Epoch 5/5
60000/60000 [==============================] - 4s 70us/sample - loss: 0.2962 - acc: 0.8910

We’ll save the trained weights, so that we can load them later within our TensorNode.

[5]:
model_weights = "keras_weights.h5"
model.save_weights(model_weights)

Now we’re ready to create our TensorNode. Instead of using a function for our TensorNode output, in this case we’ll use a callable class so that we can include pre_build and post_build functions. These allow us to execute code at different stages during the build process, which can be necessary for more complicated TensorNodes.

NengoDL will call the pre_build function once when the model is first constructed, so we can use this function to perform any initial setup required for our node. In this case we’ll use the pre_build function to call the Keras clone_model function. This effectively reruns the Keras model definition from above (creating new layers with freshly initialized weights, which is why we reload the trained weights in post_build), but because we’re calling it in the pre_build stage, the new layers are created as part of the NengoDL model that is being built.

The __call__ function is where we do the main job of constructing the TensorFlow elements that will implement our node. It will take TensorFlow Tensors as input and produce a tf.Tensor as output, as with the tf.sin example above. In this case we apply the Keras model to the TensorNode inputs (stored in the x variable). This adds the TensorFlow elements that implement that Keras model into the simulation graph.

The post_build function is called after the rest of the graph has been constructed (and again whenever the simulation is reset). We’ll use it to load the pretrained weights into the model. This has to happen at the post_build stage because loading weights requires the initialized simulation session, which holds the model variables we want to load into.
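The order in which these hooks run can be illustrated with a small pure-Python mock. Both RecordingNode and mock_build_and_reset are hypothetical and are not part of NengoDL; the mock only encodes the sequence described above (pre_build once, then __call__ to add the node's operations to the graph, then post_build once the session exists and again on each reset), with None standing in for the real session and random number generator.

```python
class RecordingNode:
    """Hypothetical TensorNode-style callable that logs which hook ran."""

    def __init__(self):
        self.calls = []

    def pre_build(self, *args):
        self.calls.append("pre_build")

    def __call__(self, t, x):
        self.calls.append("__call__")
        return x

    def post_build(self, sess, rng):
        self.calls.append("post_build")


def mock_build_and_reset(node, n_resets=0):
    # hypothetical stand-in for the build order described above
    node.pre_build()
    node(0.0, None)  # builds the node's operations into the graph
    node.post_build(sess=None, rng=None)  # session is now initialized
    for _ in range(n_resets):
        node.post_build(sess=None, rng=None)  # called again on each reset
    return node.calls


print(mock_build_and_reset(RecordingNode(), n_resets=1))
# ['pre_build', '__call__', 'post_build', 'post_build']
```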

[6]:
class KerasNode:
    def __init__(self, keras_model):
        self.model = keras_model

    def pre_build(self, *args):
        self.model = keras.models.clone_model(self.model)

    def __call__(self, t, x):
        # reshape the flattened images into their 2D shape
        # (plus the batch dimension)
        images = tf.reshape(x, (-1,) + image_shape)
        # build the rest of the model into the graph
        return self.model.call(images)

    def post_build(self, sess, rng):
        self.model.load_weights(model_weights)

Notice that in the __call__ method we pass our input tensor to the Model.call method, not Model.predict (which you might be more familiar with if you frequently work with Keras). We do this because we want the model to return a Tensor object (i.e., an abstract representation of the computations that will be performed in the network), rather than actually simulating the network and computing predictions (as the predict function does). This way the returned Tensor can become part of the TensorFlow graph that NengoDL is constructing.

To better understand the difference between model.call(images) and model.predict(images), we can look at the code below:

[7]:
with tf.Session():
    model.load_weights(model_weights)

    # model.call takes a Tensor as input and returns a Tensor
    out1 = model.call(tf.convert_to_tensor(test_images[:10],
                                           dtype=tf.float32))
    print("Type of 'out1':", type(out1))

    # model.predict takes a numpy array as input and returns a numpy array
    out2 = model.predict(test_images[:10])
    print("Type of 'out2':", type(out2))
Type of 'out1': <class 'tensorflow.python.framework.ops.Tensor'>
Type of 'out2': <class 'numpy.ndarray'>

Now that we have our KerasNode class, we can use it to insert our Keras model into a Nengo network via a TensorNode.

[8]:
net_input_shape = np.prod(image_shape)  # because input will be a vector

with nengo.Network() as net:
    # create a normal input node to feed in our test image.
    # the `np.ones` array is a placeholder, these
    # values will be replaced with the Fashion MNIST images
    # when we run the Simulator.
    input_node = nengo.Node(output=np.ones((net_input_shape,)))

    # create a TensorNode containing the KerasNode we defined
    # above, passing it the Keras model we created.
    # we also need to specify size_in (the dimensionality of
    # our input vectors, the flattened images) and size_out (the number
    # of classification classes output by the keras network)
    keras_node = nengo_dl.TensorNode(
        KerasNode(model),
        size_in=net_input_shape,
        size_out=num_classes)

    # connect up our input to our keras node
    nengo.Connection(input_node, keras_node, synapse=None)

    # add a probe to collect the output of the keras node
    keras_p = nengo.Probe(keras_node)

At this point we could add any other Nengo components we like to the network, and connect them up to the Keras node (for example, if we wanted to take the classified image labels and use them as input to a spiking neural model). But to keep things simple, we’ll stop here.

We’ll grab some random images from our test set, in order to demonstrate that we have successfully loaded the trained Keras network.

[9]:
minibatch_size = 20

# pick some random images from test set
np.random.seed(1)
test_inds = np.random.randint(low=0, high=test_images.shape[0],
                              size=(minibatch_size,))
test_inputs = test_images[test_inds]

# flatten images so we can pass them as vectors to the input node
test_inputs = test_inputs.reshape((-1, net_input_shape))

# unlike in Keras, NengoDL simulations always run over time.
# so we need to add the time dimension to our data (even though
# in this case we'll just run for a single timestep).
test_inputs = test_inputs[:, None, :]
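The shape bookkeeping in the cell above can be checked with random stand-in data (NumPy only). The layout assumed here, (minibatch_size, n_steps, size_in), follows the reshaping done above; the 10-step variant is only an illustration of how the time axis would grow if each image were presented for several timesteps.

```python
import numpy as np

minibatch_size, n_steps = 20, 1
image_shape = (28, 28)
net_input_shape = int(np.prod(image_shape))

# stand-in for 20 Fashion-MNIST images with values in [0, 1]
images = np.random.RandomState(0).uniform(size=(minibatch_size,) + image_shape)

# (minibatch, 28, 28) -> (minibatch, 784) -> (minibatch, n_steps, 784)
data = images.reshape((-1, net_input_shape))[:, None, :]
print(data.shape)  # (20, 1, 784)

# to present each image for several timesteps, repeat it along the time axis
data_10 = np.tile(data, (1, 10, 1))
print(data_10.shape)  # (20, 10, 784)

# at any one timestep the TensorNode sees a (minibatch, 784) batch, which
# KerasNode.__call__ reshapes back to (-1, 28, 28) before applying the model
recovered = data[:, 0, :].reshape((-1,) + image_shape)
print(np.array_equal(recovered, images))  # True
```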

Now we are ready to run the simulation.

[10]:
with nengo_dl.Simulator(net, minibatch_size=minibatch_size) as sim:
    sim.step(data={input_node: test_inputs})
Build finished in 0:00:00
Optimization finished in 0:00:00
Construction finished in 0:00:00

We can see the results of the simulation using the Probe that we added to capture the output from the TensorNode. We use it as a key into the data attribute of the Simulator.

[11]:
tensornode_output = sim.data[keras_p]

for i in range(5):
    plt.figure()
    plt.imshow(test_images[test_inds][i], cmap="gray")
    plt.axis("off")
    plt.title("%s (%s)" % (
        class_names[test_labels[test_inds][i]],
        class_names[np.argmax(tensornode_output[i, 0])]));
[Figure: five randomly chosen test images, each titled with the true class and, in parentheses, the network’s predicted class]

We can see that the network is doing a pretty good job of classifying the test images (each title shows the correct class, with the network’s output shown in parentheses).
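Rather than inspecting the images one by one, we could also summarize the minibatch as a single accuracy number. The helper below is a hypothetical sketch: it assumes the probe data has the (minibatch, n_steps, num_classes) layout indexed in the cell above (tensornode_output[i, 0]), and it is demonstrated on toy arrays rather than real results. With the real data it would be called as minibatch_accuracy(tensornode_output, test_labels[test_inds]).

```python
import numpy as np


def minibatch_accuracy(probe_output, labels):
    # probe_output: (minibatch, n_steps, num_classes);
    # use the prediction from the final timestep of each item
    predictions = np.argmax(probe_output[:, -1], axis=-1)
    return np.mean(predictions == labels)


# toy example: 3 items, 1 timestep, 4 classes
toy_output = np.array([
    [[0.7, 0.1, 0.1, 0.1]],
    [[0.1, 0.6, 0.2, 0.1]],
    [[0.3, 0.3, 0.1, 0.3]],  # ties resolve to the first index (class 0)
])
print(minibatch_accuracy(toy_output, np.array([0, 1, 2])))  # 2 of 3 correct
```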

Inserting a TensorFlow-Slim network

In this example we’ll show how to insert a more complicated network into NengoDL. Specifically, we will use an Inception-v1 network to classify ImageNet images.

TensorFlow provides a number of pre-defined models in the tensorflow/models repository. These are not included when you install TensorFlow, so we need to separately clone that repository and import the components we need.

[12]:
!git clone -q https://github.com/tensorflow/models
sys.path.append(os.path.join(".", "models", "research", "slim"))
from datasets import dataset_utils, imagenet
from nets import inception
from preprocessing import inception_preprocessing

As before, we will use a TensorNode to insert our TensorFlow code into Nengo. In this case we’re going to build a TensorNode that encapsulates the Inception-v1 network. However, this same approach could be used for any TensorFlow network.

This Inception-v1 network has been trained to perform image classification on the ImageNet dataset; if we show it an image, it will output a set of probabilities over the 1000 object categories it was trained to recognize (the checkpoint we use has 1001 outputs, because index 0 is reserved for an extra “background” class). So if we show it an image of a tree, it should output a high probability for the “tree” class and a low probability for the “car” class.
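The network’s final layer produces unnormalized scores (logits); a softmax turns them into probabilities, and ranking those probabilities gives the top-k classes we print at the end. A minimal NumPy sketch of both steps, using made-up logits for five classes:

```python
import numpy as np


def softmax(logits):
    # subtract the max before exponentiating for numerical stability
    e = np.exp(logits - np.max(logits))
    return e / e.sum()


def top_k(probabilities, k=5):
    # indices of the k largest probabilities, most probable first
    return np.argsort(-probabilities)[:k]


probs = softmax(np.array([2.0, 0.5, 3.0, -1.0, 1.0]))
print(top_k(probs, k=3))  # [2 0 4]
```

Because softmax is monotonic, ranking the probabilities gives the same order as ranking the raw logits.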

The first thing we’ll do is download a sample image to test our network with (you could use a different image if you want).

[13]:
url = 'https://upload.wikimedia.org/wikipedia/commons/7/70/EnglishCockerSpaniel_simon.jpg'
image_string = urlopen(url).read()
image = np.array(Image.open(io.BytesIO(image_string)))
image_shape = image.shape

# display the test image
plt.figure()
plt.imshow(image)
plt.axis('off');
[Figure: the downloaded test image (an English cocker spaniel)]

Now we’re ready to create our TensorNode. As in the previous example, we will use a callable class so that we can use pre_build and post_build methods to help construct the model.

In this case we’ll use the pre_build function to download pre-trained weights for the Inception network. Again, if we wanted to, we could train the network from scratch using the sim.train function, but that would take a long time.

In the __call__ function we apply some preprocessing to transform the TensorNode inputs (stored in the x variable) into the form expected by the Inception network. Then we call the inception_v1 function, which constructs all the TensorFlow elements required to implement that network, and return the softmax of its output logits.

We’ll use the post_build function to load the pretrained weights into the model, as in the previous example.

[14]:
checkpoints_dir = '/tmp/checkpoints'

class InceptionNode:
    def pre_build(self, *args):
        # the shape of the inputs to the inception network
        self.input_shape = inception.inception_v1.default_image_size

        # download model checkpoint file
        if not tf.gfile.Exists(checkpoints_dir):
            tf.gfile.MakeDirs(checkpoints_dir)
        dataset_utils.download_and_uncompress_tarball(
            "http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz",
            checkpoints_dir)

    def __call__(self, t, x):
        # convert our input vector to the shape/dtype of the input image
        img = tf.reshape(tf.cast(x, tf.uint8), image_shape)

        # reshape the image to the shape expected by the
        # inception network
        img = inception_preprocessing.preprocess_image(
            img, self.input_shape, self.input_shape, is_training=False)
        img = tf.expand_dims(img, 0)

        # create inception network
        with slim.arg_scope(inception.inception_v1_arg_scope()):
            logits, _ = inception.inception_v1(img,
                                               num_classes=1001,
                                               is_training=False)

        # return our classification probabilities
        return tf.nn.softmax(logits)

    def post_build(self, sess, rng):
        # load checkpoint file into model
        init_fn = slim.assign_from_checkpoint_fn(
            os.path.join(checkpoints_dir, 'inception_v1.ckpt'),
            slim.get_model_variables('InceptionV1'))

        init_fn(sess)
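As we understand it, the evaluation branch of inception_preprocessing.preprocess_image (is_training=False) converts the image to floats in [0, 1], takes a central crop (87.5% of each side by default), resizes it bilinearly to the network’s input size (224 x 224 for Inception-v1), and rescales the values to [-1, 1]. The function below is a rough NumPy approximation of those steps, not the slim implementation; as a simplifying assumption it uses nearest-neighbour resizing instead of bilinear interpolation.

```python
import numpy as np


def approx_inception_eval_preprocess(image, size=224, central_fraction=0.875):
    """Rough NumPy approximation of evaluation-time Inception preprocessing.

    Assumptions: uint8 RGB input; nearest-neighbour resizing is used
    here as a simplification (the slim code resizes bilinearly).
    """
    # uint8 [0, 255] -> float [0, 1]
    img = image.astype(np.float64) / 255.0

    # crop the central region covering `central_fraction` of each side
    h, w = img.shape[:2]
    crop_h, crop_w = int(h * central_fraction), int(w * central_fraction)
    top, left = (h - crop_h) // 2, (w - crop_w) // 2
    img = img[top:top + crop_h, left:left + crop_w]

    # resize to (size, size) with nearest-neighbour sampling
    rows = np.arange(size) * crop_h // size
    cols = np.arange(size) * crop_w // size
    img = img[rows][:, cols]

    # rescale from [0, 1] to [-1, 1]
    return (img - 0.5) * 2.0


fake_image = np.random.RandomState(0).randint(
    0, 256, size=(300, 400, 3), dtype=np.uint8)
out = approx_inception_eval_preprocess(fake_image)
print(out.shape, out.min() >= -1.0, out.max() <= 1.0)  # (224, 224, 3) True True
```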

Next we create a Nengo Network containing our TensorNode.

[15]:
with nengo.Network() as net:
    # create a normal input node to feed in our test image
    input_node = nengo.Node(output=image.flatten())

    # create our TensorNode containing the InceptionNode() we defined
    # above. we also need to specify size_in (the dimensionality of
    # our input vectors, the flattened images) and size_out (the number
    # of classification classes output by the inception network)
    incep_node = nengo_dl.TensorNode(
        InceptionNode(), size_in=np.prod(image_shape), size_out=1001)

    # connect up our input to our inception node
    nengo.Connection(input_node, incep_node, synapse=None)

    # add some probes to collect data
    input_p = nengo.Probe(input_node)
    incep_p = nengo.Probe(incep_node)

As with the previous example, at this point we could connect the output of incep_node to any other part of our network, if this were part of a larger model. But to keep this example simple we’ll stop here.

All that’s left is to run our network, using our example image as input, and check the output.

[16]:
# run the network for one timestep
with nengo_dl.Simulator(net) as sim:
    sim.step()

# sort the output labels based on the classification probabilities
# output from the network (most probable first)
probabilities = sim.data[incep_p][0]
sorted_inds = np.argsort(-probabilities)

# print top 5 classes
names = imagenet.create_readable_names_for_imagenet_labels()
for i in range(5):
    index = sorted_inds[i]
    print('Probability %0.2f%% => [%s]' % (
        probabilities[index] * 100, names[index]))

# display the test image
plt.figure()
plt.imshow(sim.data[input_p][0].reshape(image_shape).astype(np.uint8))
plt.axis('off');
Build finished in 0:00:00
Optimization finished in 0:00:00
>> Downloading inception_v1_2016_08_28.tar.gz 100.0%
Successfully downloaded inception_v1_2016_08_28.tar.gz 24642554 bytes.
Construction finished in 0:00:07
Probability 44.95% => [cocker spaniel, English cocker spaniel, cocker]
Probability 22.56% => [Sussex spaniel]
Probability 10.18% => [Irish setter, red setter]
Probability 4.48% => [Welsh springer spaniel]
Probability 3.42% => [clumber, clumber spaniel]
[Figure: the test image, reconstructed from the input probe data]
[17]:
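# delete the models repo we cloned
def onerror(func, path, exc_info):
    # some files in the cloned repo (e.g. git object files) may be
    # read-only; make them writable and retry the failed operation
    if not os.access(path, os.W_OK):
        os.chmod(path, stat.S_IWUSR)
        func(path)
    else:
        raise exc_info[1]
shutil.rmtree("models", onerror=onerror)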