Skip to content

Commit e65c4f2

Browse files
committed
additional unit tests for coverage + bug fixes
1 parent 05084f3 commit e65c4f2

2 files changed

Lines changed: 54 additions & 4 deletions

File tree

control/flatsys/basis.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
# OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
3737
# SUCH DAMAGE.
3838

39+
import numpy as np
40+
3941

4042
# Basis family class (for use as a base class)
4143
class BasisFamily:
@@ -86,16 +88,24 @@ def eval(self, coeffs, tlist, var=None):
8688
sum([coeffs[i] * self(i, t) for i in range(self.N)])
8789
for t in tlist]
8890

89-
else:
90-
# Multi-variable basis
91+
elif var is None:
92+
# Multi-variable basis with single list of coefficients
9193
values = np.empty((self.nvars, tlist.size))
94+
offset = 0
9295
for j in range(self.nvars):
9396
coef_len = self.var_ncoefs(j)
9497
values[j] = np.array([
95-
sum([coeffs[i] * self(i, t, var=j)
98+
sum([coeffs[offset + i] * self(i, t, var=j)
9699
for i in range(coef_len)])
97100
for t in tlist])
101+
offset += coef_len
98102
return values
99103

104+
else:
105+
return np.array([
106+
sum([coeffs[i] * self(i, t, var=var)
107+
for i in range(self.var_ncoefs(var))])
108+
for t in tlist])
109+
100110
def eval_deriv(self, i, j, t, var=None):
101111
raise NotImplementedError("Internal error; improper basis functions")

control/tests/flatsys_test.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def test_flat_cost_constr(self, basis):
313313
np.testing.assert_array_almost_equal(uf, u_const[:, -1])
314314

315315
# Make sure that the solution respects the bounds (with some slop)
316-
for i in range(x_const.shape[0]):
316+
for i in range(x_const.shape[0]):
317317
assert all(lb[i] - x_const[i] < rtol * abs(lb[i]) + atol)
318318
assert all(x_const[i] - ub[i] < rtol * abs(ub[i]) + atol)
319319

@@ -673,3 +673,43 @@ def test_response(self, xf, uf, Tf):
673673
np.testing.assert_equal(T, response.time)
674674
np.testing.assert_equal(u, response.inputs)
675675
np.testing.assert_equal(x, response.states)
676+
677+
@pytest.mark.parametrize(
678+
"basis",
679+
[fs.PolyFamily(4),
680+
fs.BezierFamily(4),
681+
fs.BSplineFamily([0, 1], 4),
682+
fs.BSplineFamily([0, 1], 4, vars=2),
683+
fs.BSplineFamily([0, 1], [4, 3], [2, 1], vars=2),
684+
])
685+
def test_basis_class(self, basis):
686+
timepts = np.linspace(0, 1, 10)
687+
688+
if basis.nvars is None:
689+
# Evaluate function on basis vectors
690+
for j in range(basis.N):
691+
coefs = np.zeros(basis.N)
692+
coefs[j] = 1
693+
np.testing.assert_equal(
694+
basis.eval(coefs, timepts),
695+
basis.eval_deriv(j, 0, timepts))
696+
else:
697+
# Evaluate each variable on basis vectors
698+
for i in range(basis.nvars):
699+
for j in range(basis.var_ncoefs(i)):
700+
coefs = np.zeros(basis.var_ncoefs(i))
701+
coefs[j] = 1
702+
np.testing.assert_equal(
703+
basis.eval(coefs, timepts, var=i),
704+
basis.eval_deriv(j, 0, timepts, var=i))
705+
706+
# Evaluate multi-variable output
707+
offset = 0
708+
for i in range(basis.nvars):
709+
for j in range(basis.var_ncoefs(i)):
710+
coefs = np.zeros(basis.N)
711+
coefs[offset] = 1
712+
np.testing.assert_equal(
713+
basis.eval(coefs, timepts)[i],
714+
basis.eval_deriv(j, 0, timepts, var=i))
715+
offset += 1

0 commit comments

Comments
 (0)