@@ -56,7 +56,7 @@ def phase_plane_plot(
5656 This function plots phase plane data, including vector fields, stream
5757 lines, equilibrium points, and contour curves.
5858 If none of plot_streamlines, plot_vectorfield, or plot_streamplot are
59- set, then plot_streamlines is used by default.
59+ set, then plot_streamplot is used by default.
6060
6161 Parameters
6262 ----------
@@ -164,7 +164,7 @@ def phase_plane_plot(
164164 and plot_vectorfield is None
165165 and plot_streamplot is None
166166 ):
167- plot_streamlines = True
167+ plot_streamplot = True
168168
169169 if plot_streamplot and not plot_streamlines and not plot_vectorfield :
170170 gridspec = gridspec or [25 , 25 ]
@@ -191,7 +191,10 @@ def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
191191 return new_kwargs
192192
193193 # Create list for storing outputs
194- out = np .array ([[], None , None ], dtype = object )
194+ out = np .array ([[], None , None , None ], dtype = object )
195+
196+ # the maximum zorder of stramlines, vectorfield or streamplot
197+ flow_zorder = None
195198
196199 # Plot out the main elements
197200 if plot_streamlines :
@@ -201,6 +204,9 @@ def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
201204 out [0 ] += streamlines (
202205 sys , pointdata , timedata , _check_kwargs = False ,
203206 suppress_warnings = suppress_warnings , ** kwargs_local )
207+
208+ new_zorder = max (elem .get_zorder () for elem in out [0 ])
209+ flow_zorder = max (flow_zorder , new_zorder ) if flow_zorder else new_zorder
204210
205211 # Get rid of keyword arguments handled by streamlines
206212 for kw in ['arrows' , 'arrow_size' , 'arrow_style' , 'color' ,
@@ -211,39 +217,56 @@ def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
211217 if gridtype not in [None , 'boxgrid' , 'meshgrid' ]:
212218 gridspec = None
213219
214- if plot_separatrices :
215- kwargs_local = _create_kwargs (
216- kwargs , plot_separatrices , gridspec = gridspec , ax = ax )
217- out [0 ] += separatrices (
218- sys , pointdata , _check_kwargs = False , ** kwargs_local )
219-
220- # Get rid of keyword arguments handled by separatrices
221- for kw in ['arrows' , 'arrow_size' , 'arrow_style' , 'params' ]:
222- initial_kwargs .pop (kw , None )
223-
224220 if plot_vectorfield :
225221 kwargs_local = _create_kwargs (
226222 kwargs , plot_vectorfield , gridspec = gridspec , ax = ax )
227223 out [1 ] = vectorfield (
228224 sys , pointdata , _check_kwargs = False , ** kwargs_local )
225+
226+ new_zorder = out [1 ].get_zorder ()
227+ flow_zorder = max (flow_zorder , new_zorder ) if flow_zorder else new_zorder
229228
230229 # Get rid of keyword arguments handled by vectorfield
231230 for kw in ['color' , 'params' ]:
232231 initial_kwargs .pop (kw , None )
233232
234233 if plot_streamplot :
234+ if gridtype not in [None , 'meshgrid' ]:
235+ raise ValueError ("gridtype must be 'meshgrid' when using streamplot" )
236+
235237 kwargs_local = _create_kwargs (
236238 kwargs , plot_streamplot , gridspec = gridspec , ax = ax )
237- streamplot (
239+ out [ 3 ] = streamplot (
238240 sys , pointdata , _check_kwargs = False , ** kwargs_local )
241+
242+ new_zorder = max (out [3 ].lines .get_zorder (), out [3 ].arrows .get_zorder ())
243+ flow_zorder = max (flow_zorder , new_zorder ) if flow_zorder else new_zorder
239244
240245 # Get rid of keyword arguments handled by streamplot
241246 for kw in ['color' , 'params' ]:
242247 initial_kwargs .pop (kw , None )
243248
249+ sep_zorder = flow_zorder + 1 if flow_zorder else None
250+
251+ if plot_separatrices :
252+ kwargs_local = _create_kwargs (
253+ kwargs , plot_separatrices , gridspec = gridspec , ax = ax )
254+ kwargs_local ['zorder' ] = kwargs_local .get ('zorder' , sep_zorder )
255+ out [0 ] += separatrices (
256+ sys , pointdata , _check_kwargs = False , ** kwargs_local )
257+
258+ sep_zorder = max (elem .get_zorder () for elem in out [0 ])
259+
260+ # Get rid of keyword arguments handled by separatrices
261+ for kw in ['arrows' , 'arrow_size' , 'arrow_style' , 'params' ]:
262+ initial_kwargs .pop (kw , None )
263+
264+ equil_zorder = sep_zorder + 1 if sep_zorder else None
265+
244266 if plot_equilpoints :
245267 kwargs_local = _create_kwargs (
246268 kwargs , plot_equilpoints , gridspec = gridspec , ax = ax )
269+ kwargs_local ['zorder' ] = kwargs_local .get ('zorder' , equil_zorder )
247270 out [2 ] = equilpoints (
248271 sys , pointdata , _check_kwargs = False , ** kwargs_local )
249272
@@ -267,8 +290,8 @@ def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
267290
268291
269292def vectorfield (
270- sys , pointdata , gridspec = None , ax = None , suppress_warnings = False ,
271- _check_kwargs = True , ** kwargs ):
293+ sys , pointdata , gridspec = None , zorder = None , ax = None ,
294+ suppress_warnings = False , _check_kwargs = True , ** kwargs ):
272295 """Plot a vector field in the phase plane.
273296
274297 This function plots a vector field for a two-dimensional state
@@ -302,6 +325,9 @@ def vectorfield(
302325 dict with key 'args' and value given by a tuple (passed to callable).
303326 color : matplotlib color spec, optional
304327 Plot the vector field in the given color.
328+ zorder : float, optional
329+ Set the zorder for the separatrices. In not specified, it will be
330+ automatically chosen by `matplotlib.axes.Axes.quiver`.
305331 ax : `matplotlib.axes.Axes`, optional
306332 Use the given axes for the plot, otherwise use the current axes.
307333
@@ -354,19 +380,19 @@ def vectorfield(
354380 with plt .rc_context (rcParams ):
355381 out = ax .quiver (
356382 vfdata [:, 0 ], vfdata [:, 1 ], vfdata [:, 2 ], vfdata [:, 3 ],
357- angles = 'xy' , color = color )
383+ angles = 'xy' , color = color , zorder = zorder )
358384
359385 return out
360386
361387
362388def streamplot (
363- sys , pointdata , gridspec = None , ax = None , vary_color = False ,
389+ sys , pointdata , gridspec = None , zorder = None , ax = None , vary_color = False ,
364390 vary_linewidth = False , cmap = None , norm = None , suppress_warnings = False ,
365391 _check_kwargs = True , ** kwargs ):
366- """Plot a vector field in the phase plane.
392+ """Plot streamlines in the phase plane.
367393
368- This function plots a vector field for a two-dimensional state
369- space system.
394+ This function plots the streamlines for a two-dimensional state
395+ space system using the `matplotlib.axes.Axes.streamplot` function .
370396
371397 Parameters
372398 ----------
@@ -376,9 +402,7 @@ def streamplot(
376402 `params` keyword.
377403 pointdata : list or 2D array
378404 List of the form [xmin, xmax, ymin, ymax] describing the
379- boundaries of the phase plot or an array of shape (N, 2)
380- giving points from which to make the streamplot. In the latter case,
381- the points lie on a grid like that generated by `meshgrid`.
405+ boundaries of the phase plot.
382406 gridspec : list, optional
383407 Specifies the size of the grid in the x and y axes on which to
384408 generate points.
@@ -396,6 +420,9 @@ def streamplot(
396420 Colormap to use for varying the color of the streamlines.
397421 norm : `matplotlib.colors.Normalize`, optional
398422 An instance of Normalize to use for scaling the colormap and linewidths.
423+ zorder : float, optional
424+ Set the zorder for the separatrices. In not specified, it will be
425+ automatically chosen by `matplotlib.axes.Axes.streamplot`.
399426 ax : `matplotlib.axes.Axes`, optional
400427 Use the given axes for the plot, otherwise use the current axes.
401428
@@ -423,15 +450,6 @@ def streamplot(
423450
424451 # Determine the points on which to generate the streamplot field
425452 points , gridspec = _make_points (pointdata , gridspec , 'meshgrid' )
426-
427- # attempt to recover the grid by counting the jumps in xvals
428- if gridspec is None :
429- nrows = np .sum (np .diff (points [:, 0 ]) < 0 ) + 1
430- ncols = points .shape [0 ] // nrows
431- if nrows * ncols != points .shape [0 ]:
432- raise ValueError ("Could not recover grid from points." )
433- gridspec = [nrows , ncols ]
434-
435453 grid_arr_shape = gridspec [::- 1 ]
436454 xs , ys = points [:, 0 ].reshape (grid_arr_shape ), points [:, 1 ].reshape (grid_arr_shape )
437455
@@ -465,13 +483,15 @@ def streamplot(
465483 linewidths = normalized * (max_lw - min_lw ) + min_lw if vary_linewidth else None
466484 color = magnitudes if vary_color else color
467485
468- out = ax .streamplot (xs , ys , us , vs , color = color , linewidth = linewidths , cmap = cmap , norm = norm )
486+ out = ax .streamplot (xs , ys , us , vs , color = color , linewidth = linewidths ,
487+ cmap = cmap , norm = norm , zorder = zorder )
469488
470489 return out
471490
472491def streamlines (
473492 sys , pointdata , timedata = 1 , gridspec = None , gridtype = None , dir = None ,
474- ax = None , _check_kwargs = True , suppress_warnings = False , ** kwargs ):
493+ zorder = None , ax = None , _check_kwargs = True , suppress_warnings = False ,
494+ ** kwargs ):
475495 """Plot stream lines in the phase plane.
476496
477497 This function plots stream lines for a two-dimensional state space
@@ -513,6 +533,9 @@ def streamlines(
513533 dict with key 'args' and value given by a tuple (passed to callable).
514534 color : str
515535 Plot the streamlines in the given color.
536+ zorder : float, optional
537+ Set the zorder for the separatrices. In not specified, it will be
538+ automatically chosen by `matplotlib.axes.Axes.plot`.
516539 ax : `matplotlib.axes.Axes`, optional
517540 Use the given axes for the plot, otherwise use the current axes.
518541
@@ -591,7 +614,7 @@ def streamlines(
591614 # Plot the trajectory (if there is one)
592615 if traj .shape [1 ] > 1 :
593616 with plt .rc_context (rcParams ):
594- out += ax .plot (traj [0 ], traj [1 ], color = color )
617+ out += ax .plot (traj [0 ], traj [1 ], color = color , zorder = zorder )
595618
596619 # Add arrows to the lines at specified intervals
597620 _add_arrows_to_line2D (
@@ -600,7 +623,7 @@ def streamlines(
600623
601624
602625def equilpoints (
603- sys , pointdata , gridspec = None , color = 'k' , ax = None ,
626+ sys , pointdata , gridspec = None , color = 'k' , zorder = None , ax = None ,
604627 _check_kwargs = True , ** kwargs ):
605628 """Plot equilibrium points in the phase plane.
606629
@@ -634,6 +657,9 @@ def equilpoints(
634657 dict with key 'args' and value given by a tuple (passed to callable).
635658 color : str
636659 Plot the equilibrium points in the given color.
660+ zorder : float, optional
661+ Set the zorder for the separatrices. In not specified, it will be
662+ automatically chosen by `matplotlib.axes.Axes.plot`.
637663 ax : `matplotlib.axes.Axes`, optional
638664 Use the given axes for the plot, otherwise use the current axes.
639665
@@ -679,12 +705,12 @@ def equilpoints(
679705 out = []
680706 for xeq in equilpts :
681707 with plt .rc_context (rcParams ):
682- out += ax .plot (xeq [0 ], xeq [1 ], marker = 'o' , color = color )
708+ out += ax .plot (xeq [0 ], xeq [1 ], marker = 'o' , color = color , zorder = zorder )
683709 return out
684710
685711
686712def separatrices (
687- sys , pointdata , timedata = None , gridspec = None , ax = None ,
713+ sys , pointdata , timedata = None , gridspec = None , zorder = None , ax = None ,
688714 _check_kwargs = True , suppress_warnings = False , ** kwargs ):
689715 """Plot separatrices in the phase plane.
690716
@@ -726,6 +752,9 @@ def separatrices(
726752 separatrices. If a tuple is given, the first element is used as
727753 the color specification for stable separatrices and the second
728754 element for unstable separatrices.
755+ zorder : float, optional
756+ Set the zorder for the separatrices. In not specified, it will be
757+ automatically chosen by `matplotlib.axes.Axes.plot`.
729758 ax : `matplotlib.axes.Axes`, optional
730759 Use the given axes for the plot, otherwise use the current axes.
731760
@@ -802,7 +831,7 @@ def separatrices(
802831 for i , xeq in enumerate (equilpts ):
803832 # Plot the equilibrium points
804833 with plt .rc_context (rcParams ):
805- out += ax .plot (xeq [0 ], xeq [1 ], marker = 'o' , color = 'k' )
834+ out += ax .plot (xeq [0 ], xeq [1 ], marker = 'o' , color = 'k' , zorder = zorder )
806835
807836 # Figure out the linearization and eigenvectors
808837 evals , evecs = np .linalg .eig (sys .linearize (xeq , 0 , params = params ).A )
@@ -844,7 +873,7 @@ def separatrices(
844873 if traj .shape [1 ] > 1 :
845874 with plt .rc_context (rcParams ):
846875 out += ax .plot (
847- traj [0 ], traj [1 ], color = color , linestyle = linestyle )
876+ traj [0 ], traj [1 ], color = color , linestyle = linestyle , zorder = zorder )
848877
849878 # Add arrows to the lines at specified intervals
850879 with plt .rc_context (rcParams ):
0 commit comments