Skip to content

Commit a1dfc5e

Browse files
committed
add signal name kwargs for create_mpc_iosysm + kwargs testing for ct.optimal
1 parent 39daef6 commit a1dfc5e

File tree

4 files changed

+67
-15
lines changed

4 files changed

+67
-15
lines changed

control/optimal.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
from . import config
2020
from .exception import ControlNotImplemented
21-
from .timeresp import TimeResponseData
2221

2322
# Define module default parameter values
2423
_optimal_trajectory_methods = {'shooting', 'collocation'}
@@ -140,7 +139,8 @@ class OptimalControlProblem():
140139
def __init__(
141140
self, sys, timepts, integral_cost, trajectory_constraints=[],
142141
terminal_cost=None, terminal_constraints=[], initial_guess=None,
143-
trajectory_method=None, basis=None, log=False, **kwargs):
142+
trajectory_method=None, basis=None, log=False, kwargs_check=True,
143+
**kwargs):
144144
"""Set up an optimal control problem."""
145145
# Save the basic information for use later
146146
self.system = sys
@@ -183,7 +183,7 @@ def __init__(
183183
" discrete time systems")
184184

185185
# Make sure there were no extraneous keywords
186-
if kwargs:
186+
if kwargs_check and kwargs:
187187
raise TypeError("unrecognized keyword(s): ", str(kwargs))
188188

189189
self.trajectory_constraints = _process_constraints(
@@ -829,7 +829,7 @@ def compute_mpc(self, x, squeeze=None):
829829
return res.inputs[:, 0]
830830

831831
# Create an input/output system implementing an MPC controller
832-
def create_mpc_iosystem(self):
832+
def create_mpc_iosystem(self, **kwargs):
833833
"""Create an I/O system implementing an MPC controller"""
834834
# Check to make sure we are in discrete time
835835
if self.system.dt == 0:
@@ -857,11 +857,17 @@ def _output(t, x, u, params={}):
857857
res = self.compute_trajectory(u, print_summary=False)
858858
return res.inputs[:, 0]
859859

860+
# Define signal names, if they are not already given
861+
if not kwargs.get('inputs'):
862+
kwargs['inputs'] = self.system.state_labels
863+
if not kwargs.get('outputs'):
864+
kwargs['outputs'] = self.system.input_labels
865+
if not kwargs.get('states'):
866+
kwargs['states'] = self.system.ninputs * \
867+
(self.timepts.size if self.basis is None else self.basis.N)
868+
860869
return ct.NonlinearIOSystem(
861-
_update, _output, dt=self.system.dt,
862-
inputs=self.system.nstates, outputs=self.system.ninputs,
863-
states=self.system.ninputs * \
864-
(self.timepts.size if self.basis is None else self.basis.N))
870+
_update, _output, dt=self.system.dt, **kwargs)
865871

866872

867873
# Optimal control result
@@ -923,7 +929,7 @@ def __init__(
923929
print("* Final cost:", self.cost)
924930

925931
# Process data as a time response (with "outputs" = inputs)
926-
response = TimeResponseData(
932+
response = ct.TimeResponseData(
927933
ocp.timepts, inputs, states, issiso=ocp.system.issiso(),
928934
transpose=transpose, return_x=return_states, squeeze=squeeze)
929935

@@ -1129,10 +1135,10 @@ def create_mpc_iosystem(
11291135
ocp = OptimalControlProblem(
11301136
sys, horizon, cost, trajectory_constraints=constraints,
11311137
terminal_cost=terminal_cost, terminal_constraints=terminal_constraints,
1132-
log=log, **kwargs)
1138+
log=log, kwargs_check=False, **kwargs)
11331139

11341140
# Return an I/O system implementing the model predictive controller
1135-
return ocp.create_mpc_iosystem()
1141+
return ocp.create_mpc_iosystem(**kwargs)
11361142

11371143

11381144
#

control/tests/kwargs_test.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@
2222
import control.tests.flatsys_test as flatsys_test
2323
import control.tests.frd_test as frd_test
2424
import control.tests.interconnect_test as interconnect_test
25+
import control.tests.optimal_test as optimal_test
2526
import control.tests.statefbk_test as statefbk_test
2627
import control.tests.trdata_test as trdata_test
2728

2829

2930
@pytest.mark.parametrize("module, prefix", [
30-
(control, ""), (control.flatsys, "flatsys.")
31+
(control, ""), (control.flatsys, "flatsys."), (control.optimal, "optimal.")
3132
])
3233
def test_kwarg_search(module, prefix):
3334
# Look through every object in the package
@@ -158,7 +159,7 @@ def test_matplotlib_kwargs(function, nsysargs, moreargs, kwargs, mplcleanup):
158159
kwarg_unittest = {
159160
'bode': test_matplotlib_kwargs,
160161
'bode_plot': test_matplotlib_kwargs,
161-
'create_statefbk_iosystem': statefbk_test.TestStatefbk.test_statefbk_iosys,
162+
'create_statefbk_iosystem': statefbk_test.TestStatefbk.test_statefbk_errors,
162163
'describing_function_plot': test_matplotlib_kwargs,
163164
'dlqe': test_unrecognized_kwargs,
164165
'dlqr': test_unrecognized_kwargs,
@@ -191,6 +192,8 @@ def test_matplotlib_kwargs(function, nsysargs, moreargs, kwargs, mplcleanup):
191192
flatsys_test.TestFlatSys.test_point_to_point_errors,
192193
'flatsys.solve_flat_ocp':
193194
flatsys_test.TestFlatSys.test_solve_flat_ocp_errors,
195+
'optimal.create_mpc_iosystem': optimal_test.test_mpc_iosystem_rename,
196+
'optimal.solve_ocp': optimal_test.test_ocp_argument_errors,
194197
'FrequencyResponseData.__init__':
195198
frd_test.TestFRD.test_unrecognized_keyword,
196199
'InputOutputSystem.__init__': test_unrecognized_kwargs,
@@ -205,7 +208,13 @@ def test_matplotlib_kwargs(function, nsysargs, moreargs, kwargs, mplcleanup):
205208
'StateSpace.sample': test_unrecognized_kwargs,
206209
'TimeResponseData.__call__': trdata_test.test_response_copy,
207210
'TransferFunction.__init__': test_unrecognized_kwargs,
208-
'TransferFunction.sample': test_unrecognized_kwargs,
211+
'TransferFunction.sample': test_unrecognized_kwargs,
212+
'optimal.OptimalControlProblem.__init__':
213+
optimal_test.test_ocp_argument_errors,
214+
'optimal.OptimalControlProblem.compute_trajectory':
215+
optimal_test.test_ocp_argument_errors,
216+
'optimal.OptimalControlProblem.create_mpc_iosystem':
217+
optimal_test.test_mpc_iosystem_rename,
209218
}
210219

211220
#

control/tests/optimal_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,35 @@ def test_mpc_iosystem_aircraft():
214214
xout[0:sys.nstates, -1], xd, atol=0.1, rtol=0.01)
215215

216216

217+
def test_mpc_iosystem_rename():
218+
# Create a discrete time system (double integrator) + cost function
219+
sys = ct.ss([[1, 1], [0, 1]], [[0], [1]], np.eye(2), 0, dt=True)
220+
cost = opt.quadratic_cost(sys, np.eye(2), np.eye(1))
221+
timepts = np.arange(0, 5)
222+
223+
# Create the default optimal control problem and check labels
224+
mpc = opt.create_mpc_iosystem(sys, timepts, cost)
225+
assert mpc.input_labels == sys.state_labels
226+
assert mpc.output_labels == sys.input_labels
227+
228+
# Change the signal names
229+
input_relabels = ['x1', 'x2']
230+
output_relabels = ['u']
231+
state_relabels = [f'x_[{i}]' for i in timepts]
232+
mpc_relabeled = opt.create_mpc_iosystem(
233+
sys, timepts, cost, inputs=input_relabels, outputs=output_relabels,
234+
states=state_relabels, name='mpc_relabeled')
235+
assert mpc_relabeled.input_labels == input_relabels
236+
assert mpc_relabeled.output_labels == output_relabels
237+
assert mpc_relabeled.state_labels == state_relabels
238+
assert mpc_relabeled.name == 'mpc_relabeled'
239+
240+
# Make sure that unknown keywords are caught
241+
# Unrecognized arguments
242+
with pytest.raises(TypeError, match="unrecognized keyword"):
243+
mpc = opt.create_mpc_iosystem(sys, timepts, cost, unknown=None)
244+
245+
217246
def test_mpc_iosystem_continuous():
218247
# Create a random state space system
219248
sys = ct.rss(2, 1, 1)
@@ -492,6 +521,14 @@ def test_ocp_argument_errors():
492521
res = opt.solve_ocp(
493522
sys, time, x0, cost, constraints, terminal_constraint=None)
494523

524+
with pytest.raises(TypeError, match="unrecognized keyword"):
525+
ocp = opt.OptimalControlProblem(
526+
sys, time, x0, cost, constraints, terminal_constraint=None)
527+
528+
with pytest.raises(TypeError, match="unrecognized keyword"):
529+
ocp = opt.OptimalControlProblem(sys, time, cost, constraints)
530+
ocp.compute_trajectory(x0, unknown=None)
531+
495532
# Unrecognized trajectory constraint type
496533
constraints = [(None, np.eye(3), [0, 0, 0], [0, 0, 0])]
497534
with pytest.raises(TypeError, match="unknown trajectory constraint type"):

control/tests/statefbk_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,7 @@ def test_statefbk_errors(self):
778778
ctrl, clsys = ct.create_statefbk_iosystem(
779779
sys, K, type='nonlinear', controller_type='nonlinear')
780780

781-
with pytest.raises(TypeError, match="unrecognized keywords"):
781+
with pytest.raises(TypeError, match="unrecognized keyword"):
782782
ctrl, clsys = ct.create_statefbk_iosystem(sys, K, typo='nonlinear')
783783

784784
with pytest.raises(ControlArgument, match="unknown controller_type"):

0 commit comments

Comments
 (0)