Skip to content

Commit 08f6464

Browse files
committed
improved flat system benchmarking (+ docstring, unit test updates)
1 parent e14d7b3 commit 08f6464

3 files changed

Lines changed: 96 additions & 13 deletions

File tree

benchmarks/flatsys_bench.py

Lines changed: 71 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
import control.flatsys as flat
1212
import control.optimal as opt
1313

14+
#
15+
# System setup: vehicle steering (bicycle model)
16+
#
17+
1418
# Vehicle steering dynamics
1519
def vehicle_update(t, x, u, params):
1620
# Get the parameters for the model
@@ -67,11 +71,28 @@ def vehicle_reverse(zflag, params={}):
6771
# Define the time points where the cost/constraints will be evaluated
6872
timepts = np.linspace(0, Tf, 10, endpoint=True)
6973

70-
def time_steering_point_to_point(basis_name, basis_size):
71-
if basis_name == 'poly':
72-
basis = flat.PolyFamily(basis_size)
73-
elif basis_name == 'bezier':
74-
basis = flat.BezierFamily(basis_size)
74+
#
75+
# Benchmark test parameters
76+
#
77+
78+
basis_params = (['poly', 'bezier', 'bspline'], [8, 10, 12])
79+
basis_param_names = ["basis", "size"]
80+
81+
def get_basis(name, size):
82+
if name == 'poly':
83+
basis = flat.PolyFamily(size, T=Tf)
84+
elif name == 'bezier':
85+
basis = flat.BezierFamily(size, T=Tf)
86+
elif name == 'bspline':
87+
basis = flat.BSplineFamily([0, Tf/2, Tf], size)
88+
return basis
89+
90+
#
91+
# Benchmarks
92+
#
93+
94+
def time_point_to_point(basis_name, basis_size):
95+
basis = get_basis(basis_name, basis_size)
7596

7697
# Find trajectory between initial and final conditions
7798
traj = flat.point_to_point(vehicle, Tf, x0, u0, xf, uf, basis=basis)
@@ -80,13 +101,16 @@ def time_steering_point_to_point(basis_name, basis_size):
80101
x, u = traj.eval([0, Tf])
81102
np.testing.assert_array_almost_equal(x0, x[:, 0])
82103
np.testing.assert_array_almost_equal(u0, u[:, 0])
83-
np.testing.assert_array_almost_equal(xf, x[:, 1])
84-
np.testing.assert_array_almost_equal(uf, u[:, 1])
104+
np.testing.assert_array_almost_equal(xf, x[:, -1])
105+
np.testing.assert_array_almost_equal(uf, u[:, -1])
106+
107+
time_point_to_point.params = basis_params
108+
time_point_to_point.param_names = basis_param_names
85109

86-
time_steering_point_to_point.params = (['poly', 'bezier'], [6, 8])
87-
time_steering_point_to_point.param_names = ["basis", "size"]
88110

89-
def time_steering_cost():
111+
def time_point_to_point_with_cost(basis_name, basis_size):
112+
basis = get_basis(basis_name, basis_size)
113+
90114
# Define cost and constraints
91115
traj_cost = opt.quadratic_cost(
92116
vehicle, None, np.diag([0.1, 1]), u0=uf)
@@ -95,13 +119,47 @@ def time_steering_cost():
95119

96120
traj = flat.point_to_point(
97121
vehicle, timepts, x0, u0, xf, uf,
98-
cost=traj_cost, constraints=constraints, basis=flat.PolyFamily(8)
122+
cost=traj_cost, constraints=constraints, basis=basis,
99123
)
100124

101125
# Verify that the trajectory computation is correct
102126
x, u = traj.eval([0, Tf])
103127
np.testing.assert_array_almost_equal(x0, x[:, 0])
104128
np.testing.assert_array_almost_equal(u0, u[:, 0])
105-
np.testing.assert_array_almost_equal(xf, x[:, 1])
106-
np.testing.assert_array_almost_equal(uf, u[:, 1])
129+
np.testing.assert_array_almost_equal(xf, x[:, -1])
130+
np.testing.assert_array_almost_equal(uf, u[:, -1])
131+
132+
time_point_to_point_with_cost.params = basis_params
133+
time_point_to_point_with_cost.param_names = basis_param_names
134+
135+
136+
def time_solve_flat_ocp_terminal_cost(method, basis_name, basis_size):
137+
basis = get_basis(basis_name, basis_size)
138+
139+
# Define cost and constraints
140+
traj_cost = opt.quadratic_cost(
141+
vehicle, None, np.diag([0.1, 1]), u0=uf)
142+
term_cost = opt.quadratic_cost(
143+
vehicle, np.diag([1e3, 1e3, 1e3]), None, x0=xf)
144+
constraints = [
145+
opt.input_range_constraint(vehicle, [8, -0.1], [12, 0.1]) ]
146+
147+
# Initial guess = straight line
148+
initial_guess = np.array(
149+
[x0[i] + (xf[i] - x0[i]) * timepts/Tf for i in (0, 1)])
150+
151+
traj = flat.solve_flat_ocp(
152+
vehicle, timepts, x0, u0, basis=basis, initial_guess=initial_guess,
153+
trajectory_cost=traj_cost, constraints=constraints,
154+
terminal_cost=term_cost, minimize_method=method,
155+
)
156+
157+
# Verify that the trajectory computation is correct
158+
x, u = traj.eval([0, Tf])
159+
np.testing.assert_array_almost_equal(x0, x[:, 0])
160+
np.testing.assert_array_almost_equal(xf, x[:, -1], decimal=2)
107161

162+
time_solve_flat_ocp_terminal_cost.params = tuple(
163+
[['slsqp', 'trust-constr']] + list(basis_params))
164+
time_solve_flat_ocp_terminal_cost.param_names = tuple(
165+
['method'] + basis_param_names)

control/flatsys/flatsys.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,11 @@ def solve_flat_ocp(
654654
* cost : computed cost of the returned trajectory
655655
* message : message returned by optimization if success if False
656656
657+
3. A common failure in solving optimal control problem is that the
658+
default initial guess violates the constraints and the optimizer
659+
can't find a feasible solution. Using the `initial_guess` parameter
660+
can often be used to overcome these errors.
661+
657662
"""
658663
#
659664
# Make sure the problem is one that we can handle

control/tests/flatsys_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,26 @@ def test_bezier_basis(self):
464464
with pytest.raises(ValueError, match="index too high"):
465465
bezier.eval_deriv(4, 0, time)
466466

467+
@pytest.mark.parametrize("basis, degree, T", [
468+
(fs.PolyFamily(4), 4, 1),
469+
(fs.PolyFamily(4, 100), 4, 100),
470+
(fs.BezierFamily(4), 4, 1),
471+
(fs.BezierFamily(4, 100), 4, 100),
472+
(fs.BSplineFamily([0, 0.5, 1], 4), 3, 1),
473+
(fs.BSplineFamily([0, 50, 100], 4), 3, 100),
474+
])
475+
def test_basis_derivs(self, basis, degree, T):
476+
"""Make sure that that basis function derivates are correct"""
477+
timepts = np.linspace(0, T, 10000)
478+
dt = timepts[1] - timepts[0]
479+
for i in range(basis.N):
480+
for j in range(degree-1):
481+
# Compare numerical and analytical derivative
482+
np.testing.assert_allclose(
483+
np.diff(basis.eval_deriv(i, j, timepts)) / dt,
484+
basis.eval_deriv(i, j+1, timepts)[0:-1],
485+
atol=1e-2, rtol=1e-4)
486+
467487
def test_point_to_point_errors(self):
468488
"""Test error and warning conditions in point_to_point()"""
469489
# Double integrator system

0 commit comments

Comments
 (0)