mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 07:06:54 +08:00
optimize: Update documentation for Global and clean up usage of globa… (#265)
* optimize: Update documentation for Global and clean up usage of globalStatus
This commit is contained in:
@@ -10,17 +10,55 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// GlobalMethod is a global optimizer. Typically will require more function
|
||||
// evaluations and no sense of local convergence
|
||||
// GlobalMethod is an optimization method which seeks to find the global minimum
|
||||
// of an objective function.
|
||||
//
|
||||
// At the beginning of the optimization, InitGlobal is called to communicate
|
||||
// the dimension of the input and maximum number of concurrent tasks.
|
||||
// The actual number of concurrent tasks will be set from the return of InitGlobal,
|
||||
// which must not be greater than the input tasks.
|
||||
//
|
||||
// During the optimization, a reverse-communication interface is used between
|
||||
// the GlobalMethod and the caller.
|
||||
// GlobalMethod acts as a client that asks the caller to perform
|
||||
// needed operations given the return from IterateGlobal.
|
||||
// This allows and enforces automation of maintaining statistics and checking for
|
||||
// (various types of) convergence.
|
||||
//
|
||||
// The return from IterateGlobal can be an Evaluation, a MajorIteration or NoOperation.
|
||||
//
|
||||
// An evaluation is one or more of the Evaluation operations (FuncEvaluation,
|
||||
// GradEvaluation, etc.) combined with the bitwise or operator. In an evaluation
|
||||
// operation, the requested fields of Problem will be evaluated at the value
|
||||
// in Location.X, filling the corresponding fields of Location. These values
|
||||
// can be retrieved and used upon the next call to IterateGlobal with that task id.
|
||||
// The GlobalMethod interface requires that entries of Location are not modified
|
||||
// aside from the commanded evaluations. Thus, the type implementing GlobalMethod
|
||||
// may use multiple Operations to set the Location fields at a particular x value.
|
||||
//
|
||||
// When IterateGlobal declares MajorIteration, the caller updates the optimal
|
||||
// location to the values in Location, and checks for convergence. The type
|
||||
// implementing GlobalMethod must make sure that the fields of Location are valid
|
||||
// and consistent.
|
||||
//
|
||||
// IterateGlobal must not return InitIteration and PostIteration operations. These are
|
||||
// reserved for the clients to be passed to Recorders. A Method must also not
|
||||
// combine the Evaluation operations with the Iteration operations.
|
||||
type GlobalMethod interface {
|
||||
// Global tells method the max number of tasks, method returns how many it wants.
|
||||
// This is needed to sync the Global goroutines and inside goroutines.
|
||||
InitGlobal(dim, tasks int) int
|
||||
// Global method may assume that the same task id always has the same pointer with it.
|
||||
IterateGlobal(task int, loc *Location) (Operation, error)
|
||||
Needser
|
||||
// Done communicates to the optimization method that the optimization has
|
||||
// concluded to allow for shutdown.
|
||||
// InitGlobal communicates the input dimension and maximum number of tasks,
|
||||
// and returns the number of concurrent processes. The return must be less
|
||||
// than or equal to tasks.
|
||||
InitGlobal(dim, tasks int) int
|
||||
|
||||
// IterateGlobal retrieves information from the location associated with
|
||||
// the given task ID, and returns the next operation to perform with that
|
||||
// Location. IterateGlobal may assume that the same pointer is associated
|
||||
// with the same task.
|
||||
IterateGlobal(task int, loc *Location) (Operation, error)
|
||||
|
||||
// Done communicates that the optimization has concluded to allow for shutdown.
|
||||
// After Done is called, no more calls to IterateGlobal will be made.
|
||||
Done()
|
||||
}
|
||||
|
||||
@@ -37,7 +75,7 @@ type GlobalMethod interface {
|
||||
// described below.
|
||||
//
|
||||
// If p.Status is not nil, it is called before every evaluation. If the
|
||||
// returned Status is not NotTerminated or the error is not nil, the
|
||||
// returned Status is other than NotTerminated or if the error is not nil, the
|
||||
// optimization run is terminated.
|
||||
//
|
||||
// The third argument contains the settings for the minimization. The
|
||||
@@ -62,12 +100,11 @@ type GlobalMethod interface {
|
||||
//
|
||||
// Be aware that the default behavior of Global is to find the minimum.
|
||||
// For certain functions and optimization methods, this process can take many
|
||||
// function evaluations. If you would like to put limits on this, for example
|
||||
// maximum runtime or maximum function evaluations, modify the Settings
|
||||
// input struct.
|
||||
// function evaluations. The Settings input struct can be used to limit this,
|
||||
// for example by modifying the maximum runtime or maximum function evaluations.
|
||||
//
|
||||
// Something about Global cannot guarantee strict bounds on function evaluations,
|
||||
// iterations, etc. in the precense of concurrency.
|
||||
// Global cannot guarantee strict adherence to the bounds specified in Settings
|
||||
// when performing concurrent evaluations and updates.
|
||||
func Global(p Problem, dim int, settings *Settings, method GlobalMethod) (*Result, error) {
|
||||
startTime := time.Now()
|
||||
if method == nil {
|
||||
@@ -115,6 +152,9 @@ func Global(p Problem, dim int, settings *Settings, method GlobalMethod) (*Resul
|
||||
}, err
|
||||
}
|
||||
|
||||
// minimizeGlobal is the high-level function for a Global optimization. It launches
|
||||
// concurrent workers to perform the mimization, and shuts them down properly
|
||||
// at the conclusion.
|
||||
func minimizeGlobal(p *Problem, method GlobalMethod, settings *Settings, stats *Stats, optLoc *Location, startTime time.Time) (status Status, err error) {
|
||||
dim := len(optLoc.X)
|
||||
statuser, _ := method.(Statuser)
|
||||
@@ -130,7 +170,11 @@ func minimizeGlobal(p *Problem, method GlobalMethod, settings *Settings, stats *
|
||||
}
|
||||
|
||||
nTasks := settings.Concurrent
|
||||
nTasks = method.InitGlobal(dim, nTasks)
|
||||
newNTasks := method.InitGlobal(dim, nTasks)
|
||||
if newNTasks > nTasks {
|
||||
panic("global: too many tasks returned by GlobalMethod")
|
||||
}
|
||||
nTasks = newNTasks
|
||||
|
||||
// Launch optimization workers
|
||||
var wg sync.WaitGroup
|
||||
@@ -148,29 +192,14 @@ func minimizeGlobal(p *Problem, method GlobalMethod, settings *Settings, stats *
|
||||
return gs.status, gs.err
|
||||
}
|
||||
|
||||
type globalStatus struct {
|
||||
mux *sync.RWMutex
|
||||
stats *Stats
|
||||
status Status
|
||||
p *Problem
|
||||
startTime time.Time
|
||||
optLoc *Location
|
||||
settings *Settings
|
||||
statuser Statuser
|
||||
err error
|
||||
}
|
||||
|
||||
// globalWorker runs the optimization steps for a single (concurrently-executing)
|
||||
// optimization task.
|
||||
func globalWorker(task int, m GlobalMethod, g *globalStatus, loc *Location, x []float64) {
|
||||
for {
|
||||
// Find Evaluation location
|
||||
op, err := m.IterateGlobal(task, loc)
|
||||
if err != nil {
|
||||
// TODO(btracey): Figure out how to handle errors properly. Shut
|
||||
// everything down? Pass to globalStatus so it can shut everything down?
|
||||
g.mux.Lock()
|
||||
g.err = err
|
||||
g.status = Failure
|
||||
g.mux.Unlock()
|
||||
g.updateStatus(Failure, err)
|
||||
break
|
||||
}
|
||||
|
||||
@@ -182,52 +211,109 @@ func globalWorker(task int, m GlobalMethod, g *globalStatus, loc *Location, x []
|
||||
}
|
||||
}
|
||||
|
||||
// globalOperation updates handles the status received by an individual worker.
|
||||
// It uses a mutex to protect updates where necessary.
|
||||
func (g *globalStatus) globalOperation(op Operation, loc *Location, x []float64) Status {
|
||||
// Do a quick check to see if one of the other workers converged in the meantime.
|
||||
// globalStatus coordinates access to information shared between concurrently
|
||||
// executing optimization tasks.
|
||||
type globalStatus struct {
|
||||
mux *sync.RWMutex
|
||||
stats *Stats
|
||||
status Status
|
||||
p *Problem
|
||||
startTime time.Time
|
||||
optLoc *Location
|
||||
settings *Settings
|
||||
method GlobalMethod
|
||||
statuser Statuser
|
||||
err error
|
||||
}
|
||||
|
||||
// getStatus returns the current status of the optimization.
|
||||
func (g *globalStatus) getStatus() Status {
|
||||
var status Status
|
||||
var err error
|
||||
g.mux.RLock()
|
||||
defer g.mux.RUnlock()
|
||||
status = g.status
|
||||
g.mux.RUnlock()
|
||||
if status != NotTerminated {
|
||||
return status
|
||||
}
|
||||
switch op {
|
||||
case NoOperation:
|
||||
case InitIteration:
|
||||
panic("optimize: Method returned InitIteration")
|
||||
case PostIteration:
|
||||
panic("optimize: Method returned PostIteration")
|
||||
case MajorIteration:
|
||||
g.mux.Lock()
|
||||
g.stats.MajorIterations++
|
||||
copyLocation(g.optLoc, loc)
|
||||
g.mux.Unlock()
|
||||
|
||||
g.mux.RLock()
|
||||
status = checkConvergence(g.optLoc, g.settings, false)
|
||||
g.mux.RUnlock()
|
||||
default: // Any of the Evaluation operations.
|
||||
status, err = evaluate(g.p, loc, op, x)
|
||||
g.mux.Lock()
|
||||
updateStats(g.stats, op)
|
||||
g.mux.Unlock()
|
||||
}
|
||||
|
||||
g.mux.Lock()
|
||||
status, err = iterCleanup(status, err, g.stats, g.settings, g.statuser, g.startTime, loc, op)
|
||||
// Update the termination status if it hasn't already terminated.
|
||||
if g.status == NotTerminated {
|
||||
g.status = status
|
||||
g.err = err
|
||||
}
|
||||
g.mux.Unlock()
|
||||
|
||||
return status
|
||||
}
|
||||
|
||||
func (g *globalStatus) incrementMajorIteration() {
|
||||
g.mux.Lock()
|
||||
defer g.mux.Unlock()
|
||||
g.stats.MajorIterations++
|
||||
}
|
||||
|
||||
func (g *globalStatus) updateOptLoc(loc *Location) {
|
||||
g.mux.Lock()
|
||||
defer g.mux.Unlock()
|
||||
copyLocation(g.optLoc, loc)
|
||||
}
|
||||
|
||||
// checkConvergence checks the convergence of the global optimization and returns
|
||||
// the status
|
||||
func (g *globalStatus) checkConvergence() Status {
|
||||
g.mux.RLock()
|
||||
defer g.mux.RUnlock()
|
||||
return checkConvergence(g.optLoc, g.settings, false)
|
||||
}
|
||||
|
||||
// updateStats updates the evaluation statistics for the given operation.
|
||||
func (g *globalStatus) updateStats(op Operation) {
|
||||
g.mux.Lock()
|
||||
defer g.mux.Unlock()
|
||||
updateEvaluationStats(g.stats, op)
|
||||
}
|
||||
|
||||
// updateStatus updates the status and error fields of g. This update only happens
|
||||
// if status == NotTerminated, so that the first different status is the one
|
||||
// maintained.
|
||||
func (g *globalStatus) updateStatus(s Status, err error) {
|
||||
g.mux.Lock()
|
||||
defer g.mux.Unlock()
|
||||
if g.status != NotTerminated {
|
||||
g.status = s
|
||||
g.err = err
|
||||
}
|
||||
}
|
||||
|
||||
func (g *globalStatus) finishIteration(status Status, err error, loc *Location, op Operation) (Status, error) {
|
||||
g.mux.Lock()
|
||||
defer g.mux.Unlock()
|
||||
return finishIteration(status, err, g.stats, g.settings, g.statuser, g.startTime, loc, op)
|
||||
}
|
||||
|
||||
// globalOperation executes the requested operation at the given location.
|
||||
// When modifying this function, keep in mind that it can be called concurrently.
|
||||
// Uses of the internal fields should be through the methods of globalStatus and
|
||||
// protected by a mutex where appropriate.
|
||||
func (g *globalStatus) globalOperation(op Operation, loc *Location, x []float64) Status {
|
||||
// Do a quick check to see if one of the other workers converged in the meantime.
|
||||
status := g.getStatus()
|
||||
if status != NotTerminated {
|
||||
return status
|
||||
}
|
||||
var err error
|
||||
switch op {
|
||||
case NoOperation:
|
||||
case InitIteration:
|
||||
panic("optimize: GlobalMethod returned InitIteration")
|
||||
case PostIteration:
|
||||
panic("optimize: GlobalMethod returned PostIteration")
|
||||
case MajorIteration:
|
||||
g.incrementMajorIteration()
|
||||
g.updateOptLoc(loc)
|
||||
status = g.checkConvergence()
|
||||
default: // Any of the Evaluation operations.
|
||||
status, err = evaluate(g.p, loc, op, x)
|
||||
g.updateStats(op)
|
||||
}
|
||||
|
||||
status, err = g.finishIteration(status, err, loc, op)
|
||||
if status != NotTerminated || err != nil {
|
||||
g.updateStatus(status, err)
|
||||
}
|
||||
return status
|
||||
}
|
||||
|
||||
// DefaultSettingsGlobal returns the default settings for Global optimization.
|
||||
func DefaultSettingsGlobal() *Settings {
|
||||
return &Settings{
|
||||
FunctionThreshold: math.Inf(-1),
|
||||
|
@@ -148,10 +148,10 @@ func minimize(p *Problem, method Method, settings *Settings, stats *Stats, optLo
|
||||
status = checkConvergence(optLoc, settings, true)
|
||||
default: // Any of the Evaluation operations.
|
||||
status, err = evaluate(p, loc, op, x)
|
||||
updateStats(stats, op)
|
||||
updateEvaluationStats(stats, op)
|
||||
}
|
||||
|
||||
status, err = iterCleanup(status, err, stats, settings, statuser, startTime, loc, op)
|
||||
status, err = finishIteration(status, err, stats, settings, statuser, startTime, loc, op)
|
||||
if status != NotTerminated || err != nil {
|
||||
return
|
||||
}
|
||||
@@ -209,7 +209,7 @@ func getStartingLocation(p *Problem, method Method, initX []float64, stats *Stat
|
||||
}
|
||||
x := make([]float64, len(loc.X))
|
||||
evaluate(p, loc, eval, x)
|
||||
updateStats(stats, eval)
|
||||
updateEvaluationStats(stats, eval)
|
||||
}
|
||||
|
||||
if math.IsInf(loc.F, 1) || math.IsNaN(loc.F) {
|
||||
|
@@ -121,8 +121,8 @@ func checkConvergence(loc *Location, settings *Settings, local bool) Status {
|
||||
return NotTerminated
|
||||
}
|
||||
|
||||
// updateStats updates the statistics based on the operation.
|
||||
func updateStats(stats *Stats, op Operation) {
|
||||
// updateEvaluationStats updates the statistics based on the operation.
|
||||
func updateEvaluationStats(stats *Stats, op Operation) {
|
||||
if op&FuncEvaluation != 0 {
|
||||
stats.FuncEvaluations++
|
||||
}
|
||||
@@ -169,8 +169,9 @@ func checkLimits(loc *Location, stats *Stats, settings *Settings) Status {
|
||||
return NotTerminated
|
||||
}
|
||||
|
||||
// TODO(btracey): better name
|
||||
func iterCleanup(status Status, err error, stats *Stats, settings *Settings, statuser Statuser, startTime time.Time, loc *Location, op Operation) (Status, error) {
|
||||
// finishIteration performs cleanup tasks at the end of an optimization iteration.
|
||||
// It checks the status, sends information to recorders, and updates the runtime.
|
||||
func finishIteration(status Status, err error, stats *Stats, settings *Settings, statuser Statuser, startTime time.Time, loc *Location, op Operation) (Status, error) {
|
||||
if status != NotTerminated || err != nil {
|
||||
return status, err
|
||||
}
|
||||
|
Reference in New Issue
Block a user