Skip to content

Commit e4d373c

Browse files
committed
update default params computation in interconnect()
1 parent 8f3615b commit e4d373c

File tree

3 files changed

+27
-5
lines changed

3 files changed

+27
-5
lines changed

control/nlsys.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,11 @@ def __init__(self, syslist, connections=None, inplist=None, outlist=None,
710710
if outputs is None and outlist is not None:
711711
outputs = len(outlist)
712712

713+
if params is None:
714+
params = {}
715+
for sys in self.syslist:
716+
params = params | sys.params
717+
713718
# Create updfcn and outfcn
714719
def updfcn(t, x, u, params):
715720
self._update_params(params)
@@ -2268,7 +2273,8 @@ def interconnect(
22682273
22692274
params : dict, optional
22702275
Parameter values for the systems. Passed to the evaluation functions
2271-
for the system as default values, overriding internal defaults.
2276+
for the system as default values, overriding internal defaults. If
2277+
not specified, defaults to parameters from subsystems.
22722278
22732279
dt : timebase, optional
22742280
The timebase for the system, used to specify whether the system is

control/tests/interconnect_test.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -666,15 +666,29 @@ def test_interconnect_params():
666666
# Create a nominally unstable system
667667
sys1 = ct.nlsys(
668668
lambda t, x, u, params: params['a'] * x[0] + u[0],
669-
states=1, inputs='u', outputs='y', params={'a': 1})
669+
states=1, inputs='u', outputs='y', params={'a': 2, 'c':2})
670670

671671
# Simple system for serial interconnection
672672
sys2 = ct.nlsys(
673673
None, lambda t, x, u, params: u[0],
674-
inputs='r', outputs='u')
674+
inputs='r', outputs='u', params={'a': 4, 'b': 3})
675675

676-
# Create a series interconnection
676+
# Make sure default parameters get set as expected
677677
sys = ct.interconnect([sys1, sys2], inputs='r', outputs='y')
678+
assert sys.params == {'a': 4, 'c': 2, 'b': 3}
679+
assert sys.dynamics(0, [1], [0]).item() == 4
680+
681+
# Make sure we can override the parameters
682+
sys = ct.interconnect(
683+
[sys1, sys2], inputs='r', outputs='y', params={'b': 1})
684+
assert sys.params == {'b': 1}
685+
assert sys.dynamics(0, [1], [0]).item() == 2
686+
assert sys.dynamics(0, [1], [0], params={'a': 5}).item() == 5
687+
688+
# Create final series interconnection, with proper parameter values
689+
sys = ct.interconnect(
690+
[sys1, sys2], inputs='r', outputs='y', params={'a': 1})
691+
assert sys.params == {'a': 1}
678692

679693
# Make sure we can call the update function
680694
sys.updfcn(0, [0], [0], {})

control/tests/iosys_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,8 @@ def test_params(self, tsys):
931931
ios_secord_update = ct.NonlinearIOSystem(
932932
secord_update, secord_output, inputs=1, outputs=1, states=2,
933933
params={'omega0':2, 'zeta':0})
934+
lin_secord_update = ct.linearize(ios_secord_update, [0, 0], [0])
935+
w_update, v_update = np.linalg.eig(lin_secord_update.A)
934936

935937
# Make sure the default parameters haven't changed
936938
lin_secord_check = ct.linearize(ios_secord_default, [0, 0], [0])
@@ -960,7 +962,7 @@ def test_params(self, tsys):
960962
ios_series_default_local, [0, 0, 0, 0], [0])
961963
w, v = np.linalg.eig(lin_series_default_local.A)
962964
np.testing.assert_array_almost_equal(
963-
np.sort(w), np.sort(np.concatenate((w_default, [2j, -2j]))))
965+
w, np.concatenate([w_update, w_update]))
964966

965967
# Show that we can change the parameters at linearization
966968
lin_series_override = ct.linearize(

0 commit comments

Comments
 (0)