Skip to content

Commit a63671f

Browse files
committed
fix: preserve signal names in scalar LTI ops
1 parent 146ccee commit a63671f

3 files changed

Lines changed: 63 additions & 8 deletions

File tree

control/statesp.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,10 @@ def fmt_matrix(matrix, name):
555555
# Negation of a system
556556
def __neg__(self):
557557
"""Negate a state space system."""
558-
return StateSpace(self.A, self.B, -self.C, -self.D, self.dt)
558+
return StateSpace(
559+
self.A, self.B, -self.C, -self.D, self.dt,
560+
inputs=self.input_labels, outputs=self.output_labels,
561+
states=self.state_labels)
559562

560563
# Addition of two state space systems (parallel interconnection)
561564
def __add__(self, other):
@@ -645,7 +648,10 @@ def __mul__(self, other):
645648
A, C = self.A, self.C
646649
B = self.B * other
647650
D = self.D * other
648-
dt = self.dt
651+
return StateSpace(
652+
A, B, C, D, self.dt,
653+
inputs=self.input_labels, outputs=self.output_labels,
654+
states=self.state_labels)
649655

650656
elif isinstance(other, np.ndarray):
651657
other = np.atleast_2d(other)
@@ -706,7 +712,10 @@ def __rmul__(self, other):
706712
# Just multiplying by a scalar; change the input
707713
B = other * self.B
708714
D = other * self.D
709-
return StateSpace(self.A, B, self.C, D, self.dt)
715+
return StateSpace(
716+
self.A, B, self.C, D, self.dt,
717+
inputs=self.input_labels, outputs=self.output_labels,
718+
states=self.state_labels)
710719

711720
elif isinstance(other, np.ndarray):
712721
other = np.atleast_2d(other)

control/tests/namedio_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,40 @@ def test_named_ss():
5959
"<StateSpace random: ['u1'] -> ['y1', 'y2']>"
6060

6161

62+
@pytest.mark.parametrize(
63+
"sys",
64+
[
65+
ct.ss(
66+
[[-1, 0], [0, -2]], [[1, 0], [0, 1]],
67+
[[1, 0], [0, 1]], [[0, 0], [0, 0]],
68+
inputs=['e1', 'e2'], outputs=['u1', 'u2'],
69+
states=['x1', 'x2'], name='plant'),
70+
ct.tf(
71+
[[[1], [2]], [[3], [4]]],
72+
[[[1, 1], [1, 2]], [[1, 3], [1, 4]]],
73+
inputs=['e1', 'e2'], outputs=['u1', 'u2'], name='plant'),
74+
],
75+
)
76+
@pytest.mark.parametrize(
77+
"operation",
78+
[
79+
lambda sys: -sys,
80+
lambda sys: sys * 2,
81+
lambda sys: 2 * sys,
82+
lambda sys: ct.negate(sys),
83+
lambda sys: ct.series(sys, 2),
84+
lambda sys: ct.series(2, sys),
85+
],
86+
)
87+
def test_named_scalar_operations_preserve_signal_names(sys, operation):
88+
result = operation(sys)
89+
90+
assert result.input_labels == sys.input_labels
91+
assert result.output_labels == sys.output_labels
92+
if isinstance(sys, ct.StateSpace):
93+
assert result.state_labels == sys.state_labels
94+
95+
6296
# List of classes that are expected
6397
fun_instance = {
6498
ct.rss: (ct.NonlinearIOSystem, ct.StateSpace, ct.StateSpace),

control/xferfcn.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,9 @@ def __neg__(self):
554554
for i in range(self.noutputs):
555555
for j in range(self.ninputs):
556556
num[i, j] *= -1
557-
return TransferFunction(num, self.den, self.dt)
557+
return TransferFunction(
558+
num, self.den, self.dt,
559+
inputs=self.input_labels, outputs=self.output_labels)
558560

559561
def __add__(self, other):
560562
"""Add two LTI objects (parallel connection)."""
@@ -620,8 +622,13 @@ def __mul__(self, other):
620622
if isinstance(other, (StateSpace, np.ndarray)):
621623
other = _convert_to_transfer_function(other)
622624
elif isinstance(other, (int, float, complex, np.number)):
623-
# Multiply by a scaled identity matrix (transfer function)
624-
other = _convert_to_transfer_function(np.eye(self.ninputs) * other)
625+
num = deepcopy(self.num_array)
626+
for i in range(self.noutputs):
627+
for j in range(self.ninputs):
628+
num[i, j] *= other
629+
return TransferFunction(
630+
num, self.den, self.dt,
631+
inputs=self.input_labels, outputs=self.output_labels)
625632
if not isinstance(other, TransferFunction):
626633
return NotImplemented
627634

@@ -669,8 +676,13 @@ def __rmul__(self, other):
669676

670677
# Convert the second argument to a transfer function.
671678
if isinstance(other, (int, float, complex, np.number)):
672-
# Multiply by a scaled identity matrix (transfer function)
673-
other = _convert_to_transfer_function(np.eye(self.noutputs) * other)
679+
num = deepcopy(self.num_array)
680+
for i in range(self.noutputs):
681+
for j in range(self.ninputs):
682+
num[i, j] *= other
683+
return TransferFunction(
684+
num, self.den, self.dt,
685+
inputs=self.input_labels, outputs=self.output_labels)
674686
else:
675687
other = _convert_to_transfer_function(other)
676688

0 commit comments

Comments
 (0)