mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 15:16:59 +08:00
Add Laplacian and CrossLaplacian difference functions (#154)
* Add Laplacian and CrossLaplacian difference functions * use usesOrigin
This commit is contained in:
186
diff/fd/crosslaplacian.go
Normal file
186
diff/fd/crosslaplacian.go
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
// Copyright ©2017 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package fd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CrossLaplacian computes a Laplacian-like quantity for a function of two vectors
|
||||||
|
// at the locations x and y.
|
||||||
|
// It computes
|
||||||
|
// ∇_y · ∇_x f(x,y) = \sum_i ∂^2 f(x,y)/∂x_i ∂y_i
|
||||||
|
// The two input vector lengths must be the same.
|
||||||
|
//
|
||||||
|
// Finite difference formula and other options are specified by settings. If
|
||||||
|
// settings is nil, CrossLaplacian will be estimated using the Forward formula and
|
||||||
|
// a default step size.
|
||||||
|
//
|
||||||
|
// CrossLaplacian panics if the two input vectors are not the same length, or if
|
||||||
|
// the derivative order of the formula is not 1.
|
||||||
|
func CrossLaplacian(f func(x, y []float64) float64, x, y []float64, settings *Settings) float64 {
|
||||||
|
n := len(x)
|
||||||
|
if n == 0 {
|
||||||
|
panic("crosslaplacian: x has zero length")
|
||||||
|
}
|
||||||
|
if len(x) != len(y) {
|
||||||
|
panic("crosslaplacian: input vector length mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default settings.
|
||||||
|
formula := Forward
|
||||||
|
step := math.Sqrt(formula.Step) // Use the sqrt because taking derivatives of derivatives.
|
||||||
|
var originValue float64
|
||||||
|
var originKnown, concurrent bool
|
||||||
|
|
||||||
|
// Use user settings if provided.
|
||||||
|
if settings != nil {
|
||||||
|
if !settings.Formula.isZero() {
|
||||||
|
formula = settings.Formula
|
||||||
|
step = math.Sqrt(formula.Step)
|
||||||
|
checkFormula(formula)
|
||||||
|
if formula.Derivative != 1 {
|
||||||
|
panic(badDerivOrder)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if settings.Step != 0 {
|
||||||
|
if settings.Step < 0 {
|
||||||
|
panic(negativeStep)
|
||||||
|
}
|
||||||
|
step = settings.Step
|
||||||
|
}
|
||||||
|
originKnown = settings.OriginKnown
|
||||||
|
originValue = settings.OriginValue
|
||||||
|
concurrent = settings.Concurrent
|
||||||
|
}
|
||||||
|
|
||||||
|
evals := n * len(formula.Stencil) * len(formula.Stencil)
|
||||||
|
if usesOrigin(formula.Stencil) {
|
||||||
|
evals -= n
|
||||||
|
}
|
||||||
|
|
||||||
|
nWorkers := computeWorkers(concurrent, evals)
|
||||||
|
if nWorkers == 1 {
|
||||||
|
return crossLaplacianSerial(f, x, y, formula.Stencil, step, originKnown, originValue)
|
||||||
|
}
|
||||||
|
return crossLaplacianConcurrent(nWorkers, evals, f, x, y, formula.Stencil, step, originKnown, originValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
func crossLaplacianSerial(f func(x, y []float64) float64, x, y []float64, stencil []Point, step float64, originKnown bool, originValue float64) float64 {
|
||||||
|
n := len(x)
|
||||||
|
xCopy := make([]float64, len(x))
|
||||||
|
yCopy := make([]float64, len(y))
|
||||||
|
fo := func() float64 {
|
||||||
|
// Copy x and y in case they are modified during the call.
|
||||||
|
copy(xCopy, x)
|
||||||
|
copy(yCopy, y)
|
||||||
|
return f(x, y)
|
||||||
|
}
|
||||||
|
origin := getOrigin(originKnown, originValue, fo, stencil)
|
||||||
|
|
||||||
|
is2 := 1 / (step * step)
|
||||||
|
var laplacian float64
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
for _, pty := range stencil {
|
||||||
|
for _, ptx := range stencil {
|
||||||
|
var v float64
|
||||||
|
if ptx.Loc == 0 && pty.Loc == 0 {
|
||||||
|
v = origin
|
||||||
|
} else {
|
||||||
|
// Copying the data anew has two benefits. First, it
|
||||||
|
// avoids floating point issues where adding and then
|
||||||
|
// subtracting the step don't return to the exact same
|
||||||
|
// location. Secondly, it protects against the function
|
||||||
|
// modifying the input data.
|
||||||
|
copy(yCopy, y)
|
||||||
|
copy(xCopy, x)
|
||||||
|
yCopy[i] += pty.Loc * step
|
||||||
|
xCopy[i] += ptx.Loc * step
|
||||||
|
v = f(xCopy, yCopy)
|
||||||
|
}
|
||||||
|
laplacian += v * ptx.Coeff * pty.Coeff * is2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return laplacian
|
||||||
|
}
|
||||||
|
|
||||||
|
func crossLaplacianConcurrent(nWorkers, evals int, f func(x, y []float64) float64, x, y []float64, stencil []Point, step float64, originKnown bool, originValue float64) float64 {
|
||||||
|
n := len(x)
|
||||||
|
type run struct {
|
||||||
|
i int
|
||||||
|
xIdx, yIdx int
|
||||||
|
result float64
|
||||||
|
}
|
||||||
|
|
||||||
|
send := make(chan run, evals)
|
||||||
|
ans := make(chan run, evals)
|
||||||
|
|
||||||
|
var originWG sync.WaitGroup
|
||||||
|
hasOrigin := usesOrigin(stencil)
|
||||||
|
if hasOrigin {
|
||||||
|
originWG.Add(1)
|
||||||
|
// Launch worker to compute the origin.
|
||||||
|
go func() {
|
||||||
|
defer originWG.Done()
|
||||||
|
xCopy := make([]float64, len(x))
|
||||||
|
yCopy := make([]float64, len(y))
|
||||||
|
copy(xCopy, x)
|
||||||
|
copy(yCopy, y)
|
||||||
|
originValue = f(xCopy, yCopy)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
var workerWG sync.WaitGroup
|
||||||
|
// Launch workers.
|
||||||
|
for i := 0; i < nWorkers; i++ {
|
||||||
|
workerWG.Add(1)
|
||||||
|
go func(send <-chan run, ans chan<- run) {
|
||||||
|
defer workerWG.Done()
|
||||||
|
xCopy := make([]float64, len(x))
|
||||||
|
yCopy := make([]float64, len(y))
|
||||||
|
for r := range send {
|
||||||
|
if stencil[r.xIdx].Loc == 0 && stencil[r.yIdx].Loc == 0 {
|
||||||
|
originWG.Wait()
|
||||||
|
r.result = originValue
|
||||||
|
} else {
|
||||||
|
// See crossLaplacianSerial for comment on the copy.
|
||||||
|
copy(xCopy, x)
|
||||||
|
copy(yCopy, y)
|
||||||
|
xCopy[r.i] += stencil[r.xIdx].Loc * step
|
||||||
|
yCopy[r.i] += stencil[r.yIdx].Loc * step
|
||||||
|
r.result = f(xCopy, yCopy)
|
||||||
|
}
|
||||||
|
ans <- r
|
||||||
|
}
|
||||||
|
}(send, ans)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Launch the distributor, which sends all of runs.
|
||||||
|
go func(send chan<- run) {
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
for xIdx := range stencil {
|
||||||
|
for yIdx := range stencil {
|
||||||
|
send <- run{
|
||||||
|
i: i, xIdx: xIdx, yIdx: yIdx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
close(send)
|
||||||
|
// Wait for all the workers to quit, then close the ans channel.
|
||||||
|
workerWG.Wait()
|
||||||
|
close(ans)
|
||||||
|
}(send)
|
||||||
|
|
||||||
|
// Read in the results.
|
||||||
|
is2 := 1 / (step * step)
|
||||||
|
var laplacian float64
|
||||||
|
for r := range ans {
|
||||||
|
laplacian += r.result * stencil[r.xIdx].Coeff * stencil[r.yIdx].Coeff * is2
|
||||||
|
}
|
||||||
|
return laplacian
|
||||||
|
}
|
111
diff/fd/crosslaplacian_test.go
Normal file
111
diff/fd/crosslaplacian_test.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
// Copyright ©2017 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package fd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gonum.org/v1/gonum/floats"
|
||||||
|
"gonum.org/v1/gonum/mat"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CrossLaplacianTester interface {
|
||||||
|
Func(x, y []float64) float64
|
||||||
|
CrossLaplacian(x, y []float64) float64
|
||||||
|
}
|
||||||
|
|
||||||
|
type WrapperCL struct {
|
||||||
|
Tester HessianTester
|
||||||
|
}
|
||||||
|
|
||||||
|
func (WrapperCL) constructZ(x, y []float64) []float64 {
|
||||||
|
z := make([]float64, len(x)+len(y))
|
||||||
|
copy(z, x)
|
||||||
|
copy(z[len(x):], y)
|
||||||
|
return z
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w WrapperCL) Func(x, y []float64) float64 {
|
||||||
|
z := w.constructZ(x, y)
|
||||||
|
return w.Tester.Func(z)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w WrapperCL) CrossLaplacian(x, y []float64) float64 {
|
||||||
|
z := w.constructZ(x, y)
|
||||||
|
hess := mat.NewSymDense(len(z), nil)
|
||||||
|
w.Tester.Hess(hess, z)
|
||||||
|
// The CrossLaplacian is the trace of the off-diagonal block of the Hessian.
|
||||||
|
var l float64
|
||||||
|
for i := 0; i < len(x); i++ {
|
||||||
|
l += hess.At(i, i+len(x))
|
||||||
|
}
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrossLaplacian(t *testing.T) {
|
||||||
|
for cas, test := range []struct {
|
||||||
|
l CrossLaplacianTester
|
||||||
|
x, y []float64
|
||||||
|
settings *Settings
|
||||||
|
tol float64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
l: WrapperCL{Watson{}},
|
||||||
|
x: []float64{0.2, 0.3},
|
||||||
|
y: []float64{0.1, 0.4},
|
||||||
|
tol: 1e-3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
l: WrapperCL{Watson{}},
|
||||||
|
x: []float64{2, 3, 1},
|
||||||
|
y: []float64{1, 4, 1},
|
||||||
|
tol: 1e-3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
l: WrapperCL{ConstFunc(6)},
|
||||||
|
x: []float64{2, -3, 1},
|
||||||
|
y: []float64{1, 4, -5},
|
||||||
|
tol: 1e-6,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
l: WrapperCL{LinearFunc{w: []float64{10, 6, -1, 5}, c: 5}},
|
||||||
|
x: []float64{3, 1},
|
||||||
|
y: []float64{8, 6},
|
||||||
|
tol: 1e-6,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
l: WrapperCL{QuadFunc{
|
||||||
|
a: mat.NewSymDense(4, []float64{
|
||||||
|
10, 2, 1, 9,
|
||||||
|
2, 5, -3, 4,
|
||||||
|
1, -3, 6, 2,
|
||||||
|
9, 4, 2, -14,
|
||||||
|
}),
|
||||||
|
b: mat.NewVecDense(4, []float64{3, -2, -1, 4}),
|
||||||
|
c: 5,
|
||||||
|
}},
|
||||||
|
x: []float64{-1.6, -3},
|
||||||
|
y: []float64{1.8, 3.4},
|
||||||
|
tol: 1e-6,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
got := CrossLaplacian(test.l.Func, test.x, test.y, test.settings)
|
||||||
|
want := test.l.CrossLaplacian(test.x, test.y)
|
||||||
|
if !floats.EqualWithinAbsOrRel(got, want, test.tol, test.tol) {
|
||||||
|
t.Errorf("Cas %d: CrossLaplacian mismatch serial. got %v, want %v", cas, got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that concurrency works.
|
||||||
|
settings := test.settings
|
||||||
|
if settings == nil {
|
||||||
|
settings = &Settings{}
|
||||||
|
}
|
||||||
|
settings.Concurrent = true
|
||||||
|
got2 := CrossLaplacian(test.l.Func, test.x, test.y, settings)
|
||||||
|
if !floats.EqualWithinAbsOrRel(got, got2, 1e-6, 1e-6) {
|
||||||
|
t.Errorf("Cas %d: Laplacian mismatch. got %v, want %v", cas, got2, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@@ -51,7 +51,8 @@ func Gradient(dst []float64, f func([]float64) float64, x []float64, settings *S
|
|||||||
nWorkers := computeWorkers(concurrent, evals)
|
nWorkers := computeWorkers(concurrent, evals)
|
||||||
|
|
||||||
hasOrigin := usesOrigin(formula.Stencil)
|
hasOrigin := usesOrigin(formula.Stencil)
|
||||||
xcopy := make([]float64, len(x)) // So that x is not modified during the call.
|
// Copy x in case it is modified during the call.
|
||||||
|
xcopy := make([]float64, len(x))
|
||||||
if hasOrigin && !originKnown {
|
if hasOrigin && !originKnown {
|
||||||
copy(xcopy, x)
|
copy(xcopy, x)
|
||||||
originValue = f(xcopy)
|
originValue = f(xcopy)
|
||||||
@@ -65,7 +66,7 @@ func Gradient(dst []float64, f func([]float64) float64, x []float64, settings *S
|
|||||||
deriv += pt.Coeff * originValue
|
deriv += pt.Coeff * originValue
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// Copying the code anew has two benefits. First, it
|
// Copying the data anew has two benefits. First, it
|
||||||
// avoids floating point issues where adding and then
|
// avoids floating point issues where adding and then
|
||||||
// subtracting the step don't return to the exact same
|
// subtracting the step don't return to the exact same
|
||||||
// location. Secondly, it protects against the function
|
// location. Secondly, it protects against the function
|
||||||
|
@@ -84,6 +84,7 @@ func hessianSerial(dst *mat.SymDense, f func(x []float64) float64, x []float64,
|
|||||||
n := len(x)
|
n := len(x)
|
||||||
xCopy := make([]float64, n)
|
xCopy := make([]float64, n)
|
||||||
fo := func() float64 {
|
fo := func() float64 {
|
||||||
|
// Copy x in case it is modified during the call.
|
||||||
copy(xCopy, x)
|
copy(xCopy, x)
|
||||||
return f(x)
|
return f(x)
|
||||||
}
|
}
|
||||||
@@ -98,7 +99,7 @@ func hessianSerial(dst *mat.SymDense, f func(x []float64) float64, x []float64,
|
|||||||
if pti.Loc == 0 && ptj.Loc == 0 {
|
if pti.Loc == 0 && ptj.Loc == 0 {
|
||||||
v = origin
|
v = origin
|
||||||
} else {
|
} else {
|
||||||
// Copying the code anew has two benefits. First, it
|
// Copying the data anew has two benefits. First, it
|
||||||
// avoids floating point issues where adding and then
|
// avoids floating point issues where adding and then
|
||||||
// subtracting the step don't return to the exact same
|
// subtracting the step don't return to the exact same
|
||||||
// location. Secondly, it protects against the function
|
// location. Secondly, it protects against the function
|
||||||
@@ -125,7 +126,7 @@ func hessianConcurrent(dst *mat.SymDense, nWorkers, evals int, f func(x []float6
|
|||||||
}
|
}
|
||||||
|
|
||||||
send := make(chan run, evals)
|
send := make(chan run, evals)
|
||||||
ans := make(chan run)
|
ans := make(chan run, evals)
|
||||||
|
|
||||||
var originWG sync.WaitGroup
|
var originWG sync.WaitGroup
|
||||||
hasOrigin := usesOrigin(stencil)
|
hasOrigin := usesOrigin(stencil)
|
||||||
|
@@ -16,8 +16,7 @@ type HessianTester interface {
|
|||||||
Hess(dst mat.MutableSymmetric, x []float64)
|
Hess(dst mat.MutableSymmetric, x []float64)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHessian(t *testing.T) {
|
var hessianTestCases = []struct {
|
||||||
for cas, test := range []struct {
|
|
||||||
h HessianTester
|
h HessianTester
|
||||||
x []float64
|
x []float64
|
||||||
settings *Settings
|
settings *Settings
|
||||||
@@ -69,7 +68,10 @@ func TestHessian(t *testing.T) {
|
|||||||
x: []float64{-1.6, -3, 2},
|
x: []float64{-1.6, -3, 2},
|
||||||
tol: 1e-6,
|
tol: 1e-6,
|
||||||
},
|
},
|
||||||
} {
|
}
|
||||||
|
|
||||||
|
func TestHessian(t *testing.T) {
|
||||||
|
for cas, test := range hessianTestCases {
|
||||||
n := len(test.x)
|
n := len(test.x)
|
||||||
got := Hessian(nil, test.h.Func, test.x, test.settings)
|
got := Hessian(nil, test.h.Func, test.x, test.settings)
|
||||||
want := mat.NewSymDense(n, nil)
|
want := mat.NewSymDense(n, nil)
|
||||||
|
158
diff/fd/laplacian.go
Normal file
158
diff/fd/laplacian.go
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
// Copyright ©2017 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package fd
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
// Laplacian computes the Laplacian of the multivariate function f at the location
|
||||||
|
// x. That is, Laplacian returns
|
||||||
|
// ∆ f(x) = ∇ · ∇ f(x) = \sum_i ∂^2 f(x)/∂x_i^2
|
||||||
|
// The finite difference formula and other options are specified by settings.
|
||||||
|
// The order of the difference formula must be 2 or Laplacian will panic.
|
||||||
|
func Laplacian(f func(x []float64) float64, x []float64, settings *Settings) float64 {
|
||||||
|
n := len(x)
|
||||||
|
if n == 0 {
|
||||||
|
panic("laplacian: x has zero length")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default settings.
|
||||||
|
formula := Central2nd
|
||||||
|
step := formula.Step
|
||||||
|
var originValue float64
|
||||||
|
var originKnown, concurrent bool
|
||||||
|
|
||||||
|
// Use user settings if provided.
|
||||||
|
if settings != nil {
|
||||||
|
if !settings.Formula.isZero() {
|
||||||
|
formula = settings.Formula
|
||||||
|
step = formula.Step
|
||||||
|
checkFormula(formula)
|
||||||
|
if formula.Derivative != 2 {
|
||||||
|
panic(badDerivOrder)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if settings.Step != 0 {
|
||||||
|
if settings.Step < 0 {
|
||||||
|
panic(negativeStep)
|
||||||
|
}
|
||||||
|
step = settings.Step
|
||||||
|
}
|
||||||
|
originKnown = settings.OriginKnown
|
||||||
|
originValue = settings.OriginValue
|
||||||
|
concurrent = settings.Concurrent
|
||||||
|
}
|
||||||
|
|
||||||
|
evals := n * len(formula.Stencil)
|
||||||
|
if usesOrigin(formula.Stencil) {
|
||||||
|
evals -= n
|
||||||
|
}
|
||||||
|
|
||||||
|
nWorkers := computeWorkers(concurrent, evals)
|
||||||
|
if nWorkers == 1 {
|
||||||
|
return laplacianSerial(f, x, formula.Stencil, step, originKnown, originValue)
|
||||||
|
}
|
||||||
|
return laplacianConcurrent(nWorkers, evals, f, x, formula.Stencil, step, originKnown, originValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
func laplacianSerial(f func(x []float64) float64, x []float64, stencil []Point, step float64, originKnown bool, originValue float64) float64 {
|
||||||
|
n := len(x)
|
||||||
|
xCopy := make([]float64, n)
|
||||||
|
fo := func() float64 {
|
||||||
|
// Copy x in case it is modified during the call.
|
||||||
|
copy(xCopy, x)
|
||||||
|
return f(x)
|
||||||
|
}
|
||||||
|
is2 := 1 / (step * step)
|
||||||
|
origin := getOrigin(originKnown, originValue, fo, stencil)
|
||||||
|
var laplacian float64
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
for _, pt := range stencil {
|
||||||
|
var v float64
|
||||||
|
if pt.Loc == 0 {
|
||||||
|
v = origin
|
||||||
|
} else {
|
||||||
|
// Copying the data anew has two benefits. First, it
|
||||||
|
// avoids floating point issues where adding and then
|
||||||
|
// subtracting the step don't return to the exact same
|
||||||
|
// location. Secondly, it protects against the function
|
||||||
|
// modifying the input data.
|
||||||
|
copy(xCopy, x)
|
||||||
|
xCopy[i] += pt.Loc * step
|
||||||
|
v = f(xCopy)
|
||||||
|
}
|
||||||
|
laplacian += v * pt.Coeff * is2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return laplacian
|
||||||
|
}
|
||||||
|
|
||||||
|
func laplacianConcurrent(nWorkers, evals int, f func(x []float64) float64, x []float64, stencil []Point, step float64, originKnown bool, originValue float64) float64 {
|
||||||
|
type run struct {
|
||||||
|
i int
|
||||||
|
idx int
|
||||||
|
result float64
|
||||||
|
}
|
||||||
|
n := len(x)
|
||||||
|
send := make(chan run, evals)
|
||||||
|
ans := make(chan run, evals)
|
||||||
|
|
||||||
|
var originWG sync.WaitGroup
|
||||||
|
hasOrigin := usesOrigin(stencil)
|
||||||
|
if hasOrigin {
|
||||||
|
originWG.Add(1)
|
||||||
|
// Launch worker to compute the origin.
|
||||||
|
go func() {
|
||||||
|
defer originWG.Done()
|
||||||
|
xCopy := make([]float64, len(x))
|
||||||
|
copy(xCopy, x)
|
||||||
|
originValue = f(xCopy)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
var workerWG sync.WaitGroup
|
||||||
|
// Launch workers.
|
||||||
|
for i := 0; i < nWorkers; i++ {
|
||||||
|
workerWG.Add(1)
|
||||||
|
go func(send <-chan run, ans chan<- run) {
|
||||||
|
defer workerWG.Done()
|
||||||
|
xCopy := make([]float64, len(x))
|
||||||
|
for r := range send {
|
||||||
|
if stencil[r.idx].Loc == 0 {
|
||||||
|
originWG.Wait()
|
||||||
|
r.result = originValue
|
||||||
|
} else {
|
||||||
|
// See laplacianSerial for comment on the copy.
|
||||||
|
copy(xCopy, x)
|
||||||
|
xCopy[r.i] += stencil[r.idx].Loc * step
|
||||||
|
r.result = f(xCopy)
|
||||||
|
}
|
||||||
|
ans <- r
|
||||||
|
}
|
||||||
|
}(send, ans)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Launch the distributor, which sends all of runs.
|
||||||
|
go func(send chan<- run) {
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
for idx := range stencil {
|
||||||
|
send <- run{
|
||||||
|
i: i, idx: idx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
close(send)
|
||||||
|
// Wait for all the workers to quit, then close the ans channel.
|
||||||
|
workerWG.Wait()
|
||||||
|
close(ans)
|
||||||
|
}(send)
|
||||||
|
|
||||||
|
// Read in the results.
|
||||||
|
is2 := 1 / (step * step)
|
||||||
|
var laplacian float64
|
||||||
|
for r := range ans {
|
||||||
|
laplacian += r.result * stencil[r.idx].Coeff * is2
|
||||||
|
}
|
||||||
|
return laplacian
|
||||||
|
}
|
44
diff/fd/laplacian_test.go
Normal file
44
diff/fd/laplacian_test.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
// Copyright ©2017 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package fd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gonum.org/v1/gonum/floats"
|
||||||
|
"gonum.org/v1/gonum/mat"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLaplacian(t *testing.T) {
|
||||||
|
for cas, test := range hessianTestCases {
|
||||||
|
// Modify the test cases where the forumla is set.
|
||||||
|
settings := test.settings
|
||||||
|
if settings != nil && !settings.Formula.isZero() {
|
||||||
|
settings.Formula = Forward2nd
|
||||||
|
}
|
||||||
|
|
||||||
|
n := len(test.x)
|
||||||
|
got := Laplacian(test.h.Func, test.x, test.settings)
|
||||||
|
hess := mat.NewSymDense(n, nil)
|
||||||
|
test.h.Hess(hess, test.x)
|
||||||
|
var want float64
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
want += hess.At(i, i)
|
||||||
|
}
|
||||||
|
if !floats.EqualWithinAbsOrRel(got, want, test.tol, test.tol) {
|
||||||
|
t.Errorf("Cas %d: Laplacian mismatch. got %v, want %v", cas, got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that concurrency works.
|
||||||
|
if settings == nil {
|
||||||
|
settings = &Settings{}
|
||||||
|
}
|
||||||
|
settings.Concurrent = true
|
||||||
|
got2 := Laplacian(test.h.Func, test.x, settings)
|
||||||
|
if !floats.EqualWithinAbsOrRel(got, got2, 1e-5, 1e-5) {
|
||||||
|
t.Errorf("Cas %d: Laplacian mismatch. got %v, want %v", cas, got2, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user