Skip to content

Commit e52fca6

Browse files
committed
add string-based indexing of transfer functions
1 parent 58922ce commit e52fca6

2 files changed

Lines changed: 33 additions & 49 deletions

File tree

control/tests/xferfcn_test.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -390,19 +390,20 @@ def test_pow(self):
390390
with pytest.raises(ValueError):
391391
TransferFunction.__pow__(sys1, 0.5)
392392

393-
def test_slice(self):
393+
@pytest.mark.parametrize("named", [False, True])
394+
def test_slice(self, named):
394395
sys = TransferFunction(
395396
[ [ [1], [2], [3]], [ [3], [4], [5]] ],
396397
[ [[1, 2], [1, 3], [1, 4]], [[1, 4], [1, 5], [1, 6]] ],
397398
inputs=['u0', 'u1', 'u2'], outputs=['y0', 'y1'], name='sys')
398399

399-
sys1 = sys[1:, 1:]
400+
sys1 = sys[1:, 1:] if not named else sys['y1', ['u1', 'u2']]
400401
assert (sys1.ninputs, sys1.noutputs) == (2, 1)
401402
assert sys1.input_labels == ['u1', 'u2']
402403
assert sys1.output_labels == ['y1']
403404
assert sys1.name == 'sys$indexed'
404405

405-
sys2 = sys[:2, :2]
406+
sys2 = sys[:2, :2] if not named else sys[['y0', 'y1'], ['u0', 'u1']]
406407
assert (sys2.ninputs, sys2.noutputs) == (2, 2)
407408
assert sys2.input_labels == ['u0', 'u1']
408409
assert sys2.output_labels == ['y0', 'y1']
@@ -411,7 +412,7 @@ def test_slice(self):
411412
sys = TransferFunction(
412413
[ [ [1], [2], [3]], [ [3], [4], [5]] ],
413414
[ [[1, 2], [1, 3], [1, 4]], [[1, 4], [1, 5], [1, 6]] ], 0.5)
414-
sys1 = sys[1:, 1:]
415+
sys1 = sys[1:, 1:] if not named else sys[['y[1]'], ['u[1]', 'u[2]']]
415416
assert (sys1.ninputs, sys1.noutputs) == (2, 1)
416417
assert sys1.dt == 0.5
417418
assert sys1.input_labels == ['u[1]', 'u[2]']

control/xferfcn.py

Lines changed: 28 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -48,25 +48,26 @@
4848
"""
4949

5050
from collections.abc import Iterable
51+
from copy import deepcopy
52+
from itertools import chain
53+
from re import sub
54+
from warnings import warn
5155

5256
# External function declarations
5357
import numpy as np
54-
from numpy import angle, array, empty, finfo, ndarray, ones, \
55-
polyadd, polymul, polyval, roots, sqrt, zeros, squeeze, exp, pi, \
56-
where, delete, real, poly, nonzero
5758
import scipy as sp
58-
from scipy.signal import tf2zpk, zpk2tf, cont2discrete
59+
from numpy import angle, array, delete, empty, exp, finfo, ndarray, nonzero, \
60+
ones, pi, poly, polyadd, polymul, polyval, real, roots, sqrt, squeeze, \
61+
where, zeros
5962
from scipy.signal import TransferFunction as signalTransferFunction
60-
from copy import deepcopy
61-
from warnings import warn
62-
from itertools import chain
63-
from re import sub
64-
from .lti import LTI, _process_frequency_response
65-
from .iosys import InputOutputSystem, common_timebase, isdtime, \
66-
_process_iosys_keywords
63+
from scipy.signal import cont2discrete, tf2zpk, zpk2tf
64+
65+
from . import config
6766
from .exception import ControlMIMONotImplemented
6867
from .frdata import FrequencyResponseData
69-
from . import config
68+
from .iosys import InputOutputSystem, NamedSignal, _process_iosys_keywords, \
69+
_process_subsys_index, common_timebase, isdtime
70+
from .lti import LTI, _process_frequency_response
7071

7172
__all__ = ['TransferFunction', 'tf', 'zpk', 'ss2tf', 'tfdata']
7273

@@ -761,48 +762,30 @@ def __pow__(self, other):
761762

762763
def __getitem__(self, key):
763764
if not isinstance(key, Iterable) or len(key) != 2:
764-
raise IOError('must provide indices of length 2 for transfer functions')
765+
raise IOError(
766+
"must provide indices of length 2 for transfer functions")
767+
768+
# Convert signal names to integer offsets (via NamedSignal object)
769+
iomap = NamedSignal(
770+
np.empty((self.noutputs, self.ninputs)),
771+
self.output_labels, self.input_labels)
772+
indices = iomap._parse_key(key)
773+
outdx, outputs = _process_subsys_index(
774+
indices[0], self.output_labels, slice_to_list=True)
775+
inpdx, inputs = _process_subsys_index(
776+
indices[1], self.input_labels, slice_to_list=True)
765777

766-
key1, key2 = key
767-
if not isinstance(key1, (int, slice)) or not isinstance(key2, (int, slice)):
768-
raise TypeError(f"system indices must be integers or slices")
769-
770-
# pre-process
771-
if isinstance(key1, int):
772-
key1 = slice(key1, key1 + 1, 1)
773-
if isinstance(key2, int):
774-
key2 = slice(key2, key2 + 1, 1)
775-
# dim1
776-
start1, stop1, step1 = key1.start, key1.stop, key1.step
777-
if step1 is None:
778-
step1 = 1
779-
if start1 is None:
780-
start1 = 0
781-
if stop1 is None:
782-
stop1 = len(self.num)
783-
# dim1
784-
start2, stop2, step2 = key2.start, key2.stop, key2.step
785-
if step2 is None:
786-
step2 = 1
787-
if start2 is None:
788-
start2 = 0
789-
if stop2 is None:
790-
stop2 = len(self.num[0])
791-
778+
# Construct the transfer function for the subsyste
792779
num, den = [], []
793-
for i in range(start1, stop1, step1):
780+
for i in outdx:
794781
num_i = []
795782
den_i = []
796-
for j in range(start2, stop2, step2):
783+
for j in inpdx:
797784
num_i.append(self.num[i][j])
798785
den_i.append(self.den[i][j])
799786
num.append(num_i)
800787
den.append(den_i)
801788

802-
# Save the label names
803-
outputs = [self.output_labels[i] for i in range(start1, stop1, step1)]
804-
inputs = [self.input_labels[j] for j in range(start2, stop2, step2)]
805-
806789
# Create the system name
807790
sysname = config.defaults['iosys.indexed_system_name_prefix'] + \
808791
self.name + config.defaults['iosys.indexed_system_name_suffix']

0 commit comments

Comments
 (0)