Source code for nengo_spa.action_selection

"""Implementation of action syntax."""

from collections import OrderedDict
from collections.abc import Mapping

import nengo

from nengo_spa.ast import dynamic
from nengo_spa.connectors import ModuleInput, RoutedConnection, input_vocab_registry
from nengo_spa.exceptions import SpaActionSelectionError, SpaTypeError
from nengo_spa.types import TScalar


class ActionSelection(Mapping):
    """Implements an action selection system with basal ganglia and thalamus.

    The *ActionSelection* instance has to be used as context manager and each
    potential action is defined by an `.ifmax` call providing an expression
    for the utility value and any number of effects (routing of information)
    to activate when this utility value is highest of all.

    The instance is also a read-only mapping. Its keys are the action names
    (or the integer position for actions that were added without a name), in
    the order the actions were added, and its values are the corresponding
    utility connectors returned by `add_action`.

    Attributes
    ----------
    active : ActionSelection
        Class attribute providing the currently active ActionSelection
        instance (if any).
    built : bool
        Indicates whether the action selection system has been built
        successfully.
    bg : nengo.Network
        Basal ganglia network. Available after the action selection system has
        been built.
    thalamus : nengo.Network
        Thalamus network. Available after the action selection system has
        been built.

    See Also
    --------
    nengo_spa.modules.BasalGanglia : Default basal ganglia network
    nengo_spa.modules.Thalamus : Default thalamus network

    Examples
    --------
    .. code-block:: python

        with ActionSelection():
            ifmax(dot(state, sym.A), sym.B >> state)
            ifmax(dot(state, sym.B), sym.C >> state)
            ifmax(dot(state, sym.C), sym.A >> state)

    This will route the *B* Semantic Pointer to *state* when *state* is more
    similar to *A* than any of the other Semantic Pointers. Similarly, *C*
    will be routed to *state* when *state* is *B*. Once, *state* is *C*, it
    will be reset to *A* and the cycle begins anew.

    Further action selection examples:

    * :ref:`/examples/question-control.ipynb`
    * :ref:`/examples/spa-parser.ipynb`
    * :ref:`/examples/spa-sequence.ipynb`
    * :ref:`/examples/spa-sequence-routed.ipynb`
    """

    active = None

    def __init__(self):
        super().__init__()
        self.built = False
        self.bias = None
        self.bg = None
        self.thalamus = None
        # _utilities[i] and _actions[i] describe the i-th action that was added.
        self._utilities = []
        self._actions = []
        # Maps labels of actions to the index of that action
        self._name2idx = OrderedDict()

    def __enter__(self):
        """Make this instance the active context and enable routed mode.

        Raises
        ------
        SpaActionSelectionError
            If another action selection context is already active.
        """
        assert not self.built
        if ActionSelection.active is not None:
            raise SpaActionSelectionError("Must not nest action selection contexts.")
        ActionSelection.active = self
        ModuleInput.routed_mode = True
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Deactivate the context and build the system unless an error occurred."""
        ActionSelection.active = None
        ModuleInput.routed_mode = False
        if exc_type is not None:
            # The with-body failed: discard pending routed connections and let
            # the exception propagate without attempting to build anything.
            RoutedConnection.free_floating.clear()
            return
        self._build()

    def _build(self):
        """Create the basal ganglia and thalamus and wire up all actions.

        Does nothing (and leaves ``built`` False) when no action was added.

        Raises
        ------
        SpaActionSelectionError
            If routed connections were created inside the context without
            being passed to an ifmax call.
        """
        try:
            if len(RoutedConnection.free_floating) > 0:
                raise SpaActionSelectionError(
                    "All actions in an action selection context must be part "
                    "of an ifmax call."
                )
        finally:
            # Always reset the pending set, whether or not the check failed.
            RoutedConnection.free_floating.clear()

        if not self._utilities:
            return

        n_actions = len(self._utilities)
        self.bias = nengo.Node(1.0, label="bias")
        self.bg = dynamic.BasalGangliaRealization(n_actions)
        self.thalamus = dynamic.ThalamusRealization(n_actions)
        self.thalamus.connect_bg(self.bg)

        for index, utility in enumerate(self._utilities):
            self.bg.connect_input(utility, index=index)

        for index, action in enumerate(self._actions):
            for effect in action:
                if effect.fixed:
                    self.thalamus.connect_fixed(
                        index, effect.sink.input, transform=effect.transform()
                    )
                else:
                    self.thalamus.construct_gate(index, self.bias)
                    channel = self.thalamus.construct_channel(
                        effect.sink.input, effect.type
                    )
                    effect.connect_to(channel.input)
                    self.thalamus.connect_gate(index, channel)

        self.built = True

    def __getitem__(self, key):
        """Return the utility connector for an action name or position.

        String keys are translated through the name table (unknown names
        raise ``KeyError``); any other key is used directly to index the list
        of utilities.
        """
        if isinstance(key, str):
            key = self._name2idx[key]
        return self._utilities[key]

    def __iter__(self):
        """Yield one key per action, in the order the actions were added.

        Named actions are represented by their name, unnamed ones by their
        integer position.
        """
        # Invert the name table so every position is visited exactly once.
        # Restarting a cursor for each named entry would yield positions that
        # were already emitted and produce more keys than len(self).
        names_by_index = {index: name for name, index in self._name2idx.items()}
        for index in range(len(self)):
            yield names_by_index.get(index, index)

    def __len__(self):
        """Return the number of actions added so far."""
        return len(self._actions)

    def add_action(self, name, *actions):
        """Register an action and return the connector for its utility.

        Parameters
        ----------
        name : str or None
            Name of the action. When None, the action is only reachable by
            its position and that position is used in the node label.
        actions : sequence of `RoutedConnection`
            Effects belonging to this action. They are removed from the set
            of free-floating routed connections.

        Returns
        -------
        Connector declared through ``input_vocab_registry`` for a new
        one-dimensional `nengo.Node` that receives the utility value.
        """
        assert ActionSelection.active is self
        position = len(self._actions)
        if name is not None:
            self._name2idx[name] = position
            label_name = name
        else:
            label_name = str(position)
        utility = input_vocab_registry.declare_connector(
            nengo.Node(size_in=1, label="Utility for action " + label_name), None
        )
        self._utilities.append(utility)
        self._actions.append(actions)
        RoutedConnection.free_floating.difference_update(actions)
        return utility
def ifmax(name, condition=None, *actions):
    """Defines a potential action within an `ActionSelection` context.

    The name may be omitted: when the first argument is not a string (and not
    None) it is taken as the condition and every following argument as an
    action, so both ``ifmax("label", utility, a >> b)`` and
    ``ifmax(utility, a >> b)`` are accepted.

    Parameters
    ----------
    name : str, optional
        Name for the action. Can be omitted or given as None.
    condition : nengo_spa.ast.base.Node
        The utility value for the given actions.
    actions : sequence of `RoutedConnection`
        The actions to activate if the given utility is the highest.

    Returns
    -------
    NengoObject
        Nengo object that can be connected to, to provide additional input to
        the utility value.

    Raises
    ------
    TypeError
        If no condition was given.
    SpaActionSelectionError
        If called outside of an `ActionSelection` context or if any action is
        not a `RoutedConnection`.
    SpaTypeError
        If the condition does not evaluate to a scalar.
    """
    # Support the unnamed form by shifting the positional arguments: the first
    # argument is then the condition and the second one the first action.
    if name is not None and not isinstance(name, str):
        if condition is not None:
            actions = (condition,) + actions
        condition = name
        name = None

    if condition is None:
        # Keep raising TypeError, as a missing positional argument did before
        # the condition parameter received a default.
        raise TypeError("ifmax() missing required argument: 'condition'")

    if ActionSelection.active is None:
        raise SpaActionSelectionError(
            "ifmax must be used within the context of an ActionSelection "
            "instance."
        )

    if condition.type != TScalar:
        raise SpaTypeError(
            f"ifmax condition must evaluate to a scalar, but got {condition.type}."
        )

    if any(not isinstance(a, RoutedConnection) for a in actions):
        raise SpaActionSelectionError(
            "ifmax actions must be routing expressions like 'a >> b'."
        )

    utility = ActionSelection.active.add_action(name, *actions)
    condition.connect_to(utility)
    return utility