Mirror of https://github.com/gonum/gonum.git
optimize: remove Local implementation and replace with a call to Global (#485)
* optimize: remove Local implementation and replace with a call to Global

This PR starts the process described in #482. It removes the existing Local implementation, replacing it with a function that wraps a Method so that it acts as a GlobalMethod. This PR also adds a hack to fix an inconsistency in FunctionConverge between Global and Local (and a TODO to make it not a hack in the future).
@@ -13,16 +13,23 @@ type FunctionConverge struct {
 	Relative float64
 	Iterations int

-	best float64
-	iter int
+	first bool
+	best  float64
+	iter  int
 }

 func (fc *FunctionConverge) Init(f float64) {
-	fc.best = f
+	fc.first = true
+	fc.best = 0
 	fc.iter = 0
 }

 func (fc *FunctionConverge) FunctionConverged(f float64) Status {
+	if fc.first {
+		fc.best = f
+		fc.first = false
+		return NotTerminated
+	}
 	if fc.Iterations == 0 {
 		return NotTerminated
 	}
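The practical effect of the new first field is that the best value is now seeded by the first call to FunctionConverged rather than by the argument to Init, which lets Global call Init before any meaningful function value exists (see the TODO added to Global below). A minimal sketch of the new behavior, assuming the exported FunctionConverge API; the Relative and Iterations values are illustrative:

package main

import (
	"fmt"
	"math"

	"gonum.org/v1/gonum/optimize"
)

func main() {
	fc := &optimize.FunctionConverge{Relative: 1e-10, Iterations: 20}
	// Global initializes the checker before any evaluation has happened,
	// passing optLoc.F, which is +Inf at that point. The argument is now
	// ignored rather than recorded as the best value.
	fc.Init(math.Inf(1))
	// The first observed value seeds best and never terminates the run.
	fmt.Println(fc.FunctionConverged(100)) // NotTerminated
	// Later calls compare against the seeded best value as before; only
	// after Iterations calls without improvement does the status change.
	fmt.Println(fc.FunctionConverged(100)) // NotTerminated
}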
@@ -51,8 +51,11 @@ type GlobalMethod interface {
 	// The GlobalMethod must read from the result channel until it is closed.
 	// During this, the GlobalMethod may want to send new MajorIteration(s) on
 	// operation. GlobalMethod then must close operation, and return from RunGlobal.
+	// These steps must establish a "happens-before" relationship between result
+	// being closed (externally) and RunGlobal closing operation, for example
+	// by using a range loop to read from result even if no results are expected.
 	//
-	// The las parameter to RunGlobal is a slice of tasks with length equal to
+	// The last parameter to RunGlobal is a slice of tasks with length equal to
 	// the return from InitGlobal. GlobalTask has an ID field which may be
 	// set and modified by GlobalMethod, and must not be modified by the caller.
 	//
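The happens-before requirement documented here is exactly what the localGlobal wrapper's cleanup helper implements later in this diff. As a sketch, a conforming RunGlobal ends with a drain-then-close sequence like the following; the helper name shutdown is hypothetical, and the channel names follow the interface documentation:

// shutdown drains result until the caller closes it, guaranteeing that the
// external close of result happens before RunGlobal closes operation.
func shutdown(operation chan<- GlobalTask, result <-chan GlobalTask) {
	// The range loop ends only when result is closed, even if no results
	// are outstanding, establishing the required happens-before relation.
	for range result {
	}
	close(operation)
}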
@@ -117,6 +120,9 @@ func Global(p Problem, dim int, settings *Settings, method GlobalMethod) (*Resul
 		return nil, err
 	}

+	// TODO(btracey): These init calls don't do anything with their arguments
+	// because optLoc is meaningless at this point. Should change the function
+	// signatures.
 	optLoc := newLocation(dim, method)
 	optLoc.F = math.Inf(1)

@@ -198,6 +204,8 @@ func minimizeGlobal(prob *Problem, method GlobalMethod, settings *Settings, stat
 	// method that all results have been collected. At this point, the method
 	// may send MajorIteration(s) to update an optimum location based on these
 	// last returned results, and then the method will close the operations channel.
+	// The GlobalMethod must ensure that the closing of results happens before the
+	// closing of operations in order to ensure proper shutdown order.
 	// Now that no more tasks will be commanded by the method, the distributor
 	// closes statsChan, and with no more statistics to update the optimization
 	// concludes.
@@ -13,7 +13,7 @@ import (
 )

 func TestGuessAndCheck(t *testing.T) {
-	dim := 3000
+	dim := 30
 	problem := Problem{
 		Func: functions.ExtendedRosenbrock{}.Func,
 	}
@@ -4,10 +4,7 @@

 package optimize

-import (
-	"math"
-	"time"
-)
+import "math"

 // Local finds a local minimum of a minimization problem using a sequential
 // algorithm. A maximization problem can be transformed into a minimization
@@ -59,123 +56,18 @@ import (
 // maximum runtime or maximum function evaluations, modify the Settings
 // input struct.
 func Local(p Problem, initX []float64, settings *Settings, method Method) (*Result, error) {
-	startTime := time.Now()
-	dim := len(initX)
 	if method == nil {
 		method = getDefaultMethod(&p)
 	}
 	if settings == nil {
 		settings = DefaultSettings()
 	}
+	lg := &localGlobal{
+		Method:   method,
+		InitX:    initX,
+		Settings: settings,
-	stats := &Stats{}
-	err := checkOptimization(p, dim, method, settings.Recorder)
-	if err != nil {
-		return nil, err
-	}
-
-	optLoc, err := getStartingLocation(&p, method, initX, stats, settings)
-	if err != nil {
-		return nil, err
-	}
-
-	if settings.FunctionConverge != nil {
-		settings.FunctionConverge.Init(optLoc.F)
-	}
-
-	stats.Runtime = time.Since(startTime)
-
-	// Send initial location to Recorder
-	if settings.Recorder != nil {
-		err = settings.Recorder.Record(optLoc, InitIteration, stats)
-		if err != nil {
-			return nil, err
-		}
-	}
-
-	// Check if the starting location satisfies the convergence criteria.
-	status := checkLocationConvergence(optLoc, settings)
-
-	// Run optimization
-	if status == NotTerminated && err == nil {
-		// The starting location is not good enough, we need to perform a
-		// minimization. The optimal location will be stored in-place in
-		// optLoc.
-		status, err = minimize(&p, method, settings, stats, optLoc, startTime)
-	}
-
-	// Cleanup and collect results
-	if settings.Recorder != nil && err == nil {
-		// Send the optimal location to Recorder.
-		err = settings.Recorder.Record(optLoc, PostIteration, stats)
-	}
-	stats.Runtime = time.Since(startTime)
-	return &Result{
-		Location: *optLoc,
-		Stats:    *stats,
-		Status:   status,
-	}, err
-}
-
-func minimize(p *Problem, method Method, settings *Settings, stats *Stats, optLoc *Location, startTime time.Time) (status Status, err error) {
-	loc := &Location{}
-	copyLocation(loc, optLoc)
-	x := make([]float64, len(loc.X))
-
-	var op Operation
-	op, err = method.Init(loc)
-	if err != nil {
-		status = Failure
-		return
-	}
-
-	for {
-		// Sequentially call method.Iterate, performing the operations it has
-		// commanded, until convergence.
-
-		switch op {
-		case NoOperation:
-		case InitIteration:
-			panic("optimize: Method returned InitIteration")
-		case PostIteration:
-			panic("optimize: Method returned PostIteration")
-		case MajorIteration:
-			status = performMajorIteration(optLoc, loc, stats, startTime, settings)
-		case MethodDone:
-			statuser, ok := method.(Statuser)
-			if !ok {
-				panic("optimize: method returned MethodDone is not a Statuser")
-			}
-			status, err = statuser.Status()
-			if status == NotTerminated {
-				panic("optimize: method returned MethodDone but a NotTerminated status")
-			}
-		default: // Any of the Evaluation operations.
-			evaluate(p, loc, op, x)
-			updateEvaluationStats(stats, op)
-			status, err = checkEvaluationLimits(p, stats, settings)
-		}
-		if status != NotTerminated || err != nil {
-			return
-		}
-		if settings.Recorder != nil {
-			stats.Runtime = time.Since(startTime)
-			err = settings.Recorder.Record(loc, op, stats)
-			if err != nil {
-				if status == NotTerminated {
-					status = Failure
-				}
-				return status, err
-			}
-		}
-
-		op, err = method.Iterate(loc)
-		if err != nil {
-			status = Failure
-			return
-		}
-	}
 	}
+	return Global(p, len(initX), settings, lg)
 }

 func getDefaultMethod(p *Problem) Method {
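Callers are unaffected by this rewrite: Local keeps its signature, and nil settings or method still select DefaultSettings and the default Method, with the work now routed through Global. A usage sketch, assuming the package API at this commit and the ExtendedRosenbrock function used by the tests in this PR; the starting point is illustrative:

package main

import (
	"fmt"
	"log"

	"gonum.org/v1/gonum/optimize"
	"gonum.org/v1/gonum/optimize/functions"
)

func main() {
	problem := optimize.Problem{
		Func: functions.ExtendedRosenbrock{}.Func,
		Grad: functions.ExtendedRosenbrock{}.Grad,
	}
	x := []float64{1.3, 0.7, 0.8, 1.9, 1.2} // illustrative starting point
	// nil settings and method select the defaults, as in Local above.
	result, err := optimize.Local(problem, x, nil, nil)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("status: %v  F: %v\nX: %v\n", result.Status, result.F, result.X)
}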
@@ -185,55 +77,155 @@ func getDefaultMethod(p *Problem) Method {
 	return &NelderMead{}
 }

-// getStartingLocation allocates and initializes the starting location for the minimization.
-func getStartingLocation(p *Problem, method Method, initX []float64, stats *Stats, settings *Settings) (*Location, error) {
-	dim := len(initX)
-	loc := newLocation(dim, method)
-	copy(loc.X, initX)
-
-	if settings.UseInitialData {
-		loc.F = settings.InitialValue
-		if loc.Gradient != nil {
-			initG := settings.InitialGradient
-			if initG == nil {
-				panic("optimize: initial gradient is nil")
-			}
-			if len(initG) != dim {
-				panic("optimize: initial gradient size mismatch")
-			}
-			copy(loc.Gradient, initG)
-		}
-		if loc.Hessian != nil {
-			initH := settings.InitialHessian
-			if initH == nil {
-				panic("optimize: initial Hessian is nil")
-			}
-			if initH.Symmetric() != dim {
-				panic("optimize: initial Hessian size mismatch")
-			}
-			loc.Hessian.CopySym(initH)
-		}
-	} else {
-		eval := FuncEvaluation
-		if loc.Gradient != nil {
-			eval |= GradEvaluation
-		}
-		if loc.Hessian != nil {
-			eval |= HessEvaluation
-		}
-		x := make([]float64, len(loc.X))
-		evaluate(p, loc, eval, x)
-		updateEvaluationStats(stats, eval)
-	}
-
-	if math.IsInf(loc.F, 1) || math.IsNaN(loc.F) {
-		return loc, ErrFunc(loc.F)
-	}
-	for i, v := range loc.Gradient {
-		if math.IsInf(v, 0) || math.IsNaN(v) {
-			return loc, ErrGrad{Grad: v, Index: i}
-		}
-	}
-
-	return loc, nil
+// localGlobal is a wrapper for Local methods to allow them to be optimized by Global.
+type localGlobal struct {
+	Method   Method
+	InitX    []float64
+	Settings *Settings
+
+	dim    int
+	status Status
+	err    error
+}
+
+func (l *localGlobal) InitGlobal(dim, tasks int) int {
+	if dim != len(l.InitX) {
+		panic("optimize: initial length mismatch")
+	}
+	l.dim = dim
+	l.status = NotTerminated
+	l.err = nil
+	return 1 // Local optimizations always run in serial.
+}
+
+func (l *localGlobal) Status() (Status, error) {
+	return l.status, l.err
+}
+
+func (l *localGlobal) Needs() struct {
+	Gradient bool
+	Hessian  bool
+} {
+	return l.Method.Needs()
+}
+
+func (l *localGlobal) RunGlobal(operations chan<- GlobalTask, results <-chan GlobalTask, tasks []GlobalTask) {
+	// Local methods start with a fully-specified initial location.
+	task := tasks[0]
+	op := l.getStartingLocation(operations, results, task)
+	if op == PostIteration {
+		l.cleanup(operations, results)
+		return
+	}
+	// Check the starting condition.
+	if math.IsInf(task.F, 1) || math.IsNaN(task.F) {
+		l.status = Failure
+		l.err = ErrFunc(task.F)
+	}
+	for i, v := range task.Gradient {
+		if math.IsInf(v, 0) || math.IsNaN(v) {
+			l.status = Failure
+			l.err = ErrGrad{Grad: v, Index: i}
+			break
+		}
+	}
+	if l.status == Failure {
+		l.exitFailure(operations, results, tasks[0])
+		return
+	}
+
+	// Send a major iteration with the starting location.
+	task.Op = MajorIteration
+	operations <- task
+	task = <-results
+	if task.Op == PostIteration {
+		l.cleanup(operations, results)
+		return
+	}
+
+	op, err := l.Method.Init(task.Location)
+	if err != nil {
+		l.status = Failure
+		l.err = err
+		l.exitFailure(operations, results, tasks[0])
+		return
+	}
+	task.Op = op
+	operations <- task
+Loop:
+	for {
+		result := <-results
+		switch result.Op {
+		case PostIteration:
+			break Loop
+		default:
+			op, err := l.Method.Iterate(result.Location)
+			if err != nil {
+				l.status = Failure
+				l.err = err
+				l.exitFailure(operations, results, result)
+				return
+			}
+			result.Op = op
+			operations <- result
+		}
+	}
+	l.cleanup(operations, results)
+}
+
+// exitFailure cleans up from a failure of the local method.
+func (l *localGlobal) exitFailure(operation chan<- GlobalTask, result <-chan GlobalTask, task GlobalTask) {
+	task.Op = MethodDone
+	operation <- task
+	task = <-result
+	if task.Op != PostIteration {
+		panic("task should have returned post iteration")
+	}
+	l.cleanup(operation, result)
+}
+
+func (l *localGlobal) cleanup(operation chan<- GlobalTask, result <-chan GlobalTask) {
+	// Guarantee that result is closed before operation is closed.
+	for range result {
+	}
+	close(operation)
+}
+
+func (l *localGlobal) getStartingLocation(operation chan<- GlobalTask, result <-chan GlobalTask, task GlobalTask) Operation {
+	copy(task.X, l.InitX)
+	if l.Settings.UseInitialData {
+		task.F = l.Settings.InitialValue
+		if task.Gradient != nil {
+			g := l.Settings.InitialGradient
+			if g == nil {
+				panic("optimize: initial gradient is nil")
+			}
+			if len(g) != l.dim {
+				panic("optimize: initial gradient size mismatch")
+			}
+			copy(task.Gradient, g)
+		}
+		if task.Hessian != nil {
+			h := l.Settings.InitialHessian
+			if h == nil {
+				panic("optimize: initial Hessian is nil")
+			}
+			if h.Symmetric() != l.dim {
+				panic("optimize: initial Hessian size mismatch")
+			}
+			task.Hessian.CopySym(h)
+		}
+		return NoOperation
+	}
+
+	eval := FuncEvaluation
+	if task.Gradient != nil {
+		eval |= GradEvaluation
+	}
+	if task.Hessian != nil {
+		eval |= HessEvaluation
+	}
+	task.Op = eval
+	operation <- task
+	task = <-result
+	return task.Op
 }
@@ -1155,7 +1155,7 @@ func TestNewton(t *testing.T) {
 }

 func testLocal(t *testing.T, tests []unconstrainedTest, method Method) {
-	for _, test := range tests {
+	for cas, test := range tests {
 		if test.long && testing.Short() {
 			continue
 		}
@@ -1182,11 +1182,11 @@ func testLocal(t *testing.T, tests []unconstrainedTest, method Method) {

 		result, err := Local(test.p, test.x, settings, method)
 		if err != nil {
-			t.Errorf("error finding minimum (%v) for:\n%v", err, test)
+			t.Errorf("Case %d: error finding minimum (%v) for:\n%v", cas, err, test)
 			continue
 		}
 		if result == nil {
-			t.Errorf("nil result without error for:\n%v", test)
+			t.Errorf("Case %d: nil result without error for:\n%v", cas, test)
 			continue
 		}

@@ -1194,8 +1194,8 @@ func testLocal(t *testing.T, tests []unconstrainedTest, method Method) {
 		// equal to result.F.
 		optF := test.p.Func(result.X)
 		if optF != result.F {
-			t.Errorf("Function value at the optimum location %v not equal to the returned value %v for:\n%v",
-				optF, result.F, test)
+			t.Errorf("Case %d: Function value at the optimum location %v not equal to the returned value %v for:\n%v",
+				cas, optF, result.F, test)
 		}
 		if result.Gradient != nil {
 			// Evaluate the norm of the gradient at the found optimum location.
@@ -1203,15 +1203,15 @@ func testLocal(t *testing.T, tests []unconstrainedTest, method Method) {
 			test.p.Grad(g, result.X)

 			if !floats.Equal(result.Gradient, g) {
-				t.Errorf("Gradient at the optimum location not equal to the returned value for:\n%v", test)
+				t.Errorf("Case %d: Gradient at the optimum location not equal to the returned value for:\n%v", cas, test)
 			}

 			optNorm := floats.Norm(g, math.Inf(1))
 			// Check that the norm of the gradient at the found optimum location is
 			// smaller than the tolerance.
 			if optNorm >= settings.GradientThreshold {
-				t.Errorf("Norm of the gradient at the optimum location %v not smaller than tolerance %v for:\n%v",
-					optNorm, settings.GradientThreshold, test)
+				t.Errorf("Case %d: Norm of the gradient at the optimum location %v not smaller than tolerance %v for:\n%v",
+					cas, optNorm, settings.GradientThreshold, test)
 			}
 		}
