Source code for nengo.utils.matplotlib

from distutils.version import LooseVersion
import warnings

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from nengo.utils.ensemble import tuning_curves

# Matplotlib 1.5.0 introduced the ``axes.prop_cycle`` rcParam (built on
# ``cycler``) in place of ``axes.color_cycle``; the helpers below branch on it.
has_prop_cycle = LooseVersion(matplotlib.__version__) >= LooseVersion("1.5.0")

if has_prop_cycle:
    # Dependency of Matplotlib from 1.5.0 onward
    from cycler import cycler


def get_color_cycle():
    """Get matplotlib colour cycle.

    Returns
    -------
    colors : list
        On Matplotlib >= 1.5.0, the ``color`` entries of the current
        ``axes.prop_cycle`` rcParam; on older versions, the
        ``axes.color_cycle`` rcParam.

    Notes
    -----
    If the current ``axes.prop_cycle`` has no ``color`` key, the deprecated
    ``axes.color_cycle`` rcParam is tried instead. If that key is not
    available either, the colours of Matplotlib's default
    ``axes.prop_cycle`` (from ``rcParamsDefault``) are returned.
    """
    if not has_prop_cycle:
        return matplotlib.rcParams["axes.color_cycle"]

    cycle = matplotlib.rcParams["axes.prop_cycle"]
    # Apparently the 'color' key may not exist, so have to fail gracefully
    try:
        return [prop["color"] for prop in cycle]
    except KeyError:
        pass

    # Fall back on deprecated axes.color_cycle. Matplotlib versions that have
    # removed this key raise KeyError here, so use the default prop cycle's
    # colours as a last resort instead of crashing.
    try:
        return matplotlib.rcParams["axes.color_cycle"]
    except KeyError:
        default_cycle = matplotlib.rcParamsDefault["axes.prop_cycle"]
        return [prop["color"] for prop in default_cycle]
def set_color_cycle(colors, ax=None):
    """Set matplotlib colour cycle.

    Parameters
    ----------
    colors : sequence
        The colours to cycle through.
    ax : matplotlib.axes.Axes, optional
        If given, only the cycle of this axes is changed. Otherwise the
        ``axes`` rc settings are updated through ``plt.rc``.
    """
    if ax is not None:
        # Per-axes cycle
        if has_prop_cycle:
            ax.set_prop_cycle("color", colors)
        else:
            ax.set_color_cycle(colors)
        return

    # Global rc settings; with prop_cycle this installs a colour-only cycle
    if has_prop_cycle:
        plt.rc("axes", prop_cycle=cycler("color", colors))
    else:
        plt.rc("axes", color_cycle=colors)
def axis_size(ax=None):
    """Get axis width and height in pixels.

    Based on a StackOverflow response:
    https://stackoverflow.com/questions/19306510/determine-matplotlib-axis-size-in-pixels

    Parameters
    ----------
    ax : axis object
        The axes to determine the size of. Defaults to current axes.

    Returns
    -------
    width : float
        Width of axes in pixels.
    height : float
        Height of axes in pixels.
    """
    if ax is None:
        ax = plt.gca()
    figure = ax.figure

    # Express the axes' window extent in inches (inverse of the figure's dpi
    # scaling), then multiply back by the dpi to get the size in pixels.
    window_extent = ax.get_window_extent()
    extent_inches = window_extent.transformed(figure.dpi_scale_trans.inverted())
    dpi = figure.dpi
    width = extent_inches.width * dpi
    height = extent_inches.height * dpi
    return width, height
def implot(plt, x, y, Z, ax=None, colorbar=True, **kwargs):
    """Image plot of general data (like imshow but with non-pixel axes).

    Parameters
    ----------
    plt : plot object
        Plot object, typically ``matplotlib.pyplot``.
    x : (M,) array_like
        Vector of x-axis points, must be linear (equally spaced) and contain
        at least two points.
    y : (N,) array_like
        Vector of y-axis points, must be linear (equally spaced) and contain
        at least two points.
    Z : (N, M) array_like
        Matrix of data to be displayed. Rows follow ``y`` and columns follow
        ``x``, so (with imshow's default ``origin="upper"``) ``Z[i, j]`` is
        the value at the point ``(x[j], y[i])``.
    ax : axis object, optional
        A specific axis to plot on (defaults to ``plt.gca()``).
    colorbar : bool, optional
        Whether to plot a colorbar.
    **kwargs
        Additional arguments for ``ax.imshow``.

    Returns
    -------
    image
        The object returned by ``ax.imshow``.
    """
    ax = plt.gca() if ax is None else ax

    def is_linear(points):
        # Fewer than two points define no spacing (and ``diff[0]`` would raise
        # IndexError), so they are rejected as well.
        diff = np.diff(points)
        return diff.size > 0 and np.allclose(diff, diff[0])

    assert is_linear(x) and is_linear(y), (
        "x and y must each contain at least two equally spaced points"
    )

    # extent is (left, right, bottom, top): x runs left to right across the
    # columns, and y[0] is placed at the top, y[-1] at the bottom.
    image = ax.imshow(Z, aspect="auto", extent=(x[0], x[-1], y[-1], y[0]), **kwargs)
    if colorbar:
        plt.colorbar(image, ax=ax)
    return image
def rasterplot(time, spikes, ax=None, use_eventplot=False, **kwargs):  # noqa
    """Generate a raster plot of the provided spike data.

    Parameters
    ----------
    time : array
        Time data from the simulation
    spikes : array
        The spike data with columns for each neuron and 1s indicating spikes
    ax : matplotlib.axes.Axes, optional
        The figure axes to plot into. If None, we will use current axes.
    use_eventplot : boolean, optional
        Whether to use the new Matplotlib ``eventplot`` routine. It is slower
        and makes larger image files, so we do not use it by default.
    **kwargs
        Additional keyword arguments for ``ax.eventplot`` (when
        ``use_eventplot``) or ``ax.plot``. An optional ``colors`` keyword may
        be a single colour string or a sequence of colours (list, tuple,
        array, ...); a sequence shorter than the number of neurons is cycled.
        If ``colors`` is omitted, the current Matplotlib colour cycle is used.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes that were plotted into

    Raises
    ------
    ValueError
        If there is at least one neuron but no colours to draw it with.

    Examples
    --------
    .. testcode::

       from nengo.utils.matplotlib import rasterplot

       with nengo.Network() as net:
           a = nengo.Ensemble(20, 1)
           p = nengo.Probe(a.neurons)

       with nengo.Simulator(net) as sim:
           sim.run(1)

       rasterplot(sim.trange(), sim.data[p])

    .. testoutput::
       :hide:

       ...
    """
    n_times, n_neurons = spikes.shape

    if ax is None:
        ax = plt.gca()

    if use_eventplot and not hasattr(ax, "eventplot"):
        warnings.warn(
            "Matplotlib version %s does not have 'eventplot'. "
            "Falling back to non-eventplot version." % matplotlib.__version__
        )
        use_eventplot = False

    colors = kwargs.pop("colors", None)
    if colors is None:
        colors = get_color_cycle()
    elif isinstance(colors, str):
        # A single colour applies to every neuron
        colors = [colors]
    else:
        # Accept any sequence (tuple, ndarray, ...); the eventplot branch
        # below concatenates onto a list, which fails for non-list types
        colors = list(colors)
    if n_neurons > 0 and len(colors) == 0:
        raise ValueError("rasterplot needs at least one colour")
    # Exactly one colour per neuron, cycling through the available colours
    colors = [colors[i % len(colors)] for i in range(n_neurons)]

    # --- plotting
    if use_eventplot:
        spiketimes = [time[s > 0].ravel() for s in spikes.T]
        for ix in range(n_neurons):
            if spiketimes[ix].size == 0:
                # placeholder event for silent neurons
                spiketimes[ix] = np.array([-np.inf])

        # hack to make 'eventplot' count from 1 instead of 0
        spiketimes = [np.array([-np.inf])] + spiketimes
        colors = [["k"]] + colors

        ax.eventplot(spiketimes, colors=colors, **kwargs)
    else:
        kwargs.setdefault("linestyle", "None")
        kwargs.setdefault("marker", "|")

        # Default markersize determined by matching eventplot
        ax_height = axis_size(ax)[1]
        # max(..., 1) avoids dividing by zero when there are no neurons
        markersize = max(ax_height * 0.8 / max(n_neurons, 1), 1)
        # For 1 - 3 neurons, we need an extra fudge factor to match eventplot
        markersize -= max(4 - n_neurons, 0) ** 2 * ax_height * 0.005
        kwargs.setdefault("markersize", markersize)
        kwargs.setdefault("markeredgewidth", 1)

        for i in range(n_neurons):
            spiketimes = time[spikes[:, i] > 0].ravel()
            ax.plot(
                spiketimes,
                np.zeros_like(spiketimes) + (i + 1),
                color=colors[i],
                **kwargs,
            )

    # --- set axes limits
    if n_times > 1:
        ax.set_xlim(time[0], time[-1])

    ax.set_ylim(n_neurons + 0.6, 0.4)
    if n_neurons < 5:
        # make sure only integer ticks for small neuron numbers
        ax.set_yticks(np.arange(1, n_neurons + 1))

    # --- remove ticks as these are distracting in rasters
    ax.xaxis.set_ticks_position("none")
    ax.yaxis.set_ticks_position("none")

    return ax
def plot_tuning_curves(ensemble, sim, connection=None, ax=None):
    """Plot tuning curves for the given ensemble and simulator.

    If a connection is provided, the decoders will be used to set the colours
    of the tuning curves.

    Parameters
    ----------
    ensemble
        The ensemble whose tuning curves are computed with ``tuning_curves``.
    sim
        The simulator passed to ``tuning_curves`` and used to look up
        ``sim.data[connection]``.
    connection : optional
        If given, the first row of its decoders is mapped through the
        ``coolwarm`` colormap to colour the curves.
    ax : matplotlib.axes.Axes, optional
        Axes to plot into (defaults to ``plt.gca()``).
    """
    ax = plt.gca() if ax is None else ax
    eval_points, activities = tuning_curves(ensemble, sim)

    if connection is not None:
        if connection.dimensions > 1:
            warnings.warn("Ignoring dimensions > 1 in plot_tuning_curves")
        # Only decoders[0] is used (presumably the first decoded dimension,
        # hence the warning above); its values are mapped to coolwarm colours.
        mappable = plt.cm.ScalarMappable(cmap=plt.cm.coolwarm)
        first_decoders = sim.data[connection].decoders[0]
        set_color_cycle(mappable.to_rgba(first_decoders), ax=ax)

    ax.plot(eval_points, activities)