Skip to content

Commit fb5c194

Browse files
committed
add unit tests for common plotting functionality
1 parent 8b7d399 commit fb5c194

3 files changed

Lines changed: 127 additions & 2 deletions

File tree

control/phaseplot.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from scipy.integrate import odeint
3737

3838
from . import config
39-
from .ctrlplot import ControlPlot, _add_arrows_to_line2D
39+
from .ctrlplot import ControlPlot, _add_arrows_to_line2D, _process_ax_keyword
4040
from .exception import ControlNotImplemented
4141
from .nlsys import NonlinearIOSystem, find_eqpt, input_output_response
4242

@@ -141,6 +141,9 @@ def phase_plane_plot(
141141
pointdata = [-1, 1, -1, 1] if pointdata is None else pointdata
142142

143143
# Create axis if needed
144+
user_ax = ax
145+
# TODO: make use of _process_ax_keyword
146+
# fig, ax = _process_ax_keyword(user_ax, squeeze=True)
144147
if ax is None:
145148
fig, ax = plt.gcf(), plt.gca()
146149
else:
@@ -212,7 +215,8 @@ def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
212215
if initial_kwargs:
213216
raise TypeError("unrecognized keywords: ", str(initial_kwargs))
214217

215-
if fig is not None:
218+
# TODO: update to common code pattern
219+
if user_ax is None:
216220
ax.set_title(f"Phase portrait for {sys.name}")
217221
ax.set_xlabel(sys.state_labels[0])
218222
ax.set_ylabel(sys.state_labels[1])

control/pzmap.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,9 +315,15 @@ def pole_zero_plot(
315315

316316
# Initialize the figure
317317
# TODO: turn into standard utility function (from plotutil.py?)
318+
# fig, ax = _process_ax_keyword(
319+
# user_ax, rcParams=freqplot_rcParams, squeeze=True, create_axes=False)
320+
# axs = [ax] if ax is not None else []
318321
if user_ax is None:
319322
fig = plt.gcf()
320323
axs = fig.get_axes()
324+
elif isinstance(user_ax, np.ndarray):
325+
axs = user_ax.reshape(-1)
326+
fig = axs[0].figure
321327
else:
322328
fig = ax.figure
323329
axs = [ax]

control/tests/ctrlplot_test.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,126 @@
11
# ctrlplot_test.py - test out control plotting utilities
22
# RMM, 27 Jun 2024
33

4+
import inspect
5+
import warnings
6+
47
import matplotlib.pyplot as plt
8+
import numpy as np
59
import pytest
610

711
import control as ct
812

13+
# List of all plotting functions
14+
resp_plot_fcns = [
15+
# response function plotting function
16+
(ct.frequency_response, ct.bode_plot),
17+
(ct.frequency_response, ct.nichols_plot),
18+
(ct.singular_values_response, ct.singular_values_plot),
19+
(ct.gangof4_response, ct.gangof4_plot),
20+
(ct.describing_function_response, ct.describing_function_plot),
21+
(None, ct.phase_plane_plot),
22+
(ct.pole_zero_map, ct.pole_zero_plot),
23+
(ct.nyquist_response, ct.nyquist_plot),
24+
(ct.root_locus_map, ct.root_locus_plot),
25+
(ct.initial_response, ct.time_response_plot),
26+
(ct.step_response, ct.time_response_plot),
27+
(ct.impulse_response, ct.time_response_plot),
28+
(ct.forced_response, ct.time_response_plot),
29+
(ct.input_output_response, ct.time_response_plot),
30+
]
31+
32+
deprecated_fcns = [ct.phase_plot]
33+
34+
# Make sure we didn't miss any plotting functions
35+
def test_find_respplot_functions():
36+
# Get the list of plotting functions
37+
plot_fcns = {respplot[1] for respplot in resp_plot_fcns}
38+
39+
# Look through every object in the package
40+
found = 0
41+
for name, obj in inspect.getmembers(ct):
42+
# Skip anything that is outside of this module
43+
if inspect.getmodule(obj) is not None and \
44+
not inspect.getmodule(obj).__name__.startswith('control'):
45+
# Skip anything that isn't part of the control package
46+
continue
47+
48+
# Only look for non-deprecated functions ending in 'plot'
49+
if not inspect.isfunction(obj) or name[-4:] != 'plot' or \
50+
obj in deprecated_fcns:
51+
continue
52+
53+
# Make sure that we have this on our list of functions
54+
assert obj in plot_fcns
55+
found += 1
56+
57+
assert found == len(plot_fcns)
58+
59+
60+
@pytest.mark.parametrize("resp_fcn, plot_fcn", resp_plot_fcns)
61+
@pytest.mark.usefixtures('mplcleanup')
62+
def test_plot_functions(resp_fcn, plot_fcn):
63+
# Create some systems to use
64+
sys1 = ct.rss(2, 1, 1, strictly_proper=True)
65+
sys2 = ct.rss(4, 1, 1, strictly_proper=True)
66+
67+
# Set up arguments
68+
kwargs = meth_kwargs = plot_fcn_kwargs = {}
69+
match resp_fcn, plot_fcn:
70+
case ct.describing_function_response, _:
71+
F = ct.descfcn.saturation_nonlinearity(1)
72+
amp = np.linspace(1, 4, 10)
73+
args = (sys1, F, amp)
74+
75+
case ct.gangof4_response, _:
76+
args = (sys1, sys2)
77+
78+
case ct.frequency_response, ct.nichols_plot:
79+
args = (sys1, )
80+
meth_kwargs = {'plot_type': 'nichols'}
81+
82+
case ct.root_locus_map, ct.root_locus_plot:
83+
args = (sys1, )
84+
meth_kwargs = plot_fcn_kwargs = {'interactive': False}
85+
86+
case (ct.forced_response | ct.input_output_response, _):
87+
timepts = np.linspace(1, 10)
88+
U = np.sin(timepts)
89+
args = (sys1, timepts, U)
90+
91+
case _, _:
92+
args = (sys1, )
93+
94+
# Call the plot through the response function
95+
if resp_fcn is not None:
96+
resp = resp_fcn(*args, **kwargs)
97+
cplt1 = resp.plot(**kwargs, **meth_kwargs)
98+
assert isinstance(cplt1, ct.ControlPlot)
99+
100+
# Call the plot directly, plotting on top of previous plot
101+
if plot_fcn not in [None, ct.time_response_plot]:
102+
cplt2 = plot_fcn(*args, **kwargs, **plot_fcn_kwargs)
103+
assert isinstance(cplt2, ct.ControlPlot)
104+
105+
# Plot should have landed on top of previous plot
106+
if resp_fcn is not None:
107+
assert cplt2.figure == cplt1.figure
108+
if plot_fcn != ct.root_locus_plot:
109+
assert np.all(cplt2.axes == cplt1.axes)
110+
else:
111+
warnings.warn("test skipped for root locus plot")
112+
assert len(cplt2.lines[0]) == len(cplt1.lines[0])
113+
114+
# Pass axes explicitly
115+
if resp_fcn is not None:
116+
cplt3 = resp.plot(**kwargs, **meth_kwargs, ax=cplt1.axes)
117+
assert cplt3.figure == cplt1.figure
118+
if plot_fcn != ct.root_locus_plot:
119+
assert np.all(cplt3.axes == cplt1.axes)
120+
else:
121+
warnings.warn("test skipped for root locus plot")
122+
assert len(cplt3.lines[0]) == len(cplt1.lines[0])
123+
9124

10125
@pytest.mark.usefixtures('mplcleanup')
11126
def test_rcParams():

0 commit comments

Comments
 (0)