// Copyright ©2014 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package fd

import (
	"math"
	"testing"

	"golang.org/x/exp/rand"

	"gonum.org/v1/gonum/floats"
)

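// Rosenbrock implements the extended Rosenbrock function
//
//	f(x) = sum_{i=0}^{n-2} (1-x_i)^2 + 100*(x_{i+1} - x_i^2)^2,
//
// a standard smooth test problem whose gradient is known in closed
// form, giving a reference value for the finite difference
// approximations under test.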
type Rosenbrock struct {
	nDim int
}

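// F returns the value of the Rosenbrock function at x, discarding the
// gradient.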
func (r Rosenbrock) F(x []float64) (sum float64) {
	deriv := make([]float64, len(x))
	return r.FDf(x, deriv)
}

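// FDf returns the value of the Rosenbrock function at x and stores the
// analytic gradient in deriv.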
func (r Rosenbrock) FDf(x []float64, deriv []float64) (sum float64) {
	for i := range deriv {
		deriv[i] = 0
	}

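	// Accumulate the function value term by term.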
	for i := 0; i < len(x)-1; i++ {
		sum += math.Pow(1-x[i], 2) + 100*math.Pow(x[i+1]-math.Pow(x[i], 2), 2)
	}
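	// Accumulate the gradient. Each x_i appears in the i-th summand
	// (for i < n-1) and in the (i-1)-th summand (for i > 0), so its
	// partial derivative collects contributions from both loops below.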
	for i := 0; i < len(x)-1; i++ {
		deriv[i] += -1 * 2 * (1 - x[i])
		deriv[i] += 2 * 100 * (x[i+1] - math.Pow(x[i], 2)) * (-2 * x[i])
	}
	for i := 1; i < len(x); i++ {
		deriv[i] += 2 * 100 * (x[i] - math.Pow(x[i-1], 2))
	}

	return sum
}

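// TestGradient checks the finite difference gradient against the
// analytic gradient of the Rosenbrock function for the Forward and
// Central formulas, with and without a preallocated destination, a
// known origin value, concurrent evaluation, and nil or zero-valued
// Settings. The basic call pattern exercised below is
//
//	gradient := Gradient(nil, r.F, x, nil)
//
// where a nil destination allocates the result slice and nil settings
// select the package defaults.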
func TestGradient(t *testing.T) {
	rand.Seed(1)
	for i, test := range []struct {
		nDim    int
		tol     float64
		formula Formula
	}{
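		// Forward differences have first order truncation error, so
		// they are given a looser tolerance than the second order
		// Central differences.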
		{
			nDim:    2,
			tol:     2e-4,
			formula: Forward,
		},
		{
			nDim:    2,
			tol:     1e-6,
			formula: Central,
		},
		{
			nDim:    40,
			tol:     2e-4,
			formula: Forward,
		},
		{
			nDim:    40,
			tol:     1e-5,
			formula: Central,
		},
	} {
		x := make([]float64, test.nDim)
		for i := range x {
			x[i] = rand.Float64()
		}
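		// Keep a copy of x to verify that Gradient does not modify it.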
		xcopy := make([]float64, len(x))
		copy(xcopy, x)

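		// Compute the analytic gradient as the reference value.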
		r := Rosenbrock{len(x)}
		trueGradient := make([]float64, len(x))
		r.FDf(x, trueGradient)

		// Try with gradient nil.
		gradient := Gradient(nil, r.F, x, &Settings{
			Formula: test.formula,
		})
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch in serial with nil. Want: %v, Got: %v.", i, trueGradient, gradient)
		}
		if !floats.Equal(x, xcopy) {
			t.Errorf("Case %v: x modified during call to gradient in serial with nil.", i)
		}

		// Try with provided gradient.
		for i := range gradient {
			gradient[i] = rand.Float64()
		}
		Gradient(gradient, r.F, x, &Settings{
			Formula: test.formula,
		})
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch in serial. Want: %v, Got: %v.", i, trueGradient, gradient)
		}
		if !floats.Equal(x, xcopy) {
			t.Errorf("Case %v: x modified during call to gradient in serial with non-nil.", i)
		}

		// Try with known origin value.
		for i := range gradient {
			gradient[i] = rand.Float64()
		}
		Gradient(gradient, r.F, x, &Settings{
			Formula:     test.formula,
			OriginKnown: true,
			OriginValue: r.F(x),
		})
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch with known origin in serial. Want: %v, Got: %v.", i, trueGradient, gradient)
		}

		// Try with concurrent evaluation.
		for i := range gradient {
			gradient[i] = rand.Float64()
		}
		Gradient(gradient, r.F, x, &Settings{
			Formula:    test.formula,
			Concurrent: true,
		})
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch with unknown origin in parallel. Want: %v, Got: %v.", i, trueGradient, gradient)
		}
		if !floats.Equal(x, xcopy) {
			t.Errorf("Case %v: x modified during call to gradient in parallel.", i)
		}

		// Try with concurrent evaluation with the origin known.
		for i := range gradient {
			gradient[i] = rand.Float64()
		}
		Gradient(gradient, r.F, x, &Settings{
			Formula:     test.formula,
			Concurrent:  true,
			OriginKnown: true,
			OriginValue: r.F(x),
		})
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch with known origin in parallel. Want: %v, Got: %v.", i, trueGradient, gradient)
		}

		// Try with nil settings.
		for i := range gradient {
			gradient[i] = rand.Float64()
		}
		Gradient(gradient, r.F, x, nil)
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch with default settings. Want: %v, Got: %v.", i, trueGradient, gradient)
		}

		// Try with zero-valued settings.
		for i := range gradient {
			gradient[i] = rand.Float64()
		}
		Gradient(gradient, r.F, x, &Settings{})
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch with zero settings. Want: %v, Got: %v.", i, trueGradient, gradient)
		}
	}
}

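// Panics reports whether fun panics when called.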
func Panics(fun func()) (b bool) {
	defer func() {
		err := recover()
		if err != nil {
			b = true
		}
	}()
	fun()
	return
}

func TestGradientPanics(t *testing.T) {
	// Test that Gradient panics when the lengths of the gradient and
	// location slices do not match.
	if !Panics(func() {
		Gradient([]float64{0.0}, func(x []float64) float64 { return x[0] * x[0] }, []float64{0.0, 0.0}, nil)
	}) {
		t.Errorf("Gradient did not panic with length mismatch")
	}
}