Skip to content

Commit 52ab45e

Browse files
committed
Add test for scalar timepts arg in solve_flat_ocp
1 parent ebff125 commit 52ab45e

1 file changed

Lines changed: 23 additions & 0 deletions

File tree

control/tests/flatsys_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,29 @@ def test_flat_solve_ocp(self, basis):
452452
np.testing.assert_almost_equal(x_const, x_nlconst)
453453
np.testing.assert_almost_equal(u_const, u_nlconst)
454454

455+
def test_solve_flat_ocp_scalar_timepts(self):
456+
# scalar timepts gives expected result
457+
f = fs.LinearFlatSystem(ct.ss(ct.tf([1],[1,1])))
458+
459+
def terminal_cost(x, u):
460+
return (x-5).dot(x-5)+u.dot(u)
461+
462+
traj1 = fs.solve_flat_ocp(f, [0, 1], x0=[23],
463+
terminal_cost=terminal_cost)
464+
465+
traj2 = fs.solve_flat_ocp(f, 1, x0=[23],
466+
terminal_cost=terminal_cost)
467+
468+
teval = np.linspace(0, 1, 101)
469+
470+
r1 = traj1.response(teval)
471+
r2 = traj2.response(teval)
472+
473+
assert np.max(abs(r1.x-r2.x)) == 0
474+
assert np.max(abs(r1.u-r2.u)) == 0
475+
assert np.max(abs(r1.y-r2.y)) == 0
476+
477+
455478
def test_bezier_basis(self):
456479
bezier = fs.BezierFamily(4)
457480
time = np.linspace(0, 1, 100)

0 commit comments

Comments
 (0)