Source code for nengo.spa.basalganglia

import numpy as np

import nengo
from nengo.exceptions import ValidationError
from nengo.spa.action_objects import DotProduct, Source
from nengo.spa.module import Module
from nengo.utils.compat import is_number


class BasalGanglia(Module):
    """A basal ganglia, performing action selection on a set of given actions.

    See `.networks.BasalGanglia` for more details.

    Parameters
    ----------
    actions : Actions
        The actions to choose between.
    input_synapse : float, optional (Default: 0.002)
        The synaptic filter on all input connections.
    label : str, optional (Default: None)
        A name for the ensemble. Used for debugging and visualization.
    seed : int, optional (Default: None)
        The seed used for random number generation.
    add_to_container : bool, optional (Default: None)
        Determines if this Network will be added to the current container.
        If None, will be true if currently within a Network.
    """

    def __init__(self, actions, input_synapse=0.002,
                 label=None, seed=None, add_to_container=None):
        self.actions = actions
        self.input_synapse = input_synapse
        # The bias node is created lazily by the ``bias`` property.
        self._bias = None
        Module.__init__(self, label, seed, add_to_container)
        # Build the basal ganglia network into this module (``net=self``),
        # with one dimension per action.
        # NOTE(review): presumably this is what provides ``self.input``,
        # which the ``add_*_input`` methods connect to -- confirm.
        nengo.networks.BasalGanglia(dimensions=self.actions.count, net=self)

    @property
    def bias(self):
        """Create a bias node, when needed."""
        if self._bias is None:
            with self:
                self._bias = nengo.Node([1], label="basal ganglia bias")
        return self._bias

    def on_add(self, spa):
        """Form the connections into the BG to compute the utility values.

        Each action's condition variable contains the set of computations
        needed for that action's utility value, which is the input to the
        basal ganglia.

        Parameters
        ----------
        spa : SPA
            The SPA network this module has been added to. It is stored on
            ``self.spa`` and used to look up module outputs.

        Raises
        ------
        NotImplementedError
            If a condition contains an inverted source inside a dot product,
            a dot product between two sources, or any subexpression that is
            not a dot product, a source, or a number.
        """
        Module.on_add(self, spa)
        self.spa = spa

        self.actions.process(spa)   # parse the actions

        for i, action in enumerate(self.actions.actions):
            cond = action.condition.expression
            # the basal ganglia handles the condition part of the action;
            # the effect is handled by the thalamus

            # Note: A Source is an output from a module, and a Symbol is
            # text that can be parsed to be a SemanticPointer

            for c in cond.items:
                if isinstance(c, DotProduct):
                    if ((isinstance(c.item1, Source) and c.item1.inverted) or
                            (isinstance(c.item2, Source) and
                             c.item2.inverted)):
                        raise NotImplementedError(
                            "Inversion in subexpression '%s' from action '%s' "
                            "is not supported by the Basal Ganglia." %
                            (c, action))
                    if isinstance(c.item1, Source):
                        if isinstance(c.item2, Source):
                            # dot product between two different sources
                            self.add_compare_input(i, c.item1, c.item2,
                                                   c.scale)
                        else:
                            self.add_dot_input(i, c.item1, c.item2, c.scale)
                    else:
                        # enforced in DotProduct constructor
                        assert isinstance(c.item2, Source)
                        # the source is always passed first, whichever side
                        # of the dot product it was written on
                        self.add_dot_input(i, c.item2, c.item1, c.scale)
                elif isinstance(c, Source):
                    self.add_scalar_input(i, c)
                elif is_number(c):
                    self.add_bias_input(i, c)
                else:
                    raise NotImplementedError(
                        "Subexpression '%s' from action '%s' is not supported "
                        "by the Basal Ganglia." % (c, action))

    def add_bias_input(self, index, value):
        """Make an input that is just a fixed scalar value.

        Parameters
        ----------
        index : int
            The index of the action.
        value : float or int
            The fixed utility value to add.
        """
        with self.spa:
            # The bias node outputs a constant 1, so the transform alone
            # sets the utility contributed to this action.
            nengo.Connection(self.bias, self.input[index:index+1],
                             transform=value, synapse=self.input_synapse)

    def add_compare_input(self, index, source1, source2, scale):
        """Reject an input that is the dot product of two different sources.

        This would correspond to an input action such as
        ``dot(vision, memory)``. It is not supported: this method always
        raises.

        Parameters
        ----------
        index : int
            The index of the action.
        source1 : Source
            The first module output to read from.
        source2 : Source
            The second module output to read from.
        scale : float
            A scaling factor to be applied to the result.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("Compare between two sources will never be "
                                  "implemented as discussed in "
                                  "https://github.com/nengo/nengo/issues/759")

    def add_dot_input(self, index, source, symbol, scale):
        """Make an input that is the dot product of a Source and a Symbol.

        This would be used for an input action such as ``dot(vision, A)``.
        The source may have a transformation applied first.

        Parameters
        ----------
        index : int
            The index of the action.
        source : Source
            The module output to read from.
        symbol : Symbol
            The semantic pointer to compute the dot product with.
        scale : float
            A scaling factor to be applied to the result.
        """
        output, vocab = self.spa.get_module_output(source.name)
        # the first transformation, to handle dot(vision*A, B)
        t1 = vocab.parse(source.transform.symbol).get_convolution_matrix()
        # the linear transform to compute the fixed dot product
        # (wrapped in a list so it forms a single-row matrix)
        t2 = np.array([vocab.parse(symbol.symbol).v*scale])

        # apply the source transform first, then the scaled dot product
        transform = np.dot(t2, t1)

        with self.spa:
            nengo.Connection(output, self.input[index:index+1],
                             transform=transform, synapse=self.input_synapse)

    def add_scalar_input(self, index, source):
        """Add a scalar input that will vary over time.

        This is used for the output of the `.Compare` module.

        Parameters
        ----------
        index : int
            The index of the action.
        source : Source
            The module output to read from.

        Raises
        ------
        NotImplementedError
            If the source's output is not 1-dimensional.
        ValidationError
            If the source's transform does not evaluate to a scalar.
        """
        output, _ = self.spa.get_module_output(source.name)
        if output.size_out != 1:
            raise NotImplementedError(
                "Only 1-dimensional sources can be scalar inputs")

        symbol = source.transform.symbol
        try:
            # NOTE(review): eval() executes arbitrary Python. The transform
            # text originates from the action rules, so those rules must
            # never be built from untrusted input.
            scale = float(eval(symbol))
        except (NameError, SyntaxError, TypeError, ValueError):
            # A non-scalar transform (e.g. a semantic pointer name such as
            # ``A``) fails inside eval() with NameError rather than in
            # float() with ValueError; report all of these uniformly.
            raise ValidationError("Transform must be scalar; got '%s'"
                                  % symbol, attr='source.transform')

        with self.spa:
            nengo.Connection(output, self.input[index:index+1],
                             transform=scale, synapse=self.input_synapse)