{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Inserting a TensorFlow network into a Nengo model\n", "\n", "A wide range of pre-defined deep learning models is available for TensorFlow, and we might want to incorporate one of them into a Nengo model. For example, suppose we are building a biological reinforcement learning model, but we'd like the inputs to our model to be natural images rather than artificial vectors. We could load a vision network from TensorFlow, insert it into our model using NengoDL, and then build the rest of our model using normal Nengo syntax.\n", "\n", "In this example we'll show how to use TensorNodes to insert a pre-trained TensorFlow model (Inception-v1) into Nengo." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "%matplotlib inline\n", "\n", "import sys\n", "import os\n", "from urllib.request import urlopen\n", "import io\n", "import shutil\n", "import stat\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from PIL import Image\n", "import tensorflow as tf\n", "import tensorflow.contrib.slim as slim\n", "\n", "import nengo\n", "import nengo_dl" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TensorFlow provides a number of pre-defined models in the [tensorflow/models](https://github.com/tensorflow/models) repository. These are not included when you install TensorFlow, so we need to separately clone that repository and import the components we need."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "!git clone https://github.com/tensorflow/models\n", "sys.path.append(os.path.join(\".\", \"models\", \"slim\"))\n", "from datasets import dataset_utils, imagenet\n", "from nets import inception\n", "from preprocessing import inception_preprocessing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will use a [TensorNode](https://www.nengo.ai/nengo_dl/tensor_node.html) to insert our TensorFlow code into Nengo. `nengo_dl.TensorNode` works very similarly to `nengo.Node`, except that instead of using the node to insert Python code into our model we will use it to insert TensorFlow code.\n", "\n", "The first thing we need to do is define our TensorNode output. This should be a function that accepts the current simulation time (and, optionally, a batch of vectors) as input, and produces a batch of vectors as output. All of these variables will be represented as `tf.Tensor` objects, and the internal operations of the TensorNode will be implemented with TensorFlow operations. For example, we could use a TensorNode to output a `sin` function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "with nengo.Network() as net:\n", "    node = nengo_dl.TensorNode(lambda t: tf.reshape(tf.sin(t), (1, 1)))\n", "    p = nengo.Probe(node)\n", "\n", "with nengo_dl.Simulator(net) as sim:\n", "    sim.run(5.0)\n", "\n", "plt.figure()\n", "plt.plot(sim.trange(), sim.data[p])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "However, outputting a `sin` function is something we could do more easily with a regular `nengo.Node`. The main use case for `nengo_dl.TensorNode` is to work with artificial neural networks that are not easily defined in Nengo.\n", "\n", "In this case we're going to build a TensorNode that encapsulates the [Inception-v1](https://arxiv.org/abs/1409.4842) network. 
Inception-v1 isn't state-of-the-art anymore (we're up to Inception-v4 now), but it is relatively small so it will be quick to download/run in this example. However, this same approach could be used for any TensorFlow network.\n", "\n", "Inception-v1 performs image classification; if we show it an image, it will output a set of probabilities for the 1000 different object types it is trained to classify. So if we show it an image of a tree it should output a high probability for the \"tree\" class and a low probability for the \"car\" class.\n", "\n", "The first thing we'll do is download a sample image to test our network with (you could use a different image if you want)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "url = 'https://upload.wikimedia.org/wikipedia/commons/7/70/EnglishCockerSpaniel_simon.jpg'\n", "image_string = urlopen(url).read()\n", "image = np.array(Image.open(io.BytesIO(image_string)))\n", "image_shape = image.shape\n", "\n", "# display the test image\n", "plt.figure()\n", "plt.imshow(image)\n", "plt.axis('off');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we're ready to create our TensorNode. Instead of using a function for our TensorNode output, in this case we'll use a callable class. This allows us to include a `pre_build` function in our TensorNode. NengoDL will call this function once when the model is first constructed, so we can use this function to perform any initial setup required for our node." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "checkpoints_dir = '/tmp/checkpoints'\n", "\n", "class InceptionNode(object):\n", "    def pre_build(self, *args):\n", "        # the shape of the inputs to the inception network\n", "        self.input_shape = inception.inception_v1.default_image_size\n", "\n", "        # download model checkpoint file\n", "        if not tf.gfile.Exists(checkpoints_dir):\n", "            tf.gfile.MakeDirs(checkpoints_dir)\n", "        url = \"http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz\"\n", "        dataset_utils.download_and_uncompress_tarball(url, checkpoints_dir)\n", "\n", "    def __call__(self, t, x):\n", "        # this is the function that will be executed each timestep while the\n", "        # network is running\n", "\n", "        # convert our input vector to the shape/dtype of the input image\n", "        image = tf.reshape(tf.cast(x, tf.uint8), image_shape)\n", "\n", "        # reshape the image to the shape expected by the inception network\n", "        processed_image = inception_preprocessing.preprocess_image(\n", "            image, self.input_shape, self.input_shape, is_training=False)\n", "        processed_images = tf.expand_dims(processed_image, 0)\n", "\n", "        # create inception network\n", "        with slim.arg_scope(inception.inception_v1_arg_scope()):\n", "            logits, _ = inception.inception_v1(processed_images,\n", "                                               num_classes=1001,\n", "                                               is_training=False)\n", "        probabilities = tf.nn.softmax(logits)\n", "\n", "        # return our classification probabilities\n", "        return probabilities" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we create a Nengo Network, containing our TensorNode." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with nengo.Network() as net:\n", "    # create a normal input node to feed in our test image\n", "    input_node = nengo.Node(output=image.flatten())\n", "\n", "    # create our TensorNode containing the InceptionNode() we defined\n", "    # above. we also need to specify size_in (the dimensionality of\n", "    # our input vectors, the flattened images) and size_out (the number\n", "    # of classification classes output by the inception network)\n", "    incep_node = nengo_dl.TensorNode(\n", "        InceptionNode(), size_in=np.prod(image_shape), size_out=1001)\n", "\n", "    # connect up our input to our inception node\n", "    nengo.Connection(input_node, incep_node, synapse=None)\n", "\n", "    # add some probes to collect data\n", "    input_p = nengo.Probe(input_node)\n", "    incep_p = nengo.Probe(incep_node)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that at this point we could connect up the output of `incep_node` to any other part of our network, if this were part of a larger model. But to keep this example simple we'll stop here.\n", "\n", "Next we'll load the pre-trained weights for the Inception network. If we wanted we could train the network from scratch using the `sim.train` function, but that would take a long time and require some expertise in training deep networks."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "sim = nengo_dl.Simulator(net)\n", "\n", "with sim.tensor_graph.graph.as_default():\n", "    # load checkpoint file\n", "    init_fn = slim.assign_from_checkpoint_fn(\n", "        os.path.join(checkpoints_dir, 'inception_v1.ckpt'),\n", "        slim.get_model_variables('InceptionV1'))\n", "\n", "    # update the parameters of the model\n", "    init_fn(sim.sess)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, all that's left is to run our network, using our example image as input, and check the output." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# run the network for one timestep\n", "sim.step()\n", "\n", "# sort the output labels based on the classification probabilities\n", "# output from the network\n", "probabilities = sim.data[incep_p][0]\n", "sorted_inds = [i[0] for i in sorted(enumerate(-probabilities),\n", "                                    key=lambda x: x[1])]\n", "\n", "# print top 5 classes\n", "names = imagenet.create_readable_names_for_imagenet_labels()\n", "for i in range(5):\n", "    index = sorted_inds[i]\n", "    print('Probability %0.2f%% => [%s]' % (\n", "        probabilities[index] * 100, names[index]))\n", "\n", "# display the test image\n", "plt.figure()\n", "plt.imshow(sim.data[input_p][0].reshape(image_shape).astype(np.uint8))\n", "plt.axis('off');" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sim.close()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# delete the models repo we cloned\n", "def onerror(func, path, exc_info):\n", "    # if the path is read-only, make it writable and retry the\n", "    # failed operation; otherwise re-raise the original error\n", "    if not os.access(path, os.W_OK):\n", "        os.chmod(path, stat.S_IWUSR)\n", "        func(path)\n", "    else:\n", "        raise\n", "shutil.rmtree(\"models\", onerror=onerror)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": 
{ "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.3" } }, "nbformat": 4, "nbformat_minor": 1 }