Multiplication

A quick and easy example to start with is a toy model that takes in two numbers and outputs their product. Although the model doesn’t accomplish anything significant, the same techniques can be used to build and train much larger and more complex networks.

NumPy is seeded to make the results deterministic. The seeding has no bearing on the architecture or the training of the network.
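As a small illustration of what the seeding buys us (a sketch separate from the model itself), re-seeding NumPy's global generator with the same value reproduces the same sequence of draws:

```python
import numpy as np

# Seed, draw, then re-seed with the same value and draw again.
np.random.seed(0)
first = np.random.random(3)

np.random.seed(0)
second = np.random.random(3)

# Both draws come from the same point in the same pseudo-random sequence.
print(np.array_equal(first, second))  # True
```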

[1]:
%matplotlib inline
import nengo
import tensorflow as tf
import nengo_dl
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(0)

Architecture

We connect two input nodes (i_1, i_2), each of which generates random numbers between 0 and 1, to the two dimensions of ensemble a. Then a is connected to a second ensemble b, which we probe for the output.

[2]:
with nengo.Network() as net:

    net.config[nengo.Ensemble].neuron_type = nengo.RectifiedLinear()
    net.config[nengo.Ensemble].gain = nengo.dists.Choice([1])
    net.config[nengo.Ensemble].bias = nengo.dists.Uniform(-1, 1)
    net.config[nengo.Connection].synapse = None

    i_1 = nengo.Node(output=lambda t: np.random.random())
    i_2 = nengo.Node(output=lambda t: np.random.random())

    a = nengo.Ensemble(100, 2)
    b = nengo.Ensemble(100, 1)
    nengo.Connection(i_1, a[0])
    nengo.Connection(i_2, a[1])
    nengo.Connection(a, b, function=lambda x: [0])

    i_1_probe = nengo.Probe(i_1)
    i_2_probe = nengo.Probe(i_2)
    output_probe = nengo.Probe(b)

Before we train the network, the output is approximately zero, since that is the function we specified on the connection from a to b. That is not the output we want, so we need to train the network to multiply the inputs.
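To get a feel for how far "approximately zero" is from the target, here is a rough sketch (plain NumPy, not part of the model) that estimates the mean squared error of always predicting 0 when the target is the product of two independent uniform numbers in [0, 1]. Analytically this is E[x²]·E[y²] = (1/3)·(1/3) = 1/9. The sample size and the local generator `rng` are choices made for this illustration; because the untrained network's output is only approximately zero, its measured loss should be close to this value rather than equal to it.

```python
import numpy as np

# Local generator, so the notebook's global seed is left untouched.
rng = np.random.default_rng(0)
x = rng.uniform(0, 1, size=100_000)
y = rng.uniform(0, 1, size=100_000)

# MSE of a constant prediction of 0 against the true product x * y.
baseline_mse = np.mean((x * y - 0.0) ** 2)

# Monte Carlo estimate next to the analytic value 1/9.
print(baseline_mse, 1 / 9)
```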

[3]:
n_steps = 50
minibatch_size = 256
# Showing the output of the model pre training
with nengo_dl.Simulator(net) as sim:
    sim.run_steps(n_steps)
    true_value = np.multiply(sim.data[i_1_probe], sim.data[i_2_probe])
    fig = plt.figure()
    fig.suptitle("Pre-Training")
    plt.plot(sim.data[output_probe], "g", label="predicted value")
    plt.plot(true_value, "m", label="true value")
    plt.legend()
Build finished in 0:00:00
Optimization finished in 0:00:00
Construction finished in 0:00:00
Simulation finished in 0:00:01
../_images/deeplearning_multiplication_5_3.png

To train the network we generate training feeds consisting of two batches of random numbers (the inputs) and the elementwise product of those batches (the targets). We also generate some test data so that we can easily track the network's progress throughout training.
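The shape of these feeds can be sketched in isolation. NengoDL expects data arrays shaped (number of examples, number of timesteps, dimensions), so each input here is an array of shape (n, 1, 1). The helper `make_feed` below is a hypothetical name used only for this illustration and is not part of NengoDL.

```python
import numpy as np


def make_feed(n_examples, rng):
    """Hypothetical helper: two input arrays and their elementwise product.

    Each array has shape (n_examples, 1, 1): one timestep, one dimension.
    """
    x1 = rng.uniform(0, 1, size=(n_examples, 1, 1))
    x2 = rng.uniform(0, 1, size=(n_examples, 1, 1))
    return x1, x2, x1 * x2


rng = np.random.default_rng(0)
x1, x2, target = make_feed(256 * 5, rng)
print(x1.shape, x2.shape, target.shape)  # (1280, 1, 1) (1280, 1, 1) (1280, 1, 1)
```

With a minibatch size of 256, a feed of 256 * 5 examples is split into five minibatches per epoch, which is why the training log reports `5/5` steps for each epoch.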

[4]:
with nengo_dl.Simulator(net, minibatch_size=minibatch_size) as sim:
    # This feed is used as the "test" data
    # It's run through the network after every iteration
    # to allow easy visualization of how the output is changing
    test_inputs = {
        i_1: np.random.uniform(0, 1, size=(minibatch_size, 1, 1)),
        i_2: np.random.uniform(0, 1, size=(minibatch_size, 1, 1)),
    }
    test_targets = {output_probe: np.multiply(test_inputs[i_1], test_inputs[i_2])}

    # running through 10 rounds of training/testing
    outputs = []
    sim.compile(loss="mse", optimizer=tf.compat.v1.train.MomentumOptimizer(5e-2, 0.9))
    for i in range(10):
        # check performance on test set
        sim.step(data=test_inputs)
        print(f"LOSS: {sim.evaluate(test_inputs, test_targets)}")
        outputs.append(sim.data[output_probe].flatten())

        # run training
        input_feed = {
            i_1: np.random.uniform(0, 1, size=(minibatch_size * 5, 1, 1)),
            i_2: np.random.uniform(0, 1, size=(minibatch_size * 5, 1, 1)),
        }
        output_feed = {output_probe: np.multiply(input_feed[i_1], input_feed[i_2])}
        sim.fit(input_feed, output_feed, epochs=12)
        sim.soft_reset(include_probes=True)

    # check final performance on test set
    sim.step(data=test_inputs)
    print(f"LOSS: {sim.evaluate(test_inputs, test_targets)}")
    outputs.append(sim.data[output_probe].flatten())
Build finished in 0:00:00
Optimization finished in 0:00:00
Construction finished in 0:00:00
1/1 [==============================] - 1s 577ms/step - loss: 0.1043 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 0.1043
LOSS: {'loss': 0.1042896956205368, 'probe_loss': 0.0, 'probe_1_loss': 0.0, 'probe_2_loss': 0.1042896956205368}
Epoch 1/12
5/5 [==============================] - 1s 6ms/step - loss: 0.2337 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 0.2337
Epoch 2/12
5/5 [==============================] - 0s 5ms/step - loss: 0.3275 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 0.3275
Epoch 3/12
5/5 [==============================] - 0s 6ms/step - loss: 0.1308 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 0.1308
Epoch 4/12
5/5 [==============================] - 0s 8ms/step - loss: 0.0706 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 0.0706
Epoch 5/12
5/5 [==============================] - 0s 8ms/step - loss: 0.0579 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 0.0579
Epoch 6/12
5/5 [==============================] - 0s 9ms/step - loss: 0.0205 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 0.0205
Epoch 7/12
5/5 [==============================] - 0s 9ms/step - loss: 0.0147 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 0.0147
Epoch 8/12
5/5 [==============================] - 0s 10ms/step - loss: 0.0066 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 0.0066
Epoch 9/12
5/5 [==============================] - 0s 8ms/step - loss: 0.0037 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 0.0037
Epoch 10/12
5/5 [==============================] - 0s 7ms/step - loss: 0.0022 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 0.0022
Epoch 11/12
5/5 [==============================] - 0s 7ms/step - loss: 0.0013 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 0.0013
Epoch 12/12
5/5 [==============================] - 0s 9ms/step - loss: 5.4811e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.4811e-04
1/1 [==============================] - 0s 25ms/step - loss: 2.6955e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 2.6955e-04
LOSS: {'loss': 0.0002695529256016016, 'probe_loss': 0.0, 'probe_1_loss': 0.0, 'probe_2_loss': 0.0002695529256016016}
Epoch 1/12
5/5 [==============================] - 0s 4ms/step - loss: 4.0840e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 4.0840e-04
Epoch 2/12
5/5 [==============================] - 0s 5ms/step - loss: 2.4546e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 2.4546e-04
Epoch 3/12
5/5 [==============================] - 0s 4ms/step - loss: 2.5043e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 2.5043e-04
Epoch 4/12
5/5 [==============================] - 0s 4ms/step - loss: 2.0081e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 2.0081e-04
Epoch 5/12
5/5 [==============================] - 0s 5ms/step - loss: 1.5814e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 1.5814e-04
Epoch 6/12
5/5 [==============================] - 0s 6ms/step - loss: 1.1901e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 1.1901e-04
Epoch 7/12
5/5 [==============================] - 0s 7ms/step - loss: 1.1392e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 1.1392e-04
Epoch 8/12
5/5 [==============================] - 0s 7ms/step - loss: 1.0630e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 1.0630e-04
Epoch 9/12
5/5 [==============================] - 0s 8ms/step - loss: 1.0858e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 1.0858e-04
Epoch 10/12
5/5 [==============================] - 0s 8ms/step - loss: 1.0353e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 1.0353e-04
Epoch 11/12
5/5 [==============================] - 0s 10ms/step - loss: 1.0179e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 1.0179e-04
Epoch 12/12
5/5 [==============================] - 0s 9ms/step - loss: 1.0156e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 1.0156e-04
1/1 [==============================] - 0s 20ms/step - loss: 1.1966e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 1.1966e-04
LOSS: {'loss': 0.00011965574230998755, 'probe_loss': 0.0, 'probe_1_loss': 0.0, 'probe_2_loss': 0.00011965574230998755}
Epoch 1/12
5/5 [==============================] - 0s 4ms/step - loss: 1.1315e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 1.1315e-04
Epoch 2/12
5/5 [==============================] - 0s 5ms/step - loss: 1.1112e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 1.1112e-04
Epoch 3/12
5/5 [==============================] - 0s 6ms/step - loss: 1.0973e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 1.0973e-04
Epoch 4/12
5/5 [==============================] - 0s 5ms/step - loss: 1.0618e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 1.0618e-04
Epoch 5/12
5/5 [==============================] - 0s 4ms/step - loss: 1.0484e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 1.0484e-04
Epoch 6/12
5/5 [==============================] - 0s 5ms/step - loss: 1.0146e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 1.0146e-04
Epoch 7/12
5/5 [==============================] - 0s 5ms/step - loss: 1.0535e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 1.0535e-04
Epoch 8/12
5/5 [==============================] - 0s 5ms/step - loss: 1.0335e-04 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 1.0335e-04
Epoch 9/12
5/5 [==============================] - 0s 5ms/step - loss: 9.7415e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 9.7415e-05
Epoch 10/12
5/5 [==============================] - 0s 6ms/step - loss: 9.5973e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 9.5973e-05
Epoch 11/12
5/5 [==============================] - 0s 5ms/step - loss: 9.3890e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 9.3890e-05
Epoch 12/12
5/5 [==============================] - 0s 5ms/step - loss: 9.3631e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 9.3631e-05
1/1 [==============================] - 0s 18ms/step - loss: 8.6050e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.6050e-05
LOSS: {'loss': 8.604974573245272e-05, 'probe_loss': 0.0, 'probe_1_loss': 0.0, 'probe_2_loss': 8.604974573245272e-05}
Epoch 1/12
5/5 [==============================] - 0s 4ms/step - loss: 8.9583e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.9583e-05
Epoch 2/12
5/5 [==============================] - 0s 4ms/step - loss: 8.8900e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.8900e-05
Epoch 3/12
5/5 [==============================] - 0s 5ms/step - loss: 9.5365e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 9.5365e-05
Epoch 4/12
5/5 [==============================] - 0s 4ms/step - loss: 8.9794e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.9794e-05
Epoch 5/12
5/5 [==============================] - 0s 5ms/step - loss: 8.5178e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.5178e-05
Epoch 6/12
5/5 [==============================] - 0s 5ms/step - loss: 8.7636e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.7636e-05
Epoch 7/12
5/5 [==============================] - 0s 4ms/step - loss: 8.3952e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.3952e-05
Epoch 8/12
5/5 [==============================] - 0s 4ms/step - loss: 8.2921e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.2921e-05
Epoch 9/12
5/5 [==============================] - 0s 4ms/step - loss: 8.3131e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.3131e-05
Epoch 10/12
5/5 [==============================] - 0s 4ms/step - loss: 8.4504e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.4504e-05
Epoch 11/12
5/5 [==============================] - 0s 6ms/step - loss: 8.3002e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.3002e-05
Epoch 12/12
5/5 [==============================] - 0s 7ms/step - loss: 8.0583e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.0583e-05
1/1 [==============================] - 0s 18ms/step - loss: 7.4353e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 7.4353e-05
LOSS: {'loss': 7.435342558892444e-05, 'probe_loss': 0.0, 'probe_1_loss': 0.0, 'probe_2_loss': 7.435342558892444e-05}
Epoch 1/12
5/5 [==============================] - 0s 4ms/step - loss: 8.6869e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.6869e-05
Epoch 2/12
5/5 [==============================] - 0s 5ms/step - loss: 8.6029e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.6029e-05
Epoch 3/12
5/5 [==============================] - 0s 5ms/step - loss: 8.5612e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.5612e-05
Epoch 4/12
5/5 [==============================] - 0s 7ms/step - loss: 8.5974e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.5974e-05
Epoch 5/12
5/5 [==============================] - 0s 8ms/step - loss: 8.2408e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.2408e-05
Epoch 6/12
5/5 [==============================] - 0s 6ms/step - loss: 8.1877e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.1877e-05
Epoch 7/12
5/5 [==============================] - 0s 5ms/step - loss: 8.6833e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.6833e-05
Epoch 8/12
5/5 [==============================] - 0s 5ms/step - loss: 8.1718e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.1718e-05
Epoch 9/12
5/5 [==============================] - 0s 5ms/step - loss: 8.2583e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.2583e-05
Epoch 10/12
5/5 [==============================] - 0s 7ms/step - loss: 8.0660e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.0660e-05
Epoch 11/12
5/5 [==============================] - 0s 5ms/step - loss: 8.1962e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.1962e-05
Epoch 12/12
5/5 [==============================] - 0s 5ms/step - loss: 8.0052e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 8.0052e-05
1/1 [==============================] - 0s 20ms/step - loss: 6.7072e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.7072e-05
LOSS: {'loss': 6.707150168949738e-05, 'probe_loss': 0.0, 'probe_1_loss': 0.0, 'probe_2_loss': 6.707150168949738e-05}
Epoch 1/12
5/5 [==============================] - 0s 4ms/step - loss: 6.5084e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.5084e-05
Epoch 2/12
5/5 [==============================] - 0s 4ms/step - loss: 6.4651e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.4651e-05
Epoch 3/12
5/5 [==============================] - 0s 4ms/step - loss: 6.4055e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.4055e-05
Epoch 4/12
5/5 [==============================] - 0s 4ms/step - loss: 6.5517e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.5517e-05
Epoch 5/12
5/5 [==============================] - 0s 5ms/step - loss: 6.5879e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.5879e-05
Epoch 6/12
5/5 [==============================] - 0s 5ms/step - loss: 6.8302e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.8302e-05
Epoch 7/12
5/5 [==============================] - 0s 6ms/step - loss: 6.4038e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.4038e-05
Epoch 8/12
5/5 [==============================] - 0s 6ms/step - loss: 6.2072e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.2072e-05
Epoch 9/12
5/5 [==============================] - 0s 7ms/step - loss: 6.2071e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.2071e-05
Epoch 10/12
5/5 [==============================] - 0s 6ms/step - loss: 6.2098e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.2098e-05
Epoch 11/12
5/5 [==============================] - 0s 10ms/step - loss: 6.0936e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.0936e-05
Epoch 12/12
5/5 [==============================] - 0s 5ms/step - loss: 6.2646e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.2646e-05
1/1 [==============================] - 0s 22ms/step - loss: 6.6837e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.6837e-05
LOSS: {'loss': 6.683747778879479e-05, 'probe_loss': 0.0, 'probe_1_loss': 0.0, 'probe_2_loss': 6.683747778879479e-05}
Epoch 1/12
5/5 [==============================] - 0s 4ms/step - loss: 7.7861e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 7.7861e-05
Epoch 2/12
5/5 [==============================] - 0s 4ms/step - loss: 7.9567e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 7.9567e-05
Epoch 3/12
5/5 [==============================] - 0s 4ms/step - loss: 7.0654e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 7.0654e-05
Epoch 4/12
5/5 [==============================] - 0s 4ms/step - loss: 7.5837e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 7.5837e-05
Epoch 5/12
5/5 [==============================] - 0s 5ms/step - loss: 7.4511e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 7.4511e-05
Epoch 6/12
5/5 [==============================] - 0s 5ms/step - loss: 7.1558e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 7.1558e-05
Epoch 7/12
5/5 [==============================] - 0s 5ms/step - loss: 7.0029e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 7.0029e-05
Epoch 8/12
5/5 [==============================] - 0s 5ms/step - loss: 6.9910e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.9910e-05
Epoch 9/12
5/5 [==============================] - 0s 5ms/step - loss: 6.7262e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.7262e-05
Epoch 10/12
5/5 [==============================] - 0s 6ms/step - loss: 6.7590e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.7590e-05
Epoch 11/12
5/5 [==============================] - 0s 5ms/step - loss: 6.7905e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.7905e-05
Epoch 12/12
5/5 [==============================] - 0s 6ms/step - loss: 6.6173e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.6173e-05
1/1 [==============================] - 0s 18ms/step - loss: 5.3877e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.3877e-05
LOSS: {'loss': 5.387669807532802e-05, 'probe_loss': 0.0, 'probe_1_loss': 0.0, 'probe_2_loss': 5.387669807532802e-05}
Epoch 1/12
5/5 [==============================] - 0s 4ms/step - loss: 6.5790e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.5790e-05
Epoch 2/12
5/5 [==============================] - 0s 5ms/step - loss: 6.5030e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.5030e-05
Epoch 3/12
5/5 [==============================] - 0s 4ms/step - loss: 6.1451e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.1451e-05
Epoch 4/12
5/5 [==============================] - 0s 4ms/step - loss: 6.2528e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.2528e-05
Epoch 5/12
5/5 [==============================] - 0s 3ms/step - loss: 6.0573e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.0573e-05
Epoch 6/12
5/5 [==============================] - 0s 5ms/step - loss: 6.6364e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.6364e-05
Epoch 7/12
5/5 [==============================] - 0s 6ms/step - loss: 6.3155e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.3155e-05
Epoch 8/12
5/5 [==============================] - 0s 5ms/step - loss: 6.6353e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.6353e-05
Epoch 9/12
5/5 [==============================] - 0s 6ms/step - loss: 6.7410e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.7410e-05
Epoch 10/12
5/5 [==============================] - 0s 6ms/step - loss: 6.3646e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.3646e-05
Epoch 11/12
5/5 [==============================] - 0s 6ms/step - loss: 6.0384e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.0384e-05
Epoch 12/12
5/5 [==============================] - 0s 6ms/step - loss: 5.9496e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.9496e-05
1/1 [==============================] - 0s 22ms/step - loss: 5.7556e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.7556e-05
LOSS: {'loss': 5.7556455431040376e-05, 'probe_loss': 0.0, 'probe_1_loss': 0.0, 'probe_2_loss': 5.7556455431040376e-05}
Epoch 1/12
5/5 [==============================] - 0s 5ms/step - loss: 6.0345e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 6.0345e-05
Epoch 2/12
5/5 [==============================] - 0s 5ms/step - loss: 5.7973e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.7973e-05
Epoch 3/12
5/5 [==============================] - 0s 4ms/step - loss: 5.5279e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.5279e-05
Epoch 4/12
5/5 [==============================] - 0s 5ms/step - loss: 5.4666e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.4666e-05
Epoch 5/12
5/5 [==============================] - 0s 5ms/step - loss: 5.6673e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.6673e-05
Epoch 6/12
5/5 [==============================] - 0s 5ms/step - loss: 5.3452e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.3452e-05
Epoch 7/12
5/5 [==============================] - 0s 4ms/step - loss: 5.2916e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.2916e-05
Epoch 8/12
5/5 [==============================] - 0s 4ms/step - loss: 5.3558e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.3558e-05
Epoch 9/12
5/5 [==============================] - 0s 4ms/step - loss: 5.2590e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.2590e-05
Epoch 10/12
5/5 [==============================] - 0s 4ms/step - loss: 5.2151e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.2151e-05
Epoch 11/12
5/5 [==============================] - 0s 4ms/step - loss: 5.2926e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.2926e-05
Epoch 12/12
5/5 [==============================] - 0s 6ms/step - loss: 5.2411e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.2411e-05
1/1 [==============================] - 0s 19ms/step - loss: 4.8881e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 4.8881e-05
LOSS: {'loss': 4.888128751190379e-05, 'probe_loss': 0.0, 'probe_1_loss': 0.0, 'probe_2_loss': 4.888128751190379e-05}
Epoch 1/12
5/5 [==============================] - 0s 4ms/step - loss: 5.4918e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.4918e-05
Epoch 2/12
5/5 [==============================] - 0s 4ms/step - loss: 5.4579e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.4579e-05
Epoch 3/12
5/5 [==============================] - 0s 4ms/step - loss: 5.3743e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.3743e-05
Epoch 4/12
5/5 [==============================] - 0s 4ms/step - loss: 5.3452e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.3452e-05
Epoch 5/12
5/5 [==============================] - 0s 4ms/step - loss: 5.3543e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.3543e-05
Epoch 6/12
5/5 [==============================] - 0s 6ms/step - loss: 5.5079e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.5079e-05
Epoch 7/12
5/5 [==============================] - 0s 6ms/step - loss: 5.8177e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.8177e-05
Epoch 8/12
5/5 [==============================] - 0s 7ms/step - loss: 5.9010e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.9010e-05
Epoch 9/12
5/5 [==============================] - 0s 8ms/step - loss: 5.7728e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.7728e-05
Epoch 10/12
5/5 [==============================] - 0s 7ms/step - loss: 5.2985e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.2985e-05
Epoch 11/12
5/5 [==============================] - 0s 8ms/step - loss: 5.0936e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.0936e-05
Epoch 12/12
5/5 [==============================] - 0s 6ms/step - loss: 5.1032e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 5.1032e-05
1/1 [==============================] - 0s 17ms/step - loss: 4.5568e-05 - probe_loss: 0.0000e+00 - probe_1_loss: 0.0000e+00 - probe_2_loss: 4.5568e-05
LOSS: {'loss': 4.5567761844722554e-05, 'probe_loss': 0.0, 'probe_1_loss': 0.0, 'probe_2_loss': 4.5567761844722554e-05}

We visualize the results by plotting the pre-training, trained, and ideal outputs on the same axes.

[5]:
fig = plt.figure()
fig.suptitle("Pre/Post Training Comparison")
plt.plot(outputs[0][:50], "r", label="pre-training")
plt.plot(outputs[10][:50], "k", label="trained")
plt.plot(test_targets[output_probe].flatten()[:50], "m", label="ideal")
plt.legend()
../_images/deeplearning_multiplication_9_1.png
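A numeric summary can complement the plot. The sketch below computes the mean squared error of two sets of predictions against the ideal values. The arrays are synthetic stand-ins invented for this illustration (zeros for the untrained output, the ideal values plus small noise for the trained output), and `mse` is a helper defined here, not a library function.

```python
import numpy as np


def mse(predicted, ideal):
    """Mean squared error between two equally shaped arrays."""
    predicted = np.asarray(predicted, dtype=float)
    ideal = np.asarray(ideal, dtype=float)
    return float(np.mean((predicted - ideal) ** 2))


# Synthetic stand-ins for the notebook's data (assumptions for illustration only).
rng = np.random.default_rng(0)
ideal = rng.uniform(0, 1, size=256) * rng.uniform(0, 1, size=256)
untrained = np.zeros_like(ideal)                     # output stuck at zero
trained = ideal + rng.normal(0, 0.01, size=ideal.shape)  # small residual error

print(mse(untrained, ideal) > mse(trained, ideal))  # True
```

In the notebook itself, the same helper could be applied as `mse(outputs[0], test_targets[output_probe].flatten())` and `mse(outputs[10], test_targets[output_probe].flatten())`.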