Skip to content

Commit ed8a1c1

Browse files
committed
allow label keyword to override generated labels
1 parent cc6aeb6 commit ed8a1c1

4 files changed

Lines changed: 147 additions & 16 deletions

File tree

control/freqplot.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2647,12 +2647,13 @@ def _get_line_labels(ax, use_color=True):
26472647

26482648

26492649
# Turn label keyword into array indexed by trace, output, input
2650-
def _process_line_labels(label, nsys, ninputs=0, noutputs=0):
2650+
# TODO: move to ctrlutil.py and update parameter names to reflect general use
2651+
def _process_line_labels(label, ntraces, ninputs=0, noutputs=0):
26512652
if label is None:
26522653
return None
26532654

26542655
if isinstance(label, str):
2655-
label = [label]
2656+
label = [label] * ntraces # single label for all traces
26562657

26572658
# Convert to an ndarray, if not done aleady
26582659
try:
@@ -2664,12 +2665,14 @@ def _process_line_labels(label, nsys, ninputs=0, noutputs=0):
26642665
# TODO: allow more sophisticated broadcasting (and error checking)
26652666
try:
26662667
if ninputs > 0 and noutputs > 0:
2667-
if line_labels.ndim == 1:
2668-
line_labels = line_labels.reshape(nsys, 1, 1)
2669-
line_labels = np.broadcast_to(
2670-
line_labels,(nsys, ninputs, noutputs))
2668+
if line_labels.ndim == 1 and line_labels.size == ntraces:
2669+
line_labels = line_labels.reshape(ntraces, 1, 1)
2670+
line_labels = np.broadcast_to(
2671+
line_labels, (ntraces, ninputs, noutputs))
2672+
else:
2673+
line_labels = line_labels.reshape(ntraces, ninputs, noutputs)
26712674
except:
2672-
if line_labels.shape[0] != nsys:
2675+
if line_labels.shape[0] != ntraces:
26732676
raise ValueError("number of labels must match number of traces")
26742677
else:
26752678
raise ValueError("labels must be given for each input/output pair")

control/tests/timeplot_test.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,113 @@ def test_rcParams():
424424
my_rcParams['ytick.labelsize']
425425
assert fig._suptitle.get_fontsize() == my_rcParams['figure.titlesize']
426426

427+
428+
@pytest.mark.parametrize("resp_fcn", [
429+
ct.step_response, ct.initial_response, ct.impulse_response,
430+
ct.forced_response, ct.input_output_response])
431+
@pytest.mark.usefixtures("editsdefaults")
432+
def test_timeplot_trace_labels(resp_fcn):
433+
plt.close('all')
434+
sys1 = ct.rss(2, 2, 2, strictly_proper=True, name='sys1')
435+
sys2 = ct.rss(2, 2, 2, strictly_proper=True, name='sys2')
436+
437+
# Figure out the expected shape of the system
438+
match resp_fcn:
439+
case ct.step_response | ct.impulse_response:
440+
shape = (2, 2)
441+
kwargs = {}
442+
case ct.initial_response:
443+
shape = (2, 1)
444+
kwargs = {}
445+
case ct.forced_response | ct.input_output_response:
446+
shape = (4, 1) # outputs and inputs both plotted
447+
T = np.linspace(0, 10)
448+
U = [np.sin(T), np.cos(T)]
449+
kwargs = {'T': T, 'U': U}
450+
451+
# Use figure frame for suptitle to speed things up
452+
ct.set_defaults('freqplot', suptitle_frame='figure')
453+
454+
# Make sure default labels are as expected
455+
out = resp_fcn([sys1, sys2], **kwargs).plot()
456+
axs = ct.get_plot_axes(out)
457+
if axs.ndim == 1:
458+
legend = axs[0].get_legend().get_texts()
459+
else:
460+
legend = axs[0, -1].get_legend().get_texts()
461+
assert legend[0].get_text() == 'sys1'
462+
assert legend[1].get_text() == 'sys2'
463+
plt.close()
464+
465+
# Override labels all at once
466+
out = resp_fcn([sys1, sys2], **kwargs).plot(label=['line1', 'line2'])
467+
axs = ct.get_plot_axes(out)
468+
if axs.ndim == 1:
469+
legend = axs[0].get_legend().get_texts()
470+
else:
471+
legend = axs[0, -1].get_legend().get_texts()
472+
assert legend[0].get_text() == 'line1'
473+
assert legend[1].get_text() == 'line2'
474+
plt.close()
475+
476+
# Override labels one at a time
477+
out = resp_fcn(sys1, **kwargs).plot(label='line1')
478+
out = resp_fcn(sys2, **kwargs).plot(label='line2')
479+
axs = ct.get_plot_axes(out)
480+
if axs.ndim == 1:
481+
legend = axs[0].get_legend().get_texts()
482+
else:
483+
legend = axs[0, -1].get_legend().get_texts()
484+
assert legend[0].get_text() == 'line1'
485+
assert legend[1].get_text() == 'line2'
486+
plt.close()
487+
488+
489+
def test_full_label_override():
490+
sys1 = ct.rss(2, 2, 2, strictly_proper=True, name='sys1')
491+
sys2 = ct.rss(2, 2, 2, strictly_proper=True, name='sys2')
492+
493+
labels_2d = np.array([
494+
["outsys1u1y1", "outsys1u1y2", "outsys1u2y1", "outsys1u2y2",
495+
"outsys2u1y1", "outsys2u1y2", "outsys2u2y1", "outsys2u2y2"],
496+
["inpsys1u1y1", "inpsys1u1y2", "inpsys1u2y1", "inpsys1u2y2",
497+
"inpsys2u1y1", "inpsys2u1y2", "inpsys2u2y1", "inpsys2u2y2"]])
498+
499+
500+
labels_4d = np.empty((2, 2, 2, 2), dtype=object)
501+
for i, sys in enumerate(['sys1', 'sys2']):
502+
for j, trace in enumerate(['u1', 'u2']):
503+
for k, out in enumerate(['y1', 'y2']):
504+
labels_4d[i, j, k, 0] = "out" + sys + trace + out
505+
labels_4d[i, j, k, 1] = "inp" + sys + trace + out
506+
507+
# Test 4D labels
508+
out = ct.step_response([sys1, sys2]).plot(
509+
overlay_signals=True, overlay_traces=True, plot_inputs=True,
510+
label=labels_4d)
511+
axs = ct.get_plot_axes(out)
512+
assert axs.shape == (2, 1)
513+
legend_text = axs[0, 0].get_legend().get_texts()
514+
for i, label in enumerate(labels_2d[0]):
515+
assert legend_text[i].get_text() == label
516+
legend_text = axs[1, 0].get_legend().get_texts()
517+
for i, label in enumerate(labels_2d[1]):
518+
assert legend_text[i].get_text() == label
519+
520+
# Test 2D labels
521+
out = ct.step_response([sys1, sys2]).plot(
522+
overlay_signals=True, overlay_traces=True, plot_inputs=True,
523+
label=labels_2d)
524+
axs = ct.get_plot_axes(out)
525+
assert axs.shape == (2, 1)
526+
legend_text = axs[0, 0].get_legend().get_texts()
527+
for i, label in enumerate(labels_2d[0]):
528+
assert legend_text[i].get_text() == label
529+
legend_text = axs[1, 0].get_legend().get_texts()
530+
for i, label in enumerate(labels_2d[1]):
531+
assert legend_text[i].get_text() == label
532+
533+
427534
def test_relabel():
428535
sys1 = ct.rss(2, inputs='u', outputs='y')
429536
sys2 = ct.rss(1, 1, 1) # uses default i/o labels

control/timeplot.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
def time_response_plot(
5050
data, *fmt, ax=None, plot_inputs=None, plot_outputs=True,
5151
transpose=False, overlay_traces=False, overlay_signals=False,
52-
legend_map=None, legend_loc=None, add_initial_zero=True,
52+
legend_map=None, legend_loc=None, add_initial_zero=True, label=None,
5353
trace_labels=None, title=None, relabel=True, **kwargs):
5454
"""Plot the time response of an input/output system.
5555
@@ -112,6 +112,11 @@ def time_response_plot(
112112
input_props : array of dicts
113113
List of line properties to use when plotting combined inputs. The
114114
default values are set by config.defaults['timeplot.input_props'].
115+
label : str or array_like of str
116+
If present, replace automatically generated label(s) with the given
117+
label(s). If more than one line is being generated, an array of
118+
labels should be provided with label[trace, :, 0] representing the
119+
output labels and label[trace, :, 1] representing the input labels.
115120
legend_map : array of str, option
116121
Location of the legend for multi-trace plots. Specifies an array
117122
of legend location strings matching the shape of the subplots, with
@@ -152,7 +157,7 @@ def time_response_plot(
152157
config.defaults[''timeplot.rcParams'].
153158
154159
"""
155-
from .freqplot import _process_ax_keyword
160+
from .freqplot import _process_ax_keyword, _process_line_labels
156161
from .iosys import InputOutputSystem
157162
from .timeresp import TimeResponseData
158163

@@ -342,6 +347,7 @@ def time_response_plot(
342347
out[i, j] = [] # unique list in each element
343348

344349
# Utility function for creating line label
350+
# TODO: combine with freqplot version?
345351
def _make_line_label(signal_index, signal_labels, trace_index):
346352
label = "" # start with an empty label
347353

@@ -375,11 +381,22 @@ def _make_line_label(signal_index, signal_labels, trace_index):
375381
output_offset = fig._output_offset = getattr(fig, '_output_offset', 0)
376382
input_offset = fig._input_offset = getattr(fig, '_input_offset', 0)
377383

384+
#
385+
# Plot the lines for the response
386+
#
387+
388+
# Process labels
389+
line_labels = _process_line_labels(
390+
label, ntraces, max(ninputs, noutputs), 2)
391+
378392
# Go through each trace and each input/output
379393
for trace in range(ntraces):
380394
# Plot the output
381395
for i in range(noutputs):
382-
label = _make_line_label(i, data.output_labels, trace)
396+
if line_labels is None:
397+
label = _make_line_label(i, data.output_labels, trace)
398+
else:
399+
label = line_labels[trace, i, 0]
383400

384401
# Set up line properties for this output, trace
385402
if len(fmt) == 0:
@@ -397,7 +414,10 @@ def _make_line_label(signal_index, signal_labels, trace_index):
397414

398415
# Plot the input
399416
for i in range(ninputs):
400-
label = _make_line_label(i, data.input_labels, trace)
417+
if line_labels is None:
418+
label = _make_line_label(i, data.input_labels, trace)
419+
else:
420+
label = line_labels[trace, i, 1]
401421

402422
if add_initial_zero and data.ntraces > i \
403423
and data.trace_types[i] == 'step':
@@ -596,16 +616,15 @@ def _make_line_label(signal_index, signal_labels, trace_index):
596616
for i in range(nrows):
597617
for j in range(ncols):
598618
ax = ax_array[i, j]
599-
# Get the labels to use
600619
labels = [line.get_label() for line in ax.get_lines()]
601-
labels = _make_legend_labels(labels, plot_inputs == 'overlay')
620+
if line_labels is None:
621+
labels = _make_legend_labels(labels, plot_inputs == 'overlay')
602622

603623
# Update the labels to remove common strings
604624
if len(labels) > 1 and legend_map[i, j] != None:
605625
with plt.rc_context(rcParams):
606626
ax.legend(labels, loc=legend_map[i, j])
607627

608-
609628
#
610629
# Update the plot title (= figure suptitle)
611630
#

control/timeresp.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -746,8 +746,10 @@ class TimeResponseList(list):
746746
"""
747747
def plot(self, *args, **kwargs):
748748
out_full = None
749-
for response in self:
750-
out = TimeResponseData.plot(response, *args, **kwargs)
749+
label = kwargs.pop('label', [None] * len(self))
750+
for i, response in enumerate(self):
751+
out = TimeResponseData.plot(
752+
response, *args, label=label[i], **kwargs)
751753
if out_full is None:
752754
out_full = out
753755
else:

0 commit comments

Comments
 (0)