Source code for nengo_spa.networks.matrix_multiplication

import nengo
import numpy as np


class MatrixMult(nengo.Network):
    """Computes the matrix product A*B.

    Both matrices need to be two dimensional.

    See the Nengo :doc:`Matrix Multiplication example
    <nengo:examples/advanced/matrix-multiplication>` for a description of the
    network internals.

    Parameters
    ----------
    n_neurons : int
        Number of neurons used per product of two scalars.

        .. note::
            If an odd number of neurons is given, one less neuron will be used
            per product to obtain an even number. This is due to the
            implementation of the `.Product` network.
    shape_left : tuple
        Shape of the A input matrix.
    shape_right : tuple
        Shape of the B input matrix.
    **kwargs : dict
        Keyword arguments to pass through to the `nengo.Network` constructor.

    Attributes
    ----------
    input_left : nengo.Node
        The left matrix (A) to multiply, flattened in row-major order.
    input_right : nengo.Node
        The right matrix (B) to multiply, flattened in row-major order.
    C : nengo.networks.Product
        The product network doing the matrix multiplication.
    output : nengo.Node
        The resulting matrix A*B, flattened in row-major order.

    Raises
    ------
    ValueError
        If either shape is not two dimensional, or if the inner dimensions
        ``shape_left[1]`` and ``shape_right[0]`` do not match.
    """

    def __init__(self, n_neurons, shape_left, shape_right, **kwargs):
        if len(shape_left) != 2:
            raise ValueError(f"Shape {shape_left} is not two dimensional.")
        if len(shape_right) != 2:
            raise ValueError(f"Shape {shape_right} is not two dimensional.")
        if shape_left[1] != shape_right[0]:
            raise ValueError(
                f"Matrix dimensions {shape_left} and {shape_right} are incompatible"
            )

        super().__init__(**kwargs)

        # A is (d1 x d2), B is (d2 x d3) and the result A*B is (d1 x d3).
        # The shared inner dimension d2 is taken from shape_right (it equals
        # shape_left[1] after the check above).
        d1 = shape_left[0]
        d2, d3 = shape_right

        size_left = np.prod(shape_left)
        size_right = np.prod(shape_right)

        with self:
            self.input_left = nengo.Node(size_in=size_left)
            self.input_right = nengo.Node(size_in=size_right)

            # Each dimension pair of the Product network multiplies one
            # element of A (routed to input_a) with one element of B (routed
            # to input_b). One pair is needed per (i, j, k) index triple,
            # i.e. d1 * d2 * d3 pairs in total.
            size_c = size_left * d3
            self.C = nengo.networks.Product(n_neurons, size_c)

            # Build the routing transforms. For every triple (i, j, k) with
            # i < d1, j < d2, k < d3 we want the product A[i, j] * B[j, k].
            # With row-major flattening, A[i, j] sits at index j + i*d2 and
            # B[j, k] at index k + j*d3. The pair is placed at index
            # j + k*d2 + i*d2*d3 (size_right == d2*d3), so the d2 products
            # that contribute to the same result element (i, k) are
            # contiguous in C.
            transform_left = np.zeros((size_c, size_left))
            transform_right = np.zeros((size_c, size_right))
            for i, j, k in np.ndindex(d1, d2, d3):
                c_index = j + k * d2 + i * size_right
                transform_left[c_index, j + i * d2] = 1
                transform_right[c_index, k + j * d3] = 1

            nengo.Connection(
                self.input_left, self.C.input_a, transform=transform_left, synapse=None
            )
            nengo.Connection(
                self.input_right,
                self.C.input_b,
                transform=transform_right,
                synapse=None,
            )

            # Sum each run of d2 consecutive products. Pair index
            # c = j + k*d2 + i*d2*d3 maps to c // d2 = k + i*d3, which is
            # element (i, k) of the row-major flattened (d1 x d3) result.
            size_output = d1 * d3
            self.output = nengo.Node(size_in=size_output)
            transform_c = np.zeros((size_output, size_c))
            for c_index in range(size_c):
                transform_c[c_index // d2, c_index] = 1
            nengo.Connection(
                self.C.output, self.output, transform=transform_c, synapse=None
            )