Skip to content

Commit 05084f3

Browse files
committed
unit tests, bug fixes, algorithm improvements
* add initial_guess functionality to solve_flat_ocp * pre-compute collocation matrices in point_to_point, solve_flat_ocp * updated return values for solve_flat_ocp * add __repr__ for flat basis functions * docstring improvements * additional unit tests + examples
1 parent 78f7bbb commit 05084f3

8 files changed

Lines changed: 335 additions & 78 deletions

File tree

control/flatsys/basis.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ def __init__(self, N):
6464
self.coef_offset = [0] # coefficient offset for each variable
6565
self.coef_length = [N] # coefficient length for each variable
6666

67+
def __repr__(self):
68+
return f'<{self.__class__.__name__}: nvars={self.nvars}, ' + \
69+
f'N={self.N}>'
70+
6771
def __call__(self, i, t, var=None):
6872
"""Evaluate the ith basis function at a point in time"""
6973
return self.eval_deriv(i, 0, t, var=var)

control/flatsys/bezier.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def eval_deriv(self, i, k, t, var=None):
7878
# Return the kth derivative of the ith Bezier basis function
7979
return binom(n, i) * sum([
8080
(-1)**(j-i) *
81-
binom(n-i, j-i) * factorial(j)/factorial(j-k) * np.power(u, j-k)
81+
binom(n-i, j-i) * factorial(j)/factorial(j-k) * \
82+
np.power(u, j-k) / np.power(self.T, k)
8283
for j in range(max(i, k), n+1)
8384
])

control/flatsys/bspline.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,10 @@ def process_spline_parameters(
179179
self.knotpoints[i][offset:offset + self.degree[i] + 1] = \
180180
self.breakpoints[-1]
181181

182+
def __repr__(self):
183+
return f'<{self.__class__.__name__}: nvars={self.nvars}, ' + \
184+
f'degree={self.degree}, smoothness={self.smoothness}>'
185+
182186
# Compute the kth derivative of the ith basis function at time t
183187
def eval_deriv(self, i, k, t, var=None):
184188
"""Evaluate the kth derivative of the ith basis function at time t."""

control/flatsys/flatsys.py

Lines changed: 63 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -426,24 +426,23 @@ def point_to_point(
426426
if rank < Z.size:
427427
warnings.warn("basis too small; solution may not exist")
428428

429-
# Precompute the collocation matrix the defines the flag at timepts
430-
Mt_list = []
431-
for t in timepts:
432-
Mt_list.append(_basis_flag_matrix(sys, basis, zflag_T0, t))
433-
434429
if cost is not None or trajectory_constraints is not None:
435430
# Search over the null space to minimize cost/satisfy constraints
436431
N = sp.linalg.null_space(M)
437432

433+
# Precompute the collocation matrix the defines the flag at timepts
434+
Mt_list = []
435+
for t in timepts[1:-1]:
436+
Mt_list.append(_basis_flag_matrix(sys, basis, zflag_T0, t))
437+
438438
# Define a function to evaluate the cost along a trajectory
439439
def traj_cost(null_coeffs):
440440
# Add this to the existing solution
441441
coeffs = alpha + N @ null_coeffs
442442

443443
# Evaluate the costs at the listed time points
444-
# TODO: store Mt ahead of time, since it doesn't change
445444
costval = 0
446-
for i, t in enumerate(timepts):
445+
for i, t in enumerate(timepts[1:-1]):
447446
M_t = Mt_list[i]
448447

449448
# Compute flag at this time point
@@ -453,7 +452,7 @@ def traj_cost(null_coeffs):
453452
x, u = sys.reverse(zflag, params)
454453

455454
# Evaluate the cost at this time point
456-
costval += cost(x, u)
455+
costval += cost(x, u) * (timepts[i+1] - timepts[i])
457456
return costval
458457

459458
# If no cost given, override with magnitude of the coefficients
@@ -480,7 +479,7 @@ def traj_const(null_coeffs):
480479

481480
# Evaluate the constraints at the listed time points
482481
values = []
483-
for i, t in enumerate(timepts):
482+
for i, t in enumerate(timepts[1:-1]):
484483
# Calculate the states and inputs for the flat output
485484
M_t = Mt_list[i]
486485

@@ -504,7 +503,7 @@ def traj_const(null_coeffs):
504503

505504
# Store upper and lower bounds
506505
const_lb, const_ub = [], []
507-
for t in timepts:
506+
for t in timepts[1:-1]:
508507
for type, fun, lb, ub in traj_constraints:
509508
const_lb.append(lb)
510509
const_ub.append(ub)
@@ -515,9 +514,6 @@ def traj_const(null_coeffs):
515514
minimize_constraints = [sp.optimize.NonlinearConstraint(
516515
traj_const, const_lb, const_ub)]
517516

518-
# Add initial and terminal constraints
519-
# minimize_constraints += [sp.optimize.LinearConstraint(M, Z, Z)]
520-
521517
# Process the initial condition
522518
if initial_guess is None:
523519
initial_guess = np.zeros(M.shape[1] - M.shape[0])
@@ -528,19 +524,25 @@ def traj_const(null_coeffs):
528524
res = sp.optimize.minimize(
529525
traj_cost, initial_guess, constraints=minimize_constraints,
530526
**minimize_kwargs)
531-
if res.success:
532-
alpha += N @ res.x
533-
else:
534-
raise RuntimeError(
535-
"Unable to solve optimal control problem\n" +
536-
"scipy.optimize.minimize returned " + res.message)
527+
alpha += N @ res.x
528+
529+
# See if we got an answer
530+
if not res.success:
531+
warnings.warn(
532+
"unable to solve optimal control problem\n"
533+
f"scipy.optimize.minimize: '{res.message}'", UserWarning)
537534

538535
#
539536
# Transform the trajectory from flat outputs to states and inputs
540537
#
541538

542539
# Create a trajectory object to store the result
543540
systraj = SystemTrajectory(sys, basis, params=params)
541+
if cost is not None or trajectory_constraints is not None:
542+
# Store the result of the optimization
543+
systraj.cost = res.fun
544+
systraj.success = res.success
545+
systraj.message = res.message
544546

545547
# Store the flag lengths and coefficients
546548
# TODO: make this more pythonic
@@ -560,7 +562,7 @@ def traj_const(null_coeffs):
560562

561563
# Solve a point to point trajectory generation problem for a flat system
562564
def solve_flat_ocp(
563-
sys, timepts, x0=0, u0=0, basis=None, trajectory_cost=None,
565+
sys, timepts, x0=0, u0=0, trajectory_cost=None, basis=None,
564566
terminal_cost=None, trajectory_constraints=None,
565567
initial_guess=None, params=None, **kwargs):
566568
"""Compute trajectory between an initial and final conditions.
@@ -619,6 +621,9 @@ def solve_flat_ocp(
619621
620622
The constraints are applied at each time point along the trajectory.
621623
624+
initial_guess : 2D array_like, optional
625+
Initial guess for the optimal trajectory of the flat outputs.
626+
622627
minimize_kwargs : str, optional
623628
Pass additional keywords to :func:`scipy.optimize.minimize`.
624629
@@ -631,9 +636,14 @@ def solve_flat_ocp(
631636
632637
Notes
633638
-----
634-
Additional keyword parameters can be used to fine tune the behavior of
635-
the underlying optimization function. See `minimize_*` keywords in
636-
:func:`OptimalControlProblem` for more information.
639+
1. Additional keyword parameters can be used to fine tune the behavior
640+
of the underlying optimization function. See `minimize_*` keywords
641+
in :func:`OptimalControlProblem` for more information.
642+
643+
2. The return data structure includes the following additional attributes:
644+
* success : bool indicating whether the optimization succeeded
645+
* cost : computed cost of the returned trajectory
646+
* message : message returned by optimization if success if False
637647
638648
"""
639649
#
@@ -705,7 +715,7 @@ def solve_flat_ocp(
705715
# essentially amounts to evaluating the basis functions and their
706716
# derivatives at the initial conditions.
707717

708-
# Compute the flags for the initial and final states
718+
# Compute the flag for the initial state
709719
M_T0 = _basis_flag_matrix(sys, basis, zflag_T0, T0)
710720

711721
#
@@ -752,7 +762,7 @@ def traj_cost(null_coeffs):
752762

753763
# Evaluate the cost at this time point
754764
# TODO: make use of time interval
755-
costval += trajectory_cost(x, u)
765+
costval += trajectory_cost(x, u) * (timepts[i+1] - timepts[i])
756766

757767
# Evaluate the terminal_cost
758768
if terminal_cost is not None:
@@ -821,29 +831,47 @@ def traj_const(null_coeffs):
821831
# Add initial and terminal constraints
822832
# minimize_constraints += [sp.optimize.LinearConstraint(M, Z, Z)]
823833

824-
# Process the initial condition
834+
# Process the initial guess
825835
if initial_guess is None:
826-
initial_guess = np.zeros(M_T0.shape[1] - M_T0.shape[0])
836+
initial_coefs = np.ones(M_T0.shape[1] - M_T0.shape[0])
827837
else:
828-
raise NotImplementedError("Initial guess not yet implemented.")
838+
# Compute the map from coefficients to flat outputs
839+
initial_coefs = []
840+
for i in range(sys.ninputs):
841+
M_z = np.array([
842+
basis.eval_deriv(j, 0, timepts, var=i)
843+
for j in range(basis.var_ncoefs(i))]).transpose()
844+
845+
# Compute the parameters that give the best least squares fit
846+
coefs, _, _, _ = np.linalg.lstsq(
847+
M_z, initial_guess[i], rcond=None)
848+
initial_coefs.append(coefs)
849+
initial_coefs = np.hstack(initial_coefs)
850+
851+
# Project the parameters onto the independent variables
852+
initial_coefs, _, _, _ = np.linalg.lstsq(N, initial_coefs, rcond=None)
829853

830854
# Find the optimal solution
831855
res = sp.optimize.minimize(
832-
traj_cost, initial_guess, constraints=minimize_constraints,
856+
traj_cost, initial_coefs, constraints=minimize_constraints,
833857
**minimize_kwargs)
834-
if res.success:
835-
alpha += N @ res.x
836-
else:
837-
raise RuntimeError(
838-
"Unable to solve optimal control problem\n" +
839-
"scipy.optimize.minimize returned " + res.message)
858+
alpha += N @ res.x
859+
860+
# See if we got an answer
861+
if not res.success:
862+
warnings.warn(
863+
"unable to solve optimal control problem\n"
864+
f"scipy.optimize.minimize: '{res.message}'", UserWarning)
840865

841866
#
842867
# Transform the trajectory from flat outputs to states and inputs
843868
#
844869

845870
# Create a trajectory object to store the result
846871
systraj = SystemTrajectory(sys, basis, params=params)
872+
systraj.cost = res.fun
873+
systraj.success = res.success
874+
systraj.message = res.message
847875

848876
# Store the flag lengths and coefficients
849877
# TODO: make this more pythonic

control/flatsys/poly.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,14 @@ class PolyFamily(BasisFamily):
5050
\phi_i(t) = t^i
5151
5252
"""
53-
def __init__(self, N):
53+
def __init__(self, N, T=1.):
5454
"""Create a polynomial basis of order N."""
5555
super(PolyFamily, self).__init__(N)
56+
self.T = T
5657

5758
# Compute the kth derivative of the ith basis function at time t
5859
def eval_deriv(self, i, k, t, var=None):
5960
"""Evaluate the kth derivative of the ith basis function at time t."""
6061
if (i < k): return 0; # higher derivative than power
61-
return factorial(i)/factorial(i-k) * np.power(t, i-k)
62+
return factorial(i)/factorial(i-k) * \
63+
np.power(t/self.T, i-k) / np.power(self.T, k)

0 commit comments

Comments
 (0)