Files
gonum/optimize/gradientdescent.go
Brendan Tracey 2704973b50 optimize: Remove Local function (#538)
* optimize: Remove Local function

This change removes the Local function. In order to do so, this changes the previous LocalGlobal wrapper to LocalController to allow Local methods to be used as a Global optimizer. This adds methods to all of the Local methods in order to implement GlobalMethod, and changes the tests accordingly. The next commit will fix all of the names
2018-07-18 12:18:18 -06:00

83 lines
2.1 KiB
Go

// Copyright ©2014 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package optimize
import "gonum.org/v1/gonum/floats"
// GradientDescent implements the steepest descent optimization method that
// performs successive steps along the direction of the negative gradient.
//
// Init fills in defaults for the exported fields that are left nil, so a
// GradientDescent can be used without setting either of them.
type GradientDescent struct {
// Linesearcher selects suitable steps along the descent direction.
// If Linesearcher is nil, Init sets it to &Backtracking{}.
Linesearcher Linesearcher
// StepSizer determines the initial step size along each direction.
// If StepSizer is nil, Init sets it to &QuadraticStepSize{}.
StepSizer StepSizer
// ls is the line-search driver that Init and Iterate delegate to. It is
// allocated by Init when it is still nil and reused on later calls.
ls *LinesearchMethod
// status and err hold the values returned by Status. InitGlobal resets
// them and RunGlobal overwrites them with the outcome of the run.
status Status
err error
}
// Status returns the status and error currently stored on g. InitGlobal
// resets them to NotTerminated and nil, and RunGlobal replaces them with the
// outcome of the run.
func (g *GradientDescent) Status() (Status, error) {
	status, err := g.status, g.err
	return status, err
}
// InitGlobal resets the stored status to NotTerminated and clears the stored
// error. Neither dim nor tasks is consulted; the method always returns 1.
//
// NOTE(review): the returned 1 is presumably the number of concurrent tasks
// this method can make use of — confirm against the GlobalMethod interface.
func (g *GradientDescent) InitGlobal(dim, tasks int) int {
	g.status, g.err = NotTerminated, nil
	return 1
}
// RunGlobal performs the optimization by passing g, both task channels and
// the tasks slice to localOptimizer.runGlobal. The status and error it
// returns are stored on g so that they can be retrieved with Status.
//
// The operation channel is closed once runGlobal has returned, so callers
// must not send on or close it themselves afterwards.
func (g *GradientDescent) RunGlobal(operation chan<- GlobalTask, result <-chan GlobalTask, tasks []GlobalTask) {
	g.status, g.err = localOptimizer{}.runGlobal(g, operation, result, tasks)
	close(operation)
}
// Init prepares g for a run starting at loc and returns whatever the
// underlying line-search driver's Init returns.
//
// Any of Linesearcher, StepSizer or the internal driver that is still nil is
// given a default first (&Backtracking{}, &QuadraticStepSize{} and
// &LinesearchMethod{} respectively). Note that the defaults are written back
// to the exported fields. The driver is then pointed at the current
// Linesearcher and at g itself as its NextDirectioner.
func (g *GradientDescent) Init(loc *Location) (Operation, error) {
	if g.ls == nil {
		g.ls = &LinesearchMethod{}
	}
	if g.StepSizer == nil {
		g.StepSizer = &QuadraticStepSize{}
	}
	if g.Linesearcher == nil {
		g.Linesearcher = &Backtracking{}
	}

	// Re-wire the driver on every call so that a Linesearcher replaced
	// between runs is picked up.
	driver := g.ls
	driver.NextDirectioner = g
	driver.Linesearcher = g.Linesearcher
	return driver.Init(loc)
}
// Iterate forwards loc to the Iterate method of the line-search driver and
// returns its result unchanged. The driver is only assigned in Init, so Init
// has to be called before Iterate.
func (g *GradientDescent) Iterate(loc *Location) (Operation, error) {
	op, err := g.ls.Iterate(loc)
	return op, err
}
// InitDirection stores the negated gradient at loc in dir and returns the
// step size obtained from g.StepSizer.Init for that location and direction.
// dir is overwritten in place.
func (g *GradientDescent) InitDirection(loc *Location, dir []float64) (stepSize float64) {
	// dir = -gradient: copy first, then flip the sign in place.
	copy(dir, loc.Gradient)
	floats.Scale(-1, dir)

	stepSize = g.StepSizer.Init(loc, dir)
	return stepSize
}
// NextDirection stores the negated gradient at loc in dir and returns the
// step size obtained from g.StepSizer.StepSize for that location and
// direction. dir is overwritten in place.
func (g *GradientDescent) NextDirection(loc *Location, dir []float64) (stepSize float64) {
	// dir = -gradient: copy first, then flip the sign in place.
	copy(dir, loc.Gradient)
	floats.Scale(-1, dir)

	stepSize = g.StepSizer.StepSize(loc, dir)
	return stepSize
}
// Needs returns a struct with Gradient set to true and Hessian set to false.
// This matches the direction methods, which read loc.Gradient and never use
// a Hessian.
func (*GradientDescent) Needs() struct {
	Gradient bool
	Hessian  bool
} {
	var needs struct {
		Gradient bool
		Hessian  bool
	}
	needs.Gradient = true // Hessian keeps its zero value, false.
	return needs
}