Source code for nengo.learning_rules

import warnings

from nengo.config import SupportDefaultsMixin
from nengo.exceptions import ValidationError
from nengo.params import (Default, IntParam, FrozenObject, NumberParam,
                          Parameter, Unconfigurable)
from nengo.synapses import Lowpass, SynapseParam
from nengo.utils.compat import is_iterable, is_string, itervalues


class LearningRuleTypeSizeInParam(IntParam):
    valid_strings = ('pre', 'post', 'mid', 'pre_state', 'post_state')

    def coerce(self, instance, size_in):
        if is_string(size_in):
            if size_in not in self.valid_strings:
                raise ValidationError(
                    "%r is not a valid string value (must be one of %s)"
                    % (size_in, self.valid_strings),
                    attr=self.name, obj=instance)
            return size_in
        else:
            return super(LearningRuleTypeSizeInParam, self).coerce(
                instance, size_in)  # IntParam validation


class LearningRuleType(FrozenObject, SupportDefaultsMixin):
    """Base class for all learning rule objects.

    To use a learning rule, pass it as a ``learning_rule_type`` keyword
    argument to the `~nengo.Connection` on which you want to do learning.

    Each learning rule exposes two important pieces of metadata that the
    builder uses to determine what information should be stored.

    The ``size_in`` is the dimensionality of the incoming error signal. It
    can either take an integer or one of the following string values:

    * ``'pre'``: vector error signal in pre-object space
    * ``'post'``: vector error signal in post-object space
    * ``'mid'``: vector error signal in the ``conn.size_mid`` space
    * ``'pre_state'``: vector error signal in pre-synaptic ensemble space
    * ``'post_state'``: vector error signal in post-synaptic ensemble space

    The difference between ``'post_state'`` and ``'post'`` is that with the
    former, if a ``Neurons`` object is passed, it will use the dimensionality
    of the corresponding ``Ensemble``, whereas the latter simply uses the
    ``post`` object ``size_in``. Similarly with ``'pre_state'`` and ``'pre'``.

    The ``modifies`` attribute denotes the signal targeted by the rule.
    Options are:

    * ``'encoders'``
    * ``'decoders'``
    * ``'weights'``

    Parameters
    ----------
    learning_rate : float, optional (Default: 1e-6)
        A scalar indicating the rate at which ``modifies`` will be adjusted.
    size_in : int, str, optional (Default: 0)
        Dimensionality of the error signal (see above).

    Attributes
    ----------
    learning_rate : float
        A scalar indicating the rate at which ``modifies`` will be adjusted.
    size_in : int, str
        Dimensionality of the error signal.
    modifies : str
        The signal targeted by the learning rule.
    """

    modifies = None
    probeable = ()

    learning_rate = NumberParam(
        'learning_rate', low=0, readonly=True, default=1e-6)
    size_in = LearningRuleTypeSizeInParam('size_in', low=0)

    def __init__(self, learning_rate=Default, size_in=0):
        super(LearningRuleType, self).__init__()
        self.learning_rate = learning_rate
        self.size_in = size_in

    def __repr__(self):
        r = []
        for name, default in self._argdefaults:
            value = getattr(self, name)
            if value != default:
                r.append("%s=%r" % (name, value))
        return '%s(%s)' % (type(self).__name__, ", ".join(r))

    @property
    def _argdefaults(self):
        return ('learning_rate', LearningRuleType.learning_rate.default),
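
# Illustrative sketch (not part of this module): a learning rule instance is
# passed to `nengo.Connection` via the ``learning_rule_type`` argument. The
# model below is hypothetical and only demonstrates the metadata described
# in the docstring above.
def _example_learning_rule_type():
    import nengo

    with nengo.Network():
        pre = nengo.Ensemble(50, dimensions=1)
        post = nengo.Ensemble(50, dimensions=1)
        # PES declares size_in='post_state', so conn.learning_rule expects
        # an error signal with post's dimensionality (here, 1).
        conn = nengo.Connection(pre, post, learning_rule_type=nengo.PES())
        assert conn.learning_rule.size_in == 1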
def _deprecated_tau(old_attr, new_attr):
    def get_tau(self):
        return (None if getattr(self, new_attr) is None
                else getattr(self, new_attr).tau)

    def set_tau(self, val):
        if val is Unconfigurable:
            return

        since = "v2.8.0"
        url = "https://github.com/nengo/nengo/pull/1095"
        msg = ("%s has been deprecated, use %s instead (since %s).\n"
               "For more information, please visit %s" % (
                   old_attr, new_attr, since, url))
        warnings.warn(msg, DeprecationWarning)
        setattr(self, new_attr, None if val is None else Lowpass(val))

    return property(get_tau, set_tau)
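
# Illustrative sketch of the shim above: assigning a deprecated ``*_tau``
# attribute emits a DeprecationWarning and forwards the value to the
# corresponding ``*_synapse`` parameter as a ``Lowpass`` synapse. This
# example function is hypothetical; it uses PES, defined below, which is
# resolved when the function is called.
def _example_deprecated_tau():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        pes = PES(pre_tau=0.01)  # deprecated; prefer pre_synapse=Lowpass(0.01)
    assert isinstance(pes.pre_synapse, Lowpass)
    assert pes.pre_synapse.tau == 0.01
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)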
class PES(LearningRuleType):
    """Prescribed Error Sensitivity learning rule.

    Modifies a connection's decoders to minimize an error signal provided
    through a connection to the connection's learning rule.

    Parameters
    ----------
    learning_rate : float, optional (Default: 1e-4)
        A scalar indicating the rate at which weights will be adjusted.
    pre_synapse : `.Synapse`, optional \
                  (Default: ``nengo.synapses.Lowpass(tau=0.005)``)
        Synapse model used to filter the pre-synaptic activities.

    Attributes
    ----------
    learning_rate : float
        A scalar indicating the rate at which weights will be adjusted.
    pre_synapse : `.Synapse`
        Synapse model used to filter the pre-synaptic activities.
    """

    modifies = 'decoders'
    probeable = ('error', 'correction', 'activities', 'delta')

    learning_rate = NumberParam(
        'learning_rate', low=0, readonly=True, default=1e-4)
    pre_synapse = SynapseParam(
        'pre_synapse', default=Lowpass(tau=0.005), readonly=True)

    pre_tau = _deprecated_tau("pre_tau", "pre_synapse")

    def __init__(self, learning_rate=Default, pre_synapse=Default,
                 pre_tau=Unconfigurable):
        super(PES, self).__init__(learning_rate, size_in='post_state')
        if learning_rate is not Default and learning_rate >= 1.0:
            warnings.warn("This learning rate is very high, and can result "
                          "in floating point errors from too much current.")

        if pre_tau is Unconfigurable:
            self.pre_synapse = pre_synapse
        else:
            self.pre_tau = pre_tau

    @property
    def _argdefaults(self):
        return (('learning_rate', PES.learning_rate.default),
                ('pre_synapse', PES.pre_synapse.default))
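
# Illustrative sketch (hypothetical model, not part of this module): PES
# learns a connection's decoders from an error signal provided through a
# connection into ``conn.learning_rule``. Here the error is post minus pre,
# so the connection learns to approximate the identity function.
def _example_pes():
    import nengo

    with nengo.Network():
        pre = nengo.Ensemble(60, dimensions=1)
        post = nengo.Ensemble(60, dimensions=1)
        error = nengo.Ensemble(60, dimensions=1)
        conn = nengo.Connection(
            pre, post, learning_rule_type=nengo.PES(learning_rate=1e-4))
        # Error signal: actual (post) minus target (pre)
        nengo.Connection(post, error)
        nengo.Connection(pre, error, transform=-1)
        # The error connection drives the decoder updates
        nengo.Connection(error, conn.learning_rule)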
class BCM(LearningRuleType):
    """Bienenstock-Cooper-Munro learning rule.

    Modifies connection weights as a function of the presynaptic activity
    and the difference between the postsynaptic activity and the average
    postsynaptic activity.

    Notes
    -----
    The BCM rule is dependent on pre and post neural activities,
    not decoded values, and so is not affected by changes in the
    size of pre and post ensembles. However, if you are decoding from
    the post ensemble, the BCM rule will have an increased effect on
    larger post ensembles because more connection weights are changing.
    In these cases, it may be advantageous to scale the learning rate
    on the BCM rule by ``1 / post.n_neurons``.

    Parameters
    ----------
    learning_rate : float, optional (Default: 1e-9)
        A scalar indicating the rate at which weights will be adjusted.
    pre_synapse : `.Synapse`, optional \
                  (Default: ``nengo.synapses.Lowpass(tau=0.005)``)
        Synapse model used to filter the pre-synaptic activities.
    post_synapse : `.Synapse`, optional (Default: ``None``)
        Synapse model used to filter the post-synaptic activities.
        If None, ``post_synapse`` will be the same as ``pre_synapse``.
    theta_synapse : `.Synapse`, optional \
                    (Default: ``nengo.synapses.Lowpass(tau=1.0)``)
        Synapse model used to filter the theta signal.

    Attributes
    ----------
    learning_rate : float
        A scalar indicating the rate at which weights will be adjusted.
    post_synapse : `.Synapse`
        Synapse model used to filter the post-synaptic activities.
    pre_synapse : `.Synapse`
        Synapse model used to filter the pre-synaptic activities.
    theta_synapse : `.Synapse`
        Synapse model used to filter the theta signal.
    """

    modifies = 'weights'
    probeable = ('theta', 'pre_filtered', 'post_filtered', 'delta')

    learning_rate = NumberParam(
        'learning_rate', low=0, readonly=True, default=1e-9)
    pre_synapse = SynapseParam(
        'pre_synapse', default=Lowpass(tau=0.005), readonly=True)
    post_synapse = SynapseParam(
        'post_synapse', default=None, readonly=True)
    theta_synapse = SynapseParam(
        'theta_synapse', default=Lowpass(tau=1.0), readonly=True)

    pre_tau = _deprecated_tau("pre_tau", "pre_synapse")
    post_tau = _deprecated_tau("post_tau", "post_synapse")
    theta_tau = _deprecated_tau("theta_tau", "theta_synapse")

    def __init__(self, learning_rate=Default, pre_synapse=Default,
                 post_synapse=Default, theta_synapse=Default,
                 pre_tau=Unconfigurable, post_tau=Unconfigurable,
                 theta_tau=Unconfigurable):
        super(BCM, self).__init__(learning_rate, size_in=0)

        if pre_tau is Unconfigurable:
            self.pre_synapse = pre_synapse
        else:
            self.pre_tau = pre_tau

        if post_tau is Unconfigurable:
            self.post_synapse = (self.pre_synapse if post_synapse is Default
                                 else post_synapse)
        else:
            self.post_tau = post_tau

        if theta_tau is Unconfigurable:
            self.theta_synapse = theta_synapse
        else:
            self.theta_tau = theta_tau

    @property
    def _argdefaults(self):
        return (('learning_rate', BCM.learning_rate.default),
                ('pre_synapse', BCM.pre_synapse.default),
                ('post_synapse', self.pre_synapse),
                ('theta_synapse', BCM.theta_synapse.default))
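
# Illustrative sketch (hypothetical model): BCM modifies full connection
# weights, so it is applied to a neuron-to-neuron connection with an
# explicit weight matrix. The learning rate is scaled by 1 / post.n_neurons
# following the note in the docstring above; the initial weights are
# arbitrary small values chosen purely for illustration.
def _example_bcm():
    import numpy as np
    import nengo

    with nengo.Network():
        pre = nengo.Ensemble(20, dimensions=1)
        post = nengo.Ensemble(20, dimensions=1)
        nengo.Connection(
            pre.neurons, post.neurons,
            transform=np.random.uniform(
                -1e-3, 1e-3, size=(post.n_neurons, pre.n_neurons)),
            learning_rule_type=nengo.BCM(learning_rate=1e-9 / post.n_neurons))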
class Oja(LearningRuleType):
    """Oja learning rule.

    Modifies connection weights according to the Hebbian Oja rule, which
    augments the typical Hebbian coactivity with a "forgetting" term that is
    proportional to the weight of the connection and the square of the
    postsynaptic activity.

    Notes
    -----
    The Oja rule is dependent on pre and post neural activities,
    not decoded values, and so is not affected by changes in the
    size of pre and post ensembles. However, if you are decoding from
    the post ensemble, the Oja rule will have an increased effect on
    larger post ensembles because more connection weights are changing.
    In these cases, it may be advantageous to scale the learning rate
    on the Oja rule by ``1 / post.n_neurons``.

    Parameters
    ----------
    learning_rate : float, optional (Default: 1e-6)
        A scalar indicating the rate at which weights will be adjusted.
    pre_synapse : `.Synapse`, optional \
                  (Default: ``nengo.synapses.Lowpass(tau=0.005)``)
        Synapse model used to filter the pre-synaptic activities.
    post_synapse : `.Synapse`, optional (Default: ``None``)
        Synapse model used to filter the post-synaptic activities.
        If None, ``post_synapse`` will be the same as ``pre_synapse``.
    beta : float, optional (Default: 1.0)
        A scalar weight on the forgetting term.

    Attributes
    ----------
    beta : float
        A scalar weight on the forgetting term.
    learning_rate : float
        A scalar indicating the rate at which weights will be adjusted.
    post_synapse : `.Synapse`
        Synapse model used to filter the post-synaptic activities.
    pre_synapse : `.Synapse`
        Synapse model used to filter the pre-synaptic activities.
    """

    modifies = 'weights'
    probeable = ('pre_filtered', 'post_filtered', 'delta')

    learning_rate = NumberParam(
        'learning_rate', low=0, readonly=True, default=1e-6)
    pre_synapse = SynapseParam(
        'pre_synapse', default=Lowpass(tau=0.005), readonly=True)
    post_synapse = SynapseParam(
        'post_synapse', default=None, readonly=True)
    beta = NumberParam('beta', low=0, readonly=True, default=1.0)

    pre_tau = _deprecated_tau("pre_tau", "pre_synapse")
    post_tau = _deprecated_tau("post_tau", "post_synapse")

    def __init__(self, learning_rate=Default, pre_synapse=Default,
                 post_synapse=Default, beta=Default,
                 pre_tau=Unconfigurable, post_tau=Unconfigurable):
        super(Oja, self).__init__(learning_rate, size_in=0)
        self.beta = beta

        if pre_tau is Unconfigurable:
            self.pre_synapse = pre_synapse
        else:
            self.pre_tau = pre_tau

        if post_tau is Unconfigurable:
            self.post_synapse = (self.pre_synapse if post_synapse is Default
                                 else post_synapse)
        else:
            self.post_tau = post_tau

    @property
    def _argdefaults(self):
        return (('learning_rate', Oja.learning_rate.default),
                ('pre_synapse', Oja.pre_synapse.default),
                ('post_synapse', self.pre_synapse),
                ('beta', Oja.beta.default))
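
# Illustrative sketch (hypothetical model): like BCM, Oja modifies full
# connection weights and so applies to a neuron-to-neuron connection;
# ``beta`` scales the forgetting term described in the docstring above.
def _example_oja():
    import numpy as np
    import nengo

    with nengo.Network():
        pre = nengo.Ensemble(20, dimensions=1)
        post = nengo.Ensemble(20, dimensions=1)
        nengo.Connection(
            pre.neurons, post.neurons,
            transform=np.random.uniform(
                -1e-3, 1e-3, size=(post.n_neurons, pre.n_neurons)),
            learning_rule_type=nengo.Oja(learning_rate=1e-6, beta=0.5))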
class Voja(LearningRuleType):
    """Vector Oja learning rule.

    Modifies an ensemble's encoders to be selective to its inputs.

    A connection to the learning rule will provide a scalar weight for the
    learning rate, minus 1. For instance, 0 is normal learning, -1 is no
    learning, and less than -1 causes anti-learning or "forgetting".

    Parameters
    ----------
    learning_rate : float, optional (Default: 1e-2)
        A scalar indicating the rate at which encoders will be adjusted.
    post_synapse : `.Synapse`, optional \
                   (Default: ``nengo.synapses.Lowpass(tau=0.005)``)
        Synapse model used to filter the post-synaptic activities.

    Attributes
    ----------
    learning_rate : float
        A scalar indicating the rate at which encoders will be adjusted.
    post_synapse : `.Synapse`
        Synapse model used to filter the post-synaptic activities.
    """

    modifies = 'encoders'
    probeable = ('post_filtered', 'scaled_encoders', 'delta')

    learning_rate = NumberParam(
        'learning_rate', low=0, readonly=True, default=1e-2)
    post_synapse = SynapseParam(
        'post_synapse', default=Lowpass(tau=0.005), readonly=True)

    post_tau = _deprecated_tau("post_tau", "post_synapse")

    def __init__(self, learning_rate=Default, post_synapse=Default,
                 post_tau=Unconfigurable):
        super(Voja, self).__init__(learning_rate, size_in=1)

        if post_tau is Unconfigurable:
            self.post_synapse = post_synapse
        else:
            self.post_tau = post_tau

    @property
    def _argdefaults(self):
        return (('learning_rate', Voja.learning_rate.default),
                ('post_synapse', Voja.post_synapse.default))
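
# Illustrative sketch (hypothetical model): Voja adapts the post ensemble's
# encoders toward its inputs. The scalar connection into the learning rule
# modulates learning as described in the docstring above (0 is normal
# learning, -1 disables learning).
def _example_voja():
    import nengo

    with nengo.Network():
        stim = nengo.Node([0.5])
        ens = nengo.Ensemble(100, dimensions=1)
        conn = nengo.Connection(
            stim, ens, learning_rule_type=nengo.Voja(learning_rate=1e-2))
        learn_control = nengo.Node([0])  # 0 enables normal learning
        nengo.Connection(learn_control, conn.learning_rule, synapse=None)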
class LearningRuleTypeParam(Parameter):
    def check_rule(self, instance, rule):
        if not isinstance(rule, LearningRuleType):
            raise ValidationError(
                "'%s' must be a learning rule type or a dict or "
                "list of such types." % rule, attr=self.name, obj=instance)
        if rule.modifies not in ('encoders', 'decoders', 'weights'):
            raise ValidationError("Unrecognized target %r" % rule.modifies,
                                  attr=self.name, obj=instance)

    def coerce(self, instance, rule):
        if is_iterable(rule):
            for r in (itervalues(rule) if isinstance(rule, dict) else rule):
                self.check_rule(instance, r)
        elif rule is not None:
            self.check_rule(instance, rule)
        return super(LearningRuleTypeParam, self).coerce(instance, rule)
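
# Illustrative sketch (hypothetical model): ``learning_rule_type`` accepts a
# single rule, a list, or a dict; each entry is validated by check_rule
# above. Both rules here modify 'weights', so a neuron-to-neuron connection
# with an explicit weight matrix is used.
def _example_learning_rule_type_param():
    import numpy as np
    import nengo

    with nengo.Network():
        pre = nengo.Ensemble(20, dimensions=1)
        post = nengo.Ensemble(20, dimensions=1)
        nengo.Connection(
            pre.neurons, post.neurons,
            transform=np.zeros((post.n_neurons, pre.n_neurons)),
            learning_rule_type=[nengo.BCM(), nengo.Oja()])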