mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 07:06:54 +08:00
146 lines
3.7 KiB
Go
146 lines
3.7 KiB
Go
// Copyright ©2017 The Gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package fd
|
|
|
|
import "gonum.org/v1/gonum/floats"
|
|
|
|
// Gradient estimates the gradient of the multivariate function f at the
|
|
// location x. If dst is not nil, the result will be stored in-place into dst
|
|
// and returned, otherwise a new slice will be allocated first. Finite
|
|
// difference formula and other options are specified by settings. If settings is
|
|
// nil, the gradient will be estimated using the Forward formula and a default
|
|
// step size.
|
|
//
|
|
// Gradient panics if the length of dst and x is not equal, or if the derivative
|
|
// order of the formula is not 1.
|
|
func Gradient(dst []float64, f func([]float64) float64, x []float64, settings *Settings) []float64 {
|
|
if dst == nil {
|
|
dst = make([]float64, len(x))
|
|
}
|
|
if len(dst) != len(x) {
|
|
panic("fd: slice length mismatch")
|
|
}
|
|
|
|
// Default settings.
|
|
formula := Forward
|
|
step := formula.Step
|
|
var originValue float64
|
|
var originKnown, concurrent bool
|
|
|
|
// Use user settings if provided.
|
|
if settings != nil {
|
|
if !settings.Formula.isZero() {
|
|
formula = settings.Formula
|
|
step = formula.Step
|
|
checkFormula(formula)
|
|
if formula.Derivative != 1 {
|
|
panic(badDerivOrder)
|
|
}
|
|
}
|
|
if settings.Step != 0 {
|
|
step = settings.Step
|
|
}
|
|
originKnown = settings.OriginKnown
|
|
originValue = settings.OriginValue
|
|
concurrent = settings.Concurrent
|
|
}
|
|
|
|
evals := len(formula.Stencil) * len(x)
|
|
nWorkers := computeWorkers(concurrent, evals)
|
|
|
|
hasOrigin := usesOrigin(formula.Stencil)
|
|
// Copy x in case it is modified during the call.
|
|
xcopy := make([]float64, len(x))
|
|
if hasOrigin && !originKnown {
|
|
copy(xcopy, x)
|
|
originValue = f(xcopy)
|
|
}
|
|
|
|
if nWorkers == 1 {
|
|
for i := range xcopy {
|
|
var deriv float64
|
|
for _, pt := range formula.Stencil {
|
|
if pt.Loc == 0 {
|
|
deriv += pt.Coeff * originValue
|
|
continue
|
|
}
|
|
// Copying the data anew has two benefits. First, it
|
|
// avoids floating point issues where adding and then
|
|
// subtracting the step don't return to the exact same
|
|
// location. Secondly, it protects against the function
|
|
// modifying the input data.
|
|
copy(xcopy, x)
|
|
xcopy[i] += pt.Loc * step
|
|
deriv += pt.Coeff * f(xcopy)
|
|
}
|
|
dst[i] = deriv / step
|
|
}
|
|
return dst
|
|
}
|
|
|
|
sendChan := make(chan fdrun, evals)
|
|
ansChan := make(chan fdrun, evals)
|
|
quit := make(chan struct{})
|
|
defer close(quit)
|
|
|
|
// Launch workers. Workers receive an index and a step, and compute the answer.
|
|
for i := 0; i < nWorkers; i++ {
|
|
go func(sendChan <-chan fdrun, ansChan chan<- fdrun, quit <-chan struct{}) {
|
|
xcopy := make([]float64, len(x))
|
|
for {
|
|
select {
|
|
case <-quit:
|
|
return
|
|
case run := <-sendChan:
|
|
// See above comment on the copy.
|
|
copy(xcopy, x)
|
|
xcopy[run.idx] += run.pt.Loc * step
|
|
run.result = f(xcopy)
|
|
ansChan <- run
|
|
}
|
|
}
|
|
}(sendChan, ansChan, quit)
|
|
}
|
|
|
|
// Launch the distributor. Distributor sends the cases to be computed.
|
|
go func(sendChan chan<- fdrun, ansChan chan<- fdrun) {
|
|
for i := range x {
|
|
for _, pt := range formula.Stencil {
|
|
if pt.Loc == 0 {
|
|
// Answer already known. Send the answer on the answer channel.
|
|
ansChan <- fdrun{
|
|
idx: i,
|
|
pt: pt,
|
|
result: originValue,
|
|
}
|
|
continue
|
|
}
|
|
// Answer not known, send the answer to be computed.
|
|
sendChan <- fdrun{
|
|
idx: i,
|
|
pt: pt,
|
|
}
|
|
}
|
|
}
|
|
}(sendChan, ansChan)
|
|
|
|
for i := range dst {
|
|
dst[i] = 0
|
|
}
|
|
// Read in all of the results.
|
|
for i := 0; i < evals; i++ {
|
|
run := <-ansChan
|
|
dst[run.idx] += run.pt.Coeff * run.result
|
|
}
|
|
floats.Scale(1/step, dst)
|
|
return dst
|
|
}
|
|
|
|
type fdrun struct {
|
|
idx int
|
|
pt Point
|
|
result float64
|
|
}
|