mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 23:26:52 +08:00

* diff/fd: implement Hessian finite difference, and code cleanups. This commit primarily adds the Hessian function for finding a finite difference approximation to the Hessian. At the same time, it combines duplicated functionality across the difference routines so that the preludes to all the difference routines look similar
146 lines
2.8 KiB
Go
146 lines
2.8 KiB
Go
// Copyright ©2014 The gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package fd
|
|
|
|
import (
|
|
"math"
|
|
"testing"
|
|
)
|
|
|
|
var xSquared = func(x float64) float64 { return x * x }
|
|
|
|
type testPoint struct {
|
|
f func(float64) float64
|
|
loc float64
|
|
fofx float64
|
|
ans float64
|
|
}
|
|
|
|
var testsFirst = []testPoint{
|
|
{
|
|
f: xSquared,
|
|
loc: 0,
|
|
fofx: 0,
|
|
ans: 0,
|
|
},
|
|
{
|
|
f: xSquared,
|
|
loc: 5,
|
|
fofx: 25,
|
|
ans: 10,
|
|
},
|
|
{
|
|
f: xSquared,
|
|
loc: 2,
|
|
fofx: 4,
|
|
ans: 4,
|
|
},
|
|
{
|
|
f: xSquared,
|
|
loc: -5,
|
|
fofx: 25,
|
|
ans: -10,
|
|
},
|
|
}
|
|
|
|
var testsSecond = []testPoint{
|
|
{
|
|
f: xSquared,
|
|
loc: 0,
|
|
fofx: 0,
|
|
ans: 2,
|
|
},
|
|
{
|
|
f: xSquared,
|
|
loc: 5,
|
|
fofx: 25,
|
|
ans: 2,
|
|
},
|
|
{
|
|
f: xSquared,
|
|
loc: 2,
|
|
fofx: 4,
|
|
ans: 2,
|
|
},
|
|
{
|
|
f: xSquared,
|
|
loc: -5,
|
|
fofx: 25,
|
|
ans: 2,
|
|
},
|
|
}
|
|
|
|
func testDerivative(t *testing.T, formula Formula, tol float64, tests []testPoint) {
|
|
for i, test := range tests {
|
|
|
|
ans := Derivative(test.f, test.loc, &Settings{
|
|
Formula: formula,
|
|
})
|
|
if math.Abs(test.ans-ans) > tol {
|
|
t.Errorf("Case %v: ans mismatch serial: expected %v, found %v", i, test.ans, ans)
|
|
}
|
|
|
|
ans = Derivative(test.f, test.loc, &Settings{
|
|
Formula: formula,
|
|
OriginKnown: true,
|
|
OriginValue: test.fofx,
|
|
})
|
|
if math.Abs(test.ans-ans) > tol {
|
|
t.Errorf("Case %v: ans mismatch serial origin known: expected %v, found %v", i, test.ans, ans)
|
|
}
|
|
|
|
ans = Derivative(test.f, test.loc, &Settings{
|
|
Formula: formula,
|
|
Concurrent: true,
|
|
})
|
|
if math.Abs(test.ans-ans) > tol {
|
|
t.Errorf("Case %v: ans mismatch concurrent: expected %v, found %v", i, test.ans, ans)
|
|
}
|
|
|
|
ans = Derivative(test.f, test.loc, &Settings{
|
|
Formula: formula,
|
|
OriginKnown: true,
|
|
OriginValue: test.fofx,
|
|
Concurrent: true,
|
|
})
|
|
if math.Abs(test.ans-ans) > tol {
|
|
t.Errorf("Case %v: ans mismatch concurrent: expected %v, found %v", i, test.ans, ans)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestForward(t *testing.T) {
|
|
testDerivative(t, Forward, 2e-4, testsFirst)
|
|
}
|
|
|
|
func TestBackward(t *testing.T) {
|
|
testDerivative(t, Backward, 2e-4, testsFirst)
|
|
}
|
|
|
|
func TestCentral(t *testing.T) {
|
|
testDerivative(t, Central, 1e-6, testsFirst)
|
|
}
|
|
|
|
func TestCentralSecond(t *testing.T) {
|
|
testDerivative(t, Central2nd, 1e-3, testsSecond)
|
|
}
|
|
|
|
// TestDerivativeDefault checks that the derivative works when settings is nil
|
|
// or zero value.
|
|
func TestDerivativeDefault(t *testing.T) {
|
|
tol := 1e-6
|
|
for i, test := range testsFirst {
|
|
ans := Derivative(test.f, test.loc, nil)
|
|
if math.Abs(test.ans-ans) > tol {
|
|
t.Errorf("Case %v: ans mismatch default: expected %v, found %v", i, test.ans, ans)
|
|
}
|
|
|
|
ans = Derivative(test.f, test.loc, &Settings{})
|
|
if math.Abs(test.ans-ans) > tol {
|
|
t.Errorf("Case %v: ans mismatch zero value: expected %v, found %v", i, test.ans, ans)
|
|
}
|
|
}
|
|
}
|