Mirror of https://github.com/gonum/gonum.git (synced 2025-10-25 08:10:28 +08:00).
			
		
		
		
	
		
			
				
	
	
		
			113 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			113 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright ©2017 The Gonum Authors. All rights reserved.
 | |
| // Use of this source code is governed by a BSD-style
 | |
| // license that can be found in the LICENSE file.
 | |
| 
 | |
| package fd
 | |
| 
 | |
| import (
 | |
| 	"testing"
 | |
| 
 | |
| 	"gonum.org/v1/gonum/mat"
 | |
| )
 | |
| 
 | |
| type HessianTester interface {
 | |
| 	Func(x []float64) float64
 | |
| 	Grad(grad, x []float64)
 | |
| 	Hess(dst mat.MutableSymmetric, x []float64)
 | |
| }
 | |
| 
 | |
| type hessianTestCase struct {
 | |
| 	h        HessianTester
 | |
| 	x        []float64
 | |
| 	settings *Settings
 | |
| 	tol      float64
 | |
| }
 | |
| 
 | |
| var _hessianTestCases = []hessianTestCase{
 | |
| 	{
 | |
| 		h:   Watson{},
 | |
| 		x:   []float64{0.2, 0.3, 0.1, 0.4},
 | |
| 		tol: 1e-3,
 | |
| 	},
 | |
| 	{
 | |
| 		h:   Watson{},
 | |
| 		x:   []float64{2, 3, 1, 4},
 | |
| 		tol: 1e-3,
 | |
| 		settings: &Settings{
 | |
| 			Step:    1e-5,
 | |
| 			Formula: Central,
 | |
| 		},
 | |
| 	},
 | |
| 	{
 | |
| 		h:   Watson{},
 | |
| 		x:   []float64{2, 3, 1},
 | |
| 		tol: 1e-3,
 | |
| 		settings: &Settings{
 | |
| 			OriginKnown: true,
 | |
| 			OriginValue: 7606.529501201192,
 | |
| 		},
 | |
| 	},
 | |
| 	{
 | |
| 		h:   ConstFunc(5),
 | |
| 		x:   []float64{1, 9},
 | |
| 		tol: 1e-16,
 | |
| 	},
 | |
| 	{
 | |
| 		h:   LinearFunc{w: []float64{10, 6, -1}, c: 5},
 | |
| 		x:   []float64{3, 1, 8},
 | |
| 		tol: 1e-6,
 | |
| 	},
 | |
| 	{
 | |
| 		h: QuadFunc{
 | |
| 			a: mat.NewSymDense(3, []float64{
 | |
| 				10, 2, 1,
 | |
| 				2, 5, -3,
 | |
| 				1, -3, 6,
 | |
| 			}),
 | |
| 			b: mat.NewVecDense(3, []float64{3, -2, -1}),
 | |
| 			c: 5,
 | |
| 		},
 | |
| 		x:   []float64{-1.6, -3, 2},
 | |
| 		tol: 1e-6,
 | |
| 	},
 | |
| }
 | |
| 
 | |
| func hessianTestCases() []hessianTestCase {
 | |
| 	xs := []hessianTestCase{}
 | |
| 	for _, test := range _hessianTestCases {
 | |
| 		n := test
 | |
| 		if test.settings != nil {
 | |
| 			clone := *test.settings
 | |
| 			n.settings = &clone
 | |
| 		}
 | |
| 		xs = append(xs, n)
 | |
| 	}
 | |
| 	return xs
 | |
| }
 | |
| 
 | |
| func TestHessian(t *testing.T) {
 | |
| 	t.Parallel()
 | |
| 	for cas, test := range hessianTestCases() {
 | |
| 		n := len(test.x)
 | |
| 		var got mat.SymDense
 | |
| 		Hessian(&got, test.h.Func, test.x, test.settings)
 | |
| 		want := mat.NewSymDense(n, nil)
 | |
| 		test.h.Hess(want, test.x)
 | |
| 		if !mat.EqualApprox(&got, want, test.tol) {
 | |
| 			t.Errorf("Cas %d: Hessian mismatch\ngot=\n%0.4v\nwant=\n%0.4v\n", cas, mat.Formatted(&got), mat.Formatted(want))
 | |
| 		}
 | |
| 
 | |
| 		// Test that concurrency works.
 | |
| 		settings := test.settings
 | |
| 		if settings == nil {
 | |
| 			settings = &Settings{}
 | |
| 		}
 | |
| 		settings.Concurrent = true
 | |
| 		var got2 mat.SymDense
 | |
| 		Hessian(&got2, test.h.Func, test.x, settings)
 | |
| 		if !mat.EqualApprox(&got, &got2, 1e-5) {
 | |
| 			t.Errorf("Cas %d: Hessian mismatch concurrent\ngot=\n%0.6v\nwant=\n%0.6v\n", cas, mat.Formatted(&got2), mat.Formatted(&got))
 | |
| 		}
 | |
| 	}
 | |
| }
 | 
