// Copyright ©2014 The Gonum Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package fd import ( "math" "testing" "golang.org/x/exp/rand" "gonum.org/v1/gonum/floats" ) type Rosenbrock struct { nDim int } func (r Rosenbrock) F(x []float64) (sum float64) { deriv := make([]float64, len(x)) return r.FDf(x, deriv) } func (r Rosenbrock) FDf(x []float64, deriv []float64) (sum float64) { for i := range deriv { deriv[i] = 0 } for i := 0; i < len(x)-1; i++ { sum += math.Pow(1-x[i], 2) + 100*math.Pow(x[i+1]-math.Pow(x[i], 2), 2) } for i := 0; i < len(x)-1; i++ { deriv[i] += -1 * 2 * (1 - x[i]) deriv[i] += 2 * 100 * (x[i+1] - math.Pow(x[i], 2)) * (-2 * x[i]) } for i := 1; i < len(x); i++ { deriv[i] += 2 * 100 * (x[i] - math.Pow(x[i-1], 2)) } return sum } func TestGradient(t *testing.T) { t.Parallel() rnd := rand.New(rand.NewSource(1)) for i, test := range []struct { nDim int tol float64 formula Formula }{ { nDim: 2, tol: 2e-4, formula: Forward, }, { nDim: 2, tol: 1e-6, formula: Central, }, { nDim: 40, tol: 2e-4, formula: Forward, }, { nDim: 40, tol: 1e-5, formula: Central, }, } { x := make([]float64, test.nDim) for i := range x { x[i] = rnd.Float64() } xcopy := make([]float64, len(x)) copy(xcopy, x) r := Rosenbrock{len(x)} trueGradient := make([]float64, len(x)) r.FDf(x, trueGradient) // Try with gradient nil. gradient := Gradient(nil, r.F, x, &Settings{ Formula: test.formula, }) if !floats.EqualApprox(gradient, trueGradient, test.tol) { t.Errorf("Case %v: gradient mismatch in serial with nil. Want: %v, Got: %v.", i, trueGradient, gradient) } if !floats.Equal(x, xcopy) { t.Errorf("Case %v: x modified during call to gradient in serial with nil.", i) } // Try with provided gradient. for i := range gradient { gradient[i] = rnd.Float64() } Gradient(gradient, r.F, x, &Settings{ Formula: test.formula, }) if !floats.EqualApprox(gradient, trueGradient, test.tol) { t.Errorf("Case %v: gradient mismatch in serial. Want: %v, Got: %v.", i, trueGradient, gradient) } if !floats.Equal(x, xcopy) { t.Errorf("Case %v: x modified during call to gradient in serial with non-nil.", i) } // Try with known value. for i := range gradient { gradient[i] = rnd.Float64() } Gradient(gradient, r.F, x, &Settings{ Formula: test.formula, OriginKnown: true, OriginValue: r.F(x), }) if !floats.EqualApprox(gradient, trueGradient, test.tol) { t.Errorf("Case %v: gradient mismatch with known origin in serial. Want: %v, Got: %v.", i, trueGradient, gradient) } // Try with concurrent evaluation. for i := range gradient { gradient[i] = rnd.Float64() } Gradient(gradient, r.F, x, &Settings{ Formula: test.formula, Concurrent: true, }) if !floats.EqualApprox(gradient, trueGradient, test.tol) { t.Errorf("Case %v: gradient mismatch with unknown origin in parallel. Want: %v, Got: %v.", i, trueGradient, gradient) } if !floats.Equal(x, xcopy) { t.Errorf("Case %v: x modified during call to gradient in parallel", i) } // Try with concurrent evaluation with origin known. for i := range gradient { gradient[i] = rnd.Float64() } Gradient(gradient, r.F, x, &Settings{ Formula: test.formula, Concurrent: true, OriginKnown: true, OriginValue: r.F(x), }) if !floats.EqualApprox(gradient, trueGradient, test.tol) { t.Errorf("Case %v: gradient mismatch with known origin in parallel. Want: %v, Got: %v.", i, trueGradient, gradient) } // Try with nil settings. for i := range gradient { gradient[i] = rnd.Float64() } Gradient(gradient, r.F, x, nil) if !floats.EqualApprox(gradient, trueGradient, test.tol) { t.Errorf("Case %v: gradient mismatch with default settings. Want: %v, Got: %v.", i, trueGradient, gradient) } // Try with zero-valued settings. for i := range gradient { gradient[i] = rnd.Float64() } Gradient(gradient, r.F, x, &Settings{}) if !floats.EqualApprox(gradient, trueGradient, test.tol) { t.Errorf("Case %v: gradient mismatch with zero settings. Want: %v, Got: %v.", i, trueGradient, gradient) } } } func Panics(fun func()) (b bool) { defer func() { err := recover() if err != nil { b = true } }() fun() return } func TestGradientPanics(t *testing.T) { t.Parallel() // Test that it panics if !Panics(func() { Gradient([]float64{0.0}, func(x []float64) float64 { return x[0] * x[0] }, []float64{0.0, 0.0}, nil) }) { t.Errorf("Gradient did not panic with length mismatch") } }