Skip to content

Commit 0ab0d07

Browse files
committed
add support for multiple spline variables; update docstrings
1 parent 3dcad07 commit 0ab0d07

7 files changed

Lines changed: 142 additions & 67 deletions

File tree

control/flatsys/basis.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,11 @@ class BasisFamily:
4747
4848
:math:`z_i^{(q)}(t)` = basis.eval_deriv(self, i, j, t)
4949
50-
Parameters
50+
A basis set can either consist of a single variable that is used for
51+
each flat output (nvars = None) or a different variable for different
52+
flat outputs (nvars > 0).
53+
54+
Attributes
5155
----------
5256
N : int
5357
Order of the basis set.
@@ -56,15 +60,38 @@ class BasisFamily:
5660
def __init__(self, N):
5761
"""Create a basis family of order N."""
5862
self.N = N # save number of basis functions
63+
self.nvars = None # default number of variables
64+
self.coef_offset = [0] # coefficient offset for each variable
65+
self.coef_length = [N] # coefficient length for each variable
5966

60-
def __call__(self, i, t):
67+
def __call__(self, i, t, var=None):
6168
"""Evaluate the ith basis function at a point in time"""
62-
return self.eval_deriv(i, 0, t)
69+
return self.eval_deriv(i, 0, t, var=var)
70+
71+
def var_ncoefs(self, var):
72+
"""Get the number of coefficients for a variable"""
73+
return self.N if self.nvars is None else self.coef_length[var]
74+
75+
def eval(self, coeffs, tlist, var=None):
76+
if self.nvars is None and var != None:
77+
raise SystemError("multi-variable call to a scalar basis")
78+
79+
elif self.nvars is None:
80+
# Single variable basis
81+
return [
82+
sum([coeffs[i] * self(i, t) for i in range(self.N)])
83+
for t in tlist]
6384

64-
def eval(self, coeffs, tlist):
65-
return [
66-
sum([coeffs[i] * self(i, t) for i in range(self.N)])
67-
for t in tlist]
85+
else:
86+
# Multi-variable basis
87+
values = np.empty((self.nvars, tlist.size))
88+
for j in range(self.nvars):
89+
coef_len = self.var_ncoefs(j)
90+
values[j] = np.array([
91+
sum([coeffs[i] * self(i, t, var=j)
92+
for i in range(coef_len)])
93+
for t in tlist])
94+
return values
6895

69-
def eval_deriv(self, i, j, t):
96+
def eval_deriv(self, i, j, t, var=None):
7097
raise NotImplementedError("Internal error; improper basis functions")

control/flatsys/bezier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(self, N, T=1):
5959
self.T = T # save end of time interval
6060

6161
# Compute the kth derivative of the ith basis function at time t
62-
def eval_deriv(self, i, k, t):
62+
def eval_deriv(self, i, k, t, var=None):
6363
"""Evaluate the kth derivative of the ith basis function at time t."""
6464
if i >= self.N:
6565
raise ValueError("Basis function index too high")

control/flatsys/bspline.py

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ class BSplineFamily(BasisFamily):
1717
across a set of breakpoints with given order and smoothness.
1818
1919
"""
20-
def __init__(self, breakpoints, degree, smoothness=None, vars=1):
20+
def __init__(self, breakpoints, degree, smoothness=None, vars=None):
2121
"""Create a B-spline basis for piecewise smooth polynomials
2222
2323
Define B-spline polynomials for a set of one or more variables.
24-
B-splines are characterized by a set of intervals separated by break
25-
points. On each interval we have a polynomial of a certain order
26-
and the spline is continuous up to a given smoothness at interior
27-
break points.
24+
B-splines are used as a basis for a set of piecewise smooth
25+
polynomials joined at breakpoints. On each interval we have a
26+
polynomial of a given order and the spline is continuous up to a
27+
given smoothness at interior breakpoints.
2828
2929
Parameters
3030
----------
@@ -41,8 +41,11 @@ def __init__(self, breakpoints, degree, smoothness=None, vars=1):
4141
For each spline variable, the smoothness at breakpoints (number
4242
of derivatives that should match).
4343
44-
vars : int or list of str, option
45-
The number of spline variables or a list of spline variable names.
44+
vars : None or int, optional
45+
The number of spline variables. If specified as None (default),
46+
then the spline basis describes a single variable, with no
47+
indexing. If the number of spine variables is > 0, then the
48+
spline basis is index using the `var` keyword.
4649
4750
"""
4851
# Process the breakpoints for the spline */
@@ -58,17 +61,19 @@ def __init__(self, breakpoints, degree, smoothness=None, vars=1):
5861
raise ValueError("break points must be strictly increasing values")
5962

6063
# Decide on the number of spline variables
61-
if isinstance(vars, list) and all([isinstance(v, str) for v in vars]):
62-
raise NotImplemented("list of variable names not yet supported")
64+
if vars is None:
65+
nvars = 1
66+
self.nvars = None # track as single variable
6367
elif not isinstance(vars, int):
64-
raise TypeError("vars must be an integer or list of strings")
68+
raise TypeError("vars must be an integer")
6569
else:
6670
nvars = vars
71+
self.nvars = nvars
6772

6873
#
6974
# Process B-spline parameters (order, smoothness)
7075
#
71-
# B-splines are characterized by a set of intervals separated by
76+
# B-splines are defined on a set of intervals separated by
7277
# breakpoints. On each interval we have a polynomial of a certain
7378
# order and the spline is continuous up to a given smoothness at
7479
# breakpoints. The code in this section allows some flexibility in
@@ -99,14 +104,15 @@ def process_spline_parameters(
99104
elif all([isinstance(v, allowed_types) for v in values]):
100105
# List of values => make sure it is the right size
101106
if len(values) != length:
102-
raise ValueError(f"length of '{name}' does not match n")
107+
raise ValueError(f"length of '{name}' does not match"
108+
f" number of variables")
103109
else:
104110
raise ValueError(f"could not parse '{name}' keyword")
105111

106112
# Check to make sure the values are OK
107113
if values is not None and any([val < minimum for val in values]):
108114
raise ValueError(
109-
f"invalid value for {name}; must be at least {minimum}")
115+
f"invalid value for '{name}'; must be at least {minimum}")
110116

111117
return values
112118

@@ -123,25 +129,23 @@ def process_spline_parameters(
123129
if any([degree[i] - smoothness[i] < 1 for i in range(nvars)]):
124130
raise ValueError("degree must be greater than smoothness")
125131

126-
# Store the parameters and process them in call_ntg()
127-
self.nvars = nvars
132+
# Store the parameters for the spline (self.nvars already stored)
128133
self.breakpoints = breakpoints
129134
self.degree = degree
130135
self.smoothness = smoothness
131-
self.nintervals = breakpoints.size - 1
132136

133137
#
134138
# Compute parameters for a SciPy BSpline object
135139
#
136-
# To create a B-spline, we need to compute the knot points, keeping
137-
# track of the use of repeated knot points at the initial knot and
140+
# To create a B-spline, we need to compute the knotpoints, keeping
141+
# track of the use of repeated knotpoints at the initial knot and
138142
# final knot as well as repeated knots at intermediate points
139143
# depending on the desired smoothness.
140144
#
141145

142146
# Store the coefficients for each output (useful later)
143147
self.coef_offset, self.coef_length, offset = [], [], 0
144-
for i in range(self.nvars):
148+
for i in range(nvars):
145149
# Compute number of coefficients for the piecewise polynomial
146150
ncoefs = (self.degree[i] + 1) * (len(self.breakpoints) - 1) - \
147151
(self.smoothness[i] + 1) * (len(self.breakpoints) - 2)
@@ -151,48 +155,43 @@ def process_spline_parameters(
151155
offset += ncoefs
152156
self.N = offset # save the total number of coefficients
153157

154-
# Create knot points for each spline variable
158+
# Create knotpoints for each spline variable
155159
# TODO: extend to multi-dimensional breakpoints
156160
self.knotpoints = []
157-
for i in range(self.nvars):
161+
for i in range(nvars):
158162
# Allocate space for the knotpoints
159163
self.knotpoints.append(np.empty(
160164
(self.degree[i] + 1) + (len(self.breakpoints) - 2) * \
161165
(self.degree[i] - self.smoothness[i]) + (self.degree[i] + 1)))
162166

163-
# Initial knot points
167+
# Initial knotpoints (multiplicity = order)
164168
self.knotpoints[i][0:self.degree[i] + 1] = self.breakpoints[0]
165169
offset = self.degree[i] + 1
166170

167-
# Interior knot points
171+
# Interior knotpoints (multiplicity = degree - smoothness)
168172
nknots = self.degree[i] - self.smoothness[i]
169173
assert nknots > 0 # just in case
170174
for j in range(1, self.breakpoints.size - 1):
171175
self.knotpoints[i][offset:offset+nknots] = self.breakpoints[j]
172176
offset += nknots
173177

174-
# Final knot point
178+
# Final knotpoint (multiplicity = order)
175179
self.knotpoints[i][offset:offset + self.degree[i] + 1] = \
176180
self.breakpoints[-1]
177181

178-
def eval(self, coefs, tlist):
179-
return np.array([
180-
BSpline(self.knotpoints[i],
181-
coefs[self.coef_offset[i]:
182-
self.coef_offset[i] + self.coef_length[i]],
183-
self.degree[i])(tlist)
184-
for i in range(self.nvars)])
185-
186182
# Compute the kth derivative of the ith basis function at time t
187-
def eval_deriv(self, i, k, t, squeeze=True):
183+
def eval_deriv(self, i, k, t, var=None):
188184
"""Evaluate the kth derivative of the ith basis function at time t."""
189-
if self.nvars > 1 or not squeeze:
190-
raise NotImplementedError(
191-
"derivatives of multi-variable splines not yet supported")
185+
if self.nvars is None or (self.nvars == 1 and var is None):
186+
# Use same variable for all requests
187+
var = 0
188+
elif self.nvars > 1 and var is None:
189+
raise SystemError(
190+
"scalar variable call to multi-variable splines")
192191

193192
# Create a coefficient vector for this spline
194-
coefs = np.zeros(self.coef_length[0]); coefs[i] = 1
193+
coefs = np.zeros(self.coef_length[var]); coefs[i] = 1
195194

196195
# Evaluate the derivative of the spline at the desired point in time
197-
return BSpline(self.knotpoints[0], coefs,
198-
self.degree[0]).derivative(k)(t)
196+
return BSpline(self.knotpoints[var], coefs,
197+
self.degree[var]).derivative(k)(t)

control/flatsys/flatsys.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def __init__(self,
155155
if reverse is not None: self.reverse = reverse
156156

157157
# Save the length of the flat flag
158+
# TODO: missing
158159

159160
def __str__(self):
160161
return f"{NonlinearIOSystem.__str__(self)}\n\n" \
@@ -233,17 +234,19 @@ def _basis_flag_matrix(sys, basis, flag, t, params={}):
233234
column of the matrix corresponds to a basis function and each row is a
234235
derivative, with the derivatives (flag) for each output stacked on top
235236
of each other.
236-
237+
l
237238
"""
238239
flagshape = [len(f) for f in flag]
239-
M = np.zeros((sum(flagshape), basis.N * sys.ninputs))
240+
M = np.zeros((sum(flagshape),
241+
sum([basis.var_ncoefs(i) for i in range(sys.ninputs)])))
240242
flag_off = 0
241-
coeff_off = 0
243+
coef_off = 0
242244
for i, flag_len in enumerate(flagshape):
243-
for j, k in itertools.product(range(basis.N), range(flag_len)):
244-
M[flag_off + k, coeff_off + j] = basis.eval_deriv(j, k, t)
245+
coef_len = basis.var_ncoefs(i)
246+
for j, k in itertools.product(range(coef_len), range(flag_len)):
247+
M[flag_off + k, coef_off + j] = basis.eval_deriv(j, k, t, var=i)
245248
flag_off += flag_len
246-
coeff_off += basis.N
249+
coef_off += coef_len
247250
return M
248251

249252

@@ -362,11 +365,16 @@ def point_to_point(
362365
if basis is None:
363366
basis = PolyFamily(2 * (sys.nstates + sys.ninputs))
364367

368+
# If a multivariable basis was given, make sure the size is correct
369+
if basis.nvars is not None and basis.nvars != sys.ninputs:
370+
raise ValueError("size of basis does not match flat system size")
371+
365372
# Make sure we have enough basis functions to solve the problem
366-
if basis.N * sys.ninputs < 2 * (sys.nstates + sys.ninputs):
373+
ncoefs = sum([basis.var_ncoefs(i) for i in range(sys.ninputs)])
374+
if ncoefs < 2 * (sys.nstates + sys.ninputs):
367375
raise ValueError("basis set is too small")
368376
elif (cost is not None or trajectory_constraints is not None) and \
369-
basis.N * sys.ninputs == 2 * (sys.nstates + sys.ninputs):
377+
ncoefs == 2 * (sys.nstates + sys.ninputs):
370378
warnings.warn("minimal basis specified; optimization not possible")
371379
cost = None
372380
trajectory_constraints = None
@@ -531,11 +539,12 @@ def traj_const(null_coeffs):
531539

532540
# Store the flag lengths and coefficients
533541
# TODO: make this more pythonic
534-
coeff_off = 0
542+
coef_off = 0
535543
for i in range(sys.ninputs):
536544
# Grab the coefficients corresponding to this flat output
537-
systraj.coeffs.append(alpha[coeff_off:coeff_off + basis.N])
538-
coeff_off += basis.N
545+
coef_len = basis.var_ncoefs(i)
546+
systraj.coeffs.append(alpha[coef_off:coef_off + coef_len])
547+
coef_off += coef_len
539548

540549
# Keep track of the length of the flat flag for this output
541550
systraj.flaglen.append(len(zflag_T0[i]))

control/flatsys/poly.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(self, N):
5555
super(PolyFamily, self).__init__(N)
5656

5757
# Compute the kth derivative of the ith basis function at time t
58-
def eval_deriv(self, i, k, t):
58+
def eval_deriv(self, i, k, t, var=None):
5959
"""Evaluate the kth derivative of the ith basis function at time t."""
6060
if (i < k): return 0; # higher derivative than power
6161
return factorial(i)/factorial(i-k) * np.power(t, i-k)

control/flatsys/systraj.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,11 @@ def eval(self, tlist):
106106
for i in range(self.ninputs):
107107
flag_len = self.flaglen[i]
108108
zflag.append(np.zeros(flag_len))
109-
for j in range(self.basis.N):
109+
for j in range(self.basis.var_ncoefs(i)):
110110
for k in range(flag_len):
111111
#! TODO: rewrite eval_deriv to take in time vector
112112
zflag[i][k] += self.coeffs[i][j] * \
113-
self.basis.eval_deriv(j, k, t)
113+
self.basis.eval_deriv(j, k, t, var=i)
114114

115115
# Now copy the states and inputs
116116
# TODO: revisit order of list arguments

0 commit comments

Comments
 (0)