Source code for nengo_spa.modules.thalamus

import nengo
import numpy as np
from nengo.dists import Uniform
from nengo.params import Default, IntParam, NumberParam
from nengo.synapses import Lowpass, SynapseParam

from nengo_spa.ast import dynamic
from nengo_spa.modules.scalar import Scalar
from nengo_spa.modules.state import State
from nengo_spa.network import Network
from nengo_spa.types import TScalar


[docs]class Thalamus(Network): """Inhibits non-selected actions. The thalamus is intended to work in tandem with a `.BasalGanglia` module. It converts basal ganglia output into a signal with (approximately) 1 for the selected action and 0 elsewhere. In order to suppress low responses and strengthen high responses, a constant bias is added to each dimension (i.e., action), and dimensions mutually inhibit each other. Additionally, the ensemble representing each dimension is created with positive encoders and can be assigned positive x-intercepts to threshold low responses. Parameters ---------- neurons_action : int, optional (Default: 50) Number of neurons per action to represent the selection. threshold_action : float, optional (Default: 0.2) Minimum value for action representation. mutual_inhibit : float, optional (Default: 1.0) Strength of inhibition between actions. route_inhibit : float, optional (Default: 3.0) Strength of inhibition for unchosen actions. synapse_inhibit : float, optional (Default: 0.008) Synaptic filter to apply for inhibition between actions. synapse_bg : float, optional (Default: 0.008) Synaptic filter for connection between basal ganglia and thalamus. synapse_direct : float, optional (Default: 0.01) Synaptic filter for direct outputs. neurons_channel_dim : int, optional (Default: 50) Number of neurons per routing channel dimension. synapse_channel : float, optional (Default: 0.01) Synaptic filter for channel inputs and outputs. neurons_gate : int, optional (Default: 40) Number of neurons per gate. threshold_gate : float, optional (Default: 0.3) Minimum value for gating neurons. synapse_to-gate : float, optional (Default: 0.002) Synaptic filter for controlling a gate. **kwargs : dict Keyword arguments passed through to `nengo_spa.Network`. Attributes ---------- actions : nengo.networks.EnsembleArray Each ensemble represents one dimension (action). bias : nengo.Node The constant bias injected in each *actions* ensemble. input : nengo.Node Input to the *actions* ensembles. output : nengo.Node Output from the *actions* ensembles. """ neurons_action = IntParam("neurons_action", default=50) threshold_action = NumberParam("threshold_action", default=0.2) mutual_inhibit = NumberParam("mutual_inhibit", default=1.0) route_inhibit = NumberParam("route_inhibit", default=3.0) synapse_inhibit = SynapseParam("synapse_inhibit", default=Lowpass(0.008)) synapse_bg = SynapseParam("synapse_bg", default=Lowpass(0.008)) neurons_channel_dim = IntParam("neurons_channel_dim", default=50) synapse_channel = SynapseParam("synapse_channel", default=Lowpass(0.01)) neurons_gate = IntParam("neurons_gate", default=40) threshold_gate = NumberParam("threshold_gate", default=0.3) synapse_to_gate = SynapseParam("synapse_to_gate", default=Lowpass(0.002)) def __init__( self, action_count, neurons_action=Default, threshold_action=Default, mutual_inhibit=Default, route_inhibit=Default, synapse_inhibit=Default, synapse_bg=Default, neurons_channel_dim=Default, synapse_channel=Default, neurons_gate=Default, threshold_gate=Default, synapse_to_gate=Default, **kwargs ): super(Thalamus, self).__init__(**kwargs) self.action_count = action_count self.neurons_action = neurons_action self.mutual_inhibit = mutual_inhibit self.route_inhibit = route_inhibit self.synapse_inhibit = synapse_inhibit self.threshold_action = threshold_action self.neurons_channel_dim = neurons_channel_dim self.synapse_channel = synapse_channel self.neurons_gate = neurons_gate self.threshold_gate = threshold_gate self.synapse_to_gate = synapse_to_gate self.synapse_bg = synapse_bg self.gates = {} # gating ensembles per action (created as needed) self.channels = [] # channels to pass data between networks self.gate_in_connections = {} self.gate_out_connections = {} self.channel_out_connections = [] self.fixed_connections = {} self.bg_connection = None with self: self.actions = nengo.networks.EnsembleArray( self.neurons_action, self.action_count, intercepts=nengo.dists.Uniform(self.threshold_action, 1), encoders=nengo.dists.Choice([[1.0]]), label="actions", ) nengo.Connection( self.actions.output, self.actions.input, transform=(np.eye(self.action_count) - 1) * self.mutual_inhibit, ) self.bias = nengo.Node([1], label="thalamus bias") nengo.Connection( self.bias, self.actions.input, transform=np.ones((self.action_count, 1)) ) self.input = self.actions.input self.output = self.actions.output
[docs] def construct_gate(self, index, bias, label=None): """Construct a gate ensemble. The gate neurons have no activity when the action is selected, but are active when the action is not selected. This makes the gate useful for inhibiting ensembles that should only be active when this action is active. Parameters ---------- index : int Index to identify the gate. bias : :class:`nengo.Network` Node providing a bias input of 1. label : str, optional Label for the gate. Returns ------- nengo.Ensemble The constructed gate. """ if label is None: label = "gate[%d]" % index intercepts = Uniform(self.threshold_gate, 1) self.gates[index] = gate = nengo.Ensemble( self.neurons_gate, dimensions=1, intercepts=intercepts, label=label, encoders=[[1]] * self.neurons_gate, ) nengo.Connection(bias, gate, synapse=None) self.gate_in_connections[index] = nengo.Connection( self.actions.ensembles[index], self.gates[index], synapse=self.synapse_to_gate, transform=-1, ) return self.gates[index]
[docs] def construct_channel(self, sink, type_, label=None): """Construct a channel. Channels are an additional neural population in-between a source population and a target population. This allows inhibiting the channel without affecting the source and thus is useful in routing information. Parameters ---------- sink : nengo.base.NengoObject Sink/target that the channel feeds into. type_ : nengo_spa.types.Type Type of the data transmitted through the channel. label : str, optional Label for the channel. Returns ------- :class:`nengo.networks.EnsembleArray` The constructed channel. """ if label is None: label = "channel" if type_ == TScalar: channel = dynamic.ScalarRealization() else: channel = dynamic.StateRealization(vocab=type_.vocab) self.channels.append(channel) self.channel_out_connections.append( nengo.Connection(channel.output, sink, synapse=self.synapse_channel) ) return channel
[docs] def connect_bg(self, bg): """Connect a basal ganglia network to this thalamus.""" self.bg_connection = nengo.Connection( bg.output, self.input, synapse=self.synapse_bg )
[docs] def connect_gate(self, index, channel): """Connect a gate to a channel for information routing. Parameters ---------- index : int Index of the gate to connect. channel : nengo.networks.EnsembleArray Channel to inhibit with the gate. """ if isinstance(channel, Scalar): target = channel.scalar.neurons elif isinstance(channel, State): target = channel.state_ensembles.add_neuron_input() else: raise NotImplementedError() inhibit = [[-self.route_inhibit]] * (target.size_in) self.gate_out_connections[index] = nengo.Connection( self.gates[index], target, transform=inhibit, synapse=self.synapse_inhibit )
[docs] def connect_fixed(self, index, target, transform): """Create connection to route fixed value. Parameters ---------- index : int Index of the action to connect. target : nengo.base.NengoObject Target of the connection. transform : array-like Transform to apply to apply to the connection. """ self.fixed_connections[index] = self.connect( self.actions.ensembles[index], target, transform )
[docs] def connect(self, source, target, transform): """Create connection. The connection will use the thalamus's *synapse_channel*. Parameters ---------- source : nengo.base.NengoObject Source object. target : nengo.base.NengoObject Target object. transform : array-like Transform to apply to the connection. """ return nengo.Connection( source, target, transform=transform, synapse=self.synapse_channel )