import itertools
import sys
import threading
import numpy as np
def signals_allclose(  # noqa: C901
    t,
    targets,
    signals,
    atol=1e-8,
    rtol=1e-5,
    buf=0,
    delay=0,
    plt=None,
    show=False,
    labels=None,
    individual_results=False,
    allclose=np.allclose,
):
    """Ensure all signal elements are within tolerances.

    Allows for delay, removing the beginning of the signal, and plotting.

    Parameters
    ----------
    t : array_like (T,)
        Simulation time for the points in ``targets`` and ``signals``.
        Must be one-dimensional and evenly spaced.
    targets : array_like (T, 1) or (T, N)
        Reference signal or signals for error comparison.
    signals : array_like (T, N)
        Signals to be tested against the target signals.
    atol, rtol : float
        Absolute and relative tolerances.
    buf : float
        Length of time (in seconds) to remove from the beginnings of signals.
    delay : float
        Amount of delay (in seconds) to account for when doing comparisons.
    plt : matplotlib.pyplot or mock
        Pyplot interface for plotting the results, unless it's mocked out.
    show : bool
        Whether to show the plot immediately.
    labels : list of string, length N
        Labels of each signal to use when plotting.
    individual_results : bool
        If True, returns a separate ``allclose`` result for each signal.
    allclose : callable
        Function to compare two arrays for similarity.

    Returns
    -------
    result
        If ``individual_results`` is False, the value returned by a single
        ``allclose`` call on all signals. Otherwise, a list with one
        ``allclose`` result per signal.
    """
    t = np.asarray(t)
    dt = t[1] - t[0]
    assert t.ndim == 1
    assert np.allclose(np.diff(t), dt)  # always use default allclose here

    targets = np.asarray(targets)
    signals = np.asarray(signals)
    if targets.ndim == 1:
        targets = targets.reshape((-1, 1))
    if signals.ndim == 1:
        signals = signals.reshape((-1, 1))
    assert targets.ndim == 2 and signals.ndim == 2
    assert t.size == targets.shape[0]
    assert t.size == signals.shape[0]
    assert targets.shape[1] == 1 or targets.shape[1] == signals.shape[1]

    # Convert durations in seconds to a number of time steps.
    buf = int(np.round(buf / dt))
    delay = int(np.round(delay / dt))
    # ``slice1`` indexes the targets, ``slice2`` the (delayed) signals; both
    # select the same number of samples.
    slice1 = slice(buf, len(t) - delay)
    slice2 = slice(buf + delay, None)

    if plt is not None:
        if labels is None:
            # One label per signal (columns), not per time step (rows);
            # otherwise ``zip`` below would drop signals whenever T < N.
            labels = [None] * signals.shape[1]
        elif isinstance(labels, str):
            labels = [labels]

        colors = ["b", "g", "r", "c", "m", "y", "k"]

        def plot_target(ax, x, b=0, c="k"):
            # Draw the target (offset by ``b``) and its tolerance band,
            # shifted in time so that it lines up with the delayed signals.
            bound = atol + rtol * np.abs(x)
            y = x - b
            ax.plot(t[slice2], y[slice1], c + ":")
            ax.plot(t[slice2], (y + bound)[slice1], c + "--")
            ax.plot(t[slice2], (y - bound)[slice1], c + "--")

        # signal plot
        ax = plt.subplot(2, 1, 1)
        for y, label in zip(signals.T, labels):
            ax.plot(t, y, label=label)

        if targets.shape[1] == 1:
            plot_target(ax, targets[:, 0], c="k")
        else:
            color_cycle = itertools.cycle(colors)
            for x in targets.T:
                plot_target(ax, x, c=next(color_cycle))

        ax.set_ylabel("signal")
        if len(labels) > 0 and labels[0] is not None:
            lgd = ax.legend(loc="upper left", bbox_to_anchor=(1.0, 1.0))
            # Stored on the pyplot object itself, outside the axes.
            plt.bbox_extra_artists = (lgd,)

        # error plot
        ax = plt.subplot(2, 1, 2)
        if targets.shape[1] == 1:
            x = targets[:, 0]
            plot_target(ax, x, b=x, c="k")
            for y, label in zip(signals.T, labels):
                ax.plot(t[slice2], y[slice2] - x[slice1])
        else:
            color_cycle = itertools.cycle(colors)
            for x, y, label in zip(targets.T, signals.T, labels):
                c = next(color_cycle)
                plot_target(ax, x, b=x, c=c)
                ax.plot(t[slice2], y[slice2] - x[slice1], c, label=label)

        ax.set_xlabel("time")
        ax.set_ylabel("error")

        if show:
            plt.show()

    if individual_results:
        if targets.shape[1] == 1:
            return [
                allclose(y[slice2], targets[slice1, 0], atol=atol, rtol=rtol)
                for y in signals.T
            ]
        else:
            return [
                allclose(y[slice2], x[slice1], atol=atol, rtol=rtol)
                for x, y in zip(targets.T, signals.T)
            ]
    else:
        return allclose(signals[slice2, :], targets[slice1, :], atol=atol, rtol=rtol)
class ThreadedAssertion:
    """Performs assertions in parallel.

    Starts a number of threads, waits for each thread to execute some
    initialization code, and then executes assertions in each thread.

    Subclasses override `init_thread`, `assert_thread` and `finish_thread`.
    Constructing an instance runs everything: the constructor starts the
    workers, waits for all of them to finish, and re-raises the first
    exception recorded by a worker (in worker order) in the calling thread.

    Parameters
    ----------
    n_threads : int
        Number of worker threads to start.
    """

    class AssertionWorker(threading.Thread):
        """Thread that runs the parent's init/assert/finish hooks.

        After running, ``assertion_result`` is True if the assertion passed
        and False if a hook raised; in the latter case ``exc_info`` holds
        the ``sys.exc_info()`` triple of the failure.
        """

        def __init__(self, parent, barriers, n):
            super().__init__()
            self.parent = parent
            self.barriers = barriers  # one Event per worker, shared by all
            self.n = n  # index of this worker (and of its own Event)
            self.assertion_result = None
            self.exc_info = (None, None, None)

        def run(self):
            """Initialize, wait for all workers, then run the assertion."""
            try:
                self.parent.init_thread(self)
            except Exception:
                self.assertion_result = False
                self.exc_info = sys.exc_info()
            finally:
                # Always signal, even on failure; otherwise every other
                # worker would block forever waiting on this event.
                self.barriers[self.n].set()

            # Wait until every worker has finished its initialization.
            for barrier in self.barriers:
                barrier.wait()

            if self.assertion_result is False:
                # Initialization failed; there is nothing to assert on.
                return

            try:
                self.parent.assert_thread(self)
                self.assertion_result = True
            except Exception:
                self.assertion_result = False
                self.exc_info = sys.exc_info()
            finally:
                self.parent.finish_thread(self)

    def __init__(self, n_threads):
        barriers = [threading.Event() for _ in range(n_threads)]
        threads = [self.AssertionWorker(self, barriers, i) for i in range(n_threads)]
        for t in threads:
            t.start()
        # Let every worker finish before reporting, so that no thread is
        # still running when the exception propagates.
        for t in threads:
            t.join()

        for t in threads:
            if not t.assertion_result:
                # The exception is stored on the worker, not on ``self``.
                exc = t.exc_info[1]
                if exc is None:
                    raise RuntimeError(
                        "Worker thread %d terminated without completing "
                        "its assertion" % t.n
                    )
                raise exc.with_traceback(t.exc_info[2])

    def init_thread(self, worker):
        """Hook run in each worker before the assertions; no-op by default."""
        pass

    def assert_thread(self, worker):
        """Hook run in each worker to perform the assertion; must be overridden."""
        raise NotImplementedError()

    def finish_thread(self, worker):
        """Hook run in each worker after `assert_thread`; no-op by default."""
        pass