# Source code for nengo_dl.builder

import logging
import warnings

from nengo.exceptions import BuildError
import tensorflow as tf

logger = logging.getLogger(__name__)


class Builder(object):
    """Manages the operator build classes known to the ``nengo_dl`` build
    process.

    Attributes
    ----------
    builders : dict
        mapping from operator types to the build class registered for them
        via :meth:`.Builder.register`
    """

    # class-level registry, shared by every user of ``Builder`` (not
    # per-instance); filled in by the ``register`` decorator
    builders = {}

    @classmethod
    def pre_build(cls, ops, signals, rng, op_builds):
        """Setup step for build classes, in which they compute any of the
        values that are constant across simulation timesteps.

        Parameters
        ----------
        ops : tuple of :class:`~nengo:nengo.builder.Operator`
            the operator group to build into the model
        signals : :class:`.signals.SignalDict`
            mapping from :class:`~nengo:nengo.builder.Signal` to
            ``tf.Tensor`` (updated by operations)
        rng : :class:`~numpy:numpy.random.RandomState`
            random number generator instance (only forwarded to build
            classes that set ``pass_rng``)
        op_builds : dict of {tuple of :class:`~nengo:nengo.builder.Operator`, \
:class:`~.op_builders.OpBuilder`}
            ``pre_build`` will populate this dictionary with the OpBuilder
            objects (which execute the pre-build step in their ``__init__``)

        Raises
        ------
        BuildError
            if no build class has been registered for the type of the first
            operator in ``ops``
        """
        logger.debug("===================")
        logger.debug("PRE BUILD %s", ops)
        # only materialize the per-op attribute lists when they will
        # actually be logged
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("sets %s", [op.sets for op in ops])
            logger.debug("incs %s", [op.incs for op in ops])
            logger.debug("reads %s", [op.reads for op in ops])
            logger.debug("updates %s", [op.updates for op in ops])

        # the build class is selected from the type of the first operator
        # only (presumably all ops in a group share a type -- not checked
        # here)
        op_type = type(ops[0])
        BuildClass = cls.builders.get(op_type)
        if BuildClass is None:
            raise BuildError("No registered builder for operators of type %r"
                             % op_type)

        kwargs = {}
        if BuildClass.pass_rng:
            kwargs["rng"] = rng

        # constructing the build class performs its pre-build work
        op_builds[ops] = BuildClass(ops, signals, **kwargs)

    @classmethod
    def build(cls, ops, signals, op_builds):
        """Build the computations implementing a single simulator timestep.

        Parameters
        ----------
        ops : tuple of :class:`~nengo:nengo.builder.Operator`
            the operator group to build into the model
        signals : :class:`.signals.SignalDict`
            mapping from :class:`~nengo:nengo.builder.Signal` to
            ``tf.Tensor`` (updated by operations)
        op_builds : dict of {tuple of :class:`~nengo:nengo.builder.Operator`, \
:class:`~.op_builders.OpBuilder`}
            mapping from operator groups to the pre-built builder objects

        Returns
        -------
        list of ``tf.Tensor`` or None
            the value returned by the builder's ``build_step``; a single
            ``tf.Tensor``/``tf.Variable`` is wrapped in a list and a tuple
            is converted to a list, any other value is returned unchanged

        Raises
        ------
        BuildError
            if ``ops`` has no entry in ``op_builds`` (i.e. ``pre_build`` was
            not run for this group)
        """
        logger.debug("===================")
        logger.debug("BUILD %s", ops)

        op_build = op_builds.get(ops)
        if op_build is None:
            raise BuildError("Operators build has not been initialized "
                             "(missed pre-build step)")

        output = op_build.build_step(signals)

        # normalize single tensors and tuples into a list of outputs
        if isinstance(output, (tf.Tensor, tf.Variable)):
            output = [output]
        elif isinstance(output, tuple):
            output = list(output)

        return output

    @classmethod
    def register(cls, nengo_op):
        """A decorator for adding a class to the build function registry.

        Parameters
        ----------
        nengo_op : :class:`~nengo:nengo.builder.Operator`
            The operator associated with the build function being decorated.

        Returns
        -------
        callable
            decorator that stores the decorated class in ``builders`` under
            ``nengo_op`` and returns the class unchanged; it emits a warning
            (but still registers) if the class does not subclass
            ``OpBuilder`` or if ``nengo_op`` already has a builder
        """

        def register_builder(build_class):
            if not issubclass(build_class, OpBuilder):
                warnings.warn("Build classes should inherit from OpBuilder")

            if nengo_op in cls.builders:
                warnings.warn("Operator '%s' already has a builder. "
                              "Overwriting." % nengo_op)

            cls.builders[nengo_op] = build_class
            return build_class

        return register_builder
class OpBuilder(object):  # pragma: no cover
    """Base class for operator build classes.

    Subclasses do any work that is fixed for the operator group in their
    constructor (that is, anything that does not have to be recomputed on
    every simulation timestep). They implement :meth:`build_step` for the
    per-timestep computations.

    Parameters
    ----------
    ops : sequence of :class:`~nengo:nengo.builder.Operator`
        the operator group to build into the model
    signals : :class:`.signals.SignalDict`
        mapping from :class:`~nengo:nengo.builder.Signal` to
        ``tf.Tensor`` (updated by operations)

    Attributes
    ----------
    pass_rng : bool
        True if this build class needs the simulation random number
        generator to be passed to its constructor
    """

    pass_rng = False

    def __init__(self, ops, signals):
        """The base implementation has nothing to precompute."""

    def build_step(self, signals):
        """Build the computations that must run on every simulation timestep.

        Parameters
        ----------
        signals : :class:`.signals.SignalDict`
            mapping from :class:`~nengo:nengo.builder.Signal` to
            ``tf.Tensor`` (updated by operations)

        Returns
        -------
        list of ``tf.Tensor``, optional
            if not None, the returned tensors are outputs with possible
            side effects, i.e. computations that must be executed in the
            tensorflow graph even if their output does not appear to be used

        Raises
        ------
        BuildError
            always, in this base class; subclasses must override it
        """
        message = "OpBuilders must implement a `build_step` function"
        raise BuildError(message)