mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 23:26:52 +08:00
97 lines
2.0 KiB
Go
97 lines
2.0 KiB
Go
// Copyright ©2017 The Gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package fd
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"gonum.org/v1/gonum/mat"
|
|
)
|
|
|
|
type HessianTester interface {
|
|
Func(x []float64) float64
|
|
Grad(grad, x []float64)
|
|
Hess(dst mat.MutableSymmetric, x []float64)
|
|
}
|
|
|
|
var hessianTestCases = []struct {
|
|
h HessianTester
|
|
x []float64
|
|
settings *Settings
|
|
tol float64
|
|
}{
|
|
{
|
|
h: Watson{},
|
|
x: []float64{0.2, 0.3, 0.1, 0.4},
|
|
tol: 1e-3,
|
|
},
|
|
{
|
|
h: Watson{},
|
|
x: []float64{2, 3, 1, 4},
|
|
tol: 1e-3,
|
|
settings: &Settings{
|
|
Step: 1e-5,
|
|
Formula: Central,
|
|
},
|
|
},
|
|
{
|
|
h: Watson{},
|
|
x: []float64{2, 3, 1},
|
|
tol: 1e-3,
|
|
settings: &Settings{
|
|
OriginKnown: true,
|
|
OriginValue: 7606.529501201192,
|
|
},
|
|
},
|
|
{
|
|
h: ConstFunc(5),
|
|
x: []float64{1, 9},
|
|
tol: 1e-16,
|
|
},
|
|
{
|
|
h: LinearFunc{w: []float64{10, 6, -1}, c: 5},
|
|
x: []float64{3, 1, 8},
|
|
tol: 1e-6,
|
|
},
|
|
{
|
|
h: QuadFunc{
|
|
a: mat.NewSymDense(3, []float64{
|
|
10, 2, 1,
|
|
2, 5, -3,
|
|
1, -3, 6,
|
|
}),
|
|
b: mat.NewVecDense(3, []float64{3, -2, -1}),
|
|
c: 5,
|
|
},
|
|
x: []float64{-1.6, -3, 2},
|
|
tol: 1e-6,
|
|
},
|
|
}
|
|
|
|
func TestHessian(t *testing.T) {
|
|
for cas, test := range hessianTestCases {
|
|
n := len(test.x)
|
|
var got mat.SymDense
|
|
Hessian(&got, test.h.Func, test.x, test.settings)
|
|
want := mat.NewSymDense(n, nil)
|
|
test.h.Hess(want, test.x)
|
|
if !mat.EqualApprox(&got, want, test.tol) {
|
|
t.Errorf("Cas %d: Hessian mismatch\ngot=\n%0.4v\nwant=\n%0.4v\n", cas, mat.Formatted(&got), mat.Formatted(want))
|
|
}
|
|
|
|
// Test that concurrency works.
|
|
settings := test.settings
|
|
if settings == nil {
|
|
settings = &Settings{}
|
|
}
|
|
settings.Concurrent = true
|
|
var got2 mat.SymDense
|
|
Hessian(&got2, test.h.Func, test.x, settings)
|
|
if !mat.EqualApprox(&got, &got2, 1e-5) {
|
|
t.Errorf("Cas %d: Hessian mismatch concurrent\ngot=\n%0.6v\nwant=\n%0.6v\n", cas, mat.Formatted(&got2), mat.Formatted(&got))
|
|
}
|
|
}
|
|
}
|