Skip to content

Commit 1f5de6f

Browse files
committed
refactor processing of x0, u0 keywords in nlsys
1 parent 36263d8 commit 1f5de6f

3 files changed

Lines changed: 78 additions & 57 deletions

File tree

control/nlsys.py

Lines changed: 67 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -532,26 +532,16 @@ def linearize(self, x0, u0, t=0, params=None, eps=1e-6,
532532
# numerical linearization use the `_rhs()` and `_out()` member
533533
# functions.
534534
#
535-
# If x0 and u0 are specified as lists, concatenate the elements
536-
x0 = _concatenate_list_elements(x0, 'x0')
537-
u0 = _concatenate_list_elements(u0, 'u0')
535+
# Process nominal states and inputs
536+
x0, nstates = _process_vector_argument(x0, "x0", self.nstates)
537+
u0, ninputs = _process_vector_argument(u0, "u0", self.ninputs)
538538

539-
# Figure out dimensions if they were not specified.
540-
nstates = _find_size(self.nstates, x0, "x0")
541-
ninputs = _find_size(self.ninputs, u0, "u0")
542-
543-
# Convert x0, u0 to arrays, if needed
544-
if np.isscalar(x0):
545-
x0 = np.ones((nstates,)) * x0
546-
if np.isscalar(u0):
547-
u0 = np.ones((ninputs,)) * u0
539+
# Update the current parameters (prior to calling _out())
540+
self._update_params(params)
548541

549542
# Compute number of outputs by evaluating the output function
550543
noutputs = _find_size(self.noutputs, self._out(t, x0, u0), "outputs")
551544

552-
# Update the current parameters
553-
self._update_params(params)
554-
555545
# Compute the nominal value of the update law and output
556546
F0 = self._rhs(t, x0, u0)
557547
H0 = self._out(t, x0, u0)
@@ -1468,8 +1458,16 @@ def input_output_response(
14681458
# Use the input time points as the output time points
14691459
t_eval = T
14701460

1461+
#
1462+
# Process input argument
1463+
#
1464+
# The input argument is interpreted very flexibly, allowing the
1465+
# use of listsa and/or tuples of mixed scalar and vector elements.
1466+
#
1467+
# Much of the processing here is similar to the processing in
1468+
# _process_vector_argument, but applied to a time series.
1469+
14711470
# If we were passed a list of inputs, concatenate them (w/ broadcast)
1472-
# TODO: call _concatenate_list_elements
14731471
if isinstance(U, (tuple, list)) and len(U) != ntimepts:
14741472
U_elements = []
14751473
for i, u in enumerate(U):
@@ -1494,7 +1492,6 @@ def input_output_response(
14941492
U = np.vstack(U_elements)
14951493

14961494
# Figure out the number of inputs
1497-
# TODO: call _concatenate_list_elements?
14981495
if sys.ninputs is None:
14991496
if isinstance(U, np.ndarray):
15001497
ninputs = U.shape[0] if U.size > 1 else U.size
@@ -1516,22 +1513,8 @@ def input_output_response(
15161513
U = U.reshape(-1, ntimepts)
15171514
ninputs = U.shape[0]
15181515

1519-
# If we were passed a list of initial states, concatenate them
1520-
X0 = _concatenate_list_elements(X0, 'X0')
1521-
1522-
# If the initial state is too short, make it longer (NB: sys.nstates
1523-
# could be None if nstates comes from size of initial condition)
1524-
if sys.nstates and isinstance(X0, np.ndarray) and X0.size < sys.nstates:
1525-
if X0[-1] != 0:
1526-
warn("initial state too short; padding with zeros")
1527-
X0 = np.hstack([X0, np.zeros(sys.nstates - X0.size)])
1528-
1529-
# Compute the number of states
1530-
nstates = _find_size(sys.nstates, X0, "states")
1531-
1532-
# create X0 if not given, test if X0 has correct shape
1533-
X0 = _check_convert_array(
1534-
X0, [(nstates,), (nstates, 1)], 'Parameter ``X0``: ', squeeze=True)
1516+
# Process initial states
1517+
X0, nstates = _process_vector_argument(X0, "X0", sys.nstates)
15351518

15361519
# Update the parameter values (prior to evaluating outfcn)
15371520
sys._update_params(params)
@@ -1752,17 +1735,9 @@ def find_eqpt(sys, x0, u0=None, y0=None, t=0, params=None,
17521735
from scipy.optimize import root
17531736

17541737
# Figure out the number of states, inputs, and outputs
1755-
nstates = _find_size(sys.nstates, x0, "x0")
1756-
ninputs = _find_size(sys.ninputs, u0, "u0")
1757-
noutputs = _find_size(sys.noutputs, y0, "y0")
1758-
1759-
# Convert x0, u0, y0 to arrays, if needed
1760-
if np.isscalar(x0):
1761-
x0 = np.ones((nstates,)) * x0
1762-
if np.isscalar(u0):
1763-
u0 = np.ones((ninputs,)) * u0
1764-
if np.isscalar(y0):
1765-
y0 = np.ones((ninputs,)) * y0
1738+
x0, nstates = _process_vector_argument(x0, "x0", sys.nstates)
1739+
u0, ninputs = _process_vector_argument(u0, "u0", sys.ninputs)
1740+
y0, noutputs = _process_vector_argument(y0, "y0", sys.noutputs)
17661741

17671742
# Make sure the input arguments match the sizes of the system
17681743
if len(x0) != nstates or \
@@ -2572,18 +2547,55 @@ def interconnect(
25722547
return newsys
25732548

25742549

2575-
# Utility function to allow lists of states, inputs
2576-
def _concatenate_list_elements(X, name='X'):
2577-
# If we were passed a list, concatenate the elements together
2578-
if isinstance(X, (tuple, list)):
2579-
X_list = []
2580-
for i, x in enumerate(X):
2581-
x = np.array(x).reshape(-1) # convert everyting to 1D array
2582-
X_list += x.tolist() # add elements to initial state
2583-
return np.array(X_list)
2550+
def _process_vector_argument(arg, name, size):
2551+
"""Utility function to process vector elements (states, inputs)
2552+
2553+
Process state and input arguments to turn them into lists of the
2554+
appropriate length.
2555+
2556+
Parameters
2557+
----------
2558+
arg : array_like
2559+
Value of the parameter passed to the function. Can be a list,
2560+
tuple, ndarray, scalar, or None.
2561+
name : string
2562+
Name of the argument being processed. Used in errors/warnings.
2563+
size : int or None
2564+
Size of the element. If None, size is determined by arg.
2565+
2566+
Returns
2567+
-------
2568+
val : array or None
2569+
Value of the element, zero-padded to proper length.
2570+
nelem : int or None
2571+
Number of elements in the returned value.
2572+
2573+
Warns
2574+
-----
2575+
UserWarning : "{name} too short; padding with zeros"
2576+
If argument is too short and last value in arg is not 0.
2577+
2578+
"""
2579+
# Allow and expand list
2580+
if isinstance(arg, (tuple, list)):
2581+
val_list = []
2582+
for i, v in enumerate(arg):
2583+
v = np.array(v).reshape(-1) # convert to 1D array
2584+
val_list += v.tolist() # add elements to list
2585+
val = np.array(val_list)
2586+
elif np.isscalar(arg) and size is not None: # extend scalars
2587+
val = np.ones((size, )) * arg
2588+
else:
2589+
val = arg # return what we were given
2590+
2591+
if size is not None and isinstance(val, np.ndarray) and val.size < size:
2592+
# If needed, extend the size of the vector to match desired size
2593+
if val[-1] != 0:
2594+
warn(f"{name} too short; padding with zeros")
2595+
val = np.hstack([val, np.zeros(size - val.size)])
25842596

2585-
# Otherwise, do nothing
2586-
return X
2597+
nelem = _find_size(size, val, name) # determine size
2598+
return val, nelem
25872599

25882600

25892601
# Utility function to create an I/O system from a static gain

control/tests/iosys_test.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,15 @@ def test_linearize(self, tsys, kincar):
231231
linearized.C, [[1, 0, 0], [0, 1, 0]])
232232
np.testing.assert_array_almost_equal(linearized.D, np.zeros((2,2)))
233233

234+
# Pass fewer than the required elements
235+
padded = iosys.linearize([0, 0], np.array([0]))
236+
assert padded.nstates == linearized.nstates
237+
assert padded.ninputs == linearized.ninputs
238+
239+
# Check for warning if last element before padding is nonzero
240+
with pytest.warns(UserWarning, match="x0 too short; padding"):
241+
padded = iosys.linearize([0, 1], np.array([0]))
242+
234243
@pytest.mark.usefixtures("editsdefaults")
235244
def test_linearize_named_signals(self, kincar):
236245
# Full form of the call
@@ -1886,7 +1895,7 @@ def test_input_output_broadcasting():
18861895
np.testing.assert_equal(resp_cov0.states, resp_init.states)
18871896

18881897
# Specify only some of the initial conditions
1889-
with pytest.warns(UserWarning, match="initial state too short; padding"):
1898+
with pytest.warns(UserWarning, match="X0 too short; padding"):
18901899
resp_short = ct.input_output_response(sys, T, [U[0], [0, 1]], [X0, 1])
18911900

18921901
# Make sure that inconsistent settings don't work

control/tests/nlsys_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def kincar_output(t, x, u, params):
4646
])
4747
def test_lti_nlsys_response(nin, nout, input, output):
4848
sys_ss = ct.rss(4, nin, nout, strictly_proper=True)
49-
sys_ss.A = np.diag([-1, -2, -3, -4]) # avoid random noise errors
49+
sys_ss.A = np.diag([-1, -2, -3, -4]) # avoid random numerical errors
5050
sys_nl = ct.nlsys(
5151
lambda t, x, u, params: sys_ss.A @ x + sys_ss.B @ u,
5252
lambda t, x, u, params: sys_ss.C @ x + sys_ss.D @ u,

0 commit comments

Comments
 (0)