Source code for nengo.networks.workingmemory

import numpy as np

import nengo
from nengo.exceptions import ObsoleteError
from nengo.networks import EnsembleArray


[docs]class InputGatedMemory(nengo.Network): """Stores a given vector in memory, with input controlled by a gate. Parameters ---------- n_neurons : int Number of neurons per dimension in the vector. dimensions : int Dimensionality of the vector. feedback : float, optional Strength of the recurrent connection from the memory to itself. difference_gain : float, optional Strength of the connection from the difference ensembles to the memory ensembles. recurrent_synapse : float, optional difference_synapse : Synapse If None, ... **kwargs Keyword arguments passed through to ``nengo.Network`` like 'label' and 'seed'. Attributes ---------- diff : EnsembleArray Represents the difference between the desired vector and the current vector represented by ``mem``. gate : Node With input of 0, the network is not gated, and ``mem`` will be updated to minimize ``diff``. With input greater than 0, the network will be increasingly gated such that ``mem`` will retain its current value, and ``diff`` will be inhibited. input : Node The desired vector. mem : EnsembleArray Integrative population that stores the vector. output : Node The vector currently represented by ``mem``. reset : Node With positive input, the ``mem`` population will be inhibited, effectively wiping out the vector currently being remembered. """ def __init__( self, n_neurons, dimensions, feedback=1.0, difference_gain=1.0, recurrent_synapse=0.1, difference_synapse=None, **kwargs ): if "net" in kwargs: raise ObsoleteError("The 'net' argument is no longer supported.") kwargs.setdefault("label", "Input gated memory") super().__init__(**kwargs) if difference_synapse is None: difference_synapse = recurrent_synapse n_total_neurons = n_neurons * dimensions with self: # integrator to store value self.mem = EnsembleArray(n_neurons, dimensions, label="mem") nengo.Connection( self.mem.output, self.mem.input, transform=feedback, synapse=recurrent_synapse, ) # calculate difference between stored value and input self.diff = EnsembleArray(n_neurons, dimensions, label="diff") nengo.Connection(self.mem.output, self.diff.input, transform=-1) # feed difference into integrator nengo.Connection( self.diff.output, self.mem.input, transform=difference_gain, synapse=difference_synapse, ) # gate difference (if gate==0, update stored value, # otherwise retain stored value) self.gate = nengo.Node(size_in=1) self.diff.add_neuron_input() nengo.Connection( self.gate, self.diff.neuron_input, transform=np.ones((n_total_neurons, 1)) * -10, synapse=None, ) # reset input (if reset=1, remove all values, and set to 0) self.reset = nengo.Node(size_in=1) nengo.Connection( self.reset, self.mem.add_neuron_input(), transform=np.ones((n_total_neurons, 1)) * -3, synapse=None, ) self.input = self.diff.input self.output = self.mem.output