Skip to content

Commit 178af36

Browse files
committed
add missing derivs for Bezier basis
1 parent 66d4a53 commit 178af36

2 files changed

Lines changed: 54 additions & 11 deletions

File tree

control/flatsys/bezier.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
# SUCH DAMAGE.
4040

4141
import numpy as np
42-
from scipy.special import binom
42+
from scipy.special import binom, factorial
4343
from .basis import BasisFamily
4444

4545
class BezierFamily(BasisFamily):
@@ -59,11 +59,23 @@ def __init__(self, N, T=1):
5959
# Compute the kth derivative of the ith basis function at time t
6060
def eval_deriv(self, i, k, t):
6161
"""Evaluate the kth derivative of the ith basis function at time t."""
62-
if k > 0:
63-
raise NotImplementedError("Bezier derivatives not yet available")
64-
elif i > self.N:
62+
if i >= self.N:
6563
raise ValueError("Basis function index too high")
64+
elif k >= self.N:
65+
# Higher order derivatives are zero
66+
return np.zeros(t.shape)
6667

67-
# Return the Bezier basis function (note N = # basis functions)
68-
return binom(self.N - 1, i) * \
69-
(t/self.T)**i * (1 - t/self.T)**(self.N - i - 1)
68+
# Compute the variables used in Bezier curve formulas
69+
n = self.N - 1
70+
u = t/self.T
71+
72+
if k == 0:
73+
# No derivative => avoid expansion for speed
74+
return binom(n, i) * u**i * (1-u)**(n-i)
75+
76+
# Return the kth derivative of the ith Bezier basis function
77+
return binom(n, i) * sum([
78+
(-1)**(j-i) *
79+
binom(n-i, j-i) * factorial(j)/factorial(j-k) * np.power(u, j-k)
80+
for j in range(max(i, k), n+1)
81+
])

control/tests/flatsys_test.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ def test_double_integrator(self, xf, uf, Tf):
5151
t, y, x = ct.forced_response(sys, T, ud, x1, return_x=True)
5252
np.testing.assert_array_almost_equal(x, xd, decimal=3)
5353

54-
def test_kinematic_car(self):
54+
@pytest.mark.parametrize("poly", [fs.PolyFamily(6), fs.BezierFamily(6)])
55+
def test_kinematic_car(self, poly):
5556
"""Differential flatness for a kinematic car"""
5657
def vehicle_flat_forward(x, u, params={}):
5758
b = params.get('wheelbase', 3.) # get parameter values
@@ -98,9 +99,6 @@ def vehicle_output(t, x, u, params): return x
9899
xf = [100., 2., 0.]; uf = [10., 0.]
99100
Tf = 10
100101

101-
# Define a set of basis functions to use for the trajectories
102-
poly = fs.PolyFamily(6)
103-
104102
# Find trajectory between initial and final conditions
105103
traj = fs.point_to_point(vehicle_flat, x0, u0, xf, uf, Tf, basis=poly)
106104

@@ -121,3 +119,36 @@ def vehicle_output(t, x, u, params): return x
121119
vehicle_flat, T, ud, x0, return_x=True)
122120
np.testing.assert_allclose(x, xd, atol=0.01, rtol=0.01)
123121

122+
def test_bezier_basis(self):
123+
bezier = fs.BezierFamily(4)
124+
time = np.linspace(0, 1, 100)
125+
126+
# Sum of the Bezier curves should be one
127+
np.testing.assert_almost_equal(
128+
1, sum([bezier(i, time) for i in range(4)]))
129+
130+
# Sum of derivatives should be zero
131+
for k in range(1, 5):
132+
np.testing.assert_almost_equal(
133+
0, sum([bezier.eval_deriv(i, k, time) for i in range(4)]))
134+
135+
# Compare derivatives to formulas
136+
np.testing.assert_almost_equal(
137+
bezier.eval_deriv(1, 0, time), 3 * time - 6 * time**2 + 3 * time**3)
138+
np.testing.assert_almost_equal(
139+
bezier.eval_deriv(1, 1, time), 3 - 12 * time + 9 * time**2)
140+
np.testing.assert_almost_equal(
141+
bezier.eval_deriv(1, 2, time), -12 + 18 * time)
142+
143+
# Make sure that the second derivative integrates to the first
144+
time = np.linspace(0, 1, 1000)
145+
dt = np.diff(time)
146+
for i in range(4):
147+
for j in (2, 3, 4):
148+
np.testing.assert_almost_equal(
149+
np.diff(bezier.eval_deriv(i, j-1, time)) / dt,
150+
bezier.eval_deriv(i, j, time)[0:-1], decimal=2)
151+
152+
# Exception check
153+
with pytest.raises(ValueError, match="index too high"):
154+
bezier.eval_deriv(4, 0, time)

0 commit comments

Comments
 (0)