Skip to content

Commit a4986b6

Browse files
committed
add StateSpaceMatrix class
1 parent 1254ccb commit a4986b6

File tree

2 files changed

+359
-34
lines changed

2 files changed

+359
-34
lines changed

control/statesp.py

Lines changed: 164 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -65,29 +65,147 @@
6565
from .xferfcn import _convert_to_transfer_function
6666
from copy import deepcopy
6767

68-
__all__ = ['StateSpace', 'ss', 'rss', 'drss', 'tf2ss', 'ssdata']
68+
__all__ = ['StateSpaceMatrix', 'StateSpace', 'ss', 'rss', 'drss', 'tf2ss',
69+
'ssdata', 'ssmatrix']
6970

71+
class StateSpaceMatrix(np.ndarray):
72+
"""StateSpaceMatrix(M, [axis=0])
7073
71-
def _matrix(a):
72-
"""Wrapper around numpy.matrix that reshapes empty matrices to be 0x0
74+
A class for representing state-space matrices
7375
74-
Parameters
75-
----------
76-
a: sequence passed to numpy.matrix
76+
The StateSpaceMatrix class is used to represent the state-space
77+
matrices t hat define a StateSpace system (A, B, C, D). It is
78+
mainly a wrapper for the ndarray class, but it maintains matrices
79+
as 2-dimensional arrays. It is similar to the NDmatrix class
80+
(which is being deprecated), but with more limited functionality.
7781
78-
Returns
79-
-------
80-
am: result of numpy.matrix(a), except if a is empty, am will be 0x0.
82+
In addition, for empty matrices the size of the StateSpaceMatrix
83+
instance is set to 0x0 (needed within various StateSpace methods).
8184
82-
numpy.matrix([]) has size 1x0; for empty StateSpace objects, we
83-
need 0x0 matrices, so use this instead of numpy.matrix in this
84-
module.
8585
"""
86-
from numpy import matrix
87-
am = matrix(a, dtype=float)
88-
if (1, 0) == am.shape:
89-
am.shape = (0, 0)
90-
return am
86+
# Allow ndarray * StateSpace to give StateSpace._rmul_() priority
87+
__array_priority__ = 20 # override ndarray and matrix types
88+
89+
def __new__(subtype, data=[], axis=1, dtype=float, copy=True):
90+
"""Create a StateSpaceMatrix object
91+
92+
Parameters
93+
----------
94+
data: array-like object defining matrix values
95+
axis: if data is 1D, choose its axis in the 2D representation
96+
97+
Returns
98+
-------
99+
M: 2D array representing the matrix. If data = [], shape = (0,0)
100+
"""
101+
102+
# Legacy code
103+
# self = np.matrix.__new__(subtype, data, dtype=float)
104+
# if (self.shape == (1, 0)):
105+
# self.shape = (0, 0)
106+
# return self
107+
108+
# See if this is already a StateSpaceMatrix(from np.matrix.__new__)
109+
if isinstance(data, StateSpaceMatrix):
110+
dtype2 = data.dtype
111+
if (dtype is None):
112+
dtype = dtype2
113+
if (dtype2 == dtype) and (not copy):
114+
return data
115+
return data.astype(dtype)
116+
117+
# If data is passed as a string, use (deprecated?) matrix constructor
118+
if isinstance(data, str):
119+
data = np.matrix(data, copy=True)
120+
121+
# Convert the data into an array
122+
arr = np.array(data, dtype=dtype, copy=copy)
123+
ndim = arr.ndim
124+
shape = arr.shape
125+
126+
# Change the shape of the array into a 2D array
127+
if (ndim > 2):
128+
raise ValueError("state-space matrix must be 2-dimensional")
129+
130+
elif (ndim == 2 and shape == (1, 0)) or \
131+
(ndim == 1 and shape == (0, )):
132+
# Passed an empty matrix or empty vector; change shape to (0, 0)
133+
shape = (0, 0)
134+
135+
elif ndim == 1:
136+
# Passed a row or column vector
137+
shape = (1, shape[0]) if axis == 1 else (shape[0], 1)
138+
139+
elif ndim == 0:
140+
# Passed a constant; turn into a matrix
141+
shape = (1, 1)
142+
143+
# Make sure the data representation matches the shape we used
144+
# (from np.matrix.__new__)
145+
order = 'C'
146+
if (ndim == 2) and arr.flags.fortran:
147+
order = 'F'
148+
149+
# Create the actual object used to store the result
150+
self = np.ndarray.__new__(subtype, shape, arr.dtype,
151+
buffer = arr, order = order)
152+
return self
153+
154+
# Override multiplication operation to emulate nd.matrix (vs elementwise)
155+
def __mul__(self, other):
156+
"""Multiply or scale state-space matrices"""
157+
# return np.matrix.__mul__(self, other) # legacy
158+
# Check to see if arguments are (real) scalars in disguise
159+
# (from np.matrix.__mul__)
160+
if isinstance(other, (np.ndarray, list, tuple)) :
161+
# return np.dot(self, np.asmatrix(other)) # legacy
162+
# Promote 1D vectors to row matrices and return ndarray
163+
#! TODO: check to see if return type is correct
164+
product = np.dot(self, np.array(other, copy=False, ndmin=2))
165+
return product if np.isrealobj(product) else np.asarray(product)
166+
167+
if np.isscalar(other) or not hasattr(other, '__rmul__') :
168+
product = np.dot(self, other)
169+
return product if np.isrealobj(product) else np.asarray(product)
170+
171+
return NotImplemented
172+
173+
def __rmul__(self, other):
174+
"""Multiply or scale state-space matrices"""
175+
# return np.matrix.__rmul__(self, other) # legacy
176+
product = np.dot(other, self)
177+
return product if np.isrealobj(product) else np.asarray(product)
178+
179+
def __getitem__(self, index):
180+
"""Get elements of a state-space matrix and return matrix"""
181+
# return np.matrix.__getitem__(self, index) # legacy
182+
self._getitem = True # to get back raw item (from np.matrix.__mul__)
183+
try:
184+
out = np.ndarray.__getitem__(self, index)
185+
finally:
186+
self._getitem = False
187+
188+
if out.ndim == 0:
189+
# If we get down to a scalar, return the actual scalar
190+
return out[()]
191+
if out.ndim == 1:
192+
# Got a row/column vector; figure out what to return
193+
# (from np.matrix.__mul__)
194+
sh = out.shape[0]
195+
try:
196+
n = len(index)
197+
except Exception:
198+
n = 0
199+
if n > 1 and np.isscalar(index[1]):
200+
out.shape = (sh, 1)
201+
else:
202+
out.shape = (1, sh)
203+
return out
204+
205+
206+
#! TODO: Remove this function once changes are all done
207+
def _matrix(a): return StateSpaceMatrix(a)
208+
def ssmatrix(a): return StateSpaceMatrix(a, copy=False)
91209

92210

93211
class StateSpace(LTI):
@@ -115,6 +233,10 @@ class StateSpace(LTI):
115233
sampling time.
116234
"""
117235

236+
# Allow ndarray * StateSpace to give StateSpace._rmul_() priority
237+
# https://docs.scipy.org/doc/numpy/reference/arrays.classes.html#numpy.class.__array_priority__
238+
__array_priority__ = 11 # override ndarray and matrix types
239+
118240
def __init__(self, *args):
119241
"""
120242
StateSpace(A, B, C, D[, dt])
@@ -198,8 +320,8 @@ def _remove_useless_states(self):
198320
# as an array.
199321
ax1_A = np.where(~self.A.any(axis=1))[0]
200322
ax1_B = np.where(~self.B.any(axis=1))[0]
201-
ax0_A = np.where(~self.A.any(axis=0))[1]
202-
ax0_C = np.where(~self.C.any(axis=0))[1]
323+
ax0_A = np.where(~self.A.any(axis=0))[0]
324+
ax0_C = np.where(~self.C.any(axis=0))[0]
203325
useless_1 = np.intersect1d(ax1_A, ax1_B, assume_unique=True)
204326
useless_2 = np.intersect1d(ax0_A, ax0_C, assume_unique=True)
205327
useless = np.union1d(useless_1, useless_2)
@@ -593,9 +715,12 @@ def feedback(self, other=1, sign=-1):
593715
T1 = eye(self.outputs) + sign * D1 * E_D2
594716
T2 = eye(self.inputs) + sign * E_D2 * D1
595717

596-
A = concatenate((concatenate((A1 + sign * B1 * E_D2 * C1, sign * B1 * E_C2), axis=1),
597-
concatenate((B2 * T1 * C1, A2 + sign * B2 * D1 * E_C2), axis=1)),
598-
axis=0)
718+
A = concatenate((
719+
concatenate(
720+
(A1 + sign * B1 * E_D2 * C1, sign * B1 * E_C2), axis=1),
721+
concatenate(
722+
(B2 * T1 * C1, A2 + sign * B2 * D1 * E_C2), axis=1)),
723+
axis=0)
599724
B = concatenate((B1 * T2, B2 * D1 * T2), axis=0)
600725
C = concatenate((T1 * C1, sign * D1 * E_C2), axis=1)
601726
D = D1 * T2
@@ -740,8 +865,9 @@ def returnScipySignalLTI(self):
740865

741866
for i in range(self.outputs):
742867
for j in range(self.inputs):
743-
out[i][j] = lti(asarray(self.A), asarray(self.B[:, j]),
744-
asarray(self.C[i, :]), asarray(self.D[i, j]))
868+
out[i][j] = lti(asarray(self.A), asarray(self.B[:, [j]]),
869+
asarray(self.C[[i], :]),
870+
_matrix(self.D[[i], [j]]))
745871

746872
return out
747873

@@ -778,7 +904,8 @@ def __getitem__(self, indices):
778904
raise IOError('must provide indices of length 2 for state space')
779905
i = indices[0]
780906
j = indices[1]
781-
return StateSpace(self.A, self.B[:, j], self.C[i, :], self.D[i, j], self.dt)
907+
return StateSpace(self.A, self.B[:, [j]], self.C[[i], :],
908+
_matrix(self.D[[i], [j]]), self.dt)
782909

783910
def sample(self, Ts, method='zoh', alpha=None):
784911
"""Convert a continuous time system to discrete time
@@ -859,10 +986,10 @@ def dcgain(self):
859986
def _convertToStateSpace(sys, **kw):
860987
"""Convert a system to state space form (if needed).
861988
862-
If sys is already a state space, then it is returned. If sys is a transfer
863-
function object, then it is converted to a state space and returned. If sys
864-
is a scalar, then the number of inputs and outputs can be specified
865-
manually, as in:
989+
If sys is already a state space, then it is returned. If sys is a
990+
transfer function object, then it is converted to a state space
991+
and returned. If sys is a scalar, then the number of inputs and
992+
outputs can be specified manually, as in:
866993
867994
>>> sys = _convertToStateSpace(3.) # Assumes inputs = outputs = 1
868995
>>> sys = _convertToStateSpace(1., inputs=3, outputs=2)
@@ -875,8 +1002,8 @@ def _convertToStateSpace(sys, **kw):
8751002
import itertools
8761003
if isinstance(sys, StateSpace):
8771004
if len(kw):
878-
raise TypeError("If sys is a StateSpace, _convertToStateSpace \
879-
cannot take keywords.")
1005+
raise TypeError("If sys is a StateSpace object, "
1006+
"_convertToStateSpace cannot take keywords.")
8801007

8811008
# Already a state space system; just return it
8821009
return sys
@@ -897,8 +1024,10 @@ def _convertToStateSpace(sys, **kw):
8971024
denorder, den, num, tol=0)
8981025

8991026
states = ssout[0]
900-
return StateSpace(ssout[1][:states, :states], ssout[2][:states, :sys.inputs],
901-
ssout[3][:sys.outputs, :states], ssout[4], sys.dt)
1027+
return StateSpace(ssout[1][:states, :states],
1028+
ssout[2][:states, :sys.inputs],
1029+
ssout[3][:sys.outputs, :states],
1030+
ssout[4], sys.dt)
9021031
except ImportError:
9031032
# No Slycot. Scipy tf->ss can't handle MIMO, but static
9041033
# MIMO is an easy special case we can check for here
@@ -908,7 +1037,8 @@ def _convertToStateSpace(sys, **kw):
9081037
for drow in sys.den)
9091038
if 1 == maxn and 1 == maxd:
9101039
D = empty((sys.outputs, sys.inputs), dtype=float)
911-
for i, j in itertools.product(range(sys.outputs), range(sys.inputs)):
1040+
for i, j in itertools.product(range(sys.outputs),
1041+
range(sys.inputs)):
9121042
D[i, j] = sys.num[i][j][0] / sys.den[i][j][0]
9131043
return StateSpace([], [], [], D, sys.dt)
9141044
else:

0 commit comments

Comments
 (0)