Source code for nengo.utils.simulator

from collections import defaultdict
import itertools

from .graphs import add_edges
from .stdlib import groupby


def operator_dependency_graph(operators):  # noqa: C901
    """Build a dependency graph of operators from their signal accesses.

    Each operator is expected to expose ``sets``, ``incs``, ``reads`` and
    ``updates`` sequences of signals; each signal is expected to expose a
    ``base`` attribute and a ``may_share_memory(other)`` method.

    The ordering enforced on any block of memory is:

    1. all sets
    2. all incs
    3. all reads
    4. all updates

    Parameters
    ----------
    operators : iterable
        The operators to order. It is iterated twice (once to index the
        signal accesses, once to create the graph nodes).

    Returns
    -------
    dict
        A mapping with one entry per operator, each initialised to an empty
        set. Edges are recorded by ``add_edges`` for every
        (prerequisite op, dependent op) pair.

    Raises
    ------
    AssertionError
        If ``validate_ops`` rejects the collected sets/updates.
    """
    # -- ops keyed by the exact signal they access
    setters = defaultdict(list)
    incrementers = defaultdict(list)
    readers = defaultdict(list)
    updaters = defaultdict(list)

    # -- all views of a base signal, keyed by that base
    set_views = defaultdict(set)
    write_views = defaultdict(set)
    read_views = defaultdict(set)

    for op in operators:
        for sig in op.sets:
            set_views[sig.base].add(sig)
            setters[sig].append(op)

        # writes are sets followed by incs
        for sig in op.sets + op.incs:
            write_views[sig.base].add(sig)

        for sig in op.reads:
            read_views[sig.base].add(sig)
            readers[sig].append(op)

        for sig in op.incs:
            incrementers[sig].append(op)

        for sig in op.updates:
            updaters[sig].append(op)

    validate_ops(setters, updaters, incrementers)

    # every operator is a node, even if it ends up with no edges
    dg = {op: set() for op in operators}

    # -- incs come after sets of the same or overlapping memory
    for sig, dependents in incrementers.items():
        prerequisites = list(setters[sig])
        for view in set_views[sig.base]:
            if sig.may_share_memory(view):
                prerequisites += setters[view]
        add_edges(dg, itertools.product(set(prerequisites), dependents))

    # -- reads come after writes (sets and incs)
    for sig, dependents in readers.items():
        prerequisites = setters[sig] + incrementers[sig]
        for view in write_views[sig.base]:
            if sig.may_share_memory(view):
                prerequisites += setters[view]
                prerequisites += incrementers[view]
        add_edges(dg, itertools.product(set(prerequisites), dependents))

    # -- updates come after sets, incs and reads
    for sig, dependents in updaters.items():
        prerequisites = setters[sig] + incrementers[sig] + readers[sig]
        for view in read_views[sig.base].union(write_views[sig.base]):
            if sig.may_share_memory(view):
                prerequisites += setters[view]
                prerequisites += incrementers[view]
                prerequisites += readers[view]
        add_edges(dg, itertools.product(set(prerequisites), dependents))

    return dg
def validate_ops(sets, ups, incs):
    """Check that operator sets and updates do not conflict.

    Parameters
    ----------
    sets : dict
        Maps each signal to the list of operators that set it.
    ups : dict
        Maps each signal to the list of operators that update it.
    incs : dict
        Accepted for interface compatibility; not inspected here.

    Raises
    ------
    AssertionError
        If a signal (counting ops registered on its base when the signal is
        a view) does not have exactly one setter or exactly one updater, or
        if two set signals, or two updated signals, with the same base may
        share memory.
    """
    # -- exactly one op may set, and exactly one op may update, any view
    for ops_by_sig in (sets, ups):
        for sig, direct_ops in ops_by_sig.items():
            # a view also competes with ops registered on its base signal
            base_ops = ops_by_sig.get(sig.base, []) if sig.is_view else []
            sig_ops = direct_ops + base_ops
            assert len(sig_ops) == 1, (sig, sig_ops)

    # -- no two set views, and no two updated views, may alias each other
    for ops_by_sig in (sets, ups):
        for _, base_group in groupby(ops_by_sig, lambda x: x.base, hashable=True):
            for sig, sig2 in itertools.combinations(base_group, 2):
                assert not sig.may_share_memory(sig2), "%s shares memory with %s" % (
                    sig,
                    sig2,
                )