Skip to content

Commit c806ed5

Browse files
committed
added automatic zordering and set streamplot as default, added tests
1 parent ad40832 commit c806ed5

4 files changed

Lines changed: 169 additions & 53 deletions

File tree

control/phaseplot.py

Lines changed: 70 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def phase_plane_plot(
5656
This function plots phase plane data, including vector fields, stream
5757
lines, equilibrium points, and contour curves.
5858
If none of plot_streamlines, plot_vectorfield, or plot_streamplot are
59-
set, then plot_streamlines is used by default.
59+
set, then plot_streamplot is used by default.
6060
6161
Parameters
6262
----------
@@ -164,7 +164,7 @@ def phase_plane_plot(
164164
and plot_vectorfield is None
165165
and plot_streamplot is None
166166
):
167-
plot_streamlines = True
167+
plot_streamplot = True
168168

169169
if plot_streamplot and not plot_streamlines and not plot_vectorfield:
170170
gridspec = gridspec or [25, 25]
@@ -191,7 +191,10 @@ def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
191191
return new_kwargs
192192

193193
# Create list for storing outputs
194-
out = np.array([[], None, None], dtype=object)
194+
out = np.array([[], None, None, None], dtype=object)
195+
196+
# the maximum zorder of stramlines, vectorfield or streamplot
197+
flow_zorder = None
195198

196199
# Plot out the main elements
197200
if plot_streamlines:
@@ -201,6 +204,9 @@ def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
201204
out[0] += streamlines(
202205
sys, pointdata, timedata, _check_kwargs=False,
203206
suppress_warnings=suppress_warnings, **kwargs_local)
207+
208+
new_zorder = max(elem.get_zorder() for elem in out[0])
209+
flow_zorder = max(flow_zorder, new_zorder) if flow_zorder else new_zorder
204210

205211
# Get rid of keyword arguments handled by streamlines
206212
for kw in ['arrows', 'arrow_size', 'arrow_style', 'color',
@@ -211,39 +217,56 @@ def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
211217
if gridtype not in [None, 'boxgrid', 'meshgrid']:
212218
gridspec = None
213219

214-
if plot_separatrices:
215-
kwargs_local = _create_kwargs(
216-
kwargs, plot_separatrices, gridspec=gridspec, ax=ax)
217-
out[0] += separatrices(
218-
sys, pointdata, _check_kwargs=False, **kwargs_local)
219-
220-
# Get rid of keyword arguments handled by separatrices
221-
for kw in ['arrows', 'arrow_size', 'arrow_style', 'params']:
222-
initial_kwargs.pop(kw, None)
223-
224220
if plot_vectorfield:
225221
kwargs_local = _create_kwargs(
226222
kwargs, plot_vectorfield, gridspec=gridspec, ax=ax)
227223
out[1] = vectorfield(
228224
sys, pointdata, _check_kwargs=False, **kwargs_local)
225+
226+
new_zorder = out[1].get_zorder()
227+
flow_zorder = max(flow_zorder, new_zorder) if flow_zorder else new_zorder
229228

230229
# Get rid of keyword arguments handled by vectorfield
231230
for kw in ['color', 'params']:
232231
initial_kwargs.pop(kw, None)
233232

234233
if plot_streamplot:
234+
if gridtype not in [None, 'meshgrid']:
235+
raise ValueError("gridtype must be 'meshgrid' when using streamplot")
236+
235237
kwargs_local = _create_kwargs(
236238
kwargs, plot_streamplot, gridspec=gridspec, ax=ax)
237-
streamplot(
239+
out[3] = streamplot(
238240
sys, pointdata, _check_kwargs=False, **kwargs_local)
241+
242+
new_zorder = max(out[3].lines.get_zorder(), out[3].arrows.get_zorder())
243+
flow_zorder = max(flow_zorder, new_zorder) if flow_zorder else new_zorder
239244

240245
# Get rid of keyword arguments handled by streamplot
241246
for kw in ['color', 'params']:
242247
initial_kwargs.pop(kw, None)
243248

249+
sep_zorder = flow_zorder + 1 if flow_zorder else None
250+
251+
if plot_separatrices:
252+
kwargs_local = _create_kwargs(
253+
kwargs, plot_separatrices, gridspec=gridspec, ax=ax)
254+
kwargs_local['zorder'] = kwargs_local.get('zorder', sep_zorder)
255+
out[0] += separatrices(
256+
sys, pointdata, _check_kwargs=False, **kwargs_local)
257+
258+
sep_zorder = max(elem.get_zorder() for elem in out[0])
259+
260+
# Get rid of keyword arguments handled by separatrices
261+
for kw in ['arrows', 'arrow_size', 'arrow_style', 'params']:
262+
initial_kwargs.pop(kw, None)
263+
264+
equil_zorder = sep_zorder + 1 if sep_zorder else None
265+
244266
if plot_equilpoints:
245267
kwargs_local = _create_kwargs(
246268
kwargs, plot_equilpoints, gridspec=gridspec, ax=ax)
269+
kwargs_local['zorder'] = kwargs_local.get('zorder', equil_zorder)
247270
out[2] = equilpoints(
248271
sys, pointdata, _check_kwargs=False, **kwargs_local)
249272

@@ -267,8 +290,8 @@ def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
267290

268291

269292
def vectorfield(
270-
sys, pointdata, gridspec=None, ax=None, suppress_warnings=False,
271-
_check_kwargs=True, **kwargs):
293+
sys, pointdata, gridspec=None, zorder=None, ax=None,
294+
suppress_warnings=False, _check_kwargs=True, **kwargs):
272295
"""Plot a vector field in the phase plane.
273296
274297
This function plots a vector field for a two-dimensional state
@@ -302,6 +325,9 @@ def vectorfield(
302325
dict with key 'args' and value given by a tuple (passed to callable).
303326
color : matplotlib color spec, optional
304327
Plot the vector field in the given color.
328+
zorder : float, optional
329+
Set the zorder for the separatrices. In not specified, it will be
330+
automatically chosen by `matplotlib.axes.Axes.quiver`.
305331
ax : `matplotlib.axes.Axes`, optional
306332
Use the given axes for the plot, otherwise use the current axes.
307333
@@ -354,19 +380,19 @@ def vectorfield(
354380
with plt.rc_context(rcParams):
355381
out = ax.quiver(
356382
vfdata[:, 0], vfdata[:, 1], vfdata[:, 2], vfdata[:, 3],
357-
angles='xy', color=color)
383+
angles='xy', color=color, zorder=zorder)
358384

359385
return out
360386

361387

362388
def streamplot(
363-
sys, pointdata, gridspec=None, ax=None, vary_color=False,
389+
sys, pointdata, gridspec=None, zorder=None, ax=None, vary_color=False,
364390
vary_linewidth=False, cmap=None, norm=None, suppress_warnings=False,
365391
_check_kwargs=True, **kwargs):
366-
"""Plot a vector field in the phase plane.
392+
"""Plot streamlines in the phase plane.
367393
368-
This function plots a vector field for a two-dimensional state
369-
space system.
394+
This function plots the streamlines for a two-dimensional state
395+
space system using the `matplotlib.axes.Axes.streamplot` function.
370396
371397
Parameters
372398
----------
@@ -376,9 +402,7 @@ def streamplot(
376402
`params` keyword.
377403
pointdata : list or 2D array
378404
List of the form [xmin, xmax, ymin, ymax] describing the
379-
boundaries of the phase plot or an array of shape (N, 2)
380-
giving points from which to make the streamplot. In the latter case,
381-
the points lie on a grid like that generated by `meshgrid`.
405+
boundaries of the phase plot.
382406
gridspec : list, optional
383407
Specifies the size of the grid in the x and y axes on which to
384408
generate points.
@@ -396,6 +420,9 @@ def streamplot(
396420
Colormap to use for varying the color of the streamlines.
397421
norm : `matplotlib.colors.Normalize`, optional
398422
An instance of Normalize to use for scaling the colormap and linewidths.
423+
zorder : float, optional
424+
Set the zorder for the separatrices. In not specified, it will be
425+
automatically chosen by `matplotlib.axes.Axes.streamplot`.
399426
ax : `matplotlib.axes.Axes`, optional
400427
Use the given axes for the plot, otherwise use the current axes.
401428
@@ -423,15 +450,6 @@ def streamplot(
423450

424451
# Determine the points on which to generate the streamplot field
425452
points, gridspec = _make_points(pointdata, gridspec, 'meshgrid')
426-
427-
# attempt to recover the grid by counting the jumps in xvals
428-
if gridspec is None:
429-
nrows = np.sum(np.diff(points[:, 0]) < 0) + 1
430-
ncols = points.shape[0] // nrows
431-
if nrows * ncols != points.shape[0]:
432-
raise ValueError("Could not recover grid from points.")
433-
gridspec = [nrows, ncols]
434-
435453
grid_arr_shape = gridspec[::-1]
436454
xs, ys = points[:, 0].reshape(grid_arr_shape), points[:, 1].reshape(grid_arr_shape)
437455

@@ -465,13 +483,15 @@ def streamplot(
465483
linewidths = normalized * (max_lw - min_lw) + min_lw if vary_linewidth else None
466484
color = magnitudes if vary_color else color
467485

468-
out = ax.streamplot(xs, ys, us, vs, color=color, linewidth=linewidths, cmap=cmap, norm=norm)
486+
out = ax.streamplot(xs, ys, us, vs, color=color, linewidth=linewidths,
487+
cmap=cmap, norm=norm, zorder=zorder)
469488

470489
return out
471490

472491
def streamlines(
473492
sys, pointdata, timedata=1, gridspec=None, gridtype=None, dir=None,
474-
ax=None, _check_kwargs=True, suppress_warnings=False, **kwargs):
493+
zorder=None, ax=None, _check_kwargs=True, suppress_warnings=False,
494+
**kwargs):
475495
"""Plot stream lines in the phase plane.
476496
477497
This function plots stream lines for a two-dimensional state space
@@ -513,6 +533,9 @@ def streamlines(
513533
dict with key 'args' and value given by a tuple (passed to callable).
514534
color : str
515535
Plot the streamlines in the given color.
536+
zorder : float, optional
537+
Set the zorder for the separatrices. In not specified, it will be
538+
automatically chosen by `matplotlib.axes.Axes.plot`.
516539
ax : `matplotlib.axes.Axes`, optional
517540
Use the given axes for the plot, otherwise use the current axes.
518541
@@ -591,7 +614,7 @@ def streamlines(
591614
# Plot the trajectory (if there is one)
592615
if traj.shape[1] > 1:
593616
with plt.rc_context(rcParams):
594-
out += ax.plot(traj[0], traj[1], color=color)
617+
out += ax.plot(traj[0], traj[1], color=color, zorder=zorder)
595618

596619
# Add arrows to the lines at specified intervals
597620
_add_arrows_to_line2D(
@@ -600,7 +623,7 @@ def streamlines(
600623

601624

602625
def equilpoints(
603-
sys, pointdata, gridspec=None, color='k', ax=None,
626+
sys, pointdata, gridspec=None, color='k', zorder=None, ax=None,
604627
_check_kwargs=True, **kwargs):
605628
"""Plot equilibrium points in the phase plane.
606629
@@ -634,6 +657,9 @@ def equilpoints(
634657
dict with key 'args' and value given by a tuple (passed to callable).
635658
color : str
636659
Plot the equilibrium points in the given color.
660+
zorder : float, optional
661+
Set the zorder for the separatrices. In not specified, it will be
662+
automatically chosen by `matplotlib.axes.Axes.plot`.
637663
ax : `matplotlib.axes.Axes`, optional
638664
Use the given axes for the plot, otherwise use the current axes.
639665
@@ -679,12 +705,12 @@ def equilpoints(
679705
out = []
680706
for xeq in equilpts:
681707
with plt.rc_context(rcParams):
682-
out += ax.plot(xeq[0], xeq[1], marker='o', color=color)
708+
out += ax.plot(xeq[0], xeq[1], marker='o', color=color, zorder=zorder)
683709
return out
684710

685711

686712
def separatrices(
687-
sys, pointdata, timedata=None, gridspec=None, ax=None,
713+
sys, pointdata, timedata=None, gridspec=None, zorder=None, ax=None,
688714
_check_kwargs=True, suppress_warnings=False, **kwargs):
689715
"""Plot separatrices in the phase plane.
690716
@@ -726,6 +752,9 @@ def separatrices(
726752
separatrices. If a tuple is given, the first element is used as
727753
the color specification for stable separatrices and the second
728754
element for unstable separatrices.
755+
zorder : float, optional
756+
Set the zorder for the separatrices. In not specified, it will be
757+
automatically chosen by `matplotlib.axes.Axes.plot`.
729758
ax : `matplotlib.axes.Axes`, optional
730759
Use the given axes for the plot, otherwise use the current axes.
731760
@@ -802,7 +831,7 @@ def separatrices(
802831
for i, xeq in enumerate(equilpts):
803832
# Plot the equilibrium points
804833
with plt.rc_context(rcParams):
805-
out += ax.plot(xeq[0], xeq[1], marker='o', color='k')
834+
out += ax.plot(xeq[0], xeq[1], marker='o', color='k', zorder=zorder)
806835

807836
# Figure out the linearization and eigenvectors
808837
evals, evecs = np.linalg.eig(sys.linearize(xeq, 0, params=params).A)
@@ -844,7 +873,7 @@ def separatrices(
844873
if traj.shape[1] > 1:
845874
with plt.rc_context(rcParams):
846875
out += ax.plot(
847-
traj[0], traj[1], color=color, linestyle=linestyle)
876+
traj[0], traj[1], color=color, linestyle=linestyle, zorder=zorder)
848877

849878
# Add arrows to the lines at specified intervals
850879
with plt.rc_context(rcParams):

control/tests/ctrlplot_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ def setup_plot_arguments(resp_fcn, plot_fcn, compute_time_response=True):
116116
args2 = (sys2, )
117117
argsc = ([sys1, sys2], )
118118

119+
case (None, ct.phase_plane_plot):
120+
args1 = (sys1, )
121+
args2 = (sys2, )
122+
plot_kwargs = {'plot_streamlines': True}
123+
119124
case _, _:
120125
args1 = (sys1, )
121126
args2 = (sys2, )

0 commit comments

Comments
 (0)