Skip to content

Commit 1a94f4e

Browse files
committed
add suptitle() function for better centered titles
1 parent 404fbdf commit 1a94f4e

6 files changed

Lines changed: 224 additions & 91 deletions

File tree

control/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
from .modelsimp import *
9393
from .nichols import *
9494
from .phaseplot import *
95+
from .plotutil import *
9596
from .pzmap import *
9697
from .rlocus import *
9798
from .statefbk import *

control/freqplot.py

Lines changed: 70 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,34 @@
88
# charts is in nichols.py. The code for pole-zero diagrams is in pzmap.py
99
# and rlocus.py.
1010

11-
import numpy as np
12-
import matplotlib as mpl
13-
import matplotlib.pyplot as plt
11+
import itertools
1412
import math
1513
import warnings
16-
import itertools
1714
from os.path import commonprefix
1815

19-
from .ctrlutil import unwrap
16+
import matplotlib as mpl
17+
import matplotlib.pyplot as plt
18+
import numpy as np
19+
20+
from . import config
2021
from .bdalg import feedback
21-
from .margins import stability_margins
22+
from .ctrlutil import unwrap
2223
from .exception import ControlMIMONotImplemented
23-
from .statesp import StateSpace
24-
from .lti import LTI, frequency_response, _process_frequency_response
25-
from .xferfcn import TransferFunction
2624
from .frdata import FrequencyResponseData
25+
from .lti import LTI, _process_frequency_response, frequency_response
26+
from .margins import stability_margins
27+
from .plotutil import suptitle, _find_axes_center
28+
from .statesp import StateSpace
2729
from .timeplot import _make_legend_labels
28-
from . import config
30+
from .xferfcn import TransferFunction
2931

3032
__all__ = ['bode_plot', 'NyquistResponseData', 'nyquist_response',
3133
'nyquist_plot', 'singular_values_response',
3234
'singular_values_plot', 'gangof4_plot', 'gangof4_response',
3335
'bode', 'nyquist', 'gangof4']
3436

3537
# Default font dictionary
38+
# TODO: move common plotting params to 'ctrlplot' (in plotutil)
3639
_freqplot_rcParams = mpl.rcParams.copy()
3740
_freqplot_rcParams.update({
3841
'axes.labelsize': 'small',
@@ -57,6 +60,7 @@
5760
'freqplot.share_magnitude': 'row',
5861
'freqplot.share_phase': 'row',
5962
'freqplot.share_frequency': 'col',
63+
'freqplot.suptitle_frame': 'axes',
6064
}
6165

6266
#
@@ -229,6 +233,8 @@ def bode_plot(
229233
'freqplot', 'initial_phase', kwargs, None, pop=True)
230234
rcParams = config._get_param(
231235
'freqplot', 'rcParams', kwargs, _freqplot_defaults, pop=True)
236+
suptitle_frame = config._get_param(
237+
'freqplot', 'suptitle_frame', kwargs, _freqplot_defaults, pop=True)
232238

233239
# Set the default labels
234240
freq_label = config._get_param(
@@ -803,7 +809,7 @@ def _make_line_label(response, output_index, input_index):
803809
#
804810
# Finishing handling axes limit sharing
805811
#
806-
# This code handles labels on phase plots and also removes tick labels
812+
# This code handles labels on Bode plots and also removes tick labels
807813
# on shared axes. It needs to come *after* the plots are generated,
808814
# in order to handle two things:
809815
#
@@ -867,50 +873,6 @@ def gen_zero_centered_series(val_min, val_max, period):
867873
for i, j in itertools.product(range(nrows), range(ncols)):
868874
ax_array[i, j].set_xlim(omega_limits)
869875

870-
#
871-
# Update the plot title (= figure suptitle)
872-
#
873-
# If plots are built up by multiple calls to plot() and the title is
874-
# not given, then the title is updated to provide a list of unique text
875-
# items in each successive title. For data generated by the frequency
876-
# response function this will generate a common prefix followed by a
877-
# list of systems (e.g., "Step response for sys[1], sys[2]").
878-
#
879-
880-
# Set the initial title for the data (unique system names, preserving order)
881-
seen = set()
882-
sysnames = [response.sysname for response in data \
883-
if not (response.sysname in seen or seen.add(response.sysname))]
884-
if title is None:
885-
if data[0].title is None:
886-
title = "Bode plot for " + ", ".join(sysnames)
887-
else:
888-
title = data[0].title
889-
890-
if fig is not None and isinstance(title, str):
891-
# Get the current title, if it exists
892-
old_title = None if fig._suptitle is None else fig._suptitle._text
893-
new_title = title
894-
895-
if old_title is not None:
896-
# Find the common part of the titles
897-
common_prefix = commonprefix([old_title, new_title])
898-
899-
# Back up to the last space
900-
last_space = common_prefix.rfind(' ')
901-
if last_space > 0:
902-
common_prefix = common_prefix[:last_space]
903-
common_len = len(common_prefix)
904-
905-
# Add the new part of the title (usually the system name)
906-
if old_title[common_len:] != new_title[common_len:]:
907-
separator = ',' if len(common_prefix) > 0 else ';'
908-
new_title = old_title + separator + new_title[common_len:]
909-
910-
# Add the title
911-
with plt.rc_context(rcParams):
912-
fig.suptitle(new_title)
913-
914876
#
915877
# Label the axes (including header labels)
916878
#
@@ -949,26 +911,16 @@ def gen_zero_centered_series(val_min, val_max, period):
949911
ax_mag.set_ylabel("\n" + ax_mag.get_ylabel())
950912
ax_phase.set_ylabel("\n" + ax_phase.get_ylabel())
951913

952-
# TODO: remove?
953-
# Redraw the figure to get the proper locations for everything
954-
# fig.tight_layout()
914+
# Find the midpoint between the row axes (+ tight_layout)
915+
_, ypos = _find_axes_center(fig, [ax_mag, ax_phase])
955916

956917
# Get the bounding box including the labels
957918
inv_transform = fig.transFigure.inverted()
958919
mag_bbox = inv_transform.transform(
959920
ax_mag.get_tightbbox(fig.canvas.get_renderer()))
960-
phase_bbox = inv_transform.transform(
961-
ax_phase.get_tightbbox(fig.canvas.get_renderer()))
962-
963-
# Get the axes limits without labels for use in the y position
964-
mag_bot = inv_transform.transform(
965-
ax_mag.transAxes.transform((0, 0)))[1]
966-
phase_top = inv_transform.transform(
967-
ax_phase.transAxes.transform((0, 1)))[1]
968921

969922
# Figure out location for the text (center left in figure frame)
970923
xpos = mag_bbox[0, 0] # left edge
971-
ypos = (mag_bot + phase_top) / 2 # centered between axes
972924

973925
# Put a centered label as text outside the box
974926
fig.text(
@@ -981,6 +933,49 @@ def gen_zero_centered_series(val_min, val_max, period):
981933
f"To {data[0].output_labels[i]}\n" +
982934
ax_array[i, 0].get_ylabel())
983935

936+
#
937+
# Update the plot title (= figure suptitle)
938+
#
939+
# If plots are built up by multiple calls to plot() and the title is
940+
# not given, then the title is updated to provide a list of unique text
941+
# items in each successive title. For data generated by the frequency
942+
# response function this will generate a common prefix followed by a
943+
# list of systems (e.g., "Step response for sys[1], sys[2]").
944+
#
945+
946+
# Set the initial title for the data (unique system names, preserving order)
947+
seen = set()
948+
sysnames = [response.sysname for response in data \
949+
if not (response.sysname in seen or seen.add(response.sysname))]
950+
if title is None:
951+
if data[0].title is None:
952+
title = "Bode plot for " + ", ".join(sysnames)
953+
else:
954+
title = data[0].title
955+
956+
if fig is not None and isinstance(title, str):
957+
# Get the current title, if it exists
958+
old_title = None if fig._suptitle is None else fig._suptitle._text
959+
new_title = title
960+
961+
if old_title is not None:
962+
# Find the common part of the titles
963+
common_prefix = commonprefix([old_title, new_title])
964+
965+
# Back up to the last space
966+
last_space = common_prefix.rfind(' ')
967+
if last_space > 0:
968+
common_prefix = common_prefix[:last_space]
969+
common_len = len(common_prefix)
970+
971+
# Add the new part of the title (usually the system name)
972+
if old_title[common_len:] != new_title[common_len:]:
973+
separator = ',' if len(common_prefix) > 0 else ';'
974+
new_title = old_title + separator + new_title[common_len:]
975+
976+
# Add the title
977+
suptitle(title, fig=fig, rcParams=rcParams, frame=suptitle_frame)
978+
984979
#
985980
# Create legends
986981
#
@@ -1671,6 +1666,8 @@ def nyquist_plot(
16711666
'nyquist', 'start_marker', kwargs, _nyquist_defaults, pop=True)
16721667
start_marker_size = config._get_param(
16731668
'nyquist', 'start_marker_size', kwargs, _nyquist_defaults, pop=True)
1669+
suptitle_frame = config._get_param(
1670+
'freqplot', 'suptitle_frame', kwargs, _freqplot_defaults, pop=True)
16741671

16751672
# Set line styles for the curves
16761673
def _parse_linestyle(style_name, allow_false=False):
@@ -1894,8 +1891,7 @@ def _parse_linestyle(style_name, allow_false=False):
18941891
# Add the title
18951892
if title is None:
18961893
title = "Nyquist plot for " + ", ".join(labels)
1897-
with plt.rc_context(rcParams):
1898-
fig.suptitle(title)
1894+
suptitle(title, fig=fig, rcParams=rcParams, frame=suptitle_frame)
18991895

19001896
# Legacy return pocessing
19011897
if plot is True or return_contour is not None:
@@ -2285,6 +2281,8 @@ def singular_values_plot(
22852281
'freqplot', 'grid', kwargs, _freqplot_defaults, pop=True)
22862282
rcParams = config._get_param(
22872283
'freqplot', 'rcParams', kwargs, _freqplot_defaults, pop=True)
2284+
suptitle_frame = config._get_param(
2285+
'freqplot', 'suptitle_frame', kwargs, _freqplot_defaults, pop=True)
22882286

22892287
# If argument was a singleton, turn it into a tuple
22902288
data = data if isinstance(data, (list, tuple)) else (data,)
@@ -2398,7 +2396,7 @@ def singular_values_plot(
23982396
# Add a grid to the plot + labeling
23992397
if grid:
24002398
ax_sigma.grid(grid, which='both')
2401-
2399+
24022400
ax_sigma.set_ylabel(
24032401
"Singular Values [dB]" if dB else "Singular Values")
24042402
ax_sigma.set_xlabel("Frequency [Hz]" if Hz else "Frequency [rad/sec]")
@@ -2414,8 +2412,7 @@ def singular_values_plot(
24142412
# Add the title
24152413
if title is None:
24162414
title = "Singular values for " + ", ".join(labels)
2417-
with plt.rc_context(rcParams):
2418-
fig.suptitle(title)
2415+
suptitle(title, fig=fig, rcParams=rcParams, frame=suptitle_frame)
24192416

24202417
# Legacy return processing
24212418
if plot is not None:
@@ -2755,6 +2752,7 @@ def _process_ax_keyword(
27552752

27562753
return fig, axs
27572754

2755+
27582756
#
27592757
# Utility functions to create nice looking labels (KLD 5/23/11)
27602758
#

control/nichols.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,18 @@
1313
nichols.nichols_grid
1414
"""
1515

16-
import numpy as np
1716
import matplotlib.pyplot as plt
1817
import matplotlib.transforms
18+
import numpy as np
1919

20+
from . import config
2021
from .ctrlutil import unwrap
2122
from .freqplot import _default_frequency_range, _freqplot_defaults, \
22-
_get_line_labels
23+
_get_line_labels, _process_ax_keyword
2324
from .lti import frequency_response
25+
from .plotutil import suptitle
2426
from .statesp import StateSpace
2527
from .xferfcn import TransferFunction
26-
from . import config
2728

2829
__all__ = ['nichols_plot', 'nichols', 'nichols_grid']
2930

@@ -34,7 +35,7 @@
3435

3536

3637
def nichols_plot(
37-
data, omega=None, *fmt, grid=None, title=None,
38+
data, omega=None, *fmt, grid=None, title=None, ax=None,
3839
legend_loc='upper left', **kwargs):
3940
"""Nichols plot for a system.
4041
@@ -67,7 +68,7 @@ def nichols_plot(
6768
"""
6869
# Get parameter values
6970
grid = config._get_param('nichols', 'grid', grid, True)
70-
freqplot_rcParams = config._get_param(
71+
rcParams = config._get_param(
7172
'freqplot', 'rcParams', kwargs, _freqplot_defaults, pop=True)
7273

7374
# If argument was a singleton, turn it into a list
@@ -83,6 +84,8 @@ def nichols_plot(
8384
if any([resp.ninputs > 1 or resp.noutputs > 1 for resp in data]):
8485
raise NotImplementedError("MIMO Nichols plots not implemented")
8586

87+
fig, ax_nichols = _process_ax_keyword(ax, rcParams=rcParams, squeeze=True)
88+
8689
# Create a list of lines for the output
8790
out = np.empty(len(data), dtype=object)
8891

@@ -102,8 +105,7 @@ def nichols_plot(
102105
else f"Unknown-{idx_sys}"
103106

104107
# Generate the plot
105-
with plt.rc_context(freqplot_rcParams):
106-
out[idx] = plt.plot(x, y, *fmt, label=sysname, **kwargs)
108+
out[idx] = ax_nichols.plot(x, y, *fmt, label=sysname, **kwargs)
107109

108110
# Label the plot axes
109111
plt.xlabel('Phase [deg]')
@@ -117,19 +119,17 @@ def nichols_plot(
117119
nichols_grid()
118120

119121
# List of systems that are included in this plot
120-
ax_nichols = plt.gca()
121122
lines, labels = _get_line_labels(ax_nichols)
122123

123124
# Add legend if there is more than one system plotted
124125
if len(labels) > 1 and legend_loc is not False:
125-
with plt.rc_context(freqplot_rcParams):
126+
with plt.rc_context(rcParams):
126127
ax_nichols.legend(lines, labels, loc=legend_loc)
127128

128129
# Add the title
129130
if title is None:
130131
title = "Nichols plot for " + ", ".join(labels)
131-
with plt.rc_context(freqplot_rcParams):
132-
plt.suptitle(title)
132+
suptitle(title, fig=fig, rcParams=rcParams)
133133

134134
return out
135135

0 commit comments

Comments
 (0)