Skip to content

Commit a94f6ca

Browse files
committed
add plot_streamplot option to phase_plane_plot to use matplotlibs streamplot
1 parent f73e893 commit a94f6ca

3 files changed

Lines changed: 139 additions & 1 deletion

File tree

control/phaseplot.py

Lines changed: 125 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def phase_plane_plot(
4949
sys, pointdata=None, timedata=None, gridtype=None, gridspec=None,
5050
plot_streamlines=True, plot_vectorfield=False, plot_equilpoints=True,
5151
plot_separatrices=True, ax=None, suppress_warnings=False, title=None,
52-
**kwargs
52+
plot_streamplot=False, **kwargs
5353
):
5454
"""Plot phase plane diagram.
5555
@@ -135,6 +135,10 @@ def phase_plane_plot(
135135
If True (default) then plot the vector field based on the pointdata
136136
and gridtype. If set to a dict, pass on the key-value pairs in
137137
the dict as keywords to `phaseplot.vectorfield`.
138+
plot_streamplot : bool or dict, optional
139+
If `True` then use matplotlib's streamplot function to plot the
140+
streamlines. If set to a dict, pass on the key-value pairs in the
141+
dict as keywords to :func:`~
138142
plot_equilpoints : bool or dict, optional
139143
If True (default) then plot equilibrium points based in the phase
140144
plot boundary. If set to a dict, pass on the key-value pairs in the
@@ -214,6 +218,16 @@ def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
214218
for kw in ['color', 'params']:
215219
initial_kwargs.pop(kw, None)
216220

221+
if plot_streamplot:
222+
kwargs_local = _create_kwargs(
223+
kwargs, plot_streamplot, gridspec=gridspec, ax=ax)
224+
streamplot(
225+
sys, pointdata, _check_kwargs=False, **kwargs_local)
226+
227+
# Get rid of keyword arguments handled by streamplot
228+
for kw in ['color', 'params']:
229+
initial_kwargs.pop(kw, None)
230+
217231
if plot_equilpoints:
218232
kwargs_local = _create_kwargs(
219233
kwargs, plot_equilpoints, gridspec=gridspec, ax=ax)
@@ -332,6 +346,116 @@ def vectorfield(
332346
return out
333347

334348

349+
def streamplot(
350+
sys, pointdata, gridspec=None, ax=None, vary_color=False,
351+
vary_linewidth=False, cmap=None, norm=None, suppress_warnings=False,
352+
_check_kwargs=True, **kwargs):
353+
"""Plot a vector field in the phase plane.
354+
355+
This function plots a vector field for a two-dimensional state
356+
space system.
357+
358+
Parameters
359+
----------
360+
sys : NonlinearIOSystem or callable(t, x, ...)
361+
I/O system or function used to generate phase plane data. If a
362+
function is given, the remaining arguments are drawn from the
363+
`params` keyword.
364+
pointdata : list or 2D array
365+
List of the form [xmin, xmax, ymin, ymax] describing the
366+
boundaries of the phase plot or an array of shape (N, 2)
367+
giving points from which to make the streamplot. In the latter case,
368+
the points lie on a grid like that generated by `meshgrid`.
369+
gridspec : list, optional
370+
Specifies the size of the grid in the x and y axes on which to
371+
generate points.
372+
params : dict or list, optional
373+
Parameters to pass to system. For an I/O system, `params` should be
374+
a dict of parameters and values. For a callable, `params` should be
375+
dict with key 'args' and value given by a tuple (passed to callable).
376+
color : matplotlib color spec, optional
377+
Plot the vector field in the given color.
378+
vary_color : bool, optional
379+
If set to `True`, vary the color of the streamlines based on the magnitude
380+
vary_linewidth : bool, optional
381+
If set to `True`, vary the linewidth of the streamlines based on the magnitude
382+
cmap : str or Colormap, optional
383+
Colormap to use for varying the color of the streamlines
384+
norm : Normalize, optional
385+
An instance of Normalize to use for scaling the colormap and linewidths
386+
ax : matplotlib.axes.Axes
387+
Use the given axes for the plot, otherwise use the current axes.
388+
389+
Returns
390+
-------
391+
out : Quiver
392+
393+
Other parameters
394+
----------------
395+
rcParams : dict
396+
Override the default parameters used for generating plots.
397+
Default is set by config.default['ctrlplot.rcParams'].
398+
suppress_warnings : bool, optional
399+
If set to `True`, suppress warning messages in generating trajectories.
400+
401+
"""
402+
# Process keywords
403+
rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
404+
405+
# Get system parameters
406+
params = kwargs.pop('params', None)
407+
408+
# Create system from callable, if needed
409+
sys = _create_system(sys, params)
410+
411+
# Determine the points on which to generate the streamplot field
412+
points, gridspec = _make_points(pointdata, gridspec, 'meshgrid')
413+
414+
# attempt to recover the grid by counting the jumps in xvals
415+
if gridspec is None:
416+
nrows = np.sum(np.diff(points[:, 0]) < 0) + 1
417+
ncols = points.shape[0] // nrows
418+
if nrows * ncols != points.shape[0]:
419+
raise ValueError("Could not recover grid from points.")
420+
gridspec = [nrows, ncols]
421+
422+
grid_arr_shape = gridspec[::-1]
423+
xs, ys = points[:, 0].reshape(grid_arr_shape), points[:, 1].reshape(grid_arr_shape)
424+
425+
# Create axis if needed
426+
if ax is None:
427+
ax = plt.gca()
428+
429+
# Set the plotting limits
430+
xlim, ylim, maxlim = _set_axis_limits(ax, pointdata)
431+
432+
# Figure out the color to use
433+
color = _get_color(kwargs, ax=ax)
434+
435+
# Make sure all keyword arguments were processed
436+
if _check_kwargs and kwargs:
437+
raise TypeError("unrecognized keywords: ", str(kwargs))
438+
439+
# Generate phase plane (quiver) data
440+
sys._update_params(params)
441+
us_flat, vs_flat = np.transpose([sys._rhs(0, x, np.zeros(sys.ninputs)) for x in points])
442+
us, vs = us_flat.reshape(grid_arr_shape), vs_flat.reshape(grid_arr_shape)
443+
444+
magnitudes = np.linalg.norm([us, vs], axis=0)
445+
norm = norm or mpl.colors.Normalize()
446+
normalized = norm(magnitudes)
447+
cmap = plt.get_cmap(cmap)
448+
449+
with plt.rc_context(rcParams):
450+
default_lw = plt.rcParams['lines.linewidth']
451+
min_lw, max_lw = 0.25*default_lw, 2*default_lw
452+
linewidths = normalized * (max_lw - min_lw) + min_lw if vary_linewidth else None
453+
color = magnitudes if vary_color else color
454+
455+
out = ax.streamplot(xs, ys, us, vs, color=color, linewidth=linewidths, cmap=cmap, norm=norm)
456+
457+
return out
458+
335459
def streamlines(
336460
sys, pointdata, timedata=1, gridspec=None, gridtype=None, dir=None,
337461
ax=None, _check_kwargs=True, suppress_warnings=False, **kwargs):

control/tests/kwargs_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def test_unrecognized_kwargs(function, nsssys, ntfsys, moreargs, kwargs,
172172
(control.phase_plane_plot, 1, ([-1, 1, -1, 1], 1), {}),
173173
(control.phaseplot.streamlines, 1, ([-1, 1, -1, 1], 1), {}),
174174
(control.phaseplot.vectorfield, 1, ([-1, 1, -1, 1], ), {}),
175+
(control.phaseplot.streamplot, 1, ([-1, 1, -1, 1], ), {}),
175176
(control.phaseplot.equilpoints, 1, ([-1, 1, -1, 1], ), {}),
176177
(control.phaseplot.separatrices, 1, ([-1, 1, -1, 1], ), {}),
177178
(control.singular_values_plot, 1, (), {})]
@@ -360,6 +361,7 @@ def test_response_plot_kwargs(data_fcn, plot_fcn, mimo):
360361
optimal_test.test_oep_argument_errors,
361362
'phaseplot.streamlines': test_matplotlib_kwargs,
362363
'phaseplot.vectorfield': test_matplotlib_kwargs,
364+
'phaseplot.streamplot': test_matplotlib_kwargs,
363365
'phaseplot.equilpoints': test_matplotlib_kwargs,
364366
'phaseplot.separatrices': test_matplotlib_kwargs,
365367
}

control/tests/phaseplot_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,18 @@ def oscillator_update(t, x, u, params):
227227
if savefigs:
228228
plt.savefig('phaseplot-oscillator-helpers.png')
229229

230+
plt.figure()
231+
ct.phase_plane_plot(
232+
invpend, [-2*pi, 2*pi, -2, 2], plot_streamlines=False,
233+
plot_streamplot=dict(vary_color=True, vary_density=True),
234+
gridspec=[60, 20], params={'m': 1, 'l': 1, 'b': 0.2, 'g': 1}
235+
)
236+
plt.xlabel(r"$\theta$ [rad]")
237+
plt.ylabel(r"$\dot\theta$ [rad/sec]")
238+
239+
if savefigs:
240+
plt.savefig('phaseplot-invpend-streamplot.png')
241+
230242

231243
if __name__ == "__main__":
232244
#

0 commit comments

Comments
 (0)