mirror of
https://github.com/gonum/gonum.git
synced 2025-10-03 14:26:36 +08:00

* optimize: Remove Local function This change removes the Local function. In order to do so, this changes the previous LocalGlobal wrapper to LocalController to allow Local methods to be used as a Global optimizer. This adds methods to all of the Local methods in order to implement GlobalMethod, and changes the tests accordingly. The next commit will fix all of the names
83 lines
2.1 KiB
Go
83 lines
2.1 KiB
Go
// Copyright ©2014 The Gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package optimize
|
|
|
|
import "gonum.org/v1/gonum/floats"
|
|
|
|
// GradientDescent implements the steepest descent optimization method that
// performs successive steps along the direction of the negative gradient.
//
// Every search direction (initial and subsequent) is the negative gradient
// at the current location. Linesearcher chooses the step taken along that
// direction, and StepSizer supplies the initial step-size guess handed to
// each line search.
type GradientDescent struct {
	// Linesearcher selects suitable steps along the descent direction.
	// If Linesearcher is nil, a reasonable default will be chosen
	// (Init uses &Backtracking{}).
	Linesearcher Linesearcher
	// StepSizer determines the initial step size along each direction.
	// If StepSizer is nil, a reasonable default will be chosen
	// (Init uses &QuadraticStepSize{}).
	StepSizer StepSizer

	// ls drives the line-search iterations. It is allocated by Init on first
	// use and reused by later calls to Init.
	ls *LinesearchMethod

	// status and err hold the outcome of the most recent RunGlobal and are
	// what Status reports; InitGlobal resets them.
	status Status
	err    error
}
|
|
|
|
func (g *GradientDescent) Status() (Status, error) {
|
|
return g.status, g.err
|
|
}
|
|
|
|
func (g *GradientDescent) InitGlobal(dim, tasks int) int {
|
|
g.status = NotTerminated
|
|
g.err = nil
|
|
return 1
|
|
}
|
|
|
|
func (g *GradientDescent) RunGlobal(operation chan<- GlobalTask, result <-chan GlobalTask, tasks []GlobalTask) {
|
|
g.status, g.err = localOptimizer{}.runGlobal(g, operation, result, tasks)
|
|
close(operation)
|
|
return
|
|
}
|
|
|
|
func (g *GradientDescent) Init(loc *Location) (Operation, error) {
|
|
if g.Linesearcher == nil {
|
|
g.Linesearcher = &Backtracking{}
|
|
}
|
|
if g.StepSizer == nil {
|
|
g.StepSizer = &QuadraticStepSize{}
|
|
}
|
|
|
|
if g.ls == nil {
|
|
g.ls = &LinesearchMethod{}
|
|
}
|
|
g.ls.Linesearcher = g.Linesearcher
|
|
g.ls.NextDirectioner = g
|
|
|
|
return g.ls.Init(loc)
|
|
}
|
|
|
|
func (g *GradientDescent) Iterate(loc *Location) (Operation, error) {
|
|
return g.ls.Iterate(loc)
|
|
}
|
|
|
|
func (g *GradientDescent) InitDirection(loc *Location, dir []float64) (stepSize float64) {
|
|
copy(dir, loc.Gradient)
|
|
floats.Scale(-1, dir)
|
|
return g.StepSizer.Init(loc, dir)
|
|
}
|
|
|
|
func (g *GradientDescent) NextDirection(loc *Location, dir []float64) (stepSize float64) {
|
|
copy(dir, loc.Gradient)
|
|
floats.Scale(-1, dir)
|
|
return g.StepSizer.StepSize(loc, dir)
|
|
}
|
|
|
|
func (*GradientDescent) Needs() struct {
|
|
Gradient bool
|
|
Hessian bool
|
|
} {
|
|
return struct {
|
|
Gradient bool
|
|
Hessian bool
|
|
}{true, false}
|
|
}
|