mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 15:16:59 +08:00

* diff/fd: implement Hessian finite difference, and code cleanups. This commit primarily adds the Hessian function for finding a finite difference approximation to the Hessian. At the same time, it combines duplicated functionality across the difference routines so that the preludes to all the difference routines look similar
145 lines
3.7 KiB
Go
145 lines
3.7 KiB
Go
// Copyright ©2017 The gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package fd
|
|
|
|
import "gonum.org/v1/gonum/floats"
|
|
|
|
// Gradient estimates the gradient of the multivariate function f at the
|
|
// location x. If dst is not nil, the result will be stored in-place into dst
|
|
// and returned, otherwise a new slice will be allocated first. Finite
|
|
// difference formula and other options are specified by settings. If settings is
|
|
// nil, the gradient will be estimated using the Forward formula and a default
|
|
// step size.
|
|
//
|
|
// Gradient panics if the length of dst and x is not equal, or if the derivative
|
|
// order of the formula is not 1.
|
|
func Gradient(dst []float64, f func([]float64) float64, x []float64, settings *Settings) []float64 {
|
|
if dst == nil {
|
|
dst = make([]float64, len(x))
|
|
}
|
|
if len(dst) != len(x) {
|
|
panic("fd: slice length mismatch")
|
|
}
|
|
|
|
// Default settings.
|
|
formula := Forward
|
|
step := formula.Step
|
|
var originValue float64
|
|
var originKnown, concurrent bool
|
|
|
|
// Use user settings if provided.
|
|
if settings != nil {
|
|
if !settings.Formula.isZero() {
|
|
formula = settings.Formula
|
|
step = formula.Step
|
|
checkFormula(formula)
|
|
if formula.Derivative != 1 {
|
|
panic(badDerivOrder)
|
|
}
|
|
}
|
|
if settings.Step != 0 {
|
|
step = settings.Step
|
|
}
|
|
originKnown = settings.OriginKnown
|
|
originValue = settings.OriginValue
|
|
concurrent = settings.Concurrent
|
|
}
|
|
|
|
evals := len(formula.Stencil) * len(x)
|
|
nWorkers := computeWorkers(concurrent, evals)
|
|
|
|
hasOrigin := usesOrigin(formula.Stencil)
|
|
xcopy := make([]float64, len(x)) // So that x is not modified during the call.
|
|
if hasOrigin && !originKnown {
|
|
copy(xcopy, x)
|
|
originValue = f(xcopy)
|
|
}
|
|
|
|
if nWorkers == 1 {
|
|
for i := range xcopy {
|
|
var deriv float64
|
|
for _, pt := range formula.Stencil {
|
|
if pt.Loc == 0 {
|
|
deriv += pt.Coeff * originValue
|
|
continue
|
|
}
|
|
// Copying the code anew has two benefits. First, it
|
|
// avoids floating point issues where adding and then
|
|
// subtracting the step don't return to the exact same
|
|
// location. Secondly, it protects against the function
|
|
// modifying the input data.
|
|
copy(xcopy, x)
|
|
xcopy[i] += pt.Loc * step
|
|
deriv += pt.Coeff * f(xcopy)
|
|
}
|
|
dst[i] = deriv / step
|
|
}
|
|
return dst
|
|
}
|
|
|
|
sendChan := make(chan fdrun, evals)
|
|
ansChan := make(chan fdrun, evals)
|
|
quit := make(chan struct{})
|
|
defer close(quit)
|
|
|
|
// Launch workers. Workers receive an index and a step, and compute the answer.
|
|
for i := 0; i < nWorkers; i++ {
|
|
go func(sendChan <-chan fdrun, ansChan chan<- fdrun, quit <-chan struct{}) {
|
|
xcopy := make([]float64, len(x))
|
|
for {
|
|
select {
|
|
case <-quit:
|
|
return
|
|
case run := <-sendChan:
|
|
// See above comment on the copy.
|
|
copy(xcopy, x)
|
|
xcopy[run.idx] += run.pt.Loc * step
|
|
run.result = f(xcopy)
|
|
ansChan <- run
|
|
}
|
|
}
|
|
}(sendChan, ansChan, quit)
|
|
}
|
|
|
|
// Launch the distributor. Distributor sends the cases to be computed.
|
|
go func(sendChan chan<- fdrun, ansChan chan<- fdrun) {
|
|
for i := range x {
|
|
for _, pt := range formula.Stencil {
|
|
if pt.Loc == 0 {
|
|
// Answer already known. Send the answer on the answer channel.
|
|
ansChan <- fdrun{
|
|
idx: i,
|
|
pt: pt,
|
|
result: originValue,
|
|
}
|
|
continue
|
|
}
|
|
// Answer not known, send the answer to be computed.
|
|
sendChan <- fdrun{
|
|
idx: i,
|
|
pt: pt,
|
|
}
|
|
}
|
|
}
|
|
}(sendChan, ansChan)
|
|
|
|
for i := range dst {
|
|
dst[i] = 0
|
|
}
|
|
// Read in all of the results.
|
|
for i := 0; i < evals; i++ {
|
|
run := <-ansChan
|
|
dst[run.idx] += run.pt.Coeff * run.result
|
|
}
|
|
floats.Scale(1/step, dst)
|
|
return dst
|
|
}
|
|
|
|
type fdrun struct {
|
|
idx int
|
|
pt Point
|
|
result float64
|
|
}
|