Skip to content

Commit 4f4746d

Browse files
committed
deprecate get_plot_axes (with legacy testing)
1 parent 02f2724 commit 4f4746d

4 files changed

Lines changed: 74 additions & 44 deletions

File tree

control/ctrlplot.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ class ControlPlot(object):
148148
def __init__(self, lines, axes=None, figure=None, legend=None):
149149
self.lines = lines
150150
if axes is None:
151-
axes = get_plot_axes(lines)
151+
_get_axes = np.vectorize(lambda lines: lines[0].axes)
152+
axes = _get_axes(lines)
152153
self.axes = np.atleast_2d(axes)
153154
if figure is None:
154155
figure = self.axes[0, 0].figure
@@ -240,6 +241,7 @@ def get_plot_axes(line_array):
240241
Only the first element of each array entry is used to determine the axes.
241242
242243
"""
244+
warnings.warn("get_plot_axes is deprecated; use cplt.axes", FutureWarning)
243245
_get_axes = np.vectorize(lambda lines: lines[0].axes)
244246
if isinstance(line_array, ControlPlot):
245247
return _get_axes(line_array.lines)

control/tests/ctrlplot_test.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,10 @@ def test_plot_ax_processing(resp_fcn, plot_fcn):
8989
get_line_color = lambda cplt: cplt.lines.reshape(-1)[0][0].get_color()
9090
match resp_fcn, plot_fcn:
9191
case ct.describing_function_response, _:
92+
sys = ct.tf([1], [1, 2, 2, 1])
9293
F = ct.descfcn.saturation_nonlinearity(1)
9394
amp = np.linspace(1, 4, 10)
94-
args = (sys1, F, amp)
95+
args = (sys, F, amp)
9596
resp_kwargs = plot_kwargs = {'refine': False}
9697

9798
case ct.gangof4_response, _:
@@ -224,6 +225,8 @@ def test_plot_label_processing(resp_fcn, plot_fcn):
224225
expected_labels = ["sys1_", "sys2_"]
225226
match resp_fcn, plot_fcn:
226227
case ct.describing_function_response, _:
228+
sys1 = ct.tf([1], [1, 2, 2, 1], name="sys[1]")
229+
sys2 = ct.tf([1.1], [1, 2, 2, 1], name="sys[2]")
227230
F = ct.descfcn.saturation_nonlinearity(1)
228231
amp = np.linspace(1, 4, 10)
229232
args1 = (sys1, F, amp)
@@ -332,6 +335,8 @@ def test_siso_plot_legend_processing(resp_fcn, plot_fcn):
332335
default_labels = ["sys[1]", "sys[2]"]
333336
match resp_fcn, plot_fcn:
334337
case ct.describing_function_response, _:
338+
sys1 = ct.tf([1], [1, 2, 2, 1], name="sys[1]")
339+
sys2 = ct.tf([1.1], [1, 2, 2, 1], name="sys[2]")
335340
F = ct.descfcn.saturation_nonlinearity(1)
336341
amp = np.linspace(1, 4, 10)
337342
args1 = (sys1, F, amp)
@@ -488,6 +493,8 @@ def test_plot_title_processing(resp_fcn, plot_fcn):
488493
expected_title = "sys1_, sys2_"
489494
match resp_fcn, plot_fcn:
490495
case ct.describing_function_response, _:
496+
sys1 = ct.tf([1], [1, 2, 2, 1], name="sys[1]")
497+
sys2 = ct.tf([1.1], [1, 2, 2, 1], name="sys[2]")
491498
F = ct.descfcn.saturation_nonlinearity(1)
492499
amp = np.linspace(1, 4, 10)
493500
args1 = (sys1, F, amp)
@@ -617,6 +624,8 @@ def test_rcParams(resp_fcn, plot_fcn):
617624
expected_title = "sys1_, sys2_"
618625
match resp_fcn, plot_fcn:
619626
case ct.describing_function_response, _:
627+
sys1 = ct.tf([1], [1, 2, 2, 1], name="sys[1]")
628+
sys2 = ct.tf([1], [1, 2, 2, 1], name="sys[2]")
620629
F = ct.descfcn.saturation_nonlinearity(1)
621630
amp = np.linspace(1, 4, 10)
622631
args1 = (sys1, F, amp)
@@ -747,8 +756,29 @@ def test_rcParams(resp_fcn, plot_fcn):
747756
assert ct.ctrlplot.rcParams[key] != my_rcParams[key]
748757

749758

750-
def test_deprecation_warning():
759+
def test_deprecation_warnings():
751760
sys = ct.rss(2, 2, 2)
752761
lines = ct.step_response(sys).plot(overlay_traces=True)
753762
with pytest.warns(FutureWarning, match="deprecated"):
754763
assert len(lines[0, 0]) == 2
764+
765+
cplt = ct.step_response(sys).plot()
766+
with pytest.warns(FutureWarning, match="deprecated"):
767+
axs = ct.get_plot_axes(cplt)
768+
assert np.all(axs == cplt.axes)
769+
770+
with pytest.warns(FutureWarning, match="deprecated"):
771+
axs = ct.get_plot_axes(cplt.lines)
772+
assert np.all(axs == cplt.axes)
773+
774+
775+
def test_ControlPlot_init():
776+
sys = ct.rss(2, 2, 2)
777+
cplt = ct.step_response(sys).plot()
778+
779+
# Create a ControlPlot from data, without the axes or figure
780+
cplt_raw = ct.ControlPlot(cplt.lines)
781+
assert np.all(cplt_raw.lines == cplt.lines)
782+
assert np.all(cplt_raw.axes == cplt.axes)
783+
assert cplt_raw.figure == cplt.figure
784+

control/tests/freqplot_test.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def test_response_plots(
142142
def test_manual_response_limits():
143143
# Default response: limits should be the same across rows
144144
cplt = manual_response.plot()
145-
axs = ct.get_plot_axes(cplt) # legacy usage OK
145+
axs = cplt.axes
146146
for i in range(manual_response.noutputs):
147147
for j in range(1, manual_response.ninputs):
148148
# Everything in the same row should have the same limits
@@ -305,8 +305,8 @@ def test_bode_share_options():
305305
ct.set_defaults('freqplot', title_frame='figure')
306306

307307
# Default sharing should share along rows and cols for mag and phase
308-
lines = ct.bode_plot(manual_response)
309-
axs = ct.get_plot_axes(lines)
308+
cplt = ct.bode_plot(manual_response)
309+
axs = cplt.axes
310310
for i in range(axs.shape[0]):
311311
for j in range(axs.shape[1]):
312312
# Share y limits along rows
@@ -317,8 +317,8 @@ def test_bode_share_options():
317317

318318
# Sharing along y axis for mag but not phase
319319
plt.figure()
320-
lines = ct.bode_plot(manual_response, share_phase='none')
321-
axs = ct.get_plot_axes(lines)
320+
cplt = ct.bode_plot(manual_response, share_phase='none')
321+
axs = cplt.axes
322322
for i in range(int(axs.shape[0] / 2)):
323323
for j in range(axs.shape[1]):
324324
if i != 0:
@@ -330,8 +330,8 @@ def test_bode_share_options():
330330

331331
# Turn off sharing for magnitude and phase
332332
plt.figure()
333-
lines = ct.bode_plot(manual_response, sharey='none')
334-
axs = ct.get_plot_axes(lines)
333+
cplt = ct.bode_plot(manual_response, sharey='none')
334+
axs = cplt.axes
335335
for i in range(int(axs.shape[0] / 2)):
336336
for j in range(axs.shape[1]):
337337
if i != 0:
@@ -345,7 +345,7 @@ def test_bode_share_options():
345345

346346
# Turn off sharing in x axes
347347
plt.figure()
348-
lines = ct.bode_plot(manual_response, sharex='none')
348+
cplt = ct.bode_plot(manual_response, sharex='none')
349349
# TODO: figure out what to check
350350

351351

@@ -355,11 +355,11 @@ def test_freqplot_plot_type(plot_type):
355355
response = ct.singular_values_response(ct.rss(2, 1, 1))
356356
else:
357357
response = ct.frequency_response(ct.rss(2, 1, 1))
358-
lines = response.plot(plot_type=plot_type)
358+
cplt = response.plot(plot_type=plot_type)
359359
if plot_type == 'bode':
360-
assert lines.shape == (2, 1)
360+
assert cplt.lines.shape == (2, 1)
361361
else:
362-
assert lines.shape == (1, )
362+
assert cplt.lines.shape == (1, )
363363

364364
@pytest.mark.parametrize("plt_fcn", [ct.bode_plot, ct.singular_values_plot])
365365
@pytest.mark.usefixtures("editsdefaults")
@@ -379,14 +379,14 @@ def _get_visible_limits(ax):
379379
ct.tf([1], [1, 2, 1]), np.logspace(-1, 1))
380380

381381
# Generate a plot without overridding the limits
382-
lines = plt_fcn(response)
383-
ax = ct.get_plot_axes(lines)
382+
cplt = plt_fcn(response)
383+
ax = cplt.axes
384384
np.testing.assert_allclose(
385385
_get_visible_limits(ax.reshape(-1)[0]), np.array([0.1, 10]))
386386

387387
# Now reset the limits
388-
lines = plt_fcn(response, omega_limits=(1, 100))
389-
ax = ct.get_plot_axes(lines)
388+
cplt = plt_fcn(response, omega_limits=(1, 100))
389+
ax = cplt.axes
390390
np.testing.assert_allclose(
391391
_get_visible_limits(ax.reshape(-1)[0]), np.array([1, 100]))
392392

@@ -400,7 +400,7 @@ def test_gangof4_trace_labels():
400400
# Make sure default labels are as expected
401401
cplt = ct.gangof4_response(P1, C1).plot()
402402
cplt = ct.gangof4_response(P2, C2).plot()
403-
axs = ct.get_plot_axes(cplt) # legacy usage OK
403+
axs = cplt.axes
404404
legend = axs[0, 1].get_legend().get_texts()
405405
assert legend[0].get_text() == 'P=P1, C=C1'
406406
assert legend[1].get_text() == 'P=P2, C=C2'
@@ -409,7 +409,7 @@ def test_gangof4_trace_labels():
409409
# Suffix truncation
410410
cplt = ct.gangof4_response(P1, C1).plot()
411411
cplt = ct.gangof4_response(P2, C1).plot()
412-
axs = ct.get_plot_axes(cplt) # legacy usage OK
412+
axs = cplt.axes
413413
legend = axs[0, 1].get_legend().get_texts()
414414
assert legend[0].get_text() == 'P=P1'
415415
assert legend[1].get_text() == 'P=P2'
@@ -418,7 +418,7 @@ def test_gangof4_trace_labels():
418418
# Prefix turncation
419419
cplt = ct.gangof4_response(P1, C1).plot()
420420
cplt = ct.gangof4_response(P1, C2).plot()
421-
axs = ct.get_plot_axes(cplt) # legacy usage OK
421+
axs = cplt.axes
422422
legend = axs[0, 1].get_legend().get_texts()
423423
assert legend[0].get_text() == 'C=C1'
424424
assert legend[1].get_text() == 'C=C2'
@@ -427,7 +427,7 @@ def test_gangof4_trace_labels():
427427
# Override labels
428428
cplt = ct.gangof4_response(P1, C1).plot(label='xxx, line1, yyy')
429429
cplt = ct.gangof4_response(P2, C2).plot(label='xxx, line2, yyy')
430-
axs = ct.get_plot_axes(cplt) # legacy usage OK
430+
axs = cplt.axes
431431
legend = axs[0, 1].get_legend().get_texts()
432432
assert legend[0].get_text() == 'xxx, line1, yyy'
433433
assert legend[1].get_text() == 'xxx, line2, yyy'
@@ -446,7 +446,7 @@ def test_freqplot_line_labels(plt_fcn):
446446

447447
# Make sure default labels are as expected
448448
cplt = plt_fcn([sys1, sys2])
449-
axs = ct.get_plot_axes(cplt) # legacy usage OK
449+
axs = cplt.axes
450450
if axs.ndim == 1:
451451
legend = axs[0].get_legend().get_texts()
452452
else:
@@ -457,7 +457,7 @@ def test_freqplot_line_labels(plt_fcn):
457457

458458
# Override labels all at once
459459
cplt = plt_fcn([sys1, sys2], label=['line1', 'line2'])
460-
axs = ct.get_plot_axes(cplt) # legacy usage OK
460+
axs = cplt.axes
461461
if axs.ndim == 1:
462462
legend = axs[0].get_legend().get_texts()
463463
else:
@@ -469,7 +469,7 @@ def test_freqplot_line_labels(plt_fcn):
469469
# Override labels one at a time
470470
cplt = plt_fcn(sys1, label='line1')
471471
cplt = plt_fcn(sys2, label='line2')
472-
axs = ct.get_plot_axes(cplt) # legacy usage OK
472+
axs = cplt.axes
473473
if axs.ndim == 1:
474474
legend = axs[0].get_legend().get_texts()
475475
else:
@@ -495,7 +495,7 @@ def test_line_labels_bode(kwargs, labels):
495495
ct.bode_plot([sys1, sys2], label=['line1'])
496496

497497
cplt = ct.bode_plot([sys1, sys2], label=labels, **kwargs)
498-
axs = ct.get_plot_axes(cplt) # legacy usage OK
498+
axs = cplt.axes
499499
legend_texts = axs[0, -1].get_legend().get_texts()
500500
for i, legend in enumerate(legend_texts):
501501
assert legend.get_text() == labels[i]
@@ -524,7 +524,7 @@ def test_freqplot_ax_keyword(plt_fcn, ninputs, noutputs):
524524
cplt1 = plt_fcn(sys)
525525

526526
# Draw again on the same figure, using array
527-
axs = ct.get_plot_axes(cplt1) # legacy usage OK
527+
axs = cplt1.axes
528528
cplt2 = plt_fcn(sys, ax=axs)
529529
np.testing.assert_equal(cplt1.axes, cplt2.axes)
530530

control/tests/timeplot_test.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,6 @@ def test_response_plots(
194194

195195
@pytest.mark.usefixtures('mplcleanup')
196196
def test_axes_setup():
197-
get_plot_axes = ct.get_plot_axes
198-
199197
sys_2x3 = ct.rss(4, 2, 3)
200198
sys_2x3b = ct.rss(4, 2, 3)
201199
sys_3x2 = ct.rss(4, 3, 2)
@@ -204,30 +202,30 @@ def test_axes_setup():
204202
# Two plots of the same size leaves axes unchanged
205203
cplt1 = ct.step_response(sys_2x3).plot()
206204
cplt2 = ct.step_response(sys_2x3b).plot()
207-
np.testing.assert_equal(get_plot_axes(cplt1), get_plot_axes(cplt2))
205+
np.testing.assert_equal(cplt1.axes, cplt2.axes)
208206
plt.close()
209207

210208
# Two plots of same net size leaves axes unchanged (unfortunately)
211209
cplt1 = ct.step_response(sys_2x3).plot()
212210
cplt2 = ct.step_response(sys_3x2).plot()
213211
np.testing.assert_equal(
214-
get_plot_axes(cplt1).reshape(-1), get_plot_axes(cplt2).reshape(-1))
212+
cplt1.axes.reshape(-1), cplt2.axes.reshape(-1))
215213
plt.close()
216214

217215
# Plots of different shapes generate new plots
218216
cplt1 = ct.step_response(sys_2x3).plot()
219217
cplt2 = ct.step_response(sys_3x1).plot()
220-
ax1_list = get_plot_axes(cplt1).reshape(-1).tolist()
221-
ax2_list = get_plot_axes(cplt2).reshape(-1).tolist()
218+
ax1_list = cplt1.axes.reshape(-1).tolist()
219+
ax2_list = cplt2.axes.reshape(-1).tolist()
222220
for ax in ax1_list:
223221
assert ax not in ax2_list
224222
plt.close()
225223

226224
# Passing a list of axes preserves those axes
227225
cplt1 = ct.step_response(sys_2x3).plot()
228226
cplt2 = ct.step_response(sys_3x1).plot()
229-
cplt3 = ct.step_response(sys_2x3b).plot(ax=get_plot_axes(cplt1))
230-
np.testing.assert_equal(get_plot_axes(cplt1), get_plot_axes(cplt3))
227+
cplt3 = ct.step_response(sys_2x3b).plot(ax=cplt1.axes)
228+
np.testing.assert_equal(cplt1.axes, cplt3.axes)
231229
plt.close()
232230

233231
# Sending an axes array of the wrong size raises exception
@@ -433,7 +431,7 @@ def test_timeplot_trace_labels(resp_fcn):
433431

434432
# Make sure default labels are as expected
435433
cplt = resp_fcn([sys1, sys2], **kwargs).plot()
436-
axs = ct.get_plot_axes(cplt.lines)
434+
axs = cplt.axes
437435
if axs.ndim == 1:
438436
legend = axs[0].get_legend().get_texts()
439437
else:
@@ -444,7 +442,7 @@ def test_timeplot_trace_labels(resp_fcn):
444442

445443
# Override labels all at once
446444
cplt = resp_fcn([sys1, sys2], **kwargs).plot(label=['line1', 'line2'])
447-
axs = ct.get_plot_axes(cplt.lines)
445+
axs = cplt.axes
448446
if axs.ndim == 1:
449447
legend = axs[0].get_legend().get_texts()
450448
else:
@@ -456,7 +454,7 @@ def test_timeplot_trace_labels(resp_fcn):
456454
# Override labels one at a time
457455
cplt = resp_fcn(sys1, **kwargs).plot(label='line1')
458456
cplt = resp_fcn(sys2, **kwargs).plot(label='line2')
459-
axs = ct.get_plot_axes(cplt.lines)
457+
axs = cplt.axes
460458
if axs.ndim == 1:
461459
legend = axs[0].get_legend().get_texts()
462460
else:
@@ -489,7 +487,7 @@ def test_full_label_override():
489487
cplt = ct.step_response([sys1, sys2]).plot(
490488
overlay_signals=True, overlay_traces=True, plot_inputs=True,
491489
label=labels_4d)
492-
axs = ct.get_plot_axes(cplt.lines)
490+
axs = cplt.axes
493491
assert axs.shape == (2, 1)
494492
legend_text = axs[0, 0].get_legend().get_texts()
495493
for i, label in enumerate(labels_2d[0]):
@@ -502,7 +500,7 @@ def test_full_label_override():
502500
cplt = ct.step_response([sys1, sys2]).plot(
503501
overlay_signals=True, overlay_traces=True, plot_inputs=True,
504502
label=labels_2d)
505-
axs = ct.get_plot_axes(cplt.lines)
503+
axs = cplt.axes
506504
assert axs.shape == (2, 1)
507505
legend_text = axs[0, 0].get_legend().get_texts()
508506
for i, label in enumerate(labels_2d[0]):
@@ -522,7 +520,7 @@ def test_relabel():
522520

523521
# Generate a new plot, which overwrites labels
524522
cplt = ct.step_response(sys2).plot()
525-
ax = ct.get_plot_axes(cplt.lines)
523+
ax = cplt.axes
526524
assert ax[0, 0].get_ylabel() == 'y[0]'
527525

528526
# Regenerate the first plot
@@ -570,23 +568,23 @@ def test_legend_customization():
570568

571569
# Generic input/output plot
572570
cplt = resp.plot(overlay_signals=True)
573-
axs = ct.get_plot_axes(cplt.lines)
571+
axs = cplt.axes
574572
assert axs[0, 0].get_legend()._loc == 7 # center right
575573
assert len(axs[0, 0].get_legend().get_texts()) == 2
576574
assert axs[1, 0].get_legend() == None
577575
plt.close()
578576

579577
# Hide legend
580578
cplt = resp.plot(overlay_signals=True, show_legend=False)
581-
axs = ct.get_plot_axes(cplt.lines)
579+
axs = cplt.axes
582580
assert axs[0, 0].get_legend() == None
583581
assert axs[1, 0].get_legend() == None
584582
plt.close()
585583

586584
# Put legend in both axes
587585
cplt = resp.plot(
588586
overlay_signals=True, legend_map=[['center left'], ['center right']])
589-
axs = ct.get_plot_axes(cplt.lines)
587+
axs = cplt.axes
590588
assert axs[0, 0].get_legend()._loc == 6 # center left
591589
assert len(axs[0, 0].get_legend().get_texts()) == 2
592590
assert axs[1, 0].get_legend()._loc == 7 # center right

0 commit comments

Comments
 (0)