Skip to content

Commit 234e6ec

Browse files
committed
move code around to new locations
1 parent 6406868 commit 234e6ec

File tree

8 files changed

+281
-270
lines changed

8 files changed

+281
-270
lines changed

control/ctrlplot.py

Lines changed: 227 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,36 @@
55

66
from os.path import commonprefix
77

8+
import matplotlib as mpl
89
import matplotlib.pyplot as plt
910
import numpy as np
1011

1112
from . import config
1213

1314
__all__ = ['suptitle', 'get_plot_axes']
1415

16+
#
17+
# Style parameters
18+
#
19+
20+
_ctrlplot_rcParams = mpl.rcParams.copy()
21+
_ctrlplot_rcParams.update({
22+
'axes.labelsize': 'small',
23+
'axes.titlesize': 'small',
24+
'figure.titlesize': 'medium',
25+
'legend.fontsize': 'x-small',
26+
'xtick.labelsize': 'small',
27+
'ytick.labelsize': 'small',
28+
})
29+
30+
31+
#
32+
# User functions
33+
#
34+
# The functions below can be used by users to modify ctrl plots or get
35+
# information about them.
36+
#
37+
1538

1639
def suptitle(
1740
title, fig=None, frame='axes', **kwargs):
@@ -35,7 +58,7 @@ def suptitle(
3558
Additional keywords (passed to matplotlib).
3659
3760
"""
38-
rcParams = config._get_param('freqplot', 'rcParams', kwargs, pop=True)
61+
rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
3962

4063
if fig is None:
4164
fig = plt.gcf()
@@ -61,10 +84,10 @@ def suptitle(
6184
def get_plot_axes(line_array):
6285
"""Get a list of axes from an array of lines.
6386
64-
This function can be used to return the set of axes corresponding to
65-
the line array that is returned by `time_response_plot`. This is useful for
66-
generating an axes array that can be passed to subsequent plotting
67-
calls.
87+
This function can be used to return the set of axes corresponding
88+
to the line array that is returned by `time_response_plot`. This
89+
is useful for generating an axes array that can be passed to
90+
subsequent plotting calls.
6891
6992
Parameters
7093
----------
@@ -89,6 +112,125 @@ def get_plot_axes(line_array):
89112
#
90113
# Utility functions
91114
#
115+
# These functions are used by plotting routines to provide a consistent way
116+
# of processing and displaing information.
117+
#
118+
119+
120+
def _process_ax_keyword(
121+
axs, shape=(1, 1), rcParams=None, squeeze=False, clear_text=False):
122+
"""Utility function to process ax keyword to plotting commands.
123+
124+
This function processes the `ax` keyword to plotting commands. If no
125+
ax keyword is passed, the current figure is checked to see if it has
126+
the correct shape. If the shape matches the desired shape, then the
127+
current figure and axes are returned. Otherwise a new figure is
128+
created with axes of the desired shape.
129+
130+
Legacy behavior: some of the older plotting commands use a axes label
131+
to identify the proper axes for plotting. This behavior is supported
132+
through the use of the label keyword, but will only work if shape ==
133+
(1, 1) and squeeze == True.
134+
135+
"""
136+
if axs is None:
137+
fig = plt.gcf() # get current figure (or create new one)
138+
axs = fig.get_axes()
139+
140+
# Check to see if axes are the right shape; if not, create new figure
141+
# Note: can't actually check the shape, just the total number of axes
142+
if len(axs) != np.prod(shape):
143+
with plt.rc_context(rcParams):
144+
if len(axs) != 0:
145+
# Create a new figure
146+
fig, axs = plt.subplots(*shape, squeeze=False)
147+
else:
148+
# Create new axes on (empty) figure
149+
axs = fig.subplots(*shape, squeeze=False)
150+
fig.set_layout_engine('tight')
151+
fig.align_labels()
152+
else:
153+
# Use the existing axes, properly reshaped
154+
axs = np.asarray(axs).reshape(*shape)
155+
156+
if clear_text:
157+
# Clear out any old text from the current figure
158+
for text in fig.texts:
159+
text.set_visible(False) # turn off the text
160+
del text # get rid of it completely
161+
else:
162+
try:
163+
axs = np.asarray(axs).reshape(shape)
164+
except ValueError:
165+
raise ValueError(
166+
"specified axes are not the right shape; "
167+
f"got {axs.shape} but expecting {shape}")
168+
fig = axs[0, 0].figure
169+
170+
# Process the squeeze keyword
171+
if squeeze and shape == (1, 1):
172+
axs = axs[0, 0] # Just return the single axes object
173+
elif squeeze:
174+
axs = axs.squeeze()
175+
176+
return fig, axs
177+
178+
179+
# Turn label keyword into array indexed by trace, output, input
180+
# TODO: move to ctrlutil.py and update parameter names to reflect general use
181+
def _process_line_labels(label, ntraces, ninputs=0, noutputs=0):
182+
if label is None:
183+
return None
184+
185+
if isinstance(label, str):
186+
label = [label] * ntraces # single label for all traces
187+
188+
# Convert to an ndarray, if not done aleady
189+
try:
190+
line_labels = np.asarray(label)
191+
except ValueError:
192+
raise ValueError("label must be a string or array_like")
193+
194+
# Turn the data into a 3D array of appropriate shape
195+
# TODO: allow more sophisticated broadcasting (and error checking)
196+
try:
197+
if ninputs > 0 and noutputs > 0:
198+
if line_labels.ndim == 1 and line_labels.size == ntraces:
199+
line_labels = line_labels.reshape(ntraces, 1, 1)
200+
line_labels = np.broadcast_to(
201+
line_labels, (ntraces, ninputs, noutputs))
202+
else:
203+
line_labels = line_labels.reshape(ntraces, ninputs, noutputs)
204+
except ValueError:
205+
if line_labels.shape[0] != ntraces:
206+
raise ValueError("number of labels must match number of traces")
207+
else:
208+
raise ValueError("labels must be given for each input/output pair")
209+
210+
return line_labels
211+
212+
213+
# Get labels for all lines in an axes
214+
def _get_line_labels(ax, use_color=True):
215+
labels, lines = [], []
216+
last_color, counter = None, 0 # label unknown systems
217+
for i, line in enumerate(ax.get_lines()):
218+
label = line.get_label()
219+
if use_color and label.startswith("Unknown"):
220+
label = f"Unknown-{counter}"
221+
if last_color is None:
222+
last_color = line.get_color()
223+
elif last_color != line.get_color():
224+
counter += 1
225+
last_color = line.get_color()
226+
elif label[0] == '_':
227+
continue
228+
229+
if label not in labels:
230+
lines.append(line)
231+
labels.append(label)
232+
233+
return lines, labels
92234

93235

94236
# Utility function to make legend labels
@@ -160,3 +302,83 @@ def _find_axes_center(fig, axs):
160302
ylim = [min(ll[1], ylim[0]), max(ur[1], ylim[1])]
161303

162304
return (np.sum(xlim)/2, np.sum(ylim)/2)
305+
306+
307+
# Internal function to add arrows to a curve
308+
def _add_arrows_to_line2D(
309+
axes, line, arrow_locs=[0.2, 0.4, 0.6, 0.8],
310+
arrowstyle='-|>', arrowsize=1, dir=1):
311+
"""
312+
Add arrows to a matplotlib.lines.Line2D at selected locations.
313+
314+
Parameters:
315+
-----------
316+
axes: Axes object as returned by axes command (or gca)
317+
line: Line2D object as returned by plot command
318+
arrow_locs: list of locations where to insert arrows, % of total length
319+
arrowstyle: style of the arrow
320+
arrowsize: size of the arrow
321+
322+
Returns:
323+
--------
324+
arrows: list of arrows
325+
326+
Based on https://stackoverflow.com/questions/26911898/
327+
328+
"""
329+
# Get the coordinates of the line, in plot coordinates
330+
if not isinstance(line, mpl.lines.Line2D):
331+
raise ValueError("expected a matplotlib.lines.Line2D object")
332+
x, y = line.get_xdata(), line.get_ydata()
333+
334+
# Determine the arrow properties
335+
arrow_kw = {"arrowstyle": arrowstyle}
336+
337+
color = line.get_color()
338+
use_multicolor_lines = isinstance(color, np.ndarray)
339+
if use_multicolor_lines:
340+
raise NotImplementedError("multicolor lines not supported")
341+
else:
342+
arrow_kw['color'] = color
343+
344+
linewidth = line.get_linewidth()
345+
if isinstance(linewidth, np.ndarray):
346+
raise NotImplementedError("multiwidth lines not supported")
347+
else:
348+
arrow_kw['linewidth'] = linewidth
349+
350+
# Figure out the size of the axes (length of diagonal)
351+
xlim, ylim = axes.get_xlim(), axes.get_ylim()
352+
ul, lr = np.array([xlim[0], ylim[0]]), np.array([xlim[1], ylim[1]])
353+
diag = np.linalg.norm(ul - lr)
354+
355+
# Compute the arc length along the curve
356+
s = np.cumsum(np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2))
357+
358+
# Truncate the number of arrows if the curve is short
359+
# TODO: figure out a smarter way to do this
360+
frac = min(s[-1] / diag, 1)
361+
if len(arrow_locs) and frac < 0.05:
362+
arrow_locs = [] # too short; no arrows at all
363+
elif len(arrow_locs) and frac < 0.2:
364+
arrow_locs = [0.5] # single arrow in the middle
365+
366+
# Plot the arrows (and return list if patches)
367+
arrows = []
368+
for loc in arrow_locs:
369+
n = np.searchsorted(s, s[-1] * loc)
370+
371+
if dir == 1 and n == 0:
372+
# Move the arrow forward by one if it is at start of a segment
373+
n = 1
374+
375+
# Place the head of the arrow at the desired location
376+
arrow_head = [x[n], y[n]]
377+
arrow_tail = [x[n - dir], y[n - dir]]
378+
379+
p = mpl.patches.FancyArrowPatch(
380+
arrow_tail, arrow_head, transform=axes.transData, lw=0,
381+
**arrow_kw)
382+
axes.add_patch(p)
383+
arrows.append(p)
384+
return arrows

0 commit comments

Comments
 (0)