Skip to content

Commit 5df6fb7

Browse files
committed
implement NamedSignal's for time responses
1 parent 5d7fb42 commit 5df6fb7

3 files changed

Lines changed: 107 additions & 7 deletions

File tree

control/iosys.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
from . import config
1616

17-
__all__ = ['InputOutputSystem', 'issiso', 'timebase', 'common_timebase',
18-
'isdtime', 'isctime']
17+
__all__ = ['InputOutputSystem', 'NamedSignal', 'issiso', 'timebase',
18+
'common_timebase', 'isdtime', 'isctime']
1919

2020
# Define module default parameter values
2121
_iosys_defaults = {
@@ -33,6 +33,51 @@
3333
}
3434

3535

36+
# Named signal class
37+
class NamedSignal(np.ndarray):
38+
def __new__(cls, input_array, signal_labels=None, trace_labels=None):
39+
# See https://numpy.org/doc/stable/user/basics.subclassing.html
40+
obj = np.asarray(input_array).view(cls) # Cast to our class type
41+
obj.signal_labels = signal_labels # Save signal labels
42+
obj.trace_labels = trace_labels # Save trace labels
43+
return obj # Return new object
44+
45+
def __array_finalize__(self, obj):
46+
# See https://numpy.org/doc/stable/user/basics.subclassing.html
47+
if obj is None: return
48+
self.signal_labels = getattr(obj, 'signal_labels', None)
49+
self.trace_labels = getattr(obj, 'trace_labels', None)
50+
51+
def _parse_key(self, key, labels=None):
52+
if labels is None:
53+
labels = self.signal_labels
54+
try:
55+
if isinstance(key, str):
56+
key = labels.index(item := key)
57+
elif isinstance(key, list):
58+
keylist = []
59+
for item in key: # use for loop to save item for error
60+
keylist.append(self._parse_key(item, labels=labels))
61+
key = keylist
62+
elif isinstance(key, tuple):
63+
keylist = []
64+
keylist.append(
65+
self._parse_key(item := key[0], labels=self.signal_labels))
66+
if len(key) > 1:
67+
keylist.append(
68+
self._parse_key(
69+
item := key[1], labels=self.trace_labels))
70+
for i in range(2, len(key)):
71+
keylist.append(key[i]) # pass on remaining elements
72+
key = tuple(keylist)
73+
except ValueError:
74+
raise ValueError(f"unknown signal name '{item}'")
75+
return key
76+
77+
def __getitem__(self, key):
78+
return super().__getitem__(self._parse_key(key))
79+
80+
3681
class InputOutputSystem(object):
3782
"""A class for representing input/output systems.
3883

control/tests/timeresp_test.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ def test_discrete_time_impulse(self, tsystem):
537537
sysdt = sys.sample(dt, 'impulse')
538538
np.testing.assert_array_almost_equal(impulse_response(sys, t)[1],
539539
impulse_response(sysdt, t)[1])
540-
540+
541541
def test_discrete_time_impulse_input(self):
542542
# discrete time impulse input, Only one active input for each trace
543543
A = [[.5, 0.25],[.0, .5]]
@@ -1318,3 +1318,54 @@ def test_step_info_nonstep():
13181318
assert step_info['Peak'] == 1
13191319
assert step_info['PeakTime'] == 0
13201320
assert isclose(step_info['SteadyStateValue'], 0.96)
1321+
1322+
1323+
def test_signal_labels():
1324+
# Create a system response for a SISO system
1325+
sys = ct.rss(4, 1, 1)
1326+
response = ct.step_response(sys)
1327+
1328+
# Make sure access via strings works
1329+
np.testing.assert_equal(response.inputs['u[0]'], response.inputs[0])
1330+
np.testing.assert_equal(response.states['x[2]'], response.states[2])
1331+
1332+
# Make sure access via lists of strings works
1333+
np.testing.assert_equal(
1334+
response.states[['x[1]', 'x[2]']], response.states[[1, 2]])
1335+
1336+
# Make sure errors are generated if key is unknown
1337+
with pytest.raises(ValueError, match="unknown signal name 'bad'"):
1338+
response.inputs['bad']
1339+
1340+
with pytest.raises(ValueError, match="unknown signal name 'bad'"):
1341+
response.states[['x[1]', 'bad']]
1342+
1343+
# Create a system response for a MIMO system
1344+
sys = ct.rss(4, 2, 2)
1345+
response = ct.step_response(sys)
1346+
1347+
# Make sure access via strings works
1348+
np.testing.assert_equal(
1349+
response.outputs['y[0]', 'u[1]'],
1350+
response.outputs[0, 1])
1351+
np.testing.assert_equal(
1352+
response.states['x[2]', 'u[0]'], response.states[2, 0])
1353+
1354+
# Make sure access via lists of strings works
1355+
np.testing.assert_equal(
1356+
response.states[['x[1]', 'x[2]'], 'u[0]'],
1357+
response.states[[1, 2], 0])
1358+
1359+
np.testing.assert_equal(
1360+
response.outputs[['y[1]'], ['u[1]', 'u[0]']],
1361+
response.outputs[[1], [1, 0]])
1362+
1363+
# Make sure errors are generated if key is unknown
1364+
with pytest.raises(ValueError, match="unknown signal name 'bad'"):
1365+
response.inputs['bad']
1366+
1367+
with pytest.raises(ValueError, match="unknown signal name 'bad'"):
1368+
response.states[['x[1]', 'bad']]
1369+
1370+
with pytest.raises(ValueError, match=r"unknown signal name 'x\[2\]'"):
1371+
response.states['x[1]', 'x[2]'] # second index = input name

control/timeresp.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080

8181
from . import config
8282
from .exception import pandas_check
83-
from .iosys import isctime, isdtime
83+
from .iosys import NamedSignal, isctime, isdtime
8484
from .timeplot import time_response_plot
8585

8686
__all__ = ['forced_response', 'step_response', 'step_info',
@@ -567,10 +567,11 @@ def outputs(self):
567567
:type: 1D, 2D, or 3D array
568568
569569
"""
570+
# TODO: move to __init__ to avoid recomputing each time?
570571
y = _process_time_response(
571572
self.y, issiso=self.issiso,
572573
transpose=self.transpose, squeeze=self.squeeze)
573-
return y
574+
return NamedSignal(y, self.output_labels, self.input_labels)
574575

575576
# Getter for states (implements squeeze processing)
576577
@property
@@ -586,6 +587,7 @@ def states(self):
586587
:type: 2D or 3D array
587588
588589
"""
590+
# TODO: move to __init__ to avoid recomputing each time?
589591
if self.x is None:
590592
return None
591593

@@ -606,7 +608,7 @@ def states(self):
606608
if self.transpose:
607609
x = np.transpose(x, np.roll(range(x.ndim), 1))
608610

609-
return x
611+
return NamedSignal(x, self.state_labels, self.input_labels)
610612

611613
# Getter for inputs (implements squeeze processing)
612614
@property
@@ -628,15 +630,17 @@ def inputs(self):
628630
:type: 1D or 2D array
629631
630632
"""
633+
# TODO: move to __init__ to avoid recomputing each time?
631634
if self.u is None:
632635
return None
633636

634637
u = _process_time_response(
635638
self.u, issiso=self.issiso,
636639
transpose=self.transpose, squeeze=self.squeeze)
637-
return u
640+
return NamedSignal(u, self.input_labels, self.input_labels)
638641

639642
# Getter for legacy state (implements non-standard squeeze processing)
643+
# TODO: remove when no longer needed
640644
@property
641645
def _legacy_states(self):
642646
"""Time response state vector (legacy version).

0 commit comments

Comments
 (0)