55
66from os .path import commonprefix
77
8+ import matplotlib as mpl
89import matplotlib .pyplot as plt
910import numpy as np
1011
1112from . import config
1213
1314__all__ = ['suptitle' , 'get_plot_axes' ]
1415
16+ #
17+ # Style parameters
18+ #
19+
20+ _ctrlplot_rcParams = mpl .rcParams .copy ()
21+ _ctrlplot_rcParams .update ({
22+ 'axes.labelsize' : 'small' ,
23+ 'axes.titlesize' : 'small' ,
24+ 'figure.titlesize' : 'medium' ,
25+ 'legend.fontsize' : 'x-small' ,
26+ 'xtick.labelsize' : 'small' ,
27+ 'ytick.labelsize' : 'small' ,
28+ })
29+
30+
31+ #
32+ # User functions
33+ #
34+ # The functions below can be used by users to modify ctrl plots or get
35+ # information about them.
36+ #
37+
1538
1639def suptitle (
1740 title , fig = None , frame = 'axes' , ** kwargs ):
@@ -35,7 +58,7 @@ def suptitle(
3558 Additional keywords (passed to matplotlib).
3659
3760 """
38- rcParams = config ._get_param ('freqplot ' , 'rcParams' , kwargs , pop = True )
61+ rcParams = config ._get_param ('ctrlplot ' , 'rcParams' , kwargs , pop = True )
3962
4063 if fig is None :
4164 fig = plt .gcf ()
@@ -61,10 +84,10 @@ def suptitle(
6184def get_plot_axes (line_array ):
6285 """Get a list of axes from an array of lines.
6386
64- This function can be used to return the set of axes corresponding to
65- the line array that is returned by `time_response_plot`. This is useful for
66- generating an axes array that can be passed to subsequent plotting
67- calls.
87+ This function can be used to return the set of axes corresponding
88+ to the line array that is returned by `time_response_plot`. This
89+ is useful for generating an axes array that can be passed to
90+ subsequent plotting calls.
6891
6992 Parameters
7093 ----------
@@ -89,6 +112,125 @@ def get_plot_axes(line_array):
89112#
90113# Utility functions
91114#
115+ # These functions are used by plotting routines to provide a consistent way
116+ # of processing and displaing information.
117+ #
118+
119+
120+ def _process_ax_keyword (
121+ axs , shape = (1 , 1 ), rcParams = None , squeeze = False , clear_text = False ):
122+ """Utility function to process ax keyword to plotting commands.
123+
124+ This function processes the `ax` keyword to plotting commands. If no
125+ ax keyword is passed, the current figure is checked to see if it has
126+ the correct shape. If the shape matches the desired shape, then the
127+ current figure and axes are returned. Otherwise a new figure is
128+ created with axes of the desired shape.
129+
130+ Legacy behavior: some of the older plotting commands use a axes label
131+ to identify the proper axes for plotting. This behavior is supported
132+ through the use of the label keyword, but will only work if shape ==
133+ (1, 1) and squeeze == True.
134+
135+ """
136+ if axs is None :
137+ fig = plt .gcf () # get current figure (or create new one)
138+ axs = fig .get_axes ()
139+
140+ # Check to see if axes are the right shape; if not, create new figure
141+ # Note: can't actually check the shape, just the total number of axes
142+ if len (axs ) != np .prod (shape ):
143+ with plt .rc_context (rcParams ):
144+ if len (axs ) != 0 :
145+ # Create a new figure
146+ fig , axs = plt .subplots (* shape , squeeze = False )
147+ else :
148+ # Create new axes on (empty) figure
149+ axs = fig .subplots (* shape , squeeze = False )
150+ fig .set_layout_engine ('tight' )
151+ fig .align_labels ()
152+ else :
153+ # Use the existing axes, properly reshaped
154+ axs = np .asarray (axs ).reshape (* shape )
155+
156+ if clear_text :
157+ # Clear out any old text from the current figure
158+ for text in fig .texts :
159+ text .set_visible (False ) # turn off the text
160+ del text # get rid of it completely
161+ else :
162+ try :
163+ axs = np .asarray (axs ).reshape (shape )
164+ except ValueError :
165+ raise ValueError (
166+ "specified axes are not the right shape; "
167+ f"got { axs .shape } but expecting { shape } " )
168+ fig = axs [0 , 0 ].figure
169+
170+ # Process the squeeze keyword
171+ if squeeze and shape == (1 , 1 ):
172+ axs = axs [0 , 0 ] # Just return the single axes object
173+ elif squeeze :
174+ axs = axs .squeeze ()
175+
176+ return fig , axs
177+
178+
179+ # Turn label keyword into array indexed by trace, output, input
180+ # TODO: move to ctrlutil.py and update parameter names to reflect general use
181+ def _process_line_labels (label , ntraces , ninputs = 0 , noutputs = 0 ):
182+ if label is None :
183+ return None
184+
185+ if isinstance (label , str ):
186+ label = [label ] * ntraces # single label for all traces
187+
188+ # Convert to an ndarray, if not done aleady
189+ try :
190+ line_labels = np .asarray (label )
191+ except ValueError :
192+ raise ValueError ("label must be a string or array_like" )
193+
194+ # Turn the data into a 3D array of appropriate shape
195+ # TODO: allow more sophisticated broadcasting (and error checking)
196+ try :
197+ if ninputs > 0 and noutputs > 0 :
198+ if line_labels .ndim == 1 and line_labels .size == ntraces :
199+ line_labels = line_labels .reshape (ntraces , 1 , 1 )
200+ line_labels = np .broadcast_to (
201+ line_labels , (ntraces , ninputs , noutputs ))
202+ else :
203+ line_labels = line_labels .reshape (ntraces , ninputs , noutputs )
204+ except ValueError :
205+ if line_labels .shape [0 ] != ntraces :
206+ raise ValueError ("number of labels must match number of traces" )
207+ else :
208+ raise ValueError ("labels must be given for each input/output pair" )
209+
210+ return line_labels
211+
212+
213+ # Get labels for all lines in an axes
214+ def _get_line_labels (ax , use_color = True ):
215+ labels , lines = [], []
216+ last_color , counter = None , 0 # label unknown systems
217+ for i , line in enumerate (ax .get_lines ()):
218+ label = line .get_label ()
219+ if use_color and label .startswith ("Unknown" ):
220+ label = f"Unknown-{ counter } "
221+ if last_color is None :
222+ last_color = line .get_color ()
223+ elif last_color != line .get_color ():
224+ counter += 1
225+ last_color = line .get_color ()
226+ elif label [0 ] == '_' :
227+ continue
228+
229+ if label not in labels :
230+ lines .append (line )
231+ labels .append (label )
232+
233+ return lines , labels
92234
93235
94236# Utility function to make legend labels
@@ -160,3 +302,83 @@ def _find_axes_center(fig, axs):
160302 ylim = [min (ll [1 ], ylim [0 ]), max (ur [1 ], ylim [1 ])]
161303
162304 return (np .sum (xlim )/ 2 , np .sum (ylim )/ 2 )
305+
306+
307+ # Internal function to add arrows to a curve
308+ def _add_arrows_to_line2D (
309+ axes , line , arrow_locs = [0.2 , 0.4 , 0.6 , 0.8 ],
310+ arrowstyle = '-|>' , arrowsize = 1 , dir = 1 ):
311+ """
312+ Add arrows to a matplotlib.lines.Line2D at selected locations.
313+
314+ Parameters:
315+ -----------
316+ axes: Axes object as returned by axes command (or gca)
317+ line: Line2D object as returned by plot command
318+ arrow_locs: list of locations where to insert arrows, % of total length
319+ arrowstyle: style of the arrow
320+ arrowsize: size of the arrow
321+
322+ Returns:
323+ --------
324+ arrows: list of arrows
325+
326+ Based on https://stackoverflow.com/questions/26911898/
327+
328+ """
329+ # Get the coordinates of the line, in plot coordinates
330+ if not isinstance (line , mpl .lines .Line2D ):
331+ raise ValueError ("expected a matplotlib.lines.Line2D object" )
332+ x , y = line .get_xdata (), line .get_ydata ()
333+
334+ # Determine the arrow properties
335+ arrow_kw = {"arrowstyle" : arrowstyle }
336+
337+ color = line .get_color ()
338+ use_multicolor_lines = isinstance (color , np .ndarray )
339+ if use_multicolor_lines :
340+ raise NotImplementedError ("multicolor lines not supported" )
341+ else :
342+ arrow_kw ['color' ] = color
343+
344+ linewidth = line .get_linewidth ()
345+ if isinstance (linewidth , np .ndarray ):
346+ raise NotImplementedError ("multiwidth lines not supported" )
347+ else :
348+ arrow_kw ['linewidth' ] = linewidth
349+
350+ # Figure out the size of the axes (length of diagonal)
351+ xlim , ylim = axes .get_xlim (), axes .get_ylim ()
352+ ul , lr = np .array ([xlim [0 ], ylim [0 ]]), np .array ([xlim [1 ], ylim [1 ]])
353+ diag = np .linalg .norm (ul - lr )
354+
355+ # Compute the arc length along the curve
356+ s = np .cumsum (np .sqrt (np .diff (x ) ** 2 + np .diff (y ) ** 2 ))
357+
358+ # Truncate the number of arrows if the curve is short
359+ # TODO: figure out a smarter way to do this
360+ frac = min (s [- 1 ] / diag , 1 )
361+ if len (arrow_locs ) and frac < 0.05 :
362+ arrow_locs = [] # too short; no arrows at all
363+ elif len (arrow_locs ) and frac < 0.2 :
364+ arrow_locs = [0.5 ] # single arrow in the middle
365+
366+ # Plot the arrows (and return list if patches)
367+ arrows = []
368+ for loc in arrow_locs :
369+ n = np .searchsorted (s , s [- 1 ] * loc )
370+
371+ if dir == 1 and n == 0 :
372+ # Move the arrow forward by one if it is at start of a segment
373+ n = 1
374+
375+ # Place the head of the arrow at the desired location
376+ arrow_head = [x [n ], y [n ]]
377+ arrow_tail = [x [n - dir ], y [n - dir ]]
378+
379+ p = mpl .patches .FancyArrowPatch (
380+ arrow_tail , arrow_head , transform = axes .transData , lw = 0 ,
381+ ** arrow_kw )
382+ axes .add_patch (p )
383+ arrows .append (p )
384+ return arrows
0 commit comments