Source code for keras_spiking.regularizers

"""
Regularization methods designed to work with spiking layers.
"""

import tensorflow as tf


class L1L2(tf.keras.regularizers.L1L2):
    """
    A version of ``tf.keras.regularizers.L1L2`` that allows the user to specify a
    nonzero target output.

    The regularized tensor is shifted by ``target`` and the shifted values are
    handed to the parent Keras ``L1L2`` penalty, so distance is measured from
    ``target`` instead of from zero.

    Parameters
    ----------
    l1 : float
        Weight on L1 regularization penalty.
    l2 : float
        Weight on L2 regularization penalty.
    target : float
        Target output value (values will be penalized based on their distance from
        this point).
    """

    def __init__(self, l1=0.0, l2=0.0, target=0, **kwargs):
        super().__init__(l1=l1, l2=l2, **kwargs)
        self.target = target

    def __call__(self, x):
        # Penalize the offset from ``target`` rather than the raw values.
        shifted = x - self.target
        return super().__call__(shifted)

    def get_config(self):
        """Return config (for serialization during model saving/loading)."""
        # Extend the parent's config with ``target`` so ``from_config`` can
        # reconstruct this regularizer with the same offset.
        return {**super().get_config(), "target": self.target}
class L1(L1L2):
    """
    A version of ``tf.keras.regularizers.L1`` that allows the user to specify a
    nonzero target output.

    This is a convenience wrapper around ``L1L2`` that exposes only the L1 weight
    (defaulting to 0.01).

    Parameters
    ----------
    l1 : float
        Weight on L1 regularization penalty.
    target : float
        Target output value (values will be penalized based on their distance from
        this point).
    """

    def __init__(self, l1=0.01, target=0, **kwargs):
        # Any extra keyword arguments (e.g. ``l2``) are passed straight to ``L1L2``.
        super().__init__(target=target, l1=l1, **kwargs)
class L2(L1L2):
    """
    A version of ``tf.keras.regularizers.L2`` that allows the user to specify a
    nonzero target output.

    This is a convenience wrapper around ``L1L2`` that exposes only the L2 weight
    (defaulting to 0.01).

    Parameters
    ----------
    l2 : float
        Weight on L2 regularization penalty.
    target : float
        Target output value (values will be penalized based on their distance from
        this point).
    """

    def __init__(self, l2=0.01, target=0, **kwargs):
        # Any extra keyword arguments (e.g. ``l1``) are passed straight to ``L1L2``.
        super().__init__(target=target, l2=l2, **kwargs)