Skip to content

Commit 609da0c

Browse files
committed
RadioButtons: use a layout parameter instead of grid labels
1 parent 2df3eca commit 609da0c

File tree

4 files changed

+150
-135
lines changed

4 files changed

+150
-135
lines changed

doc/release/next_whats_new/radio_buttons_2d_grid.rst

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1-
RadioButtons widget supports 2D grid layout
2-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1+
RadioButtons widget supports flexible layouts
2+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
33

4-
The `.widgets.RadioButtons` widget now supports arranging buttons in a 2D grid
5-
layout. Pass a list of lists of strings as the *labels* parameter to arrange
6-
buttons in a grid where each inner list represents a row.
4+
The `.widgets.RadioButtons` widget now supports arranging buttons in different
5+
layouts via the new *layout* parameter. You can arrange buttons vertically
6+
(default), horizontally, or in a 2D grid by passing a ``(rows, cols)`` tuple.
77

88
The *active* parameter and the ``RadioButtons.index_selected`` attribute
9-
continue to use a single integer index into the flattened array, reading
10-
left-to-right, top-to-bottom. The column positions are automatically calculated
11-
based on the maximum text width in each column, ensuring optimal spacing.
9+
continue to use a single integer index into the labels list. For grid layouts,
10+
buttons are positioned left-to-right, top-to-bottom. The column positions are
11+
automatically calculated based on the maximum text width in each column,
12+
ensuring optimal spacing.
1213

13-
See :doc:`/gallery/widgets/radio_buttons_grid` for a complete example.
14+
See :doc:`/gallery/widgets/radio_buttons_grid` for a ``(rows, cols)`` example.
1415

1516
.. plot::
1617
:include-source: true
@@ -23,23 +24,22 @@ See :doc:`/gallery/widgets/radio_buttons_grid` for a complete example.
2324
t = np.arange(0.0, 2.0, 0.01)
2425
s = np.sin(2*np.pi*t)
2526

26-
fig, (ax_plot, ax_buttons) = plt.subplots(1, 2, figsize=(8, 4),
27-
width_ratios=[3, 1])
28-
29-
line, = ax_plot.plot(t, s, lw=2, color='red')
30-
ax_plot.set_xlabel('Time (s)')
31-
ax_plot.set_ylabel('Amplitude')
32-
33-
ax_buttons.set_facecolor('lightgray')
34-
ax_buttons.set_title('Line Color', fontsize=12, pad=10)
35-
36-
colors = [
37-
['red', 'orange', 'yellow'],
38-
['green', 'blue', 'purple'],
39-
['brown', 'pink', 'gray'],
40-
]
41-
42-
radio = RadioButtons(ax_buttons, colors, active=0)
27+
fig, axes = plt.subplot_mosaic(
28+
[
29+
['main'],
30+
['.'],
31+
['buttons'],
32+
],
33+
height_ratios=[8, 0.4, 1],
34+
)
35+
36+
line, = axes['main'].plot(t, s, lw=2, color='red')
37+
axes['main'].set_xlabel('Time (s)')
38+
axes['main'].set_ylabel('Amplitude')
39+
40+
axes['buttons'].set_facecolor('lightgray')
41+
colors = ['red', 'orange', 'yellow', 'green', 'blue', 'purple', 'brown', 'black']
42+
radio = RadioButtons(axes['buttons'], colors, active=0, layout='horizontal')
4343

4444
def color_func(label):
4545
line.set_color(label)

galleries/examples/widgets/radio_buttons_grid.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
66
Using radio buttons in a 2D grid layout.
77
8-
Radio buttons can be arranged in a 2D grid by passing a list of lists of
9-
strings as the *labels* parameter. This is useful when you have multiple
8+
Radio buttons can be arranged in a 2D grid by passing a ``(rows, cols)``
9+
tuple to the *layout* parameter. This is useful when you have multiple
1010
related options that are best displayed in a grid format rather than a
1111
vertical list.
1212
@@ -26,7 +26,7 @@
2626
fig, (ax_plot, ax_buttons) = plt.subplots(
2727
1, 2,
2828
figsize=(8, 4),
29-
width_ratios=[4, 1],
29+
width_ratios=[4, 1.4],
3030
)
3131

3232
# Create initial plot
@@ -40,12 +40,8 @@
4040
ax_buttons.set_facecolor('lightgray')
4141
ax_buttons.set_title('Line Color', fontsize=12, pad=10)
4242
# Create a 2D grid of color options (3 rows x 2 columns)
43-
colors = [
44-
['red', 'yellow'],
45-
['green', 'purple'],
46-
['brown', 'gray'],
47-
]
48-
radio = RadioButtons(ax_buttons, colors, active=0)
43+
colors = ['red', 'yellow', 'green', 'purple', 'brown', 'gray']
44+
radio = RadioButtons(ax_buttons, colors, active=0, layout=(3, 2))
4945

5046
def color_func(label):
5147
"""Update the line color based on selected button."""

lib/matplotlib/widgets.py

Lines changed: 117 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1578,32 +1578,35 @@ class RadioButtons(AxesWidget):
15781578
value_selected : str
15791579
The label text of the currently selected button.
15801580
index_selected : int
1581-
The index of the selected button in the flattened array. For 2D grids,
1582-
this is the index when reading left-to-right, top-to-bottom.
1581+
The index of the selected button.
15831582
"""
15841583

15851584
def __init__(self, ax, labels, active=0, activecolor=None, *,
1586-
useblit=True, label_props=None, radio_props=None):
1585+
layout="vertical", useblit=True, label_props=None, radio_props=None):
15871586
"""
15881587
Add radio buttons to an `~.axes.Axes`.
15891588
15901589
Parameters
15911590
----------
15921591
ax : `~matplotlib.axes.Axes`
15931592
The Axes to add the buttons to.
1594-
labels : list of str or list of list of str
1595-
The button labels. If a list of strings, buttons are arranged
1596-
vertically. If a list of lists of strings, buttons are arranged
1597-
in a 2D grid where each inner list represents a row. For simple
1598-
horizontal radio buttons, use:
1599-
``labels=[['button1', 'button2', 'button3']]``
1593+
labels : list of str
1594+
The button labels.
16001595
active : int
1601-
The index of the initially selected button in the flattened array.
1602-
For 2D grids, this is the index when reading left-to-right,
1603-
top-to-bottom.
1596+
The index of the initially selected button.
16041597
activecolor : :mpltype:`color`
16051598
The color of the selected button. The default is ``'blue'`` if not
16061599
specified here or in *radio_props*.
1600+
layout : {"vertical", "horizontal"} or (int, int), default: "vertical"
1601+
The layout of the radio buttons. Options are:
1602+
1603+
- ``"vertical"``: Arrange buttons in a single column (default).
1604+
- ``"horizontal"``: Arrange buttons in a single row.
1605+
- ``(rows, cols)`` tuple: Arrange buttons in a grid with the
1606+
specified number of rows and columns. Buttons are placed
1607+
left-to-right, top-to-bottom.
1608+
1609+
.. versionadded:: 3.11
16071610
useblit : bool, default: True
16081611
Use blitting for faster drawing if supported by the backend.
16091612
See the tutorial :ref:`blitting` for details.
@@ -1613,8 +1616,8 @@ def __init__(self, ax, labels, active=0, activecolor=None, *,
16131616
label_props : dict of lists, optional
16141617
Dictionary of `.Text` properties to be used for the labels. Each
16151618
dictionary value should be a list of at least a single element. If
1616-
the flat list of labels is of length M, its values are cycled such
1617-
that the Nth label gets the (N mod M) property.
1619+
the list is of length M, its values are cycled such that the Nth
1620+
label gets the (N mod M) property.
16181621
16191622
.. versionadded:: 3.7
16201623
radio_props : dict, optional
@@ -1634,13 +1637,33 @@ def __init__(self, ax, labels, active=0, activecolor=None, *,
16341637
_api.check_isinstance((dict, None), label_props=label_props,
16351638
radio_props=radio_props)
16361639

1637-
# Check if labels is 2D (list of lists)
1638-
_is_2d = isinstance(labels[0], (list, tuple))
1639-
1640-
if _is_2d:
1641-
flat_labels = [item for row in labels for item in row]
1640+
labels = list(labels)
1641+
n_labels = len(labels)
1642+
1643+
bad_layout_raise_msg = \
1644+
"layout must be 'vertical', 'horizontal', or a (rows, cols) tuple; " \
1645+
f"got {layout!r}"
1646+
# Parse layout parameter
1647+
if isinstance(layout, str):
1648+
if layout == "vertical":
1649+
n_rows, n_cols = n_labels, 1
1650+
elif layout == "horizontal":
1651+
n_rows, n_cols = 1, n_labels
1652+
else:
1653+
raise ValueError(bad_layout_raise_msg)
1654+
elif isinstance(layout, tuple) and len(layout) == 2:
1655+
n_rows, n_cols = layout
1656+
if not (isinstance(n_rows, int) and isinstance(n_cols, int)):
1657+
raise TypeError(
1658+
f"layout tuple must contain two integers; got {layout!r}"
1659+
)
1660+
if n_rows * n_cols < n_labels:
1661+
raise ValueError(
1662+
f"layout {layout} has {n_rows * n_cols} positions but "
1663+
f"{n_labels} labels were provided"
1664+
)
16421665
else:
1643-
flat_labels = list(labels)
1666+
raise ValueError(bad_layout_raise_msg)
16441667

16451668
radio_props = cbook.normalize_kwargs(radio_props,
16461669
collections.PathCollection)
@@ -1655,7 +1678,7 @@ def __init__(self, ax, labels, active=0, activecolor=None, *,
16551678

16561679
self._activecolor = activecolor
16571680
self._initial_active = active
1658-
self.value_selected = flat_labels[active]
1681+
self.value_selected = labels[active]
16591682
self.index_selected = active
16601683

16611684
ax.set_xticks([])
@@ -1666,87 +1689,82 @@ def __init__(self, ax, labels, active=0, activecolor=None, *,
16661689
self._background = None
16671690

16681691
label_props = _expand_text_props(label_props)
1669-
# Calculate positions based on layout
1670-
text_x_offset = 0.10
1671-
1672-
if _is_2d:
1673-
n_rows = len(labels)
1674-
n_cols = max(len(row) for row in labels)
1675-
# Y positions with margins
1676-
y_margin = 0.05
1677-
y_spacing = (1 - 2 * y_margin) / max(1, n_rows - 1) if n_rows > 1 else 0
1678-
1679-
# Create temporary text objects to measure widths
1680-
flat_label_list = []
1681-
temp_texts = []
1682-
for i, row in enumerate(labels):
1683-
for j, label in enumerate(row):
1684-
flat_label_list.append(label)
1685-
for label, props in zip(flat_label_list, label_props):
1686-
temp_texts.append(ax.text(
1687-
0,
1688-
0,
1689-
label,
1690-
transform=ax.transAxes,
1691-
**props,
1692-
))
1693-
# Force a draw to get accurate text measurements
1694-
ax.figure.canvas.draw()
1695-
# Calculate max text width per column (in axes coordinates)
1696-
col_widths = []
1697-
for col_idx in range(n_cols):
1698-
col_texts = []
1699-
for row_idx, row in enumerate(labels):
1700-
if col_idx < len(row):
1701-
col_texts.append(temp_texts[
1702-
sum(len(labels[r]) for r in range(row_idx)) + col_idx
1703-
])
1704-
if col_texts:
1705-
col_widths.append(
1706-
max(
1707-
text.get_window_extent(
1708-
ax.figure.canvas.get_renderer()
1709-
).width
1710-
for text in col_texts
1711-
) / ax.bbox.width
1712-
)
1713-
else:
1714-
col_widths.append(0)
1715-
# Remove temporary text objects
1716-
for text in temp_texts:
1717-
text.remove()
1718-
# Calculate x positions based on text widths
1719-
# TODO: Should these be arguments?
1720-
button_x_margin = 0.07 # Left margin for first button
1721-
col_spacing = 0.07 # Space between columns
1722-
1723-
col_x_positions = [button_x_margin] # First column starts at left margin
1724-
for col_idx in range(n_cols - 1):
1725-
col_x_positions.append(sum([
1726-
col_x_positions[-1],
1727-
text_x_offset,
1728-
col_widths[col_idx],
1729-
col_spacing
1730-
]))
1731-
# Create final positions
1732-
positions = []
1733-
for i, row in enumerate(labels):
1734-
y = 1 - y_margin - i * y_spacing
1735-
for j, label in enumerate(row):
1736-
x = col_x_positions[j]
1737-
positions.append((x, y))
1738-
xs = [pos[0] for pos in positions]
1739-
ys = [pos[1] for pos in positions]
1740-
else:
1741-
ys = np.linspace(1, 0, len(flat_labels) + 2)[1:-1]
1742-
xs = [0.15] * len(ys)
1743-
flat_label_list = flat_labels
1692+
1693+
# Define spacing in display units (pixels) for consistency
1694+
# across different axes sizes
1695+
axes_width_display = ax.bbox.width
1696+
left_margin_display = 15 # pixels
1697+
button_text_offset_display = 6.5 # pixels
1698+
col_spacing_display = 15 # pixels
1699+
1700+
# Convert to axes coordinates
1701+
left_margin = left_margin_display / axes_width_display
1702+
button_text_offset = button_text_offset_display / axes_width_display
1703+
col_spacing = col_spacing_display / axes_width_display
1704+
1705+
# Create temporary text objects to measure widths
1706+
temp_texts = []
1707+
for label, props in zip(labels, label_props):
1708+
temp_texts.append(ax.text(
1709+
0,
1710+
0,
1711+
label,
1712+
transform=ax.transAxes,
1713+
**props,
1714+
))
1715+
# Force a draw to get accurate text measurements
1716+
ax.figure.canvas.draw()
1717+
1718+
# Calculate max text width per column (in axes coordinates)
1719+
col_widths = []
1720+
for col_idx in range(n_cols):
1721+
col_texts = []
1722+
for row_idx in range(n_rows):
1723+
label_idx = row_idx * n_cols + col_idx
1724+
if label_idx < n_labels:
1725+
col_texts.append(temp_texts[label_idx])
1726+
if col_texts:
1727+
col_widths.append(
1728+
max(
1729+
text.get_window_extent(
1730+
ax.figure.canvas.get_renderer()
1731+
).width
1732+
for text in col_texts
1733+
) / axes_width_display
1734+
)
1735+
else:
1736+
col_widths.append(0)
1737+
# Remove temporary text objects
1738+
for text in temp_texts:
1739+
text.remove()
1740+
1741+
# Center rows vertically in the axes
1742+
ys_per_row = np.linspace(1, 0, n_rows + 2)[1:-1]
1743+
# Calculate x positions based on text widths
1744+
col_x_positions = [left_margin] # First column starts at left margin
1745+
for col_idx in range(n_cols - 1):
1746+
col_x_positions.append(
1747+
col_x_positions[-1] +
1748+
button_text_offset +
1749+
col_widths[col_idx] +
1750+
col_spacing
1751+
)
1752+
# Create final positions (left-to-right, top-to-bottom)
1753+
xs = []
1754+
ys = []
1755+
for label_idx in range(n_labels):
1756+
row_idx = label_idx // n_cols
1757+
col_idx = label_idx % n_cols
1758+
x = col_x_positions[col_idx]
1759+
y = ys_per_row[row_idx]
1760+
xs.append(x)
1761+
ys.append(y)
17441762

17451763
self.labels = [
1746-
ax.text(x + text_x_offset, y, label, transform=ax.transAxes,
1764+
ax.text(x + button_text_offset, y, label, transform=ax.transAxes,
17471765
horizontalalignment="left", verticalalignment="center",
17481766
**props)
1749-
for x, y, label, props in zip(xs, ys, flat_label_list, label_props)]
1767+
for x, y, label, props in zip(xs, ys, labels, label_props)]
17501768
text_size = np.array([text.get_fontsize() for text in self.labels]) / 2
17511769

17521770
radio_props = {
@@ -1765,7 +1783,7 @@ def __init__(self, ax, labels, active=0, activecolor=None, *,
17651783
# the user set.
17661784
self._active_colors = self._buttons.get_facecolor()
17671785
if len(self._active_colors) == 1:
1768-
self._active_colors = np.repeat(self._active_colors, len(flat_labels),
1786+
self._active_colors = np.repeat(self._active_colors, n_labels,
17691787
axis=0)
17701788
self._buttons.set_facecolor(
17711789
[activecolor if i == active else "none"

lib/matplotlib/widgets.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,11 @@ class RadioButtons(AxesWidget):
206206
def __init__(
207207
self,
208208
ax: Axes,
209-
labels: Iterable[str] | Iterable[Iterable[str]],
209+
labels: Iterable[str],
210210
active: int = ...,
211211
activecolor: ColorType | None = ...,
212212
*,
213+
layout: Literal["vertical", "horizontal"] | tuple[int, int] = ...,
213214
useblit: bool = ...,
214215
label_props: dict[str, Sequence[Any]] | None = ...,
215216
radio_props: dict[str, Any] | None = ...,

0 commit comments

Comments
 (0)