Skip to content

Commit 236c1d4

Browse files
committed
ctrb converts ndim=1 B correctly; ctrb & obsv check input shapes
1 parent 0ff0452 commit 236c1d4

File tree

2 files changed

+57
-12
lines changed

2 files changed

+57
-12
lines changed

control/statefbk.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,19 +1098,26 @@ def ctrb(A, B, t=None):
10981098
"""
10991099

11001100
# Convert input parameters to matrices (if they aren't already)
1101-
amat = _ssmatrix(A)
1102-
bmat = _ssmatrix(B)
1103-
n = np.shape(amat)[0]
1104-
m = np.shape(bmat)[1]
1101+
A = _ssmatrix(A)
1102+
if np.asarray(B).ndim == 1 and len(B) == A.shape[0]:
1103+
B = _ssmatrix(B, axis=0)
1104+
else:
1105+
B = _ssmatrix(B)
1106+
1107+
n = A.shape[0]
1108+
m = B.shape[1]
1109+
1110+
_check_shape('A', A, n, n, square=True)
1111+
_check_shape('B', B, n, m)
11051112

11061113
if t is None or t > n:
11071114
t = n
11081115

11091116
# Construct the controllability matrix
11101117
ctrb = np.zeros((n, t * m))
1111-
ctrb[:, :m] = bmat
1118+
ctrb[:, :m] = B
11121119
for k in range(1, t):
1113-
ctrb[:, k * m:(k + 1) * m] = np.dot(amat, ctrb[:, (k - 1) * m:k * m])
1120+
ctrb[:, k * m:(k + 1) * m] = np.dot(A, ctrb[:, (k - 1) * m:k * m])
11141121

11151122
return _ssmatrix(ctrb)
11161123

@@ -1140,20 +1147,24 @@ def obsv(A, C, t=None):
11401147
"""
11411148

11421149
# Convert input parameters to matrices (if they aren't already)
1143-
amat = _ssmatrix(A)
1144-
cmat = _ssmatrix(C)
1145-
n = np.shape(amat)[0]
1146-
p = np.shape(cmat)[0]
1150+
A = _ssmatrix(A)
1151+
C = _ssmatrix(C)
1152+
1153+
n = np.shape(A)[0]
1154+
p = np.shape(C)[0]
1155+
1156+
_check_shape('A', A, n, n, square=True)
1157+
_check_shape('C', C, p, n)
11471158

11481159
if t is None or t > n:
11491160
t = n
11501161

11511162
# Construct the observability matrix
11521163
obsv = np.zeros((t * p, n))
1153-
obsv[:p, :] = cmat
1164+
obsv[:p, :] = C
11541165

11551166
for k in range(1, t):
1156-
obsv[k * p:(k + 1) * p, :] = np.dot(obsv[(k - 1) * p:k * p, :], amat)
1167+
obsv[k * p:(k + 1) * p, :] = np.dot(obsv[(k - 1) * p:k * p, :], A)
11571168

11581169
return _ssmatrix(obsv)
11591170

control/tests/statefbk_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,23 @@ def testCtrbT(self):
5757
Wc = ctrb(A, B, t=t)
5858
np.testing.assert_array_almost_equal(Wc, Wctrue)
5959

60+
def testCtrbNdim1(self):
61+
# gh-1097: treat 1-dim B as nx1
62+
A = np.array([[1., 2.], [3., 4.]])
63+
B = np.array([5., 7.])
64+
Wctrue = np.array([[5., 19.], [7., 43.]])
65+
Wc = ctrb(A, B)
66+
np.testing.assert_array_almost_equal(Wc, Wctrue)
67+
68+
def testCtrbRejectMismatch(self):
69+
# gh-1097: check A, B for compatible shapes
70+
with pytest.raises(ControlDimension, match='A must be a square matrix'):
71+
ctrb([[1,2]],[1])
72+
with pytest.raises(ControlDimension, match='Incompatible dimensions of B matrix'):
73+
ctrb([[1,2],[2,3]], 1)
74+
with pytest.raises(ControlDimension, match='Incompatible dimensions of B matrix'):
75+
ctrb([[1,2],[2,3]], [[1,2]])
76+
6077
def testObsvSISO(self):
6178
A = np.array([[1., 2.], [3., 4.]])
6279
C = np.array([[5., 7.]])
@@ -79,6 +96,23 @@ def testObsvT(self):
7996
Wo = obsv(A, C, t=t)
8097
np.testing.assert_array_almost_equal(Wo, Wotrue)
8198

99+
def testObsvNdim1(self):
100+
# gh-1097: treat 1-dim C as 1xn
101+
A = np.array([[1., 2.], [3., 4.]])
102+
C = np.array([5., 7.])
103+
Wotrue = np.array([[5., 7.], [26., 38.]])
104+
Wo = obsv(A, C)
105+
np.testing.assert_array_almost_equal(Wo, Wotrue)
106+
107+
def testObsvRejectMismatch(self):
108+
# gh-1097: check A, B for compatible shapes
109+
with pytest.raises(ControlDimension, match='A must be a square matrix'):
110+
obsv([[1,2]],[1])
111+
with pytest.raises(ControlDimension, match='Incompatible dimensions of C matrix'):
112+
obsv([[1,2],[2,3]], 1)
113+
with pytest.raises(ControlDimension, match='Incompatible dimensions of C matrix'):
114+
obsv([[1,2],[2,3]], [[1],[2]])
115+
82116
def testCtrbObsvDuality(self):
83117
A = np.array([[1.2, -2.3], [3.4, -4.5]])
84118
B = np.array([[5.8, 6.9], [8., 9.1]])

0 commit comments

Comments
 (0)