Skip to content

Commit 58922ce

Browse files
committed
add string-based indexing of state space systems
1 parent b382abf commit 58922ce

3 files changed

Lines changed: 75 additions & 42 deletions

File tree

control/iosys.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,3 +1010,30 @@ def _parse_spec(syslist, spec, signame, dictname=None):
10101010
ValueError(f"signal index '{index}' is out of range")
10111011

10121012
return system_index, signal_indices, gain
1013+
1014+
1015+
#
1016+
# Utility function for processing subsystem indices
1017+
#
1018+
# This function processes an index specification (int, list, or slice) and
1019+
# returns a index specification that can be used to create a subsystem
1020+
#
1021+
def _process_subsys_index(idx, sys_labels, slice_to_list=False):
1022+
if not isinstance(idx, (slice, list, int)):
1023+
raise TypeError(f"system indices must be integers, slices, or lists")
1024+
1025+
# Convert singleton lists to integers for proper slicing (below)
1026+
if isinstance(idx, (list, tuple)) and len(idx) == 1:
1027+
idx = idx[0]
1028+
1029+
# Convert int to slice so that numpy doesn't drop dimension
1030+
if isinstance(idx, int): idx = slice(idx, idx+1, 1)
1031+
1032+
# Get label names (taking care of possibility that we were passed a list)
1033+
labels = [sys_labels[i] for i in idx] if isinstance(idx, list) \
1034+
else sys_labels[idx]
1035+
1036+
if slice_to_list and isinstance(idx, slice):
1037+
idx = range(len(sys_labels))[idx]
1038+
1039+
return idx, labels

control/statesp.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,15 @@
4848
"""
4949

5050
import math
51+
from collections.abc import Iterable
5152
from copy import deepcopy
5253
from warnings import warn
53-
from collections.abc import Iterable
5454

5555
import numpy as np
5656
import scipy as sp
5757
import scipy.linalg
58-
from numpy import (any, asarray, concatenate, cos, delete, empty, exp, eye,
59-
isinf, ones, pad, sin, squeeze, zeros)
58+
from numpy import any, asarray, concatenate, cos, delete, empty, exp, eye, \
59+
isinf, ones, pad, sin, squeeze, zeros
6060
from numpy.linalg import LinAlgError, eigvals, matrix_rank, solve
6161
from numpy.random import rand, randn
6262
from scipy.signal import StateSpace as signalStateSpace
@@ -65,9 +65,9 @@
6565
from . import config
6666
from .exception import ControlMIMONotImplemented, ControlSlycot, slycot_check
6767
from .frdata import FrequencyResponseData
68-
from .iosys import (InputOutputSystem, _process_dt_keyword,
69-
_process_iosys_keywords, _process_signal_list,
70-
common_timebase, isdtime, issiso)
68+
from .iosys import InputOutputSystem, NamedSignal, _process_dt_keyword, \
69+
_process_iosys_keywords, _process_signal_list, _process_subsys_index, \
70+
common_timebase, isdtime, issiso
7171
from .lti import LTI, _process_frequency_response
7272
from .nlsys import InterconnectedSystem, NonlinearIOSystem
7373

@@ -1214,25 +1214,25 @@ def append(self, other):
12141214
D[self.noutputs:, self.ninputs:] = other.D
12151215
return StateSpace(A, B, C, D, self.dt)
12161216

1217-
def __getitem__(self, indices):
1217+
def __getitem__(self, key):
12181218
"""Array style access"""
1219-
if not isinstance(indices, Iterable) or len(indices) != 2:
1220-
raise IOError('must provide indices of length 2 for state space')
1221-
outdx, inpdx = indices
1222-
1223-
# Convert int to slice to ensure that numpy doesn't drop the dimension
1224-
if isinstance(outdx, int): outdx = slice(outdx, outdx+1, 1)
1225-
if isinstance(inpdx, int): inpdx = slice(inpdx, inpdx+1, 1)
1219+
if not isinstance(key, Iterable) or len(key) != 2:
1220+
raise IOError("must provide indices of length 2 for state space")
12261221

1227-
if not isinstance(outdx, slice) or not isinstance(inpdx, slice):
1228-
raise TypeError(f"system indices must be integers or slices")
1222+
# Convert signal names to integer offsets
1223+
iomap = NamedSignal(self.D, self.output_labels, self.input_labels)
1224+
indices = iomap._parse_key(key)
1225+
outdx, output_labels = _process_subsys_index(
1226+
indices[0], self.output_labels)
1227+
inpdx, input_labels = _process_subsys_index(
1228+
indices[1], self.input_labels)
12291229

12301230
sysname = config.defaults['iosys.indexed_system_name_prefix'] + \
12311231
self.name + config.defaults['iosys.indexed_system_name_suffix']
12321232
return StateSpace(
1233-
self.A, self.B[:, inpdx], self.C[outdx, :], self.D[outdx, inpdx],
1234-
self.dt, name=sysname,
1235-
inputs=self.input_labels[inpdx], outputs=self.output_labels[outdx])
1233+
self.A, self.B[:, inpdx], self.C[outdx, :],
1234+
self.D[outdx, :][:, inpdx], self.dt,
1235+
name=sysname, inputs=input_labels, outputs=output_labels)
12361236

12371237
def sample(self, Ts, method='zoh', alpha=None, prewarp_frequency=None,
12381238
name=None, copy_names=True, **kwargs):

control/tests/statesp_test.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -473,39 +473,45 @@ def test_array_access_ss_failure(self):
473473
with pytest.raises(IOError):
474474
sys1[0]
475475

476-
@pytest.mark.parametrize("outdx, inpdx",
477-
[(0, 1),
478-
(slice(0, 1, 1), 1),
479-
(0, slice(1, 2, 1)),
480-
(slice(0, 1, 1), slice(1, 2, 1)),
481-
(slice(None, None, -1), 1),
482-
(0, slice(None, None, -1)),
483-
(slice(None, 2, None), 1),
484-
(slice(None, None, 1), slice(None, None, 2)),
485-
(0, slice(1, 2, 1)),
486-
(slice(0, 1, 1), slice(1, 2, 1))])
487-
def test_array_access_ss(self, outdx, inpdx):
476+
@pytest.mark.parametrize(
477+
"outdx, inpdx",
478+
[(0, 1),
479+
(slice(0, 1, 1), 1),
480+
(0, slice(1, 2, 1)),
481+
(slice(0, 1, 1), slice(1, 2, 1)),
482+
(slice(None, None, -1), 1),
483+
(0, slice(None, None, -1)),
484+
(slice(None, 2, None), 1),
485+
(slice(None, None, 1), slice(None, None, 2)),
486+
(0, slice(1, 2, 1)),
487+
(slice(0, 1, 1), slice(1, 2, 1)),
488+
# ([0, 1], [0]), # lists of indices
489+
])
490+
@pytest.mark.parametrize("named", [False, True])
491+
def test_array_access_ss(self, outdx, inpdx, named):
488492
sys1 = StateSpace(
489493
[[1., 2.], [3., 4.]],
490494
[[5., 6.], [7., 8.]],
491495
[[9., 10.], [11., 12.]],
492496
[[13., 14.], [15., 16.]], 1,
493497
inputs=['u0', 'u1'], outputs=['y0', 'y1'])
494498

495-
sys1_01 = sys1[outdx, inpdx]
496-
499+
if named:
500+
# Use names instead of numbers (and re-convert in statesp)
501+
outnames = sys1.output_labels[outdx]
502+
inpnames = sys1.input_labels[inpdx]
503+
sys1_01 = sys1[outnames, inpnames]
504+
else:
505+
sys1_01 = sys1[outdx, inpdx]
506+
497507
# Convert int to slice to ensure that numpy doesn't drop the dimension
498508
if isinstance(outdx, int): outdx = slice(outdx, outdx+1, 1)
499509
if isinstance(inpdx, int): inpdx = slice(inpdx, inpdx+1, 1)
500-
501-
np.testing.assert_array_almost_equal(sys1_01.A,
502-
sys1.A)
503-
np.testing.assert_array_almost_equal(sys1_01.B,
504-
sys1.B[:, inpdx])
505-
np.testing.assert_array_almost_equal(sys1_01.C,
506-
sys1.C[outdx, :])
507-
np.testing.assert_array_almost_equal(sys1_01.D,
508-
sys1.D[outdx, inpdx])
510+
511+
np.testing.assert_array_almost_equal(sys1_01.A, sys1.A)
512+
np.testing.assert_array_almost_equal(sys1_01.B, sys1.B[:, inpdx])
513+
np.testing.assert_array_almost_equal(sys1_01.C, sys1.C[outdx, :])
514+
np.testing.assert_array_almost_equal(sys1_01.D, sys1.D[outdx, inpdx])
509515

510516
assert sys1.dt == sys1_01.dt
511517
assert sys1_01.input_labels == sys1.input_labels[inpdx]

0 commit comments

Comments
 (0)