Skip to content

Commit 5f261cc

Browse files
committed
updated argument checking + unit tests (and coverage) + fixes
1 parent acc4439 commit 5f261cc

2 files changed

Lines changed: 146 additions & 29 deletions

File tree

control/optimal.py

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,15 @@ def __init__(
193193
# See whether we got entire guess or just first time point
194194
if len(initial_guess.shape) == 1:
195195
# Broadcast inputs to entire time vector
196-
initial_guess = np.broadcast_to(
197-
initial_guess.reshape(-1, 1),
198-
(self.system.ninputs, self.time_vector.size))
199-
elif len(initial_guess.shape) != 2:
196+
try:
197+
initial_guess = np.broadcast_to(
198+
initial_guess.reshape(-1, 1),
199+
(self.system.ninputs, self.time_vector.size))
200+
except:
201+
raise ValueError("initial guess is the wrong shape")
202+
203+
elif initial_guess.shape != \
204+
(self.system.ninputs, self.time_vector.size):
200205
raise ValueError("initial guess is the wrong shape")
201206

202207
# Reshape for use by scipy.optimize.minimize()
@@ -975,7 +980,13 @@ def state_poly_constraint(sys, A, b):
975980
A tuple consisting of the constraint type and parameter values.
976981
977982
"""
978-
# TODO: make sure the system and constraints are compatible
983+
# Convert arguments to arrays and make sure dimensions are right
984+
A = np.atleast_2d(A)
985+
b = np.atleast_1d(b)
986+
if len(A.shape) != 2 or A.shape[1] != sys.nstates:
987+
raise ValueError("polytope matrix must match number of states")
988+
elif len(b.shape) != 1 or A.shape[0] != b.shape[0]:
989+
raise ValueError("number of bounds must match number of constraints")
979990

980991
# Return a linear constraint object based on the polynomial
981992
return (opt.LinearConstraint,
@@ -1006,7 +1017,11 @@ def state_range_constraint(sys, lb, ub):
10061017
A tuple consisting of the constraint type and parameter values.
10071018
10081019
"""
1009-
# TODO: make sure the system and constraints are compatible
1020+
# Convert bounds to lists and make sure they are the right dimension
1021+
lb = np.atleast_1d(lb)
1022+
ub = np.atleast_1d(ub)
1023+
if lb.shape != (sys.nstates,) or ub.shape != (sys.nstates,):
1024+
raise ValueError("state bounds must match number of states")
10101025

10111026
# Return a linear constraint object based on the polynomial
10121027
return (opt.LinearConstraint,
@@ -1037,7 +1052,13 @@ def input_poly_constraint(sys, A, b):
10371052
A tuple consisting of the constraint type and parameter values.
10381053
10391054
"""
1040-
# TODO: make sure the system and constraints are compatible
1055+
# Convert arguments to arrays and make sure dimensions are right
1056+
A = np.atleast_2d(A)
1057+
b = np.atleast_1d(b)
1058+
if len(A.shape) != 2 or A.shape[1] != sys.ninputs:
1059+
raise ValueError("polytope matrix must match number of inputs")
1060+
elif len(b.shape) != 1 or A.shape[0] != b.shape[0]:
1061+
raise ValueError("number of bounds must match number of constraints")
10411062

10421063
# Return a linear constraint object based on the polynomial
10431064
return (opt.LinearConstraint,
@@ -1069,13 +1090,17 @@ def input_range_constraint(sys, lb, ub):
10691090
A tuple consisting of the constraint type and parameter values.
10701091
10711092
"""
1072-
# TODO: make sure the system and constraints are compatible
1093+
# Convert bounds to lists and make sure they are the right dimension
1094+
lb = np.atleast_1d(lb)
1095+
ub = np.atleast_1d(ub)
1096+
if lb.shape != (sys.ninputs,) or ub.shape != (sys.ninputs,):
1097+
raise ValueError("input bounds must match number of inputs")
10731098

10741099
# Return a linear constraint object based on the polynomial
10751100
return (opt.LinearConstraint,
10761101
np.hstack(
10771102
[np.zeros((sys.ninputs, sys.nstates)), np.eye(sys.ninputs)]),
1078-
np.array(lb), np.array(ub))
1103+
lb, ub)
10791104

10801105

10811106
#
@@ -1112,15 +1137,17 @@ def output_poly_constraint(sys, A, b):
11121137
A tuple consisting of the constraint type and parameter values.
11131138
11141139
"""
1115-
# TODO: make sure the system and constraints are compatible
1140+
# Convert arguments to arrays and make sure dimensions are right
1141+
A = np.atleast_2d(A)
1142+
b = np.atleast_1d(b)
1143+
if len(A.shape) != 2 or A.shape[1] != sys.noutputs:
1144+
raise ValueError("polytope matrix must match number of outputs")
1145+
elif len(b.shape) != 1 or A.shape[0] != b.shape[0]:
1146+
raise ValueError("number of bounds must match number of constraints")
11161147

11171148
# Function to create the output
1118-
def _evaluate_output_poly_constraint(x):
1119-
# Separate the constraint into states and inputs
1120-
states = x[:sys.nstates]
1121-
inputs = x[sys.nstates:]
1122-
outputs = sys._out(0, states, inputs)
1123-
return A @ outputs
1149+
def _evaluate_output_poly_constraint(x, u):
1150+
return A @ sys._out(0, x, u)
11241151

11251152
# Return a nonlinear constraint object based on the polynomial
11261153
return (opt.NonlinearConstraint,
@@ -1151,14 +1178,16 @@ def output_range_constraint(sys, lb, ub):
11511178
A tuple consisting of the constraint type and parameter values.
11521179
11531180
"""
1154-
# TODO: make sure the system and constraints are compatible
1181+
# Convert bounds to lists and make sure they are the right dimension
1182+
lb = np.atleast_1d(lb)
1183+
ub = np.atleast_1d(ub)
1184+
if lb.shape != (sys.noutputs,) or ub.shape != (sys.noutputs,):
1185+
raise ValueError("output bounds must match number of outputs")
11551186

11561187
# Function to create the output
1157-
def _evaluate_output_range_constraint(x):
1188+
def _evaluate_output_range_constraint(x, u):
11581189
# Separate the constraint into states and inputs
1159-
states = x[:sys.nstates]
1160-
inputs = x[sys.nstates:]
1161-
outputs = sys._out(0, states, inputs)
1190+
return sys._out(0, x, u)
11621191

11631192
# Return a nonlinear constraint object based on the polynomial
11641193
return (opt.NonlinearConstraint, _evaluate_output_range_constraint, lb, ub)

control/tests/optimal_test.py

Lines changed: 96 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_discrete_lqr():
7777

7878
# Compute the integral and terminal cost
7979
integral_cost = opt.quadratic_cost(sys, Q, R)
80-
terminal_cost = opt.quadratic_cost(sys, S, 0)
80+
terminal_cost = opt.quadratic_cost(sys, S, None)
8181

8282
# Formulate finite horizon MPC problem
8383
time = np.arange(0, 5, 1)
@@ -171,6 +171,11 @@ def test_mpc_iosystem():
171171
[(opt.state_poly_constraint,
172172
np.array([[1, 0], [0, 1], [-1, 0], [0, -1]]), [5, 5, 5, 5]),
173173
(opt.input_poly_constraint, np.array([[1], [-1]]), [1, 1])],
174+
[(opt.output_range_constraint, [-5, -5], [5, 5]),
175+
(opt.input_poly_constraint, np.array([[1], [-1]]), [1, 1])],
176+
[(opt.output_poly_constraint,
177+
np.array([[1, 0], [0, 1], [-1, 0], [0, -1]]), [5, 5, 5, 5]),
178+
(opt.input_poly_constraint, np.array([[1], [-1]]), [1, 1])],
174179
[(sp.optimize.NonlinearConstraint,
175180
lambda x, u: np.array([x[0], x[1], u[0]]), [-5, -5, -1], [5, 5, 1])],
176181
])
@@ -258,6 +263,10 @@ def test_terminal_constraints(sys_args):
258263
np.testing.assert_allclose(
259264
x1, np.kron(x0.reshape((2, 1)), time[::-1]/Tf), atol=0.1, rtol=0.01)
260265

266+
# Re-run using initial guess = optional and make sure nothing chnages
267+
res = optctrl.compute_trajectory(x0, initial_guess=u1)
268+
np.testing.assert_almost_equal(res.inputs, u1)
269+
261270
# Impose some cost on the state, which should change the path
262271
Q = np.eye(2)
263272
R = np.eye(2) * 0.1
@@ -305,22 +314,101 @@ def test_terminal_constraints(sys_args):
305314

306315
def test_optimal_logging(capsys):
307316
"""Test logging functions (mainly for code coverage)"""
308-
sys = ct.ss2io(ct.ss([[1, 1], [0, 1]], [[1], [0.5]], np.eye(2), 0, 1))
317+
sys = ct.ss2io(ct.ss(np.eye(2), np.eye(2), np.eye(2), 0, 1))
309318

310319
# Set up the optimal control problem
311320
cost = opt.quadratic_cost(sys, 1, 1)
312321
state_constraint = opt.state_range_constraint(
313-
sys, [-np.inf, -10], [10, np.inf])
314-
input_constraint = opt.input_range_constraint(sys, -100, 100)
322+
sys, [-np.inf, 1], [10, 1])
323+
input_constraint = opt.input_range_constraint(sys, [-100, -100], [100, 100])
315324
time = np.arange(0, 3, 1)
316325
x0 = [-1, 1]
317326

318-
# Solve it, with logging turned on
319-
res = opt.solve_ocp(
320-
sys, time, x0, cost, input_constraint, terminal_cost=cost,
321-
terminal_constraints=state_constraint, log=True)
327+
# Solve it, with logging turned on (with warning due to mixed constraints)
328+
with pytest.warns(sp.optimize.optimize.OptimizeWarning,
329+
match="Equality and inequality .* same element"):
330+
res = opt.solve_ocp(
331+
sys, time, x0, cost, input_constraint, terminal_cost=cost,
332+
terminal_constraints=state_constraint, log=True)
322333

323334
# Make sure the output has info available only with logging turned on
324335
captured = capsys.readouterr()
325336
assert captured.out.find("process time") != -1
326337

338+
339+
@pytest.mark.parametrize("fun, args, exception, match", [
340+
[opt.quadratic_cost, (np.zeros((2, 3)), np.eye(2)), ValueError,
341+
"Q matrix is the wrong shape"],
342+
[opt.quadratic_cost, (np.eye(2), 1), ValueError,
343+
"R matrix is the wrong shape"],
344+
])
345+
def test_constraint_constructor_errors(fun, args, exception, match):
346+
"""Test various error conditions for constraint constructors"""
347+
sys = ct.ss2io(ct.rss(2, 2, 2))
348+
with pytest.raises(exception, match=match):
349+
fun(sys, *args)
350+
351+
352+
@pytest.mark.parametrize("fun, args, exception, match", [
353+
[opt.input_poly_constraint, (np.zeros((2, 3)), [0, 0]), ValueError,
354+
"polytope matrix must match number of inputs"],
355+
[opt.output_poly_constraint, (np.zeros((2, 3)), [0, 0]), ValueError,
356+
"polytope matrix must match number of outputs"],
357+
[opt.state_poly_constraint, (np.zeros((2, 3)), [0, 0]), ValueError,
358+
"polytope matrix must match number of states"],
359+
[opt.input_poly_constraint, (np.zeros((2, 2)), [0, 0, 0]), ValueError,
360+
"number of bounds must match number of constraints"],
361+
[opt.output_poly_constraint, (np.zeros((2, 2)), [0, 0, 0]), ValueError,
362+
"number of bounds must match number of constraints"],
363+
[opt.state_poly_constraint, (np.zeros((2, 2)), [0, 0, 0]), ValueError,
364+
"number of bounds must match number of constraints"],
365+
[opt.input_poly_constraint, (np.zeros((2, 2)), [[0, 0, 0]]), ValueError,
366+
"number of bounds must match number of constraints"],
367+
[opt.output_poly_constraint, (np.zeros((2, 2)), [[0, 0, 0]]), ValueError,
368+
"number of bounds must match number of constraints"],
369+
[opt.state_poly_constraint, (np.zeros((2, 2)), 0), ValueError,
370+
"number of bounds must match number of constraints"],
371+
[opt.input_range_constraint, ([1, 2, 3], [0, 0]), ValueError,
372+
"input bounds must match"],
373+
[opt.output_range_constraint, ([2, 3], [0, 0, 0]), ValueError,
374+
"output bounds must match"],
375+
[opt.state_range_constraint, ([1, 2, 3], [0, 0, 0]), ValueError,
376+
"state bounds must match"],
377+
])
378+
def test_constraint_constructor_errors(fun, args, exception, match):
379+
"""Test various error conditions for constraint constructors"""
380+
sys = ct.ss2io(ct.rss(2, 2, 2))
381+
with pytest.raises(exception, match=match):
382+
fun(sys, *args)
383+
384+
385+
def test_ocp_argument_errors():
386+
sys = ct.ss2io(ct.ss([[1, 1], [0, 1]], [[1], [0.5]], np.eye(2), 0, 1))
387+
388+
# State and input constraints
389+
constraints = [
390+
(sp.optimize.LinearConstraint, np.eye(3), [-5, -5, -1], [5, 5, 1]),
391+
]
392+
393+
# Quadratic state and input penalty
394+
Q = [[1, 0], [0, 1]]
395+
R = [[1]]
396+
cost = opt.quadratic_cost(sys, Q, R)
397+
398+
# Set up the optimal control problem
399+
time = np.arange(0, 5, 1)
400+
x0 = [4, 0]
401+
402+
# Trajectory constraints not in the right form
403+
with pytest.raises(TypeError, match="constraints must be a list"):
404+
res = opt.solve_ocp(sys, time, x0, cost, np.eye(2))
405+
406+
# Terminal constraints not in the right form
407+
with pytest.raises(TypeError, match="constraints must be a list"):
408+
res = opt.solve_ocp(
409+
sys, time, x0, cost, constraints, terminal_constraints=np.eye(2))
410+
411+
# Initial guess in the wrong shape
412+
with pytest.raises(ValueError, match="initial guess is the wrong shape"):
413+
res = opt.solve_ocp(
414+
sys, time, x0, cost, constraints, initial_guess=np.zeros((4,1,1)))

0 commit comments

Comments
 (0)