Skip to content

Commit 391d29c

Browse files
committed
update StateSpace size checks/error messages to be more informative
1 parent 8780bdc commit 391d29c

2 files changed

Lines changed: 38 additions & 32 deletions

File tree

control/statesp.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,12 @@ def __init__(self, *args, **kwargs):
252252
D = np.zeros((C.shape[0], B.shape[1]))
253253
D = _ssmatrix(D)
254254

255+
# If only direct term is present, adjust sizes of C and D if needed
256+
if D.size > 0 and B.size == 0:
257+
B = np.zeros((0, D.shape[1]))
258+
if D.size > 0 and C.size == 0:
259+
C = np.zeros((D.shape[0], 0))
260+
255261
# Matrices defining the linear system
256262
self.A = A
257263
self.B = B
@@ -268,7 +274,7 @@ def __init__(self, *args, **kwargs):
268274

269275
# Process iosys keywords
270276
defaults = args[0] if len(args) == 1 else \
271-
{'inputs': D.shape[1], 'outputs': D.shape[0],
277+
{'inputs': B.shape[1], 'outputs': C.shape[0],
272278
'states': A.shape[0]}
273279
name, inputs, outputs, states, dt = _process_iosys_keywords(
274280
kwargs, defaults, static=(A.size == 0))
@@ -295,16 +301,15 @@ def __init__(self, *args, **kwargs):
295301
# Check to make sure everything is consistent
296302
#
297303
# Check that the matrix sizes are consistent
298-
if A.shape[0] != A.shape[1] or self.nstates != A.shape[0]:
299-
raise ValueError("A must be square.")
300-
if self.nstates != B.shape[0]:
301-
raise ValueError("A and B must have the same number of rows.")
302-
if self.nstates != C.shape[1]:
303-
raise ValueError("A and C must have the same number of columns.")
304-
if self.ninputs != B.shape[1] or self.ninputs != D.shape[1]:
305-
raise ValueError("B and D must have the same number of columns.")
306-
if self.noutputs != C.shape[0] or self.noutputs != D.shape[0]:
307-
raise ValueError("C and D must have the same number of rows.")
304+
def _check_shape(matrix, expected, name):
305+
if matrix.shape != expected:
306+
raise ValueError(
307+
f"{name} is the wrong shape; "
308+
f"expected {expected} instead of {matrix.shape}")
309+
_check_shape(A, (self.nstates, self.nstates), "A")
310+
_check_shape(B, (self.nstates, self.ninputs), "B")
311+
_check_shape(C, (self.noutputs, self.nstates), "C")
312+
_check_shape(D, (self.noutputs, self.ninputs), "D")
308313

309314
#
310315
# Final processing

control/tests/statesp_test.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -121,29 +121,30 @@ def test_constructor(self, sys322ABCD, dt, argfun):
121121
np.testing.assert_almost_equal(sys.D, sys322ABCD[3])
122122
assert sys.dt == dtref
123123

124-
@pytest.mark.parametrize("args, exc, errmsg",
125-
[((True, ), TypeError,
126-
"(can only take in|sys must be) a StateSpace"),
127-
((1, 2), TypeError, "1, 4, or 5 arguments"),
128-
((np.ones((3, 2)), np.ones((3, 2)),
129-
np.ones((2, 2)), np.ones((2, 2))),
130-
ValueError, "A must be square"),
131-
((np.ones((3, 3)), np.ones((2, 2)),
132-
np.ones((2, 3)), np.ones((2, 2))),
133-
ValueError, "A and B"),
134-
((np.ones((3, 3)), np.ones((3, 2)),
135-
np.ones((2, 2)), np.ones((2, 2))),
136-
ValueError, "A and C"),
137-
((np.ones((3, 3)), np.ones((3, 2)),
138-
np.ones((2, 3)), np.ones((2, 3))),
139-
ValueError, "B and D"),
140-
((np.ones((3, 3)), np.ones((3, 2)),
141-
np.ones((2, 3)), np.ones((3, 2))),
142-
ValueError, "C and D"),
143-
])
124+
@pytest.mark.parametrize(
125+
"args, exc, errmsg",
126+
[((True, ), TypeError, "(can only take in|sys must be) a StateSpace"),
127+
((1, 2), TypeError, "1, 4, or 5 arguments"),
128+
((np.ones((3, 2)), np.ones((3, 2)),
129+
np.ones((2, 2)), np.ones((2, 2))), ValueError,
130+
"A is the wrong shape; expected \(3, 3\)"),
131+
((np.ones((3, 3)), np.ones((2, 2)),
132+
np.ones((2, 3)), np.ones((2, 2))), ValueError,
133+
"B is the wrong shape; expected \(3, 2\)"),
134+
((np.ones((3, 3)), np.ones((3, 2)),
135+
np.ones((2, 2)), np.ones((2, 2))), ValueError,
136+
"C is the wrong shape; expected \(2, 3\)"),
137+
((np.ones((3, 3)), np.ones((3, 2)),
138+
np.ones((2, 3)), np.ones((2, 3))), ValueError,
139+
"D is the wrong shape; expected \(2, 2\)"),
140+
((np.ones((3, 3)), np.ones((3, 2)),
141+
np.ones((2, 3)), np.ones((3, 2))), ValueError,
142+
"D is the wrong shape; expected \(2, 2\)"),
143+
])
144144
def test_constructor_invalid(self, args, exc, errmsg):
145145
"""Test invalid input to StateSpace() constructor"""
146-
with pytest.raises(exc, match=errmsg):
146+
147+
with pytest.raises(exc, match=errmsg) as w:
147148
StateSpace(*args)
148149
with pytest.raises(exc, match=errmsg):
149150
ss(*args)

0 commit comments

Comments
 (0)