Skip to content

Commit c0e4cb4

Browse files
committed
add plot_streamplot option to phase_plane_plot to use matplotlibs streamplot
1 parent 0ff0452 commit c0e4cb4

File tree

3 files changed

+140
-1
lines changed

3 files changed

+140
-1
lines changed

control/phaseplot.py

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
* separatrices: Plot separatrices in the phase plane
2525
* streamlines: Plot stream lines in the phase plane
2626
* vectorfield: Plot a vector field in the phase plane
27+
* streampot: Plot streamlines using matplotlib's streamplot function
2728
2829
"""
2930

@@ -55,7 +56,7 @@ def phase_plane_plot(
5556
sys, pointdata=None, timedata=None, gridtype=None, gridspec=None,
5657
plot_streamlines=True, plot_vectorfield=False, plot_equilpoints=True,
5758
plot_separatrices=True, ax=None, suppress_warnings=False, title=None,
58-
**kwargs
59+
plot_streamplot=False, **kwargs
5960
):
6061
"""Plot phase plane diagram.
6162
@@ -133,6 +134,10 @@ def phase_plane_plot(
133134
If `True` (default) then plot the vector field based on the pointdata
134135
and gridtype. If set to a dict, pass on the key-value pairs in
135136
the dict as keywords to :func:`~control.phaseplot.vectorfield`.
137+
plot_streamplot : bool or dict, optional
138+
If `True` then use matplotlib's streamplot function to plot the
139+
streamlines. If set to a dict, pass on the key-value pairs in the
140+
dict as keywords to :func:`~
136141
plot_equilpoints : bool or dict, optional
137142
If `True` (default) then plot equilibrium points based in the phase
138143
plot boundary. If set to a dict, pass on the key-value pairs in the
@@ -213,6 +218,16 @@ def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
213218
for kw in ['color', 'params']:
214219
initial_kwargs.pop(kw, None)
215220

221+
if plot_streamplot:
222+
kwargs_local = _create_kwargs(
223+
kwargs, plot_streamplot, gridspec=gridspec, ax=ax)
224+
streamplot(
225+
sys, pointdata, _check_kwargs=False, **kwargs_local)
226+
227+
# Get rid of keyword arguments handled by streamplot
228+
for kw in ['color', 'params']:
229+
initial_kwargs.pop(kw, None)
230+
216231
if plot_equilpoints:
217232
kwargs_local = _create_kwargs(
218233
kwargs, plot_equilpoints, gridspec=gridspec, ax=ax)
@@ -331,6 +346,116 @@ def vectorfield(
331346
return out
332347

333348

349+
def streamplot(
350+
sys, pointdata, gridspec=None, ax=None, vary_color=False,
351+
vary_linewidth=False, cmap=None, norm=None, suppress_warnings=False,
352+
_check_kwargs=True, **kwargs):
353+
"""Plot a vector field in the phase plane.
354+
355+
This function plots a vector field for a two-dimensional state
356+
space system.
357+
358+
Parameters
359+
----------
360+
sys : NonlinearIOSystem or callable(t, x, ...)
361+
I/O system or function used to generate phase plane data. If a
362+
function is given, the remaining arguments are drawn from the
363+
`params` keyword.
364+
pointdata : list or 2D array
365+
List of the form [xmin, xmax, ymin, ymax] describing the
366+
boundaries of the phase plot or an array of shape (N, 2)
367+
giving points from which to make the streamplot. In the latter case,
368+
the points lie on a grid like that generated by `meshgrid`.
369+
gridspec : list, optional
370+
Specifies the size of the grid in the x and y axes on which to
371+
generate points.
372+
params : dict or list, optional
373+
Parameters to pass to system. For an I/O system, `params` should be
374+
a dict of parameters and values. For a callable, `params` should be
375+
dict with key 'args' and value given by a tuple (passed to callable).
376+
color : matplotlib color spec, optional
377+
Plot the vector field in the given color.
378+
vary_color : bool, optional
379+
If set to `True`, vary the color of the streamlines based on the magnitude
380+
vary_linewidth : bool, optional
381+
If set to `True`, vary the linewidth of the streamlines based on the magnitude
382+
cmap : str or Colormap, optional
383+
Colormap to use for varying the color of the streamlines
384+
norm : Normalize, optional
385+
An instance of Normalize to use for scaling the colormap and linewidths
386+
ax : matplotlib.axes.Axes
387+
Use the given axes for the plot, otherwise use the current axes.
388+
389+
Returns
390+
-------
391+
out : Quiver
392+
393+
Other parameters
394+
----------------
395+
rcParams : dict
396+
Override the default parameters used for generating plots.
397+
Default is set by config.default['ctrlplot.rcParams'].
398+
suppress_warnings : bool, optional
399+
If set to `True`, suppress warning messages in generating trajectories.
400+
401+
"""
402+
# Process keywords
403+
rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
404+
405+
# Get system parameters
406+
params = kwargs.pop('params', None)
407+
408+
# Create system from callable, if needed
409+
sys = _create_system(sys, params)
410+
411+
# Determine the points on which to generate the streamplot field
412+
points, gridspec = _make_points(pointdata, gridspec, 'meshgrid')
413+
414+
# attempt to recover the grid by counting the jumps in xvals
415+
if gridspec is None:
416+
nrows = np.sum(np.diff(points[:, 0]) < 0) + 1
417+
ncols = points.shape[0] // nrows
418+
if nrows * ncols != points.shape[0]:
419+
raise ValueError("Could not recover grid from points.")
420+
gridspec = [nrows, ncols]
421+
422+
grid_arr_shape = gridspec[::-1]
423+
xs, ys = points[:, 0].reshape(grid_arr_shape), points[:, 1].reshape(grid_arr_shape)
424+
425+
# Create axis if needed
426+
if ax is None:
427+
ax = plt.gca()
428+
429+
# Set the plotting limits
430+
xlim, ylim, maxlim = _set_axis_limits(ax, pointdata)
431+
432+
# Figure out the color to use
433+
color = _get_color(kwargs, ax=ax)
434+
435+
# Make sure all keyword arguments were processed
436+
if _check_kwargs and kwargs:
437+
raise TypeError("unrecognized keywords: ", str(kwargs))
438+
439+
# Generate phase plane (quiver) data
440+
sys._update_params(params)
441+
us_flat, vs_flat = np.transpose([sys._rhs(0, x, np.zeros(sys.ninputs)) for x in points])
442+
us, vs = us_flat.reshape(grid_arr_shape), vs_flat.reshape(grid_arr_shape)
443+
444+
magnitudes = np.linalg.norm([us, vs], axis=0)
445+
norm = norm or mpl.colors.Normalize()
446+
normalized = norm(magnitudes)
447+
cmap = plt.get_cmap(cmap)
448+
449+
with plt.rc_context(rcParams):
450+
default_lw = plt.rcParams['lines.linewidth']
451+
min_lw, max_lw = 0.25*default_lw, 2*default_lw
452+
linewidths = normalized * (max_lw - min_lw) + min_lw if vary_linewidth else None
453+
color = magnitudes if vary_color else color
454+
455+
out = ax.streamplot(xs, ys, us, vs, color=color, linewidth=linewidths, cmap=cmap, norm=norm)
456+
457+
return out
458+
334459
def streamlines(
335460
sys, pointdata, timedata=1, gridspec=None, gridtype=None, dir=None,
336461
ax=None, _check_kwargs=True, suppress_warnings=False, **kwargs):

control/tests/kwargs_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def test_unrecognized_kwargs(function, nsssys, ntfsys, moreargs, kwargs,
171171
(control.phase_plane_plot, 1, ([-1, 1, -1, 1], 1), {}),
172172
(control.phaseplot.streamlines, 1, ([-1, 1, -1, 1], 1), {}),
173173
(control.phaseplot.vectorfield, 1, ([-1, 1, -1, 1], ), {}),
174+
(control.phaseplot.streamplot, 1, ([-1, 1, -1, 1], ), {}),
174175
(control.phaseplot.equilpoints, 1, ([-1, 1, -1, 1], ), {}),
175176
(control.phaseplot.separatrices, 1, ([-1, 1, -1, 1], ), {}),
176177
(control.singular_values_plot, 1, (), {})]
@@ -347,6 +348,7 @@ def test_response_plot_kwargs(data_fcn, plot_fcn, mimo):
347348
optimal_test.test_oep_argument_errors,
348349
'phaseplot.streamlines': test_matplotlib_kwargs,
349350
'phaseplot.vectorfield': test_matplotlib_kwargs,
351+
'phaseplot.streamplot': test_matplotlib_kwargs,
350352
'phaseplot.equilpoints': test_matplotlib_kwargs,
351353
'phaseplot.separatrices': test_matplotlib_kwargs,
352354
}

control/tests/phaseplot_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,18 @@ def oscillator_update(t, x, u, params):
227227
if savefigs:
228228
plt.savefig('phaseplot-oscillator-helpers.png')
229229

230+
plt.figure()
231+
ct.phase_plane_plot(
232+
invpend, [-2*pi, 2*pi, -2, 2], plot_streamlines=False,
233+
plot_streamplot=dict(vary_color=True, vary_density=True),
234+
gridspec=[60, 20], params={'m': 1, 'l': 1, 'b': 0.2, 'g': 1}
235+
)
236+
plt.xlabel(r"$\theta$ [rad]")
237+
plt.ylabel(r"$\dot\theta$ [rad/sec]")
238+
239+
if savefigs:
240+
plt.savefig('phaseplot-invpend-streamplot.png')
241+
230242

231243
if __name__ == "__main__":
232244
#

0 commit comments

Comments
 (0)