Skip to content

Commit 0103fe7

Browse files
committed
update root_locus_plot to use common ax processing
1 parent f818034 commit 0103fe7

5 files changed

Lines changed: 127 additions & 75 deletions

File tree

control/ctrlplot.py

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,75 @@
33
#
44
# Collection of functions that are used by various plotting functions.
55

6+
# Code pattern for control system plotting functions:
7+
#
8+
# def name_plot(sysdata, plot=None, **kwargs):
9+
# # Process keywords and set defaults
10+
# ax = kwargs.pop('ax', None)
11+
# color = kwargs.pop('color', None)
12+
# label = kwargs.pop('label', None)
13+
# rcParams = config._get_param(
14+
# 'nameplot', 'rcParams', kwargs, _nameplot_defaults, pop=True)
15+
#
16+
# # Make sure all keyword arguments were processed (if not checked later)
17+
# if kwargs:
18+
# raise TypeError("unrecognized keywords: ", str(kwargs))
19+
#
20+
# # Process the data (including generating responses for systems)
21+
# sysdata = list(sysdata)
22+
# if any([isinstance(sys, InputOutputSystem) for sys in sysdata]):
23+
# data = name_response(sysdata)
24+
# nrows = max([data.noutputs for data in sysdata])
25+
# ncols = max([data.ninputs for data in sysdata])
26+
#
27+
# # Legacy processing of plot keyword
28+
# if plot is False:
29+
# return data.x, data.y
30+
#
31+
# # Figure out the shape of the plot and find/create axes
32+
# fig, ax_array = _process_ax_keyword(ax, (nrows, ncols), rcParams)
33+
#
34+
# # Customize axes (curvilinear grids, shared axes, etc)
35+
#
36+
# # Plot the data
37+
# lines = np.full(ax_array.shape, [])
38+
# line_labels = _process_line_labels(label, ntraces, nrows, ncols)
39+
# for i, j in itertools.product(range(nrows), range(ncols)):
40+
# ax = ax_array[i, j]
41+
# color_cycle, color_offset = _process_color_keyword(ax)
42+
# for k in range(ntraces):
43+
# if color is None:
44+
# color = color_cycle[(k + color_offset) % len(color_cycle)]
45+
# label = line_labels[k, i, j]
46+
# lines[i, j] += ax.plot(data.x, data.y, color=color, label=label)
47+
#
48+
# # Customize and label the axes
49+
# for i, j in itertools.product(range(nrows), range(ncols)):
50+
# ax_array[i, j].set_xlabel("x label")
51+
# ax_array[i, j].set_ylabel("y label")
52+
#
53+
# # Create legends
54+
# legend_map = _process_legend_keywords(kwargs)
55+
# for i, j in itertools.product(range(nrows), range(ncols)):
56+
# if legend_map[i, j] is not None:
57+
# lines = ax_array[i, j].get_lines()
58+
# labels = _make_legend_labels(lines)
59+
# if len(labels) > 1:
60+
# legend_array[i, j] = ax.legend(
61+
# lines, labels, loc=legend_map[i, j])
62+
#
63+
# # Update the plot title
64+
# sysnames = [response.sysname for response in data]
65+
# if title is None:
66+
# title = "Name plot for " + ", ".join(sysnames)
67+
# _update_suptitle(fig, title, rcParams=rcParams)
68+
#
69+
# # Legacy processing of plot keyword
70+
# if plot is True:
71+
# return data
72+
#
73+
# return ControlPlot(lines, ax_array, fig, legend=legend_map)
74+
675
import warnings
776
from os.path import commonprefix
877

@@ -181,7 +250,8 @@ def get_plot_axes(line_array):
181250

182251

183252
def _process_ax_keyword(
184-
axs, shape=(1, 1), rcParams=None, squeeze=False, clear_text=False):
253+
axs, shape=(1, 1), rcParams=None, squeeze=False, clear_text=False,
254+
create_axes=True):
185255
"""Utility function to process ax keyword to plotting commands.
186256
187257
This function processes the `ax` keyword to plotting commands. If no
@@ -190,6 +260,11 @@ def _process_ax_keyword(
190260
current figure and axes are returned. Otherwise a new figure is
191261
created with axes of the desired shape.
192262
263+
If `create_axes` is False and a new/empty figure is returned, then axs
264+
is an array of the proper shape but None for each element. This allows
265+
the calling function to do the actual axis creation (needed for
266+
curvilinear grids that use the AxisArtist module).
267+
193268
Legacy behavior: some of the older plotting commands use a axes label
194269
to identify the proper axes for plotting. This behavior is supported
195270
through the use of the label keyword, but will only work if shape ==
@@ -204,14 +279,19 @@ def _process_ax_keyword(
204279
# Note: can't actually check the shape, just the total number of axes
205280
if len(axs) != np.prod(shape):
206281
with plt.rc_context(rcParams):
207-
if len(axs) != 0:
282+
if len(axs) != 0 and create_axes:
208283
# Create a new figure
209284
fig, axs = plt.subplots(*shape, squeeze=False)
210-
else:
285+
elif create_axes:
211286
# Create new axes on (empty) figure
212287
axs = fig.subplots(*shape, squeeze=False)
213-
fig.set_layout_engine('tight')
214-
fig.align_labels()
288+
else:
289+
# Create an empty array and let user create axes
290+
axs = np.full(shape, None)
291+
if create_axes: # if not creating axes, leave these to caller
292+
fig.set_layout_engine('tight')
293+
fig.align_labels()
294+
215295
else:
216296
# Use the existing axes, properly reshaped
217297
axs = np.asarray(axs).reshape(*shape)

control/pzmap.py

Lines changed: 36 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,6 @@ def plot(self, *args, **kwargs):
119119
and keywords.
120120
121121
"""
122-
# If this is a root locus plot, use rlocus defaults for grid
123-
if self.loci is not None:
124-
from .rlocus import _rlocus_defaults
125-
kwargs = kwargs.copy()
126-
kwargs['grid'] = config._get_param(
127-
'rlocus', 'grid', kwargs.get('grid', None), _rlocus_defaults)
128-
129122
return pole_zero_plot(self, *args, **kwargs)
130123

131124

@@ -267,11 +260,10 @@ def pole_zero_plot(
267260
268261
"""
269262
# Get parameter values
270-
grid = config._get_param('pzmap', 'grid', grid, _pzmap_defaults)
271263
marker_size = config._get_param('pzmap', 'marker_size', marker_size, 6)
272264
marker_width = config._get_param('pzmap', 'marker_width', marker_width, 1.5)
273265
xlim_user, ylim_user = xlim, ylim
274-
freqplot_rcParams = config._get_param(
266+
rcParams = config._get_param(
275267
'freqplot', 'rcParams', kwargs, _freqplot_defaults,
276268
pop=True, last=True)
277269
user_ax = ax
@@ -315,56 +307,41 @@ def pole_zero_plot(
315307
return poles, zeros
316308

317309
# Initialize the figure
318-
# TODO: turn into standard utility function (from plotutil.py?)
319-
# fig, ax = _process_ax_keyword(
320-
# user_ax, rcParams=freqplot_rcParams, squeeze=True, create_axes=False)
321-
# axs = [ax] if ax is not None else []
322-
if user_ax is None:
323-
fig = plt.gcf()
324-
axs = fig.get_axes()
325-
elif isinstance(user_ax, np.ndarray):
326-
axs = user_ax.reshape(-1)
327-
fig = axs[0].figure
328-
else:
329-
fig = ax.figure
330-
axs = [ax]
331-
332-
if len(axs) > 1:
333-
# Need to generate a new figure
334-
fig, axs = plt.figure(), []
335-
336-
with plt.rc_context(freqplot_rcParams):
337-
if grid and grid != 'empty':
338-
plt.clf()
339-
if all([isctime(dt=response.dt) for response in data]):
340-
ax, fig = sgrid(scaling=scaling)
341-
elif all([isdtime(dt=response.dt) for response in data]):
342-
ax, fig = zgrid(scaling=scaling)
343-
else:
344-
raise ValueError(
345-
"incompatible time bases; don't know how to grid")
346-
# Store the limits for later use
347-
xlim, ylim = ax.get_xlim(), ax.get_ylim()
348-
elif len(axs) == 0:
349-
if grid == 'empty':
350-
# Leave off grid entirely
310+
fig, ax = _process_ax_keyword(
311+
user_ax, rcParams=rcParams, squeeze=True, create_axes=False)
312+
313+
if ax is None:
314+
# Determine what type of grid to use
315+
if rlocus_plot:
316+
from .rlocus import _rlocus_defaults
317+
grid = config._get_param('rlocus', 'grid', grid, _rlocus_defaults)
318+
else:
319+
grid = config._get_param('pzmap', 'grid', grid, _pzmap_defaults)
320+
321+
# Create the axes with the appropriate grid
322+
with plt.rc_context(rcParams):
323+
if grid and grid != 'empty':
324+
if all([isctime(dt=response.dt) for response in data]):
325+
ax, fig = sgrid(scaling=scaling)
326+
elif all([isdtime(dt=response.dt) for response in data]):
327+
ax, fig = zgrid(scaling=scaling)
328+
else:
329+
raise ValueError(
330+
"incompatible time bases; don't know how to grid")
331+
# Store the limits for later use
332+
xlim, ylim = ax.get_xlim(), ax.get_ylim()
333+
elif grid == 'empty':
351334
ax = plt.axes()
352335
xlim = ylim = [np.inf, -np.inf] # use data to set limits
353336
else:
354-
# draw stability boundary; use first response timebase
355337
ax, fig = nogrid(data[0].dt, scaling=scaling)
356338
xlim, ylim = ax.get_xlim(), ax.get_ylim()
357-
else:
358-
# Use the existing axes and any grid that is there
359-
ax = axs[0]
360-
361-
# Store the limits for later use
362-
xlim, ylim = ax.get_xlim(), ax.get_ylim()
363-
364-
# Issue a warning if the user tried to set the grid type
365-
if grid:
366-
warnings.warn("axis already exists; grid keyword ignored")
367-
339+
else:
340+
# Store the limits for later use
341+
xlim, ylim = ax.get_xlim(), ax.get_ylim()
342+
if grid is not None:
343+
warnings.warn("axis already exists; grid keyword ignored")
344+
368345
# Handle color cycle manually as all root locus segments
369346
# of the same system are expected to be of the same color
370347
# TODO: replace with common function?
@@ -459,13 +436,13 @@ def pole_zero_plot(
459436
handle = (pole_line, zero_line)
460437
line_tuples.append(handle)
461438

462-
with plt.rc_context(freqplot_rcParams):
439+
with plt.rc_context(rcParams):
463440
legend = ax.legend(
464441
line_tuples, labels, loc=legend_loc,
465442
handler_map={tuple: HandlerTuple(ndivide=None)})
466443
else:
467444
# Regular legend, with lines
468-
with plt.rc_context(freqplot_rcParams):
445+
with plt.rc_context(rcParams):
469446
legend = ax.legend(lines, labels, loc=legend_loc)
470447
else:
471448
legend = None
@@ -475,7 +452,8 @@ def pole_zero_plot(
475452
title = ("Root locus plot for " if rlocus_plot
476453
else "Pole/zero plot for ") + ", ".join(labels)
477454
if user_ax is None:
478-
suptitle(title)
455+
with plt.rc_context(rcParams):
456+
fig.suptitle(title)
479457

480458
# Add dispather to handle choosing a point on the diagram
481459
if interactive:
@@ -497,7 +475,7 @@ def _click_dispatcher(event):
497475
_mark_root_locus_gain(ax, sys, K)
498476

499477
# Display the parameters in the axes title
500-
with plt.rc_context(freqplot_rcParams):
478+
with plt.rc_context(rcParams):
501479
ax.set_title(_create_root_locus_label(sys, K, s))
502480

503481
ax.figure.canvas.draw()

control/rlocus.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,6 @@ def root_locus_plot(
173173
for oldkey in ['kvect', 'k']:
174174
gains = config._process_legacy_keyword(kwargs, oldkey, 'gains', gains)
175175

176-
# Set default parameters
177-
grid = config._get_param('rlocus', 'grid', grid, _rlocus_defaults)
178-
179176
if isinstance(sysdata, list) and all(
180177
[isinstance(sys, LTI) for sys in sysdata]) or \
181178
isinstance(sysdata, LTI):

control/tests/ctrlplot_test.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,20 +105,14 @@ def test_plot_functions(resp_fcn, plot_fcn):
105105
# Plot should have landed on top of previous plot
106106
if resp_fcn is not None:
107107
assert cplt2.figure == cplt1.figure
108-
if plot_fcn != ct.root_locus_plot:
109-
assert np.all(cplt2.axes == cplt1.axes)
110-
else:
111-
warnings.warn("test skipped for root locus plot")
108+
assert np.all(cplt2.axes == cplt1.axes)
112109
assert len(cplt2.lines[0]) == len(cplt1.lines[0])
113110

114111
# Pass axes explicitly
115112
if resp_fcn is not None:
116113
cplt3 = resp.plot(**kwargs, **meth_kwargs, ax=cplt1.axes)
117114
assert cplt3.figure == cplt1.figure
118-
if plot_fcn != ct.root_locus_plot:
119-
assert np.all(cplt3.axes == cplt1.axes)
120-
else:
121-
warnings.warn("test skipped for root locus plot")
115+
assert np.all(cplt3.axes == cplt1.axes)
122116
assert len(cplt3.lines[0]) == len(cplt1.lines[0])
123117

124118

control/tests/rlocus_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_root_locus_plot_grid(self, sys, grid, method):
9595
if grid == 'empty':
9696
assert n_gridlines == 0
9797
assert not isinstance(ax, AA.Axes)
98-
elif grid is False or method == 'pzmap' and grid is None:
98+
elif grid is False:
9999
assert n_gridlines == 2 if sys.isctime() else 3
100100
assert not isinstance(ax, AA.Axes)
101101
elif sys.isdtime(strict=True):
@@ -174,6 +174,7 @@ def test_rlocus_default_wn(self):
174174
"sys, grid, xlim, ylim, interactive", [
175175
(ct.tf([1], [1, 2, 1]), None, None, None, False),
176176
])
177+
@pytest.mark.usefixtures("mplcleanup")
177178
def test_root_locus_plots(sys, grid, xlim, ylim, interactive):
178179
ct.root_locus_map(sys).plot(
179180
grid=grid, xlim=xlim, ylim=ylim, interactive=interactive)
@@ -182,13 +183,15 @@ def test_root_locus_plots(sys, grid, xlim, ylim, interactive):
182183

183184
# Test deprecated keywords
184185
@pytest.mark.parametrize("keyword", ["kvect", "k"])
186+
@pytest.mark.usefixtures("mplcleanup")
185187
def test_root_locus_legacy(keyword):
186188
sys = ct.rss(2, 1, 1)
187189
with pytest.warns(DeprecationWarning, match=f"'{keyword}' is deprecated"):
188190
ct.root_locus_plot(sys, **{keyword: [0, 1, 2]})
189191

190192

191193
# Generate plots used in documentation
194+
@pytest.mark.usefixtures("mplcleanup")
192195
def test_root_locus_documentation(savefigs=False):
193196
plt.figure()
194197
sys = ct.tf([1, 2], [1, 2, 3], name='SISO transfer function')

0 commit comments

Comments
 (0)