Optimizing Existing Networks

Nengo DL is not confined to optimizing custom-made networks; it can also be used to improve existing networks, or to achieve the same result with fewer neurons. This example shows how to train a circular convolution network.

Circular convolution is a key operation used to process semantic pointers. By optimizing this smaller network, larger, more complex networks that use circular convolution can benefit as well.
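For reference, circular convolution can be computed exactly with FFTs, since convolution in the vector domain is elementwise multiplication in the Fourier domain. Here is a minimal NumPy sketch (an illustration added for this walkthrough, not part of the model below):

import numpy as np

def circular_convolution(a, b):
    # elementwise product of the spectra, then back to the vector domain
    return np.fft.irfft(np.fft.rfft(a) * np.fft.rfft(b), n=len(a))

rng = np.random.RandomState(0)
a, b = rng.randn(2, 50) / np.sqrt(50)  # two roughly unit-length vectors
c = circular_convolution(a, b)
print(np.dot(c, a), np.dot(c, b))  # both near zero: c resembles neither input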

[1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import nengo
import nengo_dl
from nengo.spa import Vocabulary
import tensorflow as tf

To train the network properly, we create novel training data by randomly generating semantic pointers.

[2]:
def gen_pointers(n_inputs, dims, rng):
    vocabulary = Vocabulary(dimensions=dims, rng=rng, max_similarity=1)
    for v in range(n_inputs):
        # the two input pointers get keys starting with A and B;
        # their circular convolution gets a key starting with C
        conv_key = f"C{v}"
        point_key_1 = f"A{v}"
        pointer_1 = vocabulary.create_pointer()
        point_key_2 = f"B{v}"
        pointer_2 = vocabulary.create_pointer()
        vocabulary.add(point_key_1, pointer_1)
        vocabulary.add(point_key_2, pointer_2)
        vocabulary.add(conv_key, vocabulary.parse(point_key_2 + "*" + point_key_1))

    A = np.asarray([vocabulary[f"A{i}"].v for i in range(n_inputs)])[:, None, :]
    B = np.asarray([vocabulary[f"B{i}"].v for i in range(n_inputs)])[:, None, :]
    C = np.asarray([vocabulary[f"C{i}"].v for i in range(n_inputs)])[:, None, :]
    return A, B, C, vocabulary


rng = np.random.RandomState(0)
dimensions = 50
test_a, test_b, test_c, vocab = gen_pointers(10, dimensions, rng)
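
As a quick sanity check (an addition to the original example), each C pointer should match the exact circular convolution of the corresponding A and B pointers, since multiplying two semantic pointers performs circular convolution (and the operation is commutative):

assert np.allclose(test_c[0, 0], vocab.parse("A0 * B0").v, atol=1e-5)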

We want our optimized network to work with spiking LIF neurons, so we will use SoftLIFRate neurons (a differentiable approximation of LIF neurons) to train the network.
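
To make the approximation concrete, here is a minimal NumPy sketch of a soft-LIF rate curve (parameter names and details are assumptions for illustration; see nengo_dl.SoftLIFRate for the actual implementation). The idea is to replace the LIF rate's hard threshold max(0, j - 1) with a scaled softplus, so the curve is differentiable everywhere:

def soft_lif_rate(j, sigma=0.1, tau_rc=0.02, tau_ref=0.002):
    j = np.asarray(j, dtype=float)
    x = (j - 1) / sigma
    # overflow-safe softplus; log1p(exp(x)) ~= x for large x
    z = sigma * np.where(x > 30, x, np.log1p(np.exp(np.minimum(x, 30))))
    rates = np.zeros_like(z)
    valid = z > 0  # z underflows to 0 for strongly subthreshold inputs
    # standard LIF rate curve, with z in place of the rectified current
    rates[valid] = 1 / (tau_ref + tau_rc * np.log1p(1 / z[valid]))
    return rates

As sigma approaches zero, z approaches max(0, j - 1) and the curve converges to the ordinary LIF response.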

We’ll start with the nengo.networks.CircularConvolution network, where all the parameters are initialized using the standard Nengo methods, and then further optimize those parameters using deep learning training methods.

In this example only 5 neurons are used per dimension for the circular convolution. This is fewer than would typically be used in a Nengo model, but the improvements from training will allow the network to function well even with this restricted number of neurons.

[3]:
with nengo.Network(seed=rng.randint(1e6)) as net:
    net.config[nengo.Ensemble].neuron_type = nengo_dl.SoftLIFRate(sigma=0.1)
    net.config[nengo.Connection].synapse = None

    # Get the raw vectors for the A0 and B0 pointers using `vocab["A0"].v`
    a = nengo.Node(output=vocab["A0"].v)
    b = nengo.Node(output=vocab["B0"].v)

    # Make the circular convolution network with 5 neurons per dimension
    cconv = nengo.networks.CircularConvolution(5, dimensions=dimensions)

    # Connect the input nodes to the input slots `A` and `B` on the network
    nengo.Connection(a, cconv.input_a)
    nengo.Connection(b, cconv.input_b)

    # Probe the output
    out = nengo.Probe(cconv.output)
    out_filtered = nengo.Probe(cconv.output, synapse=0.01)

We now run the network in its default state to get an idea of the baseline performance. Ideally the output would be clearly C0, the result of the convolution between A0 and B0, but we can see that it is poorly differentiated.

[4]:
with nengo.Simulator(net) as sim:
    sim.run(0.3)
plt.figure()
output_vocab = vocab.create_subset([f"C{i}" for i in range(10)])
plt.plot(sim.trange(), nengo.spa.similarity(sim.data[out_filtered], output_vocab))
plt.legend(output_vocab.keys, loc=4)
plt.ylim([-1, 1])
plt.xlabel("t [s]")
plt.ylabel("dot product")
/home/tbekolay/Code/nengo-dl/nengo_dl/neurons.py:77: RuntimeWarning: divide by zero encountered in true_divide
  q = np.where(j_valid, np.log1p(1 / z), -js - np.log(self.sigma))
/home/tbekolay/Code/nengo-dl/nengo_dl/neurons.py:77: RuntimeWarning: overflow encountered in true_divide
  q = np.where(j_valid, np.log1p(1 / z), -js - np.log(self.sigma))
[4]:
Text(0, 0.5, 'dot product')
../_images/deeplearning_circularconvolution-softlif_7_6.png

Now we can optimize our network by showing it random input pointers and training it to output their circular convolution.

[5]:
with nengo_dl.Simulator(net, minibatch_size=100, device="/cpu:0") as sim:
    optimizer = tf.compat.v1.train.RMSPropOptimizer(5e-3)

    # generate random data
    train_a, train_b, train_c, _ = gen_pointers(1000, dimensions, rng)
    input_feed = {a: train_a, b: train_b}
    output_feed = {out: train_c}

    # train the network for 100 epochs
    sim.compile(loss="mse", optimizer=optimizer)
    sim.fit(input_feed, output_feed, epochs=100)

    sim.run(0.3)
/home/tbekolay/Code/nengo-dl/nengo_dl/neurons.py:75: RuntimeWarning: overflow encountered in exp
  z = np.where(js > 30, js, np.log1p(np.exp(js))) * self.sigma
/home/tbekolay/Code/nengo-dl/nengo_dl/neurons.py:77: RuntimeWarning: divide by zero encountered in true_divide
  q = np.where(j_valid, np.log1p(1 / z), -js - np.log(self.sigma))
/home/tbekolay/Code/nengo-dl/nengo_dl/neurons.py:77: RuntimeWarning: overflow encountered in true_divide
  q = np.where(j_valid, np.log1p(1 / z), -js - np.log(self.sigma))
Build finished in 0:00:01
Optimization finished in 0:00:00
Construction finished in 0:00:01
Epoch 1/100
/home/tbekolay/Code/nengo-dl/nengo_dl/simulator.py:1024: UserWarning: Running for one timestep, but the network contains synaptic filters (which will introduce at least a one-timestep delay); did you mean to set synapse=None?
  warnings.warn(
10/10 [==============================] - 2s 22ms/step - loss: 0.0260 - probe_loss: 0.0260 - probe_1_loss: 0.0000e+00
Epoch 2/100
10/10 [==============================] - 0s 23ms/step - loss: 0.0191 - probe_loss: 0.0191 - probe_1_loss: 0.0000e+00
Epoch 3/100
10/10 [==============================] - 0s 21ms/step - loss: 0.0175 - probe_loss: 0.0175 - probe_1_loss: 0.0000e+00
... (epochs 4-99 elided; the loss fluctuates but trends downward) ...
Epoch 100/100
10/10 [==============================] - 0s 22ms/step - loss: 0.0068 - probe_loss: 0.0068 - probe_1_loss: 0.0000e+00
Simulation finished in 0:00:03
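
Note that probe_1_loss stays at zero throughout: we only supplied targets for the unfiltered probe out, so the filtered probe never contributes to the objective. If we wanted to make that explicit, NengoDL's compile also accepts a per-probe loss dictionary; a sketch:

sim.compile(loss={out: "mse"}, optimizer=optimizer)
sim.fit({a: train_a, b: train_b}, {out: train_c}, epochs=100)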

After training, we run the same test on the network and plot the output. Now we can clearly see that the network's output is closest to the ideal output, C0.

[6]:
output = sim.data[out_filtered]
plt.figure()
plt.plot(sim.trange(), nengo.spa.similarity(output[0], output_vocab))
plt.legend(output_vocab.keys, loc=4)
plt.ylim([-1, 1])
plt.xlabel("t [s]")
plt.ylabel("dot product")
[6]:
Text(0, 0.5, 'dot product')
../_images/deeplearning_circularconvolution-softlif_11_1.png
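
To attach a number to the improvement, we can check the similarity between the final filtered output and the target pointer (a small addition to the example):

print(np.dot(output[0, -1], vocab["C0"].v))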

In a future example we will show how to integrate this trained circular convolution network into a larger model, improving the performance of the network as a whole.
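
When that time comes, the trained parameters will not need to be relearned from scratch. As a sketch (assuming sim.save_params was called inside the training simulator above), NengoDL simulators can persist and restore parameters:

# inside the training simulator, after sim.fit(...):
#     sim.save_params("./cconv_params")

with nengo_dl.Simulator(net) as sim:
    sim.load_params("./cconv_params")
    sim.run(0.3)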