Skip to content

Commit acc4439

Browse files
committed
update unit tests for speed and coverage
1 parent 9494092 commit acc4439

2 files changed

Lines changed: 40 additions & 2 deletions

File tree

control/optimal.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,27 @@ def __init__(
104104
self.system = sys
105105
self.time_vector = time_vector
106106
self.integral_cost = integral_cost
107-
self.trajectory_constraints = trajectory_constraints
108107
self.terminal_cost = terminal_cost
109108
self.terminal_constraints = terminal_constraints
110109
self.kwargs = kwargs
111110

111+
# Process trajectory constraints
112+
if isinstance(trajectory_constraints, tuple):
113+
self.trajectory_constraints = [trajectory_constraints]
114+
elif not isinstance(trajectory_constraints, list):
115+
raise TypeError("trajectory constraints must be a list")
116+
else:
117+
self.trajectory_constraints = trajectory_constraints
118+
119+
# Process terminal constraints
120+
if isinstance(terminal_constraints, tuple):
121+
self.terminal_constraints = [terminal_constraints]
122+
elif not isinstance(terminal_constraints, list):
123+
raise TypeError("terminal constraints must be a list")
124+
else:
125+
self.terminal_constraints = terminal_constraints
126+
127+
112128
#
113129
# Compute and store constraints
114130
#

control/tests/optimal_test.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def test_terminal_constraints(sys_args):
230230
final_point = [opt.state_range_constraint(sys, [0, 0], [0, 0])]
231231

232232
# Create the optimal control problem
233-
time = np.arange(0, 5, 1)
233+
time = np.arange(0, 3, 1)
234234
optctrl = opt.OptimalControlProblem(
235235
sys, time, cost, terminal_constraints=final_point)
236236

@@ -302,3 +302,25 @@ def test_terminal_constraints(sys_args):
302302
with pytest.warns(UserWarning, match="unable to solve"):
303303
res = optctrl.compute_trajectory(x0, squeeze=True, return_x=True)
304304
assert not res.success
305+
306+
def test_optimal_logging(capsys):
307+
"""Test logging functions (mainly for code coverage)"""
308+
sys = ct.ss2io(ct.ss([[1, 1], [0, 1]], [[1], [0.5]], np.eye(2), 0, 1))
309+
310+
# Set up the optimal control problem
311+
cost = opt.quadratic_cost(sys, 1, 1)
312+
state_constraint = opt.state_range_constraint(
313+
sys, [-np.inf, -10], [10, np.inf])
314+
input_constraint = opt.input_range_constraint(sys, -100, 100)
315+
time = np.arange(0, 3, 1)
316+
x0 = [-1, 1]
317+
318+
# Solve it, with logging turned on
319+
res = opt.solve_ocp(
320+
sys, time, x0, cost, input_constraint, terminal_cost=cost,
321+
terminal_constraints=state_constraint, log=True)
322+
323+
# Make sure the output has info available only with logging turned on
324+
captured = capsys.readouterr()
325+
assert captured.out.find("process time") != -1
326+

0 commit comments

Comments
 (0)