Skip to content

Commit d44a577

Browse files
committed
allow solve_ivp errors and improve arrow placement
1 parent f6f88f8 commit d44a577

3 files changed

Lines changed: 28 additions & 6 deletions

File tree

control/freqplot.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1963,9 +1963,22 @@ def _add_arrows_to_line2D(
19631963
if transform is None:
19641964
transform = axes.transData
19651965

1966+
# Figure out the size of the axes (length of diagonal)
1967+
xlim, ylim = axes.get_xlim(), axes.get_ylim()
1968+
ul, lr = np.array([xlim[0], ylim[0]]), np.array([xlim[1], ylim[1]])
1969+
diag = np.linalg.norm(ul - lr)
1970+
19661971
# Compute the arc length along the curve
19671972
s = np.cumsum(np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2))
19681973

1974+
# Truncate the number of arrows if the curve is short
1975+
# TODO: figure out a smarter way to do this
1976+
frac = min(s[-1] / diag, 1)
1977+
if len(arrow_locs) and frac < 0.05:
1978+
arrow_locs = [] # too short; no arrows at all
1979+
elif len(arrow_locs) and frac < 0.2:
1980+
arrow_locs = [0.5] # single arrow in the middle
1981+
19691982
arrows = []
19701983
for loc in arrow_locs:
19711984
n = np.searchsorted(s, s[-1] * loc)

control/nlsys.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1317,7 +1317,7 @@ def nlsys(
13171317

13181318

13191319
def input_output_response(
1320-
sys, T, U=0., X0=0, params=None,
1320+
sys, T, U=0., X0=0, params=None, ignore_error=False,
13211321
transpose=False, return_x=False, squeeze=None,
13221322
solve_ivp_kwargs=None, t_eval='T', **kwargs):
13231323
"""Compute the output response of a system to a given input.
@@ -1593,7 +1593,7 @@ def ivp_rhs(t, x):
15931593
soln = sp.integrate.solve_ivp(
15941594
ivp_rhs, (T0, Tf), X0, t_eval=t_eval,
15951595
vectorized=False, **solve_ivp_kwargs)
1596-
if not soln.success:
1596+
if not ignore_error and not soln.success:
15971597
raise RuntimeError("solve_ivp failed: " + soln.message)
15981598

15991599
# Compute inputs and outputs for each time point

control/phaseplot.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def streamlines(
375375
sys, revsys, timepts, X0, params, dir,
376376
gridtype=gridtype, gridspec=gridspec, xlim=xlim, ylim=ylim)
377377

378-
# Plot the trajectory
378+
# Plot the trajectory (if there is one)
379379
if traj.shape[1] > 1:
380380
out.append(
381381
ax.plot(traj[0], traj[1], color=color))
@@ -596,6 +596,7 @@ def separatrices(
596596
color = unstable_color
597597
linestyle = '-'
598598

599+
# Plot the trajectory (if there is one)
599600
if traj.shape[1] > 1:
600601
out.append(ax.plot(
601602
traj[0], traj[1], color=color, linestyle=linestyle))
@@ -883,12 +884,13 @@ def _create_trajectory(
883884
gridtype=None, gridspec=None, xlim=None, ylim=None):
884885
# Comput ethe forward trajectory
885886
if dir == 'forward' or dir == 'both':
886-
fwdresp = input_output_response(sys, timepts, X0=X0, params=params)
887+
fwdresp = input_output_response(
888+
sys, timepts, X0=X0, params=params, ignore_error=True)
887889

888890
# Compute the reverse trajectory
889891
if dir == 'reverse' or dir == 'both':
890892
revresp = input_output_response(
891-
revsys, timepts, X0=X0, params=params)
893+
revsys, timepts, X0=X0, params=params, ignore_error=True)
892894

893895
# Create the trace to plot
894896
if dir == 'forward':
@@ -898,7 +900,14 @@ def _create_trajectory(
898900
elif dir == 'both':
899901
traj = np.hstack([revresp.states[:, :1:-1], fwdresp.states])
900902

901-
return traj
903+
# Remove points outside the window (keep first point beyond boundary)
904+
inrange = np.asarray(
905+
(traj[0] >= xlim[0]) & (traj[0] <= xlim[1]) &
906+
(traj[1] >= ylim[0]) & (traj[1] <= ylim[1]))
907+
inrange[:-1] = inrange[:-1] | inrange[1:] # keep if next point in range
908+
inrange[1:] = inrange[1:] | inrange[:-1] # keep if prev point in range
909+
910+
return traj[:, inrange]
902911

903912

904913
def _make_timepts(timepts, i):

0 commit comments

Comments
 (0)