Skip to content

Commit 4bb8d02

Browse files
author
Roman Werpachowski
authored
interp: Add NaturalCubic spline interpolator (#1657)
1 parent 4d500a2 commit 4bb8d02

3 files changed

Lines changed: 251 additions & 2 deletions

File tree

interp/cubic.go

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ func (pc *PiecewiseCubic) FitWithDerivatives(xs, ys, dydxs []float64) {
104104
pc.coeffs.Set(i, 2, (3*dy-(2*dydxs[i]+dydxs[i+1])*dx)/dx/dx)
105105
pc.coeffs.Set(i, 3, (-2*dy+(dydxs[i]+dydxs[i+1])*dx)/dx/dx/dx)
106106
}
107-
pc.xs = make([]float64, n)
108-
copy(pc.xs, xs)
107+
pc.xs = append(pc.xs[:0], xs...)
109108
pc.lastY = ys[m]
110109
pc.lastDyDx = dydxs[m]
111110
}
@@ -297,3 +296,119 @@ func fritschButlandEdgeDerivative(xs, ys, slopes []float64, leftEdge bool) float
297296
}
298297
return g
299298
}
299+
300+
// fitWithSecondDerivatives fits a piecewise cubic predictor to (X, Y, d^2Y/dX^2) value
301+
// triples provided as three slices.
302+
// It panics if any of these is true:
303+
// - len(xs) < 2,
304+
// - elements of xs are not strictly increasing,
305+
// - len(xs) != len(ys),
306+
// - len(xs) != len(d2ydx2s).
307+
// Note that this method does not guarantee on its own the continuity of first derivatives.
308+
func (pc *PiecewiseCubic) fitWithSecondDerivatives(xs, ys, d2ydx2s []float64) {
309+
n := len(xs)
310+
switch {
311+
case len(ys) != n, len(d2ydx2s) != n:
312+
panic(differentLengths)
313+
case n < 2:
314+
panic(tooFewPoints)
315+
}
316+
m := n - 1
317+
pc.coeffs.Reset()
318+
pc.coeffs.ReuseAs(m, 4)
319+
for i := 0; i < m; i++ {
320+
dx := xs[i+1] - xs[i]
321+
if dx <= 0 {
322+
panic(xsNotStrictlyIncreasing)
323+
}
324+
dy := ys[i+1] - ys[i]
325+
dm := d2ydx2s[i+1] - d2ydx2s[i]
326+
pc.coeffs.Set(i, 0, ys[i]) // a_0
327+
pc.coeffs.Set(i, 1, (dy-(d2ydx2s[i]+dm/3)*dx*dx/2)/dx) // a_1
328+
pc.coeffs.Set(i, 2, d2ydx2s[i]/2) // a_2
329+
pc.coeffs.Set(i, 3, dm/6/dx) // a_3
330+
}
331+
pc.xs = append(pc.xs[:0], xs...)
332+
pc.lastY = ys[m]
333+
lastDx := xs[m] - xs[m-1]
334+
pc.lastDyDx = pc.coeffs.At(m-1, 1) + 2*pc.coeffs.At(m-1, 2)*lastDx + 3*pc.coeffs.At(m-1, 3)*lastDx*lastDx
335+
}
336+
337+
// makeCubicSplineSecondDerivativeEquations generates the basic system of linear equations
338+
// which have to be satisfied by the second derivatives to make the first derivatives of a
339+
// cubic spline continuous. It panics if elements of xs are not strictly increasing, or
340+
// len(xs) != len(ys).
341+
// makeCubicSplineSecondDerivativeEquations returns a tri-diagonal matrix A and a vector b
342+
// defining a system of linear equations A*m = b for second derivatives vector m.
343+
func makeCubicSplineSecondDerivativeEquations(xs, ys []float64) (*mat.Tridiag, mat.MutableVector) {
344+
n := len(xs)
345+
if len(ys) != n {
346+
panic(differentLengths)
347+
}
348+
b := make([]float64, n)
349+
m := n - 1
350+
// Diagonal of A:
351+
d := make([]float64, n)
352+
// Sub-diagonal of A:
353+
dl := make([]float64, m)
354+
// Super-diagonal of A:
355+
du := make([]float64, m)
356+
if n > 2 {
357+
for i := 0; i < m; i++ {
358+
dx := xs[i+1] - xs[i]
359+
if dx <= 0 {
360+
panic(xsNotStrictlyIncreasing)
361+
}
362+
slope := (ys[i+1] - ys[i]) / dx
363+
if i > 0 {
364+
b[i] += slope
365+
d[i] += dx / 3
366+
du[i] = dx / 6
367+
}
368+
if i < m-1 {
369+
b[i+1] -= slope
370+
d[i+1] += dx / 3
371+
dl[i] = dx / 6
372+
}
373+
}
374+
}
375+
return mat.NewTridiag(n, dl, d, du), mat.NewVecDense(n, b)
376+
}
377+
378+
// NaturalCubic is a piecewise cubic 1-dimensional interpolator with
379+
// continuous value, first and second derivatives, which can be fitted to (X, Y)
380+
// value pairs without providing derivatives. See e.g. https://www.math.drexel.edu/~tolya/cubicspline.pdf
381+
// for details.
382+
type NaturalCubic struct {
383+
cubic PiecewiseCubic
384+
}
385+
386+
// Predict returns the interpolation value at x.
387+
func (nc *NaturalCubic) Predict(x float64) float64 {
388+
return nc.cubic.Predict(x)
389+
}
390+
391+
// PredictDerivative returns the predicted derivative at x.
392+
func (nc *NaturalCubic) PredictDerivative(x float64) float64 {
393+
return nc.cubic.PredictDerivative(x)
394+
}
395+
396+
// Fit fits a predictor to (X, Y) value pairs provided as two slices.
397+
// It panics if len(xs) < 2, elements of xs are not strictly increasing
398+
// or len(xs) != len(ys). It returns an error if solving the required system
399+
// of linear equations fails.
400+
func (nc *NaturalCubic) Fit(xs, ys []float64) error {
401+
a, b := makeCubicSplineSecondDerivativeEquations(xs, ys)
402+
// Add boundary conditions y''(left) = y''(right) = 0:
403+
n := len(xs)
404+
b.SetVec(0, 0)
405+
b.SetVec(n-1, 0)
406+
a.SetBand(0, 0, 1)
407+
a.SetBand(n-1, n-1, 1)
408+
x := mat.NewVecDense(n, nil)
409+
err := a.SolveVecTo(x, false, b)
410+
if err == nil {
411+
nc.cubic.fitWithSecondDerivatives(xs, ys, x.RawVector().Data)
412+
}
413+
return err
414+
}

interp/cubic_test.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,3 +671,134 @@ func TestFritschButlandErrors(t *testing.T) {
671671
}
672672
}
673673
}
674+
675+
func TestPiecewiseCubicFitWithSecondDerivatives(t *testing.T) {
676+
t.Parallel()
677+
const tol = 1e-14
678+
xs := []float64{-2, 0, 3}
679+
ys := []float64{2.5, 1, 2.5}
680+
d2ydx2s := []float64{1, 2, 3}
681+
var pc PiecewiseCubic
682+
pc.fitWithSecondDerivatives(xs, ys, d2ydx2s)
683+
m := len(xs) - 1
684+
if pc.lastY != ys[m] {
685+
t.Errorf("Mismatch in lastY: got %v, want %g", pc.lastY, ys[m])
686+
}
687+
if !floats.Equal(pc.xs, xs) {
688+
t.Errorf("Mismatch in xs: got %v, want %v", pc.xs, xs)
689+
}
690+
for i := 0; i < len(xs); i++ {
691+
yHat := pc.Predict(xs[i])
692+
if math.Abs(yHat-ys[i]) > tol {
693+
t.Errorf("Mismatch in predicted Y[%d]: got %v, want %g", i, yHat, ys[i])
694+
}
695+
var d2ydx2Hat float64
696+
if i < m {
697+
d2ydx2Hat = 2 * pc.coeffs.At(i, 2)
698+
} else {
699+
d2ydx2Hat = 2*pc.coeffs.At(m-1, 2) + 6*pc.coeffs.At(m-1, 3)*(xs[m]-xs[m-1])
700+
}
701+
if math.Abs(d2ydx2Hat-d2ydx2s[i]) > tol {
702+
t.Errorf("Mismatch in predicted d2Y/dX2[%d]: got %v, want %g", i, d2ydx2Hat, d2ydx2s[i])
703+
}
704+
}
705+
// Test pc.lastDyDx without copying verbatim the calculation from the tested method:
706+
lastDyDx := pc.PredictDerivative(xs[m] - tol/1000)
707+
if math.Abs(lastDyDx-pc.lastDyDx) > tol {
708+
t.Errorf("Mismatch in lastDxDy: got %v, want %g", pc.lastDyDx, lastDyDx)
709+
}
710+
}
711+
712+
func TestPiecewiseCubicFitWithSecondDerivativesErrors(t *testing.T) {
713+
t.Parallel()
714+
for _, test := range []struct {
715+
xs, ys, d2ydx2s []float64
716+
}{
717+
{
718+
xs: []float64{0, 1, 2},
719+
ys: []float64{10, 20},
720+
d2ydx2s: []float64{0, 0, 0},
721+
},
722+
{
723+
xs: []float64{0, 1, 1},
724+
ys: []float64{10, 20, 30},
725+
d2ydx2s: []float64{0, 0, 0, 0},
726+
},
727+
{
728+
xs: []float64{0},
729+
ys: []float64{0},
730+
d2ydx2s: []float64{0},
731+
},
732+
{
733+
xs: []float64{0, 1, 1},
734+
ys: []float64{10, 20, 10},
735+
d2ydx2s: []float64{0, 0, 0},
736+
},
737+
} {
738+
var pc PiecewiseCubic
739+
if !panics(func() { pc.fitWithSecondDerivatives(test.xs, test.ys, test.d2ydx2s) }) {
740+
t.Errorf("expected panic for xs: %v, ys: %v and d2ydx2s: %v", test.xs, test.ys, test.d2ydx2s)
741+
}
742+
}
743+
}
744+
745+
func TestMakeCubicSplineSecondDerivativeEquations(t *testing.T) {
746+
t.Parallel()
747+
const tol = 1e-15
748+
xs := []float64{-1, 0, 2}
749+
ys := []float64{2, 0, 2}
750+
n := len(xs)
751+
A, b := makeCubicSplineSecondDerivativeEquations(xs, ys)
752+
if b.Len() != n {
753+
t.Errorf("Mismatch in b size: got %v, want %d", b.Len(), n)
754+
}
755+
r, c := A.Dims()
756+
if r != n || c != n {
757+
t.Errorf("Mismatch in A size: got %d x %d, want %d x %d", r, c, n, n)
758+
}
759+
expectedB := mat.NewVecDense(3, []float64{0, 3, 0})
760+
var diffB mat.VecDense
761+
diffB.SubVec(b, expectedB)
762+
if diffB.Norm(math.Inf(1)) > tol {
763+
t.Errorf("Mismatch in b values: got %v, want %v", b, expectedB)
764+
}
765+
expectedA := mat.NewDense(3, 3, []float64{0, 0, 0, 1 / 6., 1, 2 / 6., 0, 0, 0})
766+
var diffA mat.Dense
767+
diffA.Sub(A, expectedA)
768+
if diffA.Norm(math.Inf(1)) > tol {
769+
t.Errorf("Mismatch in A values: got %v, want %v", A, expectedA)
770+
}
771+
}
772+
773+
func TestNaturalCubicFit(t *testing.T) {
774+
t.Parallel()
775+
xs := []float64{-1, 0, 2, 3.5}
776+
ys := []float64{2, 0, 2, 1.5}
777+
var nc NaturalCubic
778+
err := nc.Fit(xs, ys)
779+
if err != nil {
780+
t.Errorf("Error when fitting NaturalCubic: %v", err)
781+
}
782+
testXs := []float64{-1, -0.99, -0.5, 0, 0.5, 1, 1.5, 2, 2.5, 3, 3.49, 3.5}
783+
// From scipy.interpolate.CubicSpline:
784+
want := []float64{
785+
2.0,
786+
1.9737725526315788,
787+
0.7664473684210527,
788+
0.0,
789+
0.027960526315789477,
790+
0.6184210526315789,
791+
1.3996710526315788,
792+
2.0,
793+
2.1403508771929824,
794+
1.9122807017543857,
795+
1.508859403508772,
796+
1.5}
797+
for i := 0; i < len(testXs); i++ {
798+
got := nc.Predict(testXs[i])
799+
if math.Abs(got-want[i]) > 1e-14 {
800+
t.Errorf("Mismatch in predicted Y value for x = %g: got %v, want %g", testXs[i], got, want[i])
801+
}
802+
}
803+
804+
}

mat/band.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ type RawBander interface {
5151
// A MutableBanded can set elements of a band matrix.
5252
type MutableBanded interface {
5353
Banded
54+
55+
// SetBand sets the element at row i, column j to the value v.
56+
// It panics if the location is outside the appropriate region of the matrix.
5457
SetBand(i, j int, v float64)
5558
}
5659

0 commit comments

Comments
 (0)