3939
4040import numpy as np
4141import matplotlib .pyplot as mpl
42- from matplotlib . mlab import frange , find
42+
4343from scipy .integrate import odeint
4444from .exception import ControlNotImplemented
4545
4646__all__ = ['phase_plot' , 'box_grid' ]
4747
48+
49+ def _find (condition ):
50+ """Returns indices where ravel(a) is true.
51+ Private implementation of deprecated matplotlib.mlab.find
52+ """
53+ return np .nonzero (np .ravel (condition ))[0 ]
54+
55+
4856def phase_plot (odefun , X = None , Y = None , scale = 1 , X0 = None , T = None ,
4957 lingrid = None , lintime = None , logtime = None , timepts = None ,
5058 parms = (), verbose = True ):
@@ -70,11 +78,11 @@ def phase_plot(odefun, X=None, Y=None, scale=1, X0=None, T=None,
7078 dxdt = F(x, t) that accepts a state x of dimension 2 and
7179 returns a derivative dx/dt of dimension 2.
7280
73- X, Y: ndarray , optional
74- Two 1-D arrays representing x and y coordinates of a grid.
75- These arguments are passed to meshgrid and generate the lists
76- of points at which the vector field is plotted. If absent (or
77- None), the vector field is not plotted.
81+ X, Y: 3-element sequences , optional, as [start, stop, npts]
82+ Two 3-element sequences specifying x and y coordinates of a
83+ grid. These arguments are passed to linspace and meshgrid to
84+ generate the points at which the vector field is plotted. If
85+ absent (or None), the vector field is not plotted.
7886
7987 scale: float, optional
8088 Scale size of arrows; default = 1
@@ -145,8 +153,8 @@ def phase_plot(odefun, X=None, Y=None, scale=1, X0=None, T=None,
145153 #! TODO: Add sanity checks
146154 elif (X is not None and Y is not None ):
147155 (x1 , x2 ) = np .meshgrid (
148- frange (X [0 ], X [1 ], float ( X [ 1 ] - X [ 0 ]) / X [2 ]),
149- frange (Y [0 ], Y [1 ], float ( Y [ 1 ] - Y [ 0 ]) / Y [ 2 ]));
156+ np . linspace (X [0 ], X [1 ], X [2 ]),
157+ np . linspace (Y [0 ], Y [1 ], Y [ 2 ]))
150158 else :
151159 # If we weren't given any grid points, don't plot arrows
152160 Narrows = 0 ;
@@ -234,12 +242,12 @@ def phase_plot(odefun, X=None, Y=None, scale=1, X0=None, T=None,
234242 elif (logtimeFlag ):
235243 # Use an exponential time vector
236244 # MATLAB: tind = find(time < (j-k) / lambda, 1, 'last');
237- tarr = find (time < (j - k ) / timefactor );
245+ tarr = _find (time < (j - k ) / timefactor );
238246 tind = tarr [- 1 ] if len (tarr ) else 0 ;
239247 elif (timeptsFlag ):
240248 # Use specified time points
241249 # MATLAB: tind = find(time < Y[j], 1, 'last');
242- tarr = find (time < timepts [j ]);
250+ tarr = _find (time < timepts [j ]);
243251 tind = tarr [- 1 ] if len (tarr ) else 0 ;
244252
245253 # For tailless arrows, skip the first point
@@ -295,8 +303,8 @@ def box_grid(xlimp, ylimp):
295303 box defined by the corners [xmin ymin] and [xmax ymax].
296304 """
297305
298- sx10 = frange (xlimp [0 ], xlimp [1 ], float ( xlimp [ 1 ] - xlimp [ 0 ]) / xlimp [2 ])
299- sy10 = frange (ylimp [0 ], ylimp [1 ], float ( ylimp [ 1 ] - ylimp [ 0 ]) / ylimp [2 ])
306+ sx10 = np . linspace (xlimp [0 ], xlimp [1 ], xlimp [2 ])
307+ sy10 = np . linspace (ylimp [0 ], ylimp [1 ], ylimp [2 ])
300308
301309 sx1 = np .hstack ((0 , sx10 , 0 * sy10 + sx10 [0 ], sx10 , 0 * sy10 + sx10 [- 1 ]))
302310 sx2 = np .hstack ((0 , 0 * sx10 + sy10 [0 ], sy10 , 0 * sx10 + sy10 [- 1 ], sy10 ))
0 commit comments