Skip to content

Commit 1b71fa0

Browse files
committed
add warning messages on trajectory errors (+ ability to suppress)
1 parent d44a577 commit 1b71fa0

4 files changed

Lines changed: 118 additions & 54 deletions

File tree

control/nlsys.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,17 @@
1818
1919
"""
2020

21-
import numpy as np
22-
import scipy as sp
2321
import copy
2422
from warnings import warn
2523

24+
import numpy as np
25+
import scipy as sp
26+
2627
from . import config
27-
from .iosys import InputOutputSystem, _process_signal_list, \
28-
_process_iosys_keywords, isctime, isdtime, common_timebase, _parse_spec
29-
from .timeresp import _check_convert_array, _process_time_response, \
30-
TimeResponseData
28+
from .iosys import (InputOutputSystem, _parse_spec, _process_iosys_keywords,
29+
_process_signal_list, common_timebase, isctime, isdtime)
30+
from .timeresp import (TimeResponseData, _check_convert_array,
31+
_process_time_response)
3132

3233
__all__ = ['NonlinearIOSystem', 'InterconnectedSystem', 'nlsys',
3334
'input_output_response', 'find_eqpt', 'linearize',
@@ -528,7 +529,6 @@ def linearize(self, x0, u0, t=0, params=None, eps=1e-6,
528529
# numerical linearization use the `_rhs()` and `_out()` member
529530
# functions.
530531
#
531-
532532
# If x0 and u0 are specified as lists, concatenate the elements
533533
x0 = _concatenate_list_elements(x0, 'x0')
534534
u0 = _concatenate_list_elements(u0, 'u0')
@@ -1317,7 +1317,7 @@ def nlsys(
13171317

13181318

13191319
def input_output_response(
1320-
sys, T, U=0., X0=0, params=None, ignore_error=False,
1320+
sys, T, U=0., X0=0, params=None, ignore_errors=False,
13211321
transpose=False, return_x=False, squeeze=None,
13221322
solve_ivp_kwargs=None, t_eval='T', **kwargs):
13231323
"""Compute the output response of a system to a given input.
@@ -1393,6 +1393,11 @@ def input_output_response(
13931393
to 'RK45'.
13941394
solve_ivp_kwargs : dict, optional
13951395
Pass additional keywords to :func:`scipy.integrate.solve_ivp`.
1396+
ignore_errors : bool, optional
1397+
If ``False`` (default), errors during computation of the trajectory
1398+
will raise a ``RuntimeError`` exception. If ``True``, do not raise
1399+
an exception and instead set ``results.success`` to ``False`` and
1400+
place an error message in ``results.message``.
13961401
13971402
Raises
13981403
------
@@ -1593,8 +1598,12 @@ def ivp_rhs(t, x):
15931598
soln = sp.integrate.solve_ivp(
15941599
ivp_rhs, (T0, Tf), X0, t_eval=t_eval,
15951600
vectorized=False, **solve_ivp_kwargs)
1596-
if not ignore_error and not soln.success:
1597-
raise RuntimeError("solve_ivp failed: " + soln.message)
1601+
if not soln.success:
1602+
message = "solve_ivp failed: " + soln.message
1603+
if not ignore_errors:
1604+
raise RuntimeError(message)
1605+
else:
1606+
message = None
15981607

15991608
# Compute inputs and outputs for each time point
16001609
u = np.zeros((ninputs, len(soln.t)))
@@ -1650,7 +1659,7 @@ def ivp_rhs(t, x):
16501659
u = np.transpose(np.array(u))
16511660

16521661
# Mark solution as successful
1653-
soln.success = True # No way to fail
1662+
soln.success, message = True, None # No way to fail
16541663

16551664
else: # Neither ctime or dtime??
16561665
raise TypeError("Can't determine system type")
@@ -1660,7 +1669,8 @@ def ivp_rhs(t, x):
16601669
output_labels=sys.output_labels, input_labels=sys.input_labels,
16611670
state_labels=sys.state_labels, sysname=sys.name,
16621671
title="Input/output response for " + sys.name,
1663-
transpose=transpose, return_x=return_x, squeeze=squeeze)
1672+
transpose=transpose, return_x=return_x, squeeze=squeeze,
1673+
success=soln.success, message=message)
16641674

16651675

16661676
def find_eqpt(sys, x0, u0=None, y0=None, t=0, params=None,
@@ -2252,7 +2262,7 @@ def interconnect(
22522262
`outputs`, for more natural naming of SISO systems.
22532263
22542264
"""
2255-
from .statesp import StateSpace, LinearICSystem, _convert_to_statespace
2265+
from .statesp import LinearICSystem, StateSpace, _convert_to_statespace
22562266
from .xferfcn import TransferFunction
22572267

22582268
dt = kwargs.pop('dt', None) # bypass normal 'dt' processing

control/phaseplot.py

Lines changed: 57 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
def phase_plane_plot(
5353
sys, pointdata=None, timedata=None, gridtype=None, gridspec=None,
5454
plot_streamlines=True, plot_vectorfield=False, plot_equilpoints=True,
55-
plot_separatrices=True, ax=None, **kwargs
55+
plot_separatrices=True, ax=None, suppress_warnings=False, **kwargs
5656
):
5757
"""Plot phase plane diagram.
5858
@@ -88,22 +88,6 @@ def phase_plane_plot(
8888
Parameters to pass to system. For an I/O system, `params` should be
8989
a dict of parameters and values. For a callable, `params` should be
9090
dict with key 'args' and value given by a tuple (passed to callable).
91-
plot_streamlines : bool or dict
92-
If `True` (default) then plot streamlines based on the pointdata
93-
and gridtype. If set to a dict, pass on the key-value pairs in
94-
the dict as keywords to :func:`~control.phaseplot.streamlines`.
95-
plot_vectorfield : bool or dict
96-
If `True` (default) then plot the vector field based on the pointdata
97-
and gridtype. If set to a dict, pass on the key-value pairs in
98-
the dict as keywords to :func:`~control.phaseplot.vectorfield`.
99-
plot_equilpoints : bool or dict
100-
If `True` (default) then plot equilibrium points based in the phase
101-
plot boundary. If set to a dict, pass on the key-value pairs in the
102-
dict as keywords to :func:`~control.phaseplot.equilpoints`.
103-
plot_separatrices : bool or dict
104-
If `True` (default) then plot separatrices starting from each
105-
equilibrium point. If set to a dict, pass on the key-value pairs
106-
in the dict as keywords to :func:`~control.phaseplot.separatrices`.
10791
color : str
10892
Plot all elements in the given color (use `plot_<fcn>={'color': c}`
10993
to set the color in one element of the phase plot.
@@ -117,6 +101,27 @@ def phase_plane_plot(
117101
out[1] = Quiver object (vector field arrows)
118102
out[2] = list of Line2D objects (equilibrium points)
119103
104+
Other parameters
105+
----------------
106+
plot_streamlines : bool or dict, optional
107+
If `True` (default) then plot streamlines based on the pointdata
108+
and gridtype. If set to a dict, pass on the key-value pairs in
109+
the dict as keywords to :func:`~control.phaseplot.streamlines`.
110+
plot_vectorfield : bool or dict, optional
111+
If `True` (default) then plot the vector field based on the pointdata
112+
and gridtype. If set to a dict, pass on the key-value pairs in
113+
the dict as keywords to :func:`~control.phaseplot.vectorfield`.
114+
plot_equilpoints : bool or dict, optional
115+
If `True` (default) then plot equilibrium points based in the phase
116+
plot boundary. If set to a dict, pass on the key-value pairs in the
117+
dict as keywords to :func:`~control.phaseplot.equilpoints`.
118+
plot_separatrices : bool or dict, optional
119+
If `True` (default) then plot separatrices starting from each
120+
equilibrium point. If set to a dict, pass on the key-value pairs
121+
in the dict as keywords to :func:`~control.phaseplot.separatrices`.
122+
suppress_warnings : bool, optional
123+
If set to `True`, suppress warning messages in generating trajectories.
124+
120125
"""
121126
# Process arguments
122127
params = kwargs.get('params', None)
@@ -149,7 +154,8 @@ def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
149154
kwargs, plot_streamlines, gridspec=gridspec, gridtype=gridtype,
150155
ax=ax)
151156
out[0] += streamlines(
152-
sys, pointdata, timedata, check_kwargs=False, **kwargs_local)
157+
sys, pointdata, timedata, check_kwargs=False,
158+
suppress_warnings=suppress_warnings, **kwargs_local)
153159

154160
# Get rid of keyword arguments handled by streamlines
155161
for kw in ['arrows', 'arrow_size', 'arrow_style', 'color',
@@ -203,7 +209,8 @@ def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
203209

204210

205211
def vectorfield(
206-
sys, pointdata, gridspec=None, ax=None, check_kwargs=True, **kwargs):
212+
sys, pointdata, gridspec=None, ax=None, suppress_warnings=False,
213+
check_kwargs=True, **kwargs):
207214
"""Plot a vector field in the phase plane.
208215
209216
This function plots a vector field for a two-dimensional state
@@ -244,6 +251,11 @@ def vectorfield(
244251
-------
245252
out : Quiver
246253
254+
Other parameters
255+
----------------
256+
suppress_warnings : bool, optional
257+
If set to `True`, suppress warning messages in generating trajectories.
258+
247259
"""
248260
# Get system parameters
249261
params = kwargs.pop('params', None)
@@ -283,8 +295,8 @@ def vectorfield(
283295

284296

285297
def streamlines(
286-
sys, pointdata, timedata=1, gridspec=None, gridtype=None,
287-
dir=None, ax=None, check_kwargs=True, **kwargs):
298+
sys, pointdata, timedata=1, gridspec=None, gridtype=None, dir=None,
299+
ax=None, check_kwargs=True, suppress_warnings=False, **kwargs):
288300
"""Plot stream lines in the phase plane.
289301
290302
This function plots stream lines for a two-dimensional state space
@@ -328,6 +340,11 @@ def streamlines(
328340
-------
329341
out : list of Line2D objects
330342
343+
Other parameters
344+
----------------
345+
suppress_warnings : bool, optional
346+
If set to `True`, suppress warning messages in generating trajectories.
347+
331348
"""
332349
# Get system parameters
333350
params = kwargs.pop('params', None)
@@ -373,7 +390,8 @@ def streamlines(
373390
timepts = _make_timepts(timedata, i)
374391
traj = _create_trajectory(
375392
sys, revsys, timepts, X0, params, dir,
376-
gridtype=gridtype, gridspec=gridspec, xlim=xlim, ylim=ylim)
393+
gridtype=gridtype, gridspec=gridspec, xlim=xlim, ylim=ylim,
394+
suppress_warnings=suppress_warnings)
377395

378396
# Plot the trajectory (if there is one)
379397
if traj.shape[1] > 1:
@@ -465,7 +483,7 @@ def equilpoints(
465483

466484
def separatrices(
467485
sys, pointdata, timedata=None, gridspec=None, ax=None,
468-
check_kwargs=True, **kwargs):
486+
check_kwargs=True, suppress_warnings=False, **kwargs):
469487
"""Plot separatrices in the phase plane.
470488
471489
This function plots separatrices for a two-dimensional state space
@@ -509,6 +527,11 @@ def separatrices(
509527
-------
510528
out : list of Line2D objects
511529
530+
Other parameters
531+
----------------
532+
suppress_warnings : bool, optional
533+
If set to `True`, suppress warning messages in generating trajectories.
534+
512535
"""
513536
# Get system parameters
514537
params = kwargs.pop('params', None)
@@ -586,13 +609,15 @@ def separatrices(
586609
if evals[j].real < 0:
587610
traj = _create_trajectory(
588611
sys, revsys, timepts, x0, params, 'reverse',
589-
gridtype='boxgrid', xlim=xlim, ylim=ylim)
612+
gridtype='boxgrid', xlim=xlim, ylim=ylim,
613+
suppress_warnings=suppress_warnings)
590614
color = stable_color
591615
linestyle = '--'
592616
elif evals[j].real > 0:
593617
traj = _create_trajectory(
594618
sys, revsys, timepts, x0, params, 'forward',
595-
gridtype='boxgrid', xlim=xlim, ylim=ylim)
619+
gridtype='boxgrid', xlim=xlim, ylim=ylim,
620+
suppress_warnings=suppress_warnings)
596621
color = unstable_color
597622
linestyle = '-'
598623

@@ -880,17 +905,21 @@ def _get_color(kwargs, ax=None):
880905

881906

882907
def _create_trajectory(
883-
sys, revsys, timepts, X0, params, dir,
908+
sys, revsys, timepts, X0, params, dir, suppress_warnings=False,
884909
gridtype=None, gridspec=None, xlim=None, ylim=None):
885910
# Comput ethe forward trajectory
886911
if dir == 'forward' or dir == 'both':
887912
fwdresp = input_output_response(
888-
sys, timepts, X0=X0, params=params, ignore_error=True)
913+
sys, timepts, X0=X0, params=params, ignore_errors=True)
914+
if not fwdresp.success and not suppress_warnings:
915+
warnings.warn(f"{X0=}, {fwdresp.message}")
889916

890917
# Compute the reverse trajectory
891918
if dir == 'reverse' or dir == 'both':
892919
revresp = input_output_response(
893-
revsys, timepts, X0=X0, params=params, ignore_error=True)
920+
revsys, timepts, X0=X0, params=params, ignore_errors=True)
921+
if not revresp.success and not suppress_warnings:
922+
warnings.warn(f"{X0=}, {revresp.message}")
894923

895924
# Create the trace to plot
896925
if dir == 'forward':
@@ -1212,6 +1241,3 @@ def _find(condition):
12121241
Private implementation of deprecated matplotlib.mlab.find
12131242
"""
12141243
return np.nonzero(np.ravel(condition))[0]
1215-
1216-
1217-

control/tests/phaseplot_test.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@
99
the figures so that you can check them visually.
1010
"""
1111

12+
import warnings
1213

1314
import matplotlib.pyplot as plt
1415
import numpy as np
15-
from numpy import pi
1616
import pytest
17-
from control import phase_plot
17+
from math import pi
18+
1819
import control as ct
1920
import control.phaseplot as pp
21+
from control import phase_plot
2022

2123

2224
# Legacy tests
@@ -156,7 +158,22 @@ def invpend_ode(t, x, m=0, l=0, b=0, g=0):
156158
ct.phase_plane_plot(
157159
invpend_ode, [-5, 5, 2, 2], params={'stuff': (1, 1, 0.2, 1)})
158160

159-
161+
# Warning messages for invalid solutions: nonlinear spring mass system
162+
sys = ct.nlsys(
163+
lambda t, x, u, params: np.array(
164+
[x[1], -0.25 * (x[0] - 0.01 * x[0]**3) - 0.1 * x[1]]),
165+
states=2, inputs=0)
166+
with pytest.warns(UserWarning, match=r"X0=array\(.*\), solve_ivp failed"):
167+
ct.phase_plane_plot(
168+
sys, [-12, 12, -10, 10], 15, gridspec=[2, 9],
169+
plot_separatrices=False)
170+
171+
# Turn warnings off
172+
with warnings.catch_warnings():
173+
warnings.simplefilter("error")
174+
ct.phase_plane_plot(
175+
sys, [-12, 12, -10, 10], 15, gridspec=[2, 9],
176+
plot_separatrices=False, suppress_warnings=True)
160177

161178

162179
def test_basic_phase_plots(savefigs=False):

0 commit comments

Comments
 (0)