"""Operator graph optimizers."""
from collections import defaultdict, namedtuple
from itertools import zip_longest
import logging
import warnings
import numpy as np
from nengo.builder.neurons import SimNeurons
from nengo.builder import operator as op
from nengo.builder.signal import Signal
from nengo.rc import rc
from nengo.utils.graphs import BidirectionalDAG, transitive_closure
from nengo.utils.stdlib import Timer, WeakKeyDefaultDict, WeakSet
logger = logging.getLogger(__name__)
def optimize(model, dg):
"""Optimizes the operator graph by merging operators.
    This reduces the number of operators that must be iterated over in slow
    Python code (as opposed to fast C code). The resulting merged operators
    will also operate on larger chunks of sequential memory, making better
    use of CPU caching and prefetching.

    The optimization algorithm has worst-case complexity :math:`O(n^2 + e)`,
    where :math:`n` is the number of operators and :math:`e` is the number
    of edges in the dependency graph. In practice the run time will be much
    better because not all :math:`n^2` pairwise combinations of operators
    will be evaluated. Operators are first grouped by type and view base
    using dictionaries; this grouping can be done in amortized linear time
    and reduces the actual worst-case runtime to :math:`O(gm^2 + e)`, where
    :math:`g` is the number of groups and :math:`m` is the number of elements
    in a group. Moreover, information about memory alignment is used to cut
    the inner loop short in many cases, giving a runtime much closer to
    linear in most cases.

    Note that this function modifies both ``model`` and ``dg``.

Parameters
----------
model : `nengo.builder.Model`
Builder output to optimize.
dg : dict
Dict of the form ``{a: {b, c}}`` where ``b`` and ``c`` depend on ``a``,
specifying the operator dependency graph of the model.
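
    Examples
    --------
    A minimal sketch of how the optimizer is invoked (assuming ``model`` is a
    built `nengo.builder.Model` and that ``operator_dependency_graph`` is used
    to construct ``dg``, mirroring what `nengo.Simulator` does when
    ``optimize=True``)::

        from nengo.utils.simulator import operator_dependency_graph

        dg = operator_dependency_graph(model.operators)
        optimize(model, dg)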
"""
logger.debug("Optimizing model...")
# We try first to merge operators with views only as these have a fixed
# order for the memory alignment whereas operators without views could
# be merged in a random order. Merging the views of operators will
# propagate requirements in the memory ordering via the other
# associated signals of the operator to other operators.
# Once no more operators with views can be merged, we try to merge
# operators without views and then try again merging views (because
# each operator merge might generate new views).
single_pass = OpMergePass(dg)
n_initial_ops = len(dg)
cum_duration = 0.0
before, after = None, None
only_merge_ops_with_view = True
while only_merge_ops_with_view or after < before:
only_merge_ops_with_view = before is None or before != after
before = len(single_pass.dg.forward)
with Timer() as t:
single_pass(only_merge_ops_with_view)
after = len(single_pass.dg.forward)
logger.debug(
"[%s]: Reduced %i to %i operators in %fs.",
"views" if only_merge_ops_with_view else "non-views",
before,
after,
t.duration,
)
        # Prevent the optimizer from running too long if we run into
        # diminishing returns.
# Note that we don't break if there was no reduction at all because
# in that case we want to toggle only_merge_ops_with_view which might
# still yield some significant reduction.
cum_duration += t.duration
mean_reduction_rate = float(n_initial_ops - after) / cum_duration
last_reduction_rate = float(before - after) / t.duration
threshold = 0.01
scaled_rate = threshold * mean_reduction_rate
if 0.0 < last_reduction_rate < scaled_rate: # pragma: no cover
logger.debug(
"Operator reduction rate fell below {} mean reduction rate. "
"Stopping optimizer.".format(threshold)
)
break
# Update model signals
for sigdict in model.sig.values():
for name in sigdict:
while sigdict[name] in single_pass.sig_replacements:
sigdict[name] = single_pass.sig_replacements[sigdict[name]]
# Reinitialize the model's operator list
del model.operators[:]
for operator in dg:
model.add_op(operator)
class OpMergePass:
"""Manages a single optimization pass."""
def __init__(self, dg):
self.dg = BidirectionalDAG(dg)
self.might_merge = set(dg)
self.sig_replacements = {}
self.sig2ops = WeakKeyDefaultDict(WeakSet)
self.base2views = WeakKeyDefaultDict(WeakSet)
for operator in self.dg.forward:
for s in operator.all_signals:
self.sig2ops[s].add(operator)
self.base2views[s.base].add(s)
# These variables will be initialized and used on each pass
self.dependents = None
self.only_merge_ops_with_view = None
self.merged = set()
self.merged_dependents = set()
self.opinfo = OpInfo()
def __call__(self, only_merge_ops_with_view):
"""Perform a single optimization pass.
Parameters
----------
only_merge_ops_with_view : bool
Limits operator merges to operators with views.
"""
# --- Initialize pass state
self.dependents = transitive_closure(self.dg.forward)
self.only_merge_ops_with_view = only_merge_ops_with_view
self.merged.clear()
self.merged_dependents.clear()
self.opinfo.clear()
# --- Do an optimization pass
self.perform_merges()
    def merge(self, tomerge):
"""Merges the given operators.
This method will also update ``op_replacements``, ``sig_replacements``,
and the internal list of merged operators to prevent further merges
on the same operators before all required operators and signals have
been replaced.
"""
merged_op, merged_sig = OpMerger.merge(tomerge.ops)
self.dg.merge(tomerge.ops, merged_op)
# Update tracking what has been merged and might be mergeable later
self.might_merge.difference_update(tomerge.ops)
self.might_merge.add(merged_op)
self.merged.update(tomerge.ops)
self.merged_dependents.update(tomerge.all_dependents)
for operator in tomerge.ops:
# Mark all operators referencing the same signals as merged
# (even though they are not) to prevent them from getting
# merged before their signals have been updated.
for s in operator.all_signals:
self.merged.update(self.sig2ops[s])
# Signal related updates
self.resolve_views_on_replaced_signals(merged_sig)
self.sig_replacements.update(merged_sig)
self.replace_op_signals(merged_sig)
self.update_signal_indexing(merged_op, merged_sig)
def resolve_views_on_replaced_signals(self, replaced_signals):
for sig in list(replaced_signals):
for view in self.base2views[sig]:
if view is sig:
continue
assert view.base is sig
base_replacement = replaced_signals[sig]
offset = view.offset
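                # Re-express the view's strides relative to the replacement
                # base: the view's step in elements (old stride // old base
                # stride) is scaled by the replacement base's stride.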
strides = tuple(
a // b * c
for a, b, c in zip_longest(
view.strides,
view.base.strides,
base_replacement.strides,
fillvalue=1,
)
)
if base_replacement.is_view:
offset += base_replacement.offset
base_replacement = base_replacement.base
buf = base_replacement.initial_value
initial_value = np.ndarray(
buffer=buf,
dtype=view.dtype,
shape=view.shape,
offset=offset,
strides=strides,
)
replaced_signals[view] = Signal(
initial_value,
name=view.name,
base=base_replacement,
readonly=view.readonly,
offset=offset,
)
def replace_op_signals(self, replaced_signals):
ops = (op for s in replaced_signals for op in self.sig2ops[s])
for v in ops:
# Update the op's signals
v.sets = [replaced_signals.get(s, s) for s in v.sets]
v.incs = [replaced_signals.get(s, s) for s in v.incs]
v.reads = [replaced_signals.get(s, s) for s in v.reads]
v.updates = [replaced_signals.get(s, s) for s in v.updates]
def update_signal_indexing(self, merged_op, replaced_signals):
for s in merged_op.all_signals:
self.sig2ops[s].add(merged_op)
if s.is_view:
self.base2views[s.base].add(s)
for from_sig, to_sig in replaced_signals.items():
self.sig2ops[to_sig] = self.sig2ops[from_sig]
if to_sig.is_view:
self.base2views[to_sig.base].add(to_sig)
class OpInfo:
"""Analyze and store extra information about operators."""
_OpDetails = namedtuple(
"_OpDetails", ["first_view", "v_offset", "v_size", "v_base"]
)
def __init__(self):
super().__init__()
self.info = {}
def get(self, op):
if op not in self.info:
try:
first_view = next(s for s in op.all_signals if s.is_view)
self.info[op] = self._OpDetails(
first_view=first_view,
v_offset=first_view.offset,
v_size=first_view.nbytes,
v_base=first_view.base,
)
except StopIteration:
self.info[op] = self._OpDetails(
first_view=None, v_offset=0, v_size=0, v_base=None
)
return self.info[op]
def clear(self):
self.info.clear()
class OpsToMerge:
"""Analyze and store extra information about a list of ops to be merged."""
def __init__(self, initial_op, merged, merged_dependents, dependents):
self.merged = merged
self.merged_dependents = merged_dependents
self.dependents = dependents
self.ops = [initial_op]
self.optype = type(initial_op)
self.opinfo = OpInfo()
self.all_signals = set(initial_op.all_signals)
self.all_dependents = set(self.dependents[initial_op])
@property
def last_op(self):
return self.ops[-1]
def add(self, op):
self.ops.append(op)
self.all_signals.update(op.all_signals)
self.all_dependents.update(self.dependents[op])
@staticmethod
def check_signals(op1, op2):
for s1, s2 in zip(op1.all_signals, op2.all_signals):
# If one op's signal is a view, the other must be as well
if s1.is_view is not s2.is_view:
return False
if s1.is_view:
# Views must be on the same base
if s1.base is not s2.base:
return False
# Views must have the same strides
elif s1.strides != s2.strides:
return False
return True
def not_sequential(self, op):
lastop = self.opinfo.get(self.ops[-1])
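        # True if there is a gap between the end of the last op's view and
        # the start of ``op``'s view, i.e. the views are not contiguous in
        # memory.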
return lastop.v_offset + lastop.v_size < self.opinfo.get(op).v_offset
class OpMerger:
"""Manages the op merge classes known to the optimizer."""
mergers = {}
@classmethod
def is_mergeable(cls, op, tomerge):
merger = cls.mergers[tomerge.optype]
independent_of_ops_tomerge = (
op not in tomerge.all_dependents
and len(tomerge.dependents[op].intersection(tomerge.ops)) == 0
)
independent_of_prior_merges = (
op not in tomerge.merged
and op not in tomerge.merged_dependents
and all(o not in tomerge.merged for o in tomerge.dependents[op])
)
return (
type(op) is tomerge.optype
and independent_of_ops_tomerge
and independent_of_prior_merges
and cls.is_type_mergeable(tomerge.optype)
and tomerge.check_signals(tomerge.ops[-1], op)
and merger.check_signals(op, tomerge)
and merger.is_mergeable(tomerge.last_op, op)
)
@classmethod
def is_type_mergeable(cls, optype):
return optype in cls.mergers
@classmethod
def merge(cls, ops):
return cls.mergers[type(ops[0])].merge(ops)
@classmethod
def register(cls, optype):
def register(merger):
if optype in cls.mergers:
warnings.warn("Merger for operator type {} overwritten.".format(optype))
cls.mergers[optype] = merger
return merger
return register
class Merger:
"""Base class for all op merge classes."""
@staticmethod
def check_signals(op, tomerge):
return len(tomerge.all_signals.intersection(op.all_signals)) == 0
@staticmethod
def is_mergeable(op1, op2):
raise NotImplementedError("Subclasses must implement `is_mergeable`")
@staticmethod
def merge(ops):
raise NotImplementedError("Cannot merge arbitrary ops.")
    @staticmethod
def merge_dicts(*dicts):
"""Merges the given dictionaries into a single dictionary.
This function assumes and enforces that no keys overlap.
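
        Examples
        --------
        A minimal illustration::

            >>> Merger.merge_dicts({"a": 1}, {"b": 2})
            {'a': 1, 'b': 2}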
"""
d = {}
for other_d in dicts:
assert all(k not in d for k in other_d)
d.update(other_d)
return d
@OpMerger.register(op.Reset)
class ResetMerger(Merger):
"""Merge `.Reset` ops."""
@staticmethod
def is_mergeable(op1, op2):
return SigMerger.check([op1.dst, op2.dst]) and op1.value == op2.value
@staticmethod
def merge(ops):
dst, replacements = SigMerger.merge([o.dst for o in ops])
return op.Reset(dst, ops[0].value), replacements
@OpMerger.register(op.Copy)
class CopyMerger(Merger):
"""Merge `.Copy` ops."""
@staticmethod
def is_mergeable(op1, op2):
return (
SigMerger.check([op1.src, op2.src])
and SigMerger.check([op1.dst, op2.dst])
and op1.src_slice is None
and op1.dst_slice is None
and op2.src_slice is None
and op2.dst_slice is None
and op1.inc == op2.inc
)
@staticmethod
def merge_slice(signals, slices):
if all(s is None for s in slices):
return None
elif any(s is None for s in slices):
raise ValueError("Mixed Ellipsis with list of indices.")
offset = 0
merged_slice = []
for sig, sl in zip(signals, slices):
assert isinstance(sl, list), "Expecting a list of indices, got %s" % sl
merged_slice.extend([i + offset for i in sl])
offset += sig.size
return merged_slice
@staticmethod
def merge(ops):
src_sigs = [o.src for o in ops]
dst_sigs = [o.dst for o in ops]
src, src_sigr = SigMerger.merge(src_sigs)
dst, dst_sigr = SigMerger.merge(dst_sigs)
src_slice = CopyMerger.merge_slice(src_sigs, [o.src_slice for o in ops])
dst_slice = CopyMerger.merge_slice(dst_sigs, [o.dst_slice for o in ops])
return (
op.Copy(src, dst, src_slice=src_slice, dst_slice=dst_slice, inc=ops[0].inc),
Merger.merge_dicts(src_sigr, dst_sigr),
)
@OpMerger.register(op.ElementwiseInc)
class ElementwiseIncMerger(Merger):
"""Merge `.ElementwiseInc` ops."""
@staticmethod
def is_mergeable(op1, op2):
scalar_mult = op1.A.shape == (1,) and op2.A.shape == (1,)
non_scalar_mult = op1.A.shape != (1,) and op2.A.shape != (1,)
return (
SigMerger.check([op1.X, op2.X], axis=op1.X.ndim - 1)
and SigMerger.check([op1.Y, op2.Y], axis=op1.Y.ndim - 1)
and (
(scalar_mult and op1.A.initial_value == op2.A.initial_value)
or (
non_scalar_mult
and SigMerger.check([op1.A, op2.A], axis=op1.A.ndim - 1)
)
)
)
@staticmethod
def merge(ops):
if all(o.A.shape == (1,) for o in ops):
assert all(o.A.initial_value == ops[0].A.initial_value for o in ops)
A, A_sigr = ops[0].A, {}
else:
A, A_sigr = SigMerger.merge([o.A for o in ops], axis=ops[0].A.ndim - 1)
X, X_sigr = SigMerger.merge([o.X for o in ops], axis=ops[0].X.ndim - 1)
Y, Y_sigr = SigMerger.merge([o.Y for o in ops], axis=ops[0].Y.ndim - 1)
return (op.ElementwiseInc(A, X, Y), Merger.merge_dicts(A_sigr, X_sigr, Y_sigr))
@OpMerger.register(op.DotInc)
class DotIncMerger(Merger):
"""Merge `.DotInc` ops."""
@staticmethod
def check_signals(op, tomerge):
none_shared = Merger.check_signals(op, tomerge) and (
len(tomerge.ops) < 2 or tomerge.ops[0].X is not tomerge.ops[1].X
)
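        # Alternatively, every op (including the new one) may read the very
        # same X signal; in that case the new op's A and Y must still be
        # disjoint from the signals gathered so far.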
all_x_shared = (
op.A not in tomerge.all_signals
and op.Y not in tomerge.all_signals
and all(op.X not in [o.A, o.Y] and op.X is o.X for o in tomerge.ops)
)
return none_shared or all_x_shared
@staticmethod
def is_mergeable(op1, op2):
if op1.X is op2.X:
# simple merge might be possible
return SigMerger.check([op1.Y, op2.Y]) and SigMerger.check([op1.A, op2.A])
# check if BSR merge is possible
try:
# Not using check() for A, because A must not be a view.
SigMerger.check_signals([op1.A, op2.A])
from scipy.sparse import ( # pylint: disable=import-outside-toplevel
bsr_matrix,
)
assert bsr_matrix
except ImportError:
warnings.warn(
"Skipping some optimization steps because SciPy is "
"not installed. Installing SciPy may result in "
"faster simulations."
)
return False
except ValueError:
return False
return (
SigMerger.check([op1.X, op2.X])
and SigMerger.check([op1.Y, op2.Y])
and op1.A.shape != (1,)
and op1.A.shape == op2.A.shape
)
@staticmethod
def merge(ops):
# Simple merge if all X are the same.
if all(o.X is ops[0].X for o in ops):
A, A_sigr = SigMerger.merge([o.A for o in ops])
Y, Y_sigr = SigMerger.merge([o.Y for o in ops])
return (op.DotInc(A, ops[0].X, Y), Merger.merge_dicts(A_sigr, Y_sigr))
assert all(o1.X is not o2.X for i, o1 in enumerate(ops) for o2 in ops[i + 1 :])
# BSR merge if X differ
X, X_sigr = SigMerger.merge([o.X for o in ops])
Y, Y_sigr = SigMerger.merge([o.Y for o in ops])
# Construct sparse A representation
data = np.array([o.A.initial_value for o in ops], dtype=rc.float_dtype)
        if data.ndim < 3:
            raise NotImplementedError("A.ndim should be > 2")
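        # Each op contributes exactly one block on the diagonal of the
        # block-sparse matrix: block row i contains a single block in block
        # column i.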
indptr = np.arange(len(ops) + 1, dtype=rc.int_dtype)
indices = np.arange(len(ops), dtype=rc.int_dtype)
name = "bsr_merged<{first}, ..., {last}>".format(
first=ops[0].A.name, last=ops[-1].A.name
)
readonly = all([o.A.readonly for o in ops])
A = Signal(data, name=name, readonly=readonly)
A_sigr = {}
for i, s in enumerate([o.A for o in ops]):
A_sigr[s] = Signal(
data[i],
name="%s[%i]" % (s.name, i),
base=A,
offset=i * A.itemsize * np.prod(A.shape[1:]),
)
assert np.allclose(
s.initial_value, A_sigr[s].initial_value, atol=0, rtol=0, equal_nan=True
)
assert s.shape == A_sigr[s].shape or (
s.shape == () and A_sigr[s].shape == (1, 1)
)
reshape = op.reshape_dot(
ops[0].A.initial_value,
ops[0].X.initial_value,
ops[0].Y.initial_value,
tag=ops[0].tag,
)
return (
op.BsrDotInc(A, X, Y, indices=indices, indptr=indptr, reshape=reshape),
Merger.merge_dicts(X_sigr, Y_sigr, A_sigr),
)
@OpMerger.register(SimNeurons)
class SimNeuronsMerger(Merger):
"""Merge `.SimNeurons` ops."""
@staticmethod
def is_mergeable(op1, op2):
return op1.neurons == op2.neurons and all(
SigMerger.check(s) for s in zip(op1.all_signals, op2.all_signals)
)
@staticmethod
def merge(ops):
def gather(ops, key):
return [getattr(o, key) for o in ops]
J, J_sigr = SigMerger.merge(gather(ops, "J"))
output, out_sigr = SigMerger.merge(gather(ops, "output"))
states = []
states_sigr = {}
for signals in zip(*gather(ops, "states")):
st, st_sigr = SigMerger.merge(signals)
states.append(st)
states_sigr.update(st_sigr)
return (
SimNeurons(ops[0].neurons, J, output, states),
Merger.merge_dicts(J_sigr, out_sigr, states_sigr),
)
class SigMerger:
"""Merge signals."""
    @staticmethod
def check(signals, axis=0):
"""Checks that all signals can be concatenated along a given axis.
        For views, this also includes a check that the signals share a common
        base and agree on strides.
In comparison to the ``check_*`` functions, this function does
not throw exceptions and allows for either signals or signal views.
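
        Examples
        --------
        A minimal sketch with two plain (non-view) signals of compatible
        shape::

            >>> import numpy as np
            >>> from nengo.builder.signal import Signal
            >>> SigMerger.check([Signal(np.zeros(2)), Signal(np.zeros(3))])
            True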
"""
if len(set(signals)) != len(signals):
# Signal appears twice in list.
return False
if all(s.is_view for s in signals):
try:
SigMerger.check_views(signals, axis=axis)
except ValueError:
return False
elif all(not s.is_view for s in signals):
try:
SigMerger.check_signals(signals, axis=axis)
except ValueError:
return False
else:
            # Mix of views and non-views
return False
# If we haven't failed yet, then the signals are compatible
return True
    @staticmethod
def check_signals(signals, axis=0):
"""Checks that all signals can be merged along a given axis.
If this is not possible, or any signals are views, a
``ValueError`` will be raised.
"""
if any(s.is_view for s in signals):
raise ValueError("Cannot merge views.")
if any(s.sparse for s in signals):
raise ValueError("Cannot merge sparse Signals")
for s in signals:
if s.ndim != signals[0].ndim:
raise ValueError("Signals must have the same number of dimensions.")
if s.ndim <= 0 and s.initial_value != signals[0].initial_value:
raise ValueError("0-d signals must have the same initial value.")
if (
s.shape[:axis] != signals[0].shape[:axis]
or s.shape[axis + 1 :] != signals[0].shape[axis + 1 :]
):
raise ValueError(
"Signals must have same shape except on concatenation axis."
)
if s.dtype is not signals[0].dtype:
raise ValueError("Signals must have the same dtype.")
    @staticmethod
def check_views(signals, axis=0):
"""Checks that all signal views can be merged along a given axis.
If this is not possible, or any signals are not views,
a ``ValueError`` will be raised.
``signals`` must be ordered by the offset into the base signal.
"""
if any(not s.is_view for s in signals):
raise ValueError("Cannot merge non-views.")
start = signals[0].offset
for s in signals:
if s.base is not signals[0].base:
raise ValueError("Signals must share the same base.")
if s.ndim != signals[0].ndim:
raise ValueError("Signals must have the same number of dimensions.")
if s.strides != signals[0].strides:
raise ValueError("Signals must have equal strides.")
if (
s.shape[:axis] != signals[0].shape[:axis]
or s.shape[axis + 1 :] != signals[0].shape[axis + 1 :]
):
raise ValueError(
"Signals must have same shape except on concatenation axis."
)
if s.offset != start:
raise ValueError("Views are not sequential.")
start = s.offset + s.nbytes
    @staticmethod
def merge(signals, axis=0):
"""Merges multiple signals or signal views into one contiguous signal.

        Note that if any of the signals are linked to another signal (by being
        the base of a view), the merged signal will not reflect those links
        anymore.

        Parameters
        ----------
        signals : sequence
            Signals to merge. Must contain either only views or only
            non-views; a mix of the two raises a ``ValueError``.
        axis : int, optional
            Axis along which to concatenate the signals.

        Returns
        -------
        merged_signal : Signal
            The merged signal.
        replacements : dict
            Dictionary mapping from the old signals to new signals that are
            a view into the merged signal. Used to replace old signals.
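
        Examples
        --------
        A minimal sketch merging two plain (non-view) signals; the returned
        ``replacements`` map each old signal to a view into the merged
        signal::

            >>> import numpy as np
            >>> from nengo.builder.signal import Signal
            >>> a = Signal(np.zeros(2), name="a")
            >>> b = Signal(np.ones(3), name="b")
            >>> merged, replacements = SigMerger.merge([a, b])
            >>> merged.shape
            (5,)
            >>> replacements[a].is_view and replacements[b].is_view
            True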
"""
are_views = [s.is_view for s in signals]
if all(are_views):
return SigMerger.merge_views(signals, axis=axis)
elif not any(are_views):
return SigMerger.merge_signals(signals, axis=axis)
else:
raise ValueError("Cannot merge mixed views and non-views.")
    @staticmethod
def merge_signals(signals, axis=0):
"""Merges multiple signal into one contiguous signal.
Note that if any of the signals are linked to another signal (by being
the base of a view), the merged signal will not reflect
those links anymore.
Parameters
----------
signals : sequence
Signals to merge. Must not contain views.
axis : int, optional
Axis along which to concatenate the signals.
Returns
-------
merged_signal : Signal
The merged signal.
replacements : dict
Dictionary mapping from the old signals to new signals that are
a view into the merged signal. Used to replace old signals.
"""
SigMerger.check_signals(signals, axis=axis)
if signals[0].ndim > 0:
initial_value = np.concatenate(
[s.initial_value for s in signals], axis=axis
)
else:
initial_value = signals[0].initial_value
readonly = all(s.readonly for s in signals)
name = "merged<" + signals[0].name + ", ..., " + signals[-1].name + ">"
merged_signal = Signal(initial_value, name=name, readonly=readonly)
if signals[0].ndim > 0:
replacements = {}
start = 0
for s in signals:
size = s.shape[axis]
indexing = [slice(None)] * initial_value.ndim
indexing[axis] = slice(start, start + size)
replacements[s] = merged_signal[tuple(indexing)]
start += size
else:
replacements = {s: merged_signal for s in signals}
return merged_signal, replacements
    @staticmethod
def merge_views(signals, axis=0):
"""Merges multiple signal views into one contiguous signal view.
Parameters
----------
signals : sequence
Signals to merge. Must only contain views.
axis : int, optional
Axis along which to concatenate the signals.
Returns
-------
merged_signal : Signal
The merged signal.
replacements : dict
Dictionary mapping from the old signals to new signals that are
a view into the merged signal. Used to replace old signals.
"""
SigMerger.check_views(signals, axis=axis)
# abs_offset = min(s.abs_offset for s in signals)
offset = min(s.offset for s in signals)
shape = (
signals[0].shape[:axis]
+ (sum(s.shape[axis] for s in signals),)
+ signals[0].shape[axis + 1 :]
)
initial_value = np.ndarray(
buffer=signals[0].base.initial_value,
dtype=signals[0].dtype,
shape=shape,
offset=offset,
strides=signals[0].strides,
)
merged_signal = Signal(
initial_value,
name=signals[0].base.name,
base=signals[0].base,
readonly=all(s.readonly for s in signals),
offset=offset,
)
return merged_signal, {}
def groupby(lst, keyfunc=lambda item: item):
"""Groups the given list by the value returned by ``keyfunc``.
Similar to ``itertools.groupby``, but returns a dict, and does not depend
on the order of the input list.
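
    Examples
    --------
    Group words by their first letter (the result is a ``defaultdict`` of
    lists)::

        >>> groups = groupby(["ant", "bee", "ape"], keyfunc=lambda s: s[0])
        >>> sorted(groups["a"])
        ['ant', 'ape']
        >>> groups["b"]
        ['bee']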
"""
d = defaultdict(list)
for item in lst:
d[keyfunc(item)].append(item)
return d