1212import warnings
1313from math import pi
1414
15+ import matplotlib as mpl
1516import matplotlib .pyplot as plt
1617import numpy as np
1718import pytest
@@ -137,29 +138,39 @@ def invpend_ode(t, x, m=0, l=0, b=0, g=0):
137138
138139 # Use callable form, with parameters (if not correct, will get /0 error)
139140 ct .phase_plane_plot (
140- invpend_ode , [- 5 , 5 , - 2 , 2 ], params = {'args' : (1 , 1 , 0.2 , 1 )})
141+ invpend_ode , [- 5 , 5 , - 2 , 2 ], params = {'args' : (1 , 1 , 0.2 , 1 )},
142+ plot_streamlines = True )
141143
142144 # Linear I/O system
143145 ct .phase_plane_plot (
144- ct .ss ([[0 , 1 ], [- 1 , - 1 ]], [[0 ], [1 ]], [[1 , 0 ]], 0 ))
146+ ct .ss ([[0 , 1 ], [- 1 , - 1 ]], [[0 ], [1 ]], [[1 , 0 ]], 0 ),
147+ plot_streamlines = True )
145148
146149
147150@pytest .mark .usefixtures ('mplcleanup' )
148151def test_phaseplane_errors ():
149152 with pytest .raises (ValueError , match = "invalid grid specification" ):
150- ct .phase_plane_plot (ct .rss (2 , 1 , 1 ), gridspec = 'bad' )
153+ ct .phase_plane_plot (ct .rss (2 , 1 , 1 ), gridspec = 'bad' ,
154+ plot_streamlines = True )
151155
152156 with pytest .raises (ValueError , match = "unknown grid type" ):
153- ct .phase_plane_plot (ct .rss (2 , 1 , 1 ), gridtype = 'bad' )
157+ ct .phase_plane_plot (ct .rss (2 , 1 , 1 ), gridtype = 'bad' ,
158+ plot_streamlines = True )
154159
155160 with pytest .raises (ValueError , match = "system must be planar" ):
156- ct .phase_plane_plot (ct .rss (3 , 1 , 1 ))
161+ ct .phase_plane_plot (ct .rss (3 , 1 , 1 ),
162+ plot_streamlines = True )
157163
158164 with pytest .raises (ValueError , match = "params must be dict with key" ):
159165 def invpend_ode (t , x , m = 0 , l = 0 , b = 0 , g = 0 ):
160166 return (x [1 ], - b / m * x [1 ] + (g * l / m ) * np .sin (x [0 ]))
161167 ct .phase_plane_plot (
162- invpend_ode , [- 5 , 5 , 2 , 2 ], params = {'stuff' : (1 , 1 , 0.2 , 1 )})
168+ invpend_ode , [- 5 , 5 , 2 , 2 ], params = {'stuff' : (1 , 1 , 0.2 , 1 )},
169+ plot_streamlines = True )
170+
171+ with pytest .raises (ValueError , match = "gridtype must be 'meshgrid' when using streamplot" ):
172+ ct .phase_plane_plot (ct .rss (2 , 1 , 1 ), plot_streamlines = False ,
173+ plot_streamplot = True , gridtype = 'boxgrid' )
163174
164175 # Warning messages for invalid solutions: nonlinear spring mass system
165176 sys = ct .nlsys (
@@ -170,14 +181,87 @@ def invpend_ode(t, x, m=0, l=0, b=0, g=0):
170181 UserWarning , match = r"initial_state=\[.*\], solve_ivp failed" ):
171182 ct .phase_plane_plot (
172183 sys , [- 12 , 12 , - 10 , 10 ], 15 , gridspec = [2 , 9 ],
173- plot_separatrices = False )
184+ plot_separatrices = False , plot_streamlines = True )
174185
175186 # Turn warnings off
176187 with warnings .catch_warnings ():
177188 warnings .simplefilter ("error" )
178189 ct .phase_plane_plot (
179190 sys , [- 12 , 12 , - 10 , 10 ], 15 , gridspec = [2 , 9 ],
180- plot_separatrices = False , suppress_warnings = True )
191+ plot_streamlines = True , plot_separatrices = False ,
192+ suppress_warnings = True )
193+
194+ @pytest .mark .usefixtures ('mplcleanup' )
195+ def test_phase_plot_zorder ():
196+ # some of these tests are a bit akward since the streamlines and separatrices
197+ # are stored in the same list, so we separate them by color
198+ key_color = "tab:blue" # must not be 'k', 'r', 'b' since they are used by separatrices
199+
200+ def get_zorders (cplt ):
201+ max_zorder = lambda items : max ([line .get_zorder () for line in items ])
202+ assert isinstance (cplt .lines [0 ], list )
203+ streamline_lines = [line for line in cplt .lines [0 ] if line .get_color () == key_color ]
204+ separatrice_lines = [line for line in cplt .lines [0 ] if line .get_color () != key_color ]
205+ streamlines = max_zorder (streamline_lines ) if streamline_lines else None
206+ separatrices = max_zorder (separatrice_lines ) if separatrice_lines else None
207+ assert cplt .lines [1 ] == None or isinstance (cplt .lines [1 ], mpl .quiver .Quiver )
208+ quiver = cplt .lines [1 ].get_zorder () if cplt .lines [1 ] else None
209+ assert cplt .lines [2 ] == None or isinstance (cplt .lines [2 ], list )
210+ equilpoints = max_zorder (cplt .lines [2 ]) if cplt .lines [2 ] else None
211+ assert cplt .lines [3 ] == None or isinstance (cplt .lines [3 ], mpl .streamplot .StreamplotSet )
212+ streamplot = max (cplt .lines [3 ].lines .get_zorder (), cplt .lines [3 ].arrows .get_zorder ()) if cplt .lines [3 ] else None
213+ return streamlines , quiver , streamplot , separatrices , equilpoints
214+
215+ def assert_orders (streamlines , quiver , streamplot , separatrices , equilpoints ):
216+ print (streamlines , quiver , streamplot , separatrices , equilpoints )
217+ if streamlines is not None :
218+ assert streamlines < separatrices < equilpoints
219+ if quiver is not None :
220+ assert quiver < separatrices < equilpoints
221+ if streamplot is not None :
222+ assert streamplot < separatrices < equilpoints
223+
224+ def sys (t , x ):
225+ return np .array ([4 * x [1 ], - np .sin (4 * x [0 ])])
226+
227+ # ensure correct zordering for all three flow types
228+ res_streamlines = ct .phase_plane_plot (sys , plot_streamlines = dict (color = key_color ))
229+ assert_orders (* get_zorders (res_streamlines ))
230+ res_vectorfield = ct .phase_plane_plot (sys , plot_vectorfield = True )
231+ assert_orders (* get_zorders (res_vectorfield ))
232+ res_streamplot = ct .phase_plane_plot (sys , plot_streamplot = True )
233+ assert_orders (* get_zorders (res_streamplot ))
234+
235+ # ensure that zorder can still be overwritten
236+ res_reversed = ct .phase_plane_plot (sys , plot_streamlines = dict (color = key_color , zorder = 50 ), plot_vectorfield = dict (zorder = 40 ),
237+ plot_streamplot = dict (zorder = 30 ), plot_separatrices = dict (zorder = 20 ), plot_equilpoints = dict (zorder = 10 ))
238+ streamlines , quiver , streamplot , separatrices , equilpoints = get_zorders (res_reversed )
239+ assert streamlines > quiver > streamplot > separatrices > equilpoints
240+
241+
242+ @pytest .mark .usefixtures ('mplcleanup' )
243+ def test_stream_plot_magnitude ():
244+ def sys (t , x ):
245+ return np .array ([4 * x [1 ], - np .sin (4 * x [0 ])])
246+
247+ # plt context with linewidth
248+ with plt .rc_context ({'lines.linewidth' : 4 }):
249+ res = ct .phase_plane_plot (sys , plot_streamplot = dict (vary_linewidth = True ))
250+ linewidths = res .lines [3 ].lines .get_linewidths ()
251+ # linewidths are scaled to be between 0.25 and 2 times default linewidth
252+ # but the extremes may not exist if there is no line at that point
253+ assert min (linewidths ) < 2 and max (linewidths ) > 7
254+
255+ # make sure changing the colormap works
256+ res = ct .phase_plane_plot (sys , plot_streamplot = dict (vary_color = True , cmap = 'viridis' ))
257+ assert res .lines [3 ].lines .get_cmap ().name == 'viridis'
258+ res = ct .phase_plane_plot (sys , plot_streamplot = dict (vary_color = True , cmap = 'turbo' ))
259+ assert res .lines [3 ].lines .get_cmap ().name == 'turbo'
260+
261+ # make sure changing the norm at least doesn't throw an error
262+ ct .phase_plane_plot (sys , plot_streamplot = dict (vary_color = True , norm = mpl .colors .LogNorm ()))
263+
264+
181265
182266
183267@pytest .mark .usefixtures ('mplcleanup' )
@@ -189,7 +273,7 @@ def test_basic_phase_plots(savefigs=False):
189273 plt .figure ()
190274 axis_limits = [- 1 , 1 , - 1 , 1 ]
191275 T = 8
192- ct .phase_plane_plot (sys , axis_limits , T )
276+ ct .phase_plane_plot (sys , axis_limits , T , plot_streamlines = True )
193277 if savefigs :
194278 plt .savefig ('phaseplot-dampedosc-default.png' )
195279
@@ -202,7 +286,7 @@ def invpend_update(t, x, u, params):
202286 ct .phase_plane_plot (
203287 invpend , [- 2 * pi , 2 * pi , - 2 , 2 ], 5 ,
204288 gridtype = 'meshgrid' , gridspec = [5 , 8 ], arrows = 3 ,
205- plot_separatrices = {'gridspec' : [12 , 9 ]},
289+ plot_separatrices = {'gridspec' : [12 , 9 ]}, plot_streamlines = True ,
206290 params = {'m' : 1 , 'l' : 1 , 'b' : 0.2 , 'g' : 1 })
207291 plt .xlabel (r"$\theta$ [rad]" )
208292 plt .ylabel (r"$\dot\theta$ [rad/sec]" )
@@ -217,7 +301,8 @@ def oscillator_update(t, x, u, params):
217301 oscillator_update , states = 2 , inputs = 0 , name = 'nonlinear oscillator' )
218302
219303 plt .figure ()
220- ct .phase_plane_plot (oscillator , [- 1.5 , 1.5 , - 1.5 , 1.5 ], 0.9 )
304+ ct .phase_plane_plot (oscillator , [- 1.5 , 1.5 , - 1.5 , 1.5 ], 0.9 ,
305+ plot_streamlines = True )
221306 pp .streamlines (
222307 oscillator , np .array ([[0 , 0 ]]), 1.5 ,
223308 gridtype = 'circlegrid' , gridspec = [0.5 , 6 ], dir = 'both' )
@@ -227,6 +312,18 @@ def oscillator_update(t, x, u, params):
227312 if savefigs :
228313 plt .savefig ('phaseplot-oscillator-helpers.png' )
229314
315+ plt .figure ()
316+ ct .phase_plane_plot (
317+ invpend , [- 2 * pi , 2 * pi , - 2 , 2 ],
318+ plot_streamplot = dict (vary_color = True , vary_density = True ),
319+ gridspec = [60 , 20 ], params = {'m' : 1 , 'l' : 1 , 'b' : 0.2 , 'g' : 1 }
320+ )
321+ plt .xlabel (r"$\theta$ [rad]" )
322+ plt .ylabel (r"$\dot\theta$ [rad/sec]" )
323+
324+ if savefigs :
325+ plt .savefig ('phaseplot-invpend-streamplot.png' )
326+
230327
231328if __name__ == "__main__" :
232329 #
0 commit comments