@@ -1578,32 +1578,35 @@ class RadioButtons(AxesWidget):
15781578 value_selected : str
15791579 The label text of the currently selected button.
15801580 index_selected : int
1581- The index of the selected button in the flattened array. For 2D grids,
1582- this is the index when reading left-to-right, top-to-bottom.
1581+ The index of the selected button.
15831582 """
15841583
15851584 def __init__ (self , ax , labels , active = 0 , activecolor = None , * ,
1586- useblit = True , label_props = None , radio_props = None ):
1585+ layout = "vertical" , useblit = True , label_props = None , radio_props = None ):
15871586 """
15881587 Add radio buttons to an `~.axes.Axes`.
15891588
15901589 Parameters
15911590 ----------
15921591 ax : `~matplotlib.axes.Axes`
15931592 The Axes to add the buttons to.
1594- labels : list of str or list of list of str
1595- The button labels. If a list of strings, buttons are arranged
1596- vertically. If a list of lists of strings, buttons are arranged
1597- in a 2D grid where each inner list represents a row. For simple
1598- horizontal radio buttons, use:
1599- ``labels=[['button1', 'button2', 'button3']]``
1593+ labels : list of str
1594+ The button labels.
16001595 active : int
1601- The index of the initially selected button in the flattened array.
1602- For 2D grids, this is the index when reading left-to-right,
1603- top-to-bottom.
1596+ The index of the initially selected button.
16041597 activecolor : :mpltype:`color`
16051598 The color of the selected button. The default is ``'blue'`` if not
16061599 specified here or in *radio_props*.
1600+ layout : {"vertical", "horizontal"} or (int, int), default: "vertical"
1601+ The layout of the radio buttons. Options are:
1602+
1603+ - ``"vertical"``: Arrange buttons in a single column (default).
1604+ - ``"horizontal"``: Arrange buttons in a single row.
1605+ - ``(rows, cols)`` tuple: Arrange buttons in a grid with the
1606+ specified number of rows and columns. Buttons are placed
1607+ left-to-right, top-to-bottom.
1608+
1609+ .. versionadded:: 3.11
16071610 useblit : bool, default: True
16081611 Use blitting for faster drawing if supported by the backend.
16091612 See the tutorial :ref:`blitting` for details.
@@ -1613,8 +1616,8 @@ def __init__(self, ax, labels, active=0, activecolor=None, *,
16131616 label_props : dict of lists, optional
16141617 Dictionary of `.Text` properties to be used for the labels. Each
16151618 dictionary value should be a list of at least a single element. If
1616- the flat list of labels is of length M, its values are cycled such
1617- that the Nth label gets the (N mod M) property.
1619+ the list is of length M, its values are cycled such that the Nth
1620+ label gets the (N mod M) property.
16181621
16191622 .. versionadded:: 3.7
16201623 radio_props : dict, optional
@@ -1634,13 +1637,33 @@ def __init__(self, ax, labels, active=0, activecolor=None, *,
16341637 _api .check_isinstance ((dict , None ), label_props = label_props ,
16351638 radio_props = radio_props )
16361639
1637- # Check if labels is 2D (list of lists)
1638- _is_2d = isinstance (labels [0 ], (list , tuple ))
1639-
1640- if _is_2d :
1641- flat_labels = [item for row in labels for item in row ]
1640+ labels = list (labels )
1641+ n_labels = len (labels )
1642+
1643+ bad_layout_raise_msg = \
1644+ "layout must be 'vertical', 'horizontal', or a (rows, cols) tuple; " \
1645+ f"got { layout !r} "
1646+ # Parse layout parameter
1647+ if isinstance (layout , str ):
1648+ if layout == "vertical" :
1649+ n_rows , n_cols = n_labels , 1
1650+ elif layout == "horizontal" :
1651+ n_rows , n_cols = 1 , n_labels
1652+ else :
1653+ raise ValueError (bad_layout_raise_msg )
1654+ elif isinstance (layout , tuple ) and len (layout ) == 2 :
1655+ n_rows , n_cols = layout
1656+ if not (isinstance (n_rows , int ) and isinstance (n_cols , int )):
1657+ raise TypeError (
1658+ f"layout tuple must contain two integers; got { layout !r} "
1659+ )
1660+ if n_rows * n_cols < n_labels :
1661+ raise ValueError (
1662+ f"layout { layout } has { n_rows * n_cols } positions but "
1663+ f"{ n_labels } labels were provided"
1664+ )
16421665 else :
1643- flat_labels = list ( labels )
1666+ raise ValueError ( bad_layout_raise_msg )
16441667
16451668 radio_props = cbook .normalize_kwargs (radio_props ,
16461669 collections .PathCollection )
@@ -1655,7 +1678,7 @@ def __init__(self, ax, labels, active=0, activecolor=None, *,
16551678
16561679 self ._activecolor = activecolor
16571680 self ._initial_active = active
1658- self .value_selected = flat_labels [active ]
1681+ self .value_selected = labels [active ]
16591682 self .index_selected = active
16601683
16611684 ax .set_xticks ([])
@@ -1666,87 +1689,82 @@ def __init__(self, ax, labels, active=0, activecolor=None, *,
16661689 self ._background = None
16671690
16681691 label_props = _expand_text_props (label_props )
1669- # Calculate positions based on layout
1670- text_x_offset = 0.10
1671-
1672- if _is_2d :
1673- n_rows = len (labels )
1674- n_cols = max (len (row ) for row in labels )
1675- # Y positions with margins
1676- y_margin = 0.05
1677- y_spacing = (1 - 2 * y_margin ) / max (1 , n_rows - 1 ) if n_rows > 1 else 0
1678-
1679- # Create temporary text objects to measure widths
1680- flat_label_list = []
1681- temp_texts = []
1682- for i , row in enumerate (labels ):
1683- for j , label in enumerate (row ):
1684- flat_label_list .append (label )
1685- for label , props in zip (flat_label_list , label_props ):
1686- temp_texts .append (ax .text (
1687- 0 ,
1688- 0 ,
1689- label ,
1690- transform = ax .transAxes ,
1691- ** props ,
1692- ))
1693- # Force a draw to get accurate text measurements
1694- ax .figure .canvas .draw ()
1695- # Calculate max text width per column (in axes coordinates)
1696- col_widths = []
1697- for col_idx in range (n_cols ):
1698- col_texts = []
1699- for row_idx , row in enumerate (labels ):
1700- if col_idx < len (row ):
1701- col_texts .append (temp_texts [
1702- sum (len (labels [r ]) for r in range (row_idx )) + col_idx
1703- ])
1704- if col_texts :
1705- col_widths .append (
1706- max (
1707- text .get_window_extent (
1708- ax .figure .canvas .get_renderer ()
1709- ).width
1710- for text in col_texts
1711- ) / ax .bbox .width
1712- )
1713- else :
1714- col_widths .append (0 )
1715- # Remove temporary text objects
1716- for text in temp_texts :
1717- text .remove ()
1718- # Calculate x positions based on text widths
1719- # TODO: Should these be arguments?
1720- button_x_margin = 0.07 # Left margin for first button
1721- col_spacing = 0.07 # Space between columns
1722-
1723- col_x_positions = [button_x_margin ] # First column starts at left margin
1724- for col_idx in range (n_cols - 1 ):
1725- col_x_positions .append (sum ([
1726- col_x_positions [- 1 ],
1727- text_x_offset ,
1728- col_widths [col_idx ],
1729- col_spacing
1730- ]))
1731- # Create final positions
1732- positions = []
1733- for i , row in enumerate (labels ):
1734- y = 1 - y_margin - i * y_spacing
1735- for j , label in enumerate (row ):
1736- x = col_x_positions [j ]
1737- positions .append ((x , y ))
1738- xs = [pos [0 ] for pos in positions ]
1739- ys = [pos [1 ] for pos in positions ]
1740- else :
1741- ys = np .linspace (1 , 0 , len (flat_labels ) + 2 )[1 :- 1 ]
1742- xs = [0.15 ] * len (ys )
1743- flat_label_list = flat_labels
1692+
1693+ # Define spacing in display units (pixels) for consistency
1694+ # across different axes sizes
1695+ axes_width_display = ax .bbox .width
1696+ left_margin_display = 15 # pixels
1697+ button_text_offset_display = 6.5 # pixels
1698+ col_spacing_display = 15 # pixels
1699+
1700+ # Convert to axes coordinates
1701+ left_margin = left_margin_display / axes_width_display
1702+ button_text_offset = button_text_offset_display / axes_width_display
1703+ col_spacing = col_spacing_display / axes_width_display
1704+
1705+ # Create temporary text objects to measure widths
1706+ temp_texts = []
1707+ for label , props in zip (labels , label_props ):
1708+ temp_texts .append (ax .text (
1709+ 0 ,
1710+ 0 ,
1711+ label ,
1712+ transform = ax .transAxes ,
1713+ ** props ,
1714+ ))
1715+ # Force a draw to get accurate text measurements
1716+ ax .figure .canvas .draw ()
1717+
1718+ # Calculate max text width per column (in axes coordinates)
1719+ col_widths = []
1720+ for col_idx in range (n_cols ):
1721+ col_texts = []
1722+ for row_idx in range (n_rows ):
1723+ label_idx = row_idx * n_cols + col_idx
1724+ if label_idx < n_labels :
1725+ col_texts .append (temp_texts [label_idx ])
1726+ if col_texts :
1727+ col_widths .append (
1728+ max (
1729+ text .get_window_extent (
1730+ ax .figure .canvas .get_renderer ()
1731+ ).width
1732+ for text in col_texts
1733+ ) / axes_width_display
1734+ )
1735+ else :
1736+ col_widths .append (0 )
1737+ # Remove temporary text objects
1738+ for text in temp_texts :
1739+ text .remove ()
1740+
1741+ # Center rows vertically in the axes
1742+ ys_per_row = np .linspace (1 , 0 , n_rows + 2 )[1 :- 1 ]
1743+ # Calculate x positions based on text widths
1744+ col_x_positions = [left_margin ] # First column starts at left margin
1745+ for col_idx in range (n_cols - 1 ):
1746+ col_x_positions .append (
1747+ col_x_positions [- 1 ] +
1748+ button_text_offset +
1749+ col_widths [col_idx ] +
1750+ col_spacing
1751+ )
1752+ # Create final positions (left-to-right, top-to-bottom)
1753+ xs = []
1754+ ys = []
1755+ for label_idx in range (n_labels ):
1756+ row_idx = label_idx // n_cols
1757+ col_idx = label_idx % n_cols
1758+ x = col_x_positions [col_idx ]
1759+ y = ys_per_row [row_idx ]
1760+ xs .append (x )
1761+ ys .append (y )
17441762
17451763 self .labels = [
1746- ax .text (x + text_x_offset , y , label , transform = ax .transAxes ,
1764+ ax .text (x + button_text_offset , y , label , transform = ax .transAxes ,
17471765 horizontalalignment = "left" , verticalalignment = "center" ,
17481766 ** props )
1749- for x , y , label , props in zip (xs , ys , flat_label_list , label_props )]
1767+ for x , y , label , props in zip (xs , ys , labels , label_props )]
17501768 text_size = np .array ([text .get_fontsize () for text in self .labels ]) / 2
17511769
17521770 radio_props = {
@@ -1765,7 +1783,7 @@ def __init__(self, ax, labels, active=0, activecolor=None, *,
17651783 # the user set.
17661784 self ._active_colors = self ._buttons .get_facecolor ()
17671785 if len (self ._active_colors ) == 1 :
1768- self ._active_colors = np .repeat (self ._active_colors , len ( flat_labels ) ,
1786+ self ._active_colors = np .repeat (self ._active_colors , n_labels ,
17691787 axis = 0 )
17701788 self ._buttons .set_facecolor (
17711789 [activecolor if i == active else "none"
0 commit comments