optimize: completely overhaul Global (#352)

* optimize: completely overhaul Global

The previous implementation of Global was a minefield that made it easy to implement global optimization methods incorrectly. It was very difficult to implement methods correctly (both of the provided methods were incorrect), and the resulting code was very ugly. This commit switches to using channels to communicate, allowing a clearer ordering of concurrent code. This also enables better shutdown of methods.

In addition to the main fix of Global, this refactors the two Global methods to use the updated interface, and makes some small improvements that were previously not possible. In addition, there are some small cleanups of Local to better match between the two calls.

If anyone has been curious about what is meant by 'Don't communicate by sharing memory, share memory by communicating', this is it, and why.

* respond to PR comments

* make constants

* simplify termination logic

* optimize: simplify stats collection

* overhaul documentation and respond to PR comments

* implement PR requests

* clean up cmaes
This commit is contained in:
Brendan Tracey
2018-02-05 08:44:02 -07:00
committed by GitHub
parent 95fab73f1d
commit 996b88e8f8
8 changed files with 565 additions and 440 deletions

View File

@@ -7,7 +7,6 @@ package optimize
import ( import (
"math" "math"
"sort" "sort"
"sync"
"golang.org/x/exp/rand" "golang.org/x/exp/rand"
@@ -103,15 +102,15 @@ type CmaEsChol struct {
mean []float64 mean []float64
chol mat.Cholesky chol mat.Cholesky
// Parallel fields.
mux sync.Mutex // protect access to evals.
wg sync.WaitGroup // wait for simulations to finish before iterating.
taskIdxs []int // Stores which simulation the task ran.
evals []int // remaining evaluations in this iteration.
// Overall best. // Overall best.
bestX []float64 bestX []float64
bestF float64 bestF float64
// Synchronization.
sentIdx int
receivedIdx int
operation chan<- GlobalTask
updateErr error
} }
var ( var (
@@ -123,21 +122,26 @@ func (cma *CmaEsChol) Needs() struct{ Gradient, Hessian bool } {
return struct{ Gradient, Hessian bool }{false, false} return struct{ Gradient, Hessian bool }{false, false}
} }
func (cma *CmaEsChol) Done() {} func (cma *CmaEsChol) methodConverged() Status {
// Status returns the status of the method.
func (cma *CmaEsChol) Status() (Status, error) {
sd := cma.StopLogDet sd := cma.StopLogDet
switch { switch {
case math.IsNaN(sd): case math.IsNaN(sd):
return NotTerminated, nil return NotTerminated
case sd == 0: case sd == 0:
sd = float64(cma.dim) * -36.8413614879 // ln(1e-16) sd = float64(cma.dim) * -36.8413614879 // ln(1e-16)
} }
if cma.chol.LogDet() < sd { if cma.chol.LogDet() < sd {
return MethodConverge, nil return MethodConverge
} }
return NotTerminated, nil return NotTerminated
}
// Status returns the status of the method.
func (cma *CmaEsChol) Status() (Status, error) {
if cma.updateErr != nil {
return Failure, cma.updateErr
}
return cma.methodConverged(), nil
} }
func (cma *CmaEsChol) InitGlobal(dim, tasks int) int { func (cma *CmaEsChol) InitGlobal(dim, tasks int) int {
@@ -226,90 +230,169 @@ func (cma *CmaEsChol) InitGlobal(dim, tasks int) int {
cma.chol = chol cma.chol = chol
} }
cma.evals = make([]int, cma.pop)
for i := range cma.evals {
cma.evals[i] = i
}
cma.bestX = resize(cma.bestX, dim) cma.bestX = resize(cma.bestX, dim)
cma.bestF = math.Inf(1) cma.bestF = math.Inf(1)
cma.sentIdx = 0
cma.receivedIdx = 0
cma.operation = nil
cma.updateErr = nil
t := min(tasks, cma.pop) t := min(tasks, cma.pop)
cma.taskIdxs = make([]int, t)
for i := 0; i < t; i++ {
cma.taskIdxs[i] = -1
}
// Get a new mutex and waitgroup so that if the structure is reused there
// aren't residual interactions with the previous optimization.
cma.mux = sync.Mutex{}
cma.wg = sync.WaitGroup{}
return t return t
} }
func (cma *CmaEsChol) IterateGlobal(task int, loc *Location) (Operation, error) { func (cma *CmaEsChol) sendInitTasks(tasks []GlobalTask) {
// Check the status of the incoming task. If it is a number, it means for i, task := range tasks {
// that task contains a valid location. cma.sendTask(i, task)
idx := cma.taskIdxs[task] }
if idx != -1 { cma.sentIdx = len(tasks)
cma.fs[idx] = loc.F
cma.wg.Done()
} }
// Get the next task and send it to be run if there is a next task to be run. // sendTask generates a sample and sends the task. It does not update the cma index.
// If all of the tasks have been run, perform an update step. Note that the func (cma *CmaEsChol) sendTask(idx int, task GlobalTask) {
// use of this mutex means that only one task can proceed, all of the task.ID = idx
// other tasks should get stuck and then get a new location. task.Op = FuncEvaluation
cma.mux.Lock() distmv.NormalRand(cma.xs.RawRowView(idx), cma.mean, &cma.chol, cma.Src)
if len(cma.evals) != 0 { copy(task.X, cma.xs.RawRowView(idx))
// There are still tasks to evaluate. Grab one and remove it from the list. cma.operation <- task
newIdx := cma.evals[len(cma.evals)-1]
cma.evals = cma.evals[:len(cma.evals)-1]
cma.wg.Add(1)
cma.mux.Unlock()
// Sample x and send it to be evaluated.
distmv.NormalRand(cma.xs.RawRowView(newIdx), cma.mean, &cma.chol, cma.Src)
copy(loc.X, cma.xs.RawRowView(newIdx))
cma.taskIdxs[task] = newIdx
return FuncEvaluation, nil
} }
// There are no more tasks to evaluate. This means the iteration is over.
// Find the best current f, update the parameters, and re-establish
// the evaluations to run.
// Wait for all of the outstanding tasks to finish, so the full set of functions // bestIdx returns the best index in the functions. Returns -1 if all values
// has been evaluated. // are NaN.
cma.wg.Wait() func (cma *CmaEsChol) bestIdx() int {
best := -1
bestVal := math.Inf(1)
for i, v := range cma.fs {
if math.IsNaN(v) {
continue
}
// Use equality in case somewhere evaluates to +inf.
if v <= bestVal {
best = i
bestVal = v
}
}
return best
}
// Find the best f out of all the tasks. // findBestAndUpdateTask finds the best task in the current list, updates the
best := floats.MinIdx(cma.fs) // new best overall, and then stores the best location into task.
bestF := cma.fs[best] func (cma *CmaEsChol) findBestAndUpdateTask(task GlobalTask) GlobalTask {
bestX := cma.xs.RawRowView(best) // Find and update the best location.
// Don't use floats because there may be NaN values.
best := cma.bestIdx()
bestF := math.NaN()
bestX := cma.xs.RawRowView(0)
if best != -1 {
bestF = cma.fs[best]
bestX = cma.xs.RawRowView(best)
}
if cma.ForgetBest { if cma.ForgetBest {
loc.F = bestF task.F = bestF
copy(loc.X, bestX) copy(task.X, bestX)
} else { } else {
if bestF < cma.bestF { if bestF < cma.bestF {
cma.bestF = bestF cma.bestF = bestF
copy(cma.bestX, bestX) copy(cma.bestX, bestX)
} }
loc.F = cma.bestF task.F = cma.bestF
copy(loc.X, cma.bestX) copy(task.X, cma.bestX)
}
return task
} }
cma.taskIdxs[task] = -1 func (cma *CmaEsChol) RunGlobal(operations chan<- GlobalTask, results <-chan GlobalTask, tasks []GlobalTask) {
cma.operation = operations
// Send the initial tasks. We know there are at most as many tasks as elements
// of the population.
cma.sendInitTasks(tasks)
// Update the parameters of the distribution Loop:
for {
result := <-results
switch result.Op {
default:
panic("unknown operation")
case PostIteration:
break Loop
case MajorIteration:
// The last thing we did was update all of the tasks and send the
// major iteration. Now we can send a group of tasks again.
cma.sendInitTasks(tasks)
case FuncEvaluation:
cma.receivedIdx++
cma.fs[result.ID] = result.F
switch {
case cma.sentIdx < cma.pop:
// There are still tasks to evaluate. Send the next.
cma.sendTask(cma.sentIdx, result)
cma.sentIdx++
case cma.receivedIdx < cma.pop:
// All the tasks have been sent, but not all of them have been received.
// Need to wait until all are back.
continue Loop
default:
// All of the evaluations have been received.
if cma.receivedIdx != cma.pop {
panic("bad logic")
}
cma.receivedIdx = 0
cma.sentIdx = 0
task := cma.findBestAndUpdateTask(result)
// Update the parameters and send a MajorIteration or a convergence.
err := cma.update() err := cma.update()
// Kill the existing data.
// Reset the tasks for i := range cma.fs {
cma.evals = cma.evals[:cma.pop] cma.fs[i] = math.NaN()
cma.xs.Set(i, 0, math.NaN())
cma.mux.Unlock() }
return MajorIteration, err switch {
case err != nil:
cma.updateErr = err
task.Op = MethodDone
case cma.methodConverged() != NotTerminated:
task.Op = MethodDone
default:
task.Op = MajorIteration
task.ID = -1
}
operations <- task
}
}
} }
// update computes the new parameters (mean, cholesky, etc.) // Been told to stop. Clean up.
// We need the best of the tasks evaluated so far, so first collect all of
// the remaining results and only then look for the best.
for task := range results {
switch task.Op {
case MajorIteration:
case FuncEvaluation:
cma.fs[task.ID] = task.F
default:
panic("unknown operation")
}
}
// Send the new best value if the evaluation is better than any we've
// found so far. Keep this separate from findBestAndUpdateTask so that
// we only send an iteration if we find a better location.
if !cma.ForgetBest {
best := cma.bestIdx()
if best != -1 && cma.fs[best] < cma.bestF {
task := tasks[0]
task.F = cma.fs[best]
copy(task.X, cma.xs.RawRowView(best))
task.Op = MajorIteration
task.ID = -1
operations <- task
}
}
close(operations)
}
// update computes the new parameters (mean, cholesky, etc.). Does not update
// any of the synchronization parameters (sentIdx, receivedIdx).
func (cma *CmaEsChol) update() error { func (cma *CmaEsChol) update() error {
// Sort the function values to find the elite samples. // Sort the function values to find the elite samples.
ftmp := make([]float64, cma.pop) ftmp := make([]float64, cma.pop)

View File

@@ -21,7 +21,7 @@ type cmaTestCase struct {
problem Problem problem Problem
method *CmaEsChol method *CmaEsChol
settings *Settings settings *Settings
good func(*Result, error) error good func(result *Result, err error, concurrent int) error
} }
func cmaTestCases() []cmaTestCase { func cmaTestCases() []cmaTestCase {
@@ -42,7 +42,7 @@ func cmaTestCases() []cmaTestCase {
settings: &Settings{ settings: &Settings{
FunctionThreshold: 0.01, FunctionThreshold: 0.01,
}, },
good: func(result *Result, err error) error { good: func(result *Result, err error, concurrent int) error {
if result.Status != FunctionThreshold { if result.Status != FunctionThreshold {
return errors.New("result not function threshold") return errors.New("result not function threshold")
} }
@@ -63,7 +63,7 @@ func cmaTestCases() []cmaTestCase {
settings: &Settings{ settings: &Settings{
FunctionThreshold: math.Inf(-1), FunctionThreshold: math.Inf(-1),
}, },
good: func(result *Result, err error) error { good: func(result *Result, err error, concurrent int) error {
if result.Status != MethodConverge { if result.Status != MethodConverge {
return errors.New("result not method converge") return errors.New("result not method converge")
} }
@@ -82,24 +82,30 @@ func cmaTestCases() []cmaTestCase {
}, },
method: &CmaEsChol{ method: &CmaEsChol{
Population: 100, Population: 100,
ForgetBest: true, // Otherwise may get an update at the end.
}, },
settings: &Settings{ settings: &Settings{
FunctionThreshold: math.Inf(-1), FunctionThreshold: math.Inf(-1),
MajorIterations: 10, MajorIterations: 10,
}, },
good: func(result *Result, err error) error { good: func(result *Result, err error, concurrent int) error {
if result.Status != IterationLimit { if result.Status != IterationLimit {
return errors.New("result not iteration limit") return errors.New("result not iteration limit")
} }
if result.FuncEvaluations != 1000 { threshLower := 10
return errors.New("wrong number of evaluations") threshUpper := 10
if concurrent != 0 {
// Could have one more from final update.
threshUpper++
}
if result.MajorIterations < threshLower || result.MajorIterations > threshUpper {
return errors.New("wrong number of iterations")
} }
return nil return nil
}, },
}, },
{ {
// Test that works properly in parallel, and stops with some // Test that work stops with some number of function evaluations.
// number of function evaluations.
dim: 5, dim: 5,
problem: Problem{ problem: Problem{
Func: functions.ExtendedRosenbrock{}.Func, Func: functions.ExtendedRosenbrock{}.Func,
@@ -108,18 +114,22 @@ func cmaTestCases() []cmaTestCase {
Population: 100, Population: 100,
}, },
settings: &Settings{ settings: &Settings{
Concurrent: 5,
FunctionThreshold: math.Inf(-1), FunctionThreshold: math.Inf(-1),
FuncEvaluations: 250, // Somewhere in the middle of an iteration. FuncEvaluations: 250, // Somewhere in the middle of an iteration.
}, },
good: func(result *Result, err error) error { good: func(result *Result, err error, concurrent int) error {
if result.Status != FunctionEvaluationLimit { if result.Status != FunctionEvaluationLimit {
return errors.New("result not function evaluations") return errors.New("result not function evaluations")
} }
if result.FuncEvaluations < 250 { threshLower := 250
threshUpper := 251
if concurrent != 0 {
threshUpper = threshLower + concurrent
}
if result.FuncEvaluations < threshLower {
return errors.New("too few function evaluations") return errors.New("too few function evaluations")
} }
if result.FuncEvaluations > 250+4 { // can't guarantee exactly, because could grab extras in parallel first. if result.FuncEvaluations > threshUpper {
return errors.New("too many function evaluations") return errors.New("too many function evaluations")
} }
return nil return nil
@@ -137,7 +147,7 @@ func cmaTestCases() []cmaTestCase {
settings: &Settings{ settings: &Settings{
FunctionThreshold: math.Inf(-1), FunctionThreshold: math.Inf(-1),
}, },
good: func(result *Result, err error) error { good: func(result *Result, err error, concurrent int) error {
if result.Status != MethodConverge { if result.Status != MethodConverge {
return errors.New("result not method converge") return errors.New("result not method converge")
} }
@@ -157,15 +167,16 @@ func cmaTestCases() []cmaTestCase {
Population: 100, // Increase the population size to reduce noise. Population: 100, // Increase the population size to reduce noise.
InitMean: localMinMean, InitMean: localMinMean,
InitCholesky: &localMinChol, InitCholesky: &localMinChol,
ForgetBest: true, // So that if it accidentally finds a better place we still converge to the minimum.
}, },
settings: &Settings{ settings: &Settings{
FunctionThreshold: math.Inf(-1), FunctionThreshold: math.Inf(-1),
}, },
good: func(result *Result, err error) error { good: func(result *Result, err error, concurrent int) error {
if result.Status != MethodConverge { if result.Status != MethodConverge {
return errors.New("result not method converge") return errors.New("result not method converge")
} }
if !floats.EqualApprox(result.X, []float64{2, -2}, 1e-2) { if !floats.EqualApprox(result.X, []float64{2, -2}, 3e-2) {
return errors.New("local minimum not found") return errors.New("local minimum not found")
} }
return nil return nil
@@ -181,14 +192,22 @@ func TestCmaEsChol(t *testing.T) {
method.Src = src method.Src = src
// Run and check that the expected termination occurs. // Run and check that the expected termination occurs.
result, err := Global(test.problem, test.dim, test.settings, method) result, err := Global(test.problem, test.dim, test.settings, method)
if testErr := test.good(result, err); testErr != nil { if testErr := test.good(result, err, test.settings.Concurrent); testErr != nil {
t.Errorf("cas %d: %v", i, testErr) t.Errorf("cas %d: %v", i, testErr)
} }
// Run a second time to make sure there are no residual effects // Run a second time to make sure there are no residual effects
result, err = Global(test.problem, test.dim, test.settings, method) result, err = Global(test.problem, test.dim, test.settings, method)
if testErr := test.good(result, err); testErr != nil { if testErr := test.good(result, err, test.settings.Concurrent); testErr != nil {
t.Errorf("cas %d second: %v", i, testErr) t.Errorf("cas %d second: %v", i, testErr)
} }
// Test the problem in parallel.
test.settings.Concurrent = 5
result, err = Global(test.problem, test.dim, test.settings, method)
if testErr := test.good(result, err, test.settings.Concurrent); testErr != nil {
t.Errorf("cas %d concurrent: %v", i, testErr)
}
test.settings.Concurrent = 0
} }
} }

View File

@@ -6,65 +6,65 @@ package optimize
import ( import (
"math" "math"
"sync"
"time" "time"
) )
var ( // DefaultSettingsGlobal returns the default settings for Global optimization.
nonpositiveDimension string = "optimize: non-positive input dimension" func DefaultSettingsGlobal() *Settings {
negativeTasks string = "optimize: negative input number of tasks" return &Settings{
) FunctionThreshold: math.Inf(-1),
FunctionConverge: &FunctionConverge{
Absolute: 1e-10,
Iterations: 100,
},
}
}
// GlobalMethod is an optimization method which seeks to find the global minimum // GlobalTask is a type to communicate between the GlobalMethod and the outer
// of an objective function. // calling script.
// type GlobalTask struct {
// At the beginning of the optimization, InitGlobal is called to communicate ID int
// the dimension of the input and maximum number of concurrent tasks. Op Operation
// The actual number of concurrent tasks will be set from the return of InitGlobal, *Location
// which must not be greater than the input tasks. }
//
// During the optimization, a reverse-communication interface is used between // GlobalMethod is a type which can search for a global optimum for an objective function.
// the GlobalMethod and the caller.
// GlobalMethod acts as a client that asks the caller to perform
// needed operations given the return from IterateGlobal.
// This allows and enforces automation of maintaining statistics and checking for
// (various types of) convergence.
//
// The return from IterateGlobal can be an Evaluation, a MajorIteration or NoOperation.
//
// An evaluation is one or more of the Evaluation operations (FuncEvaluation,
// GradEvaluation, etc.) combined with the bitwise or operator. In an evaluation
// operation, the requested fields of Problem will be evaluated at the value
// in Location.X, filling the corresponding fields of Location. These values
// can be retrieved and used upon the next call to IterateGlobal with that task id.
// The GlobalMethod interface requires that entries of Location are not modified
// aside from the commanded evaluations. Thus, the type implementing GlobalMethod
// may use multiple Operations to set the Location fields at a particular x value.
//
// When IterateGlobal declares MajorIteration, the caller updates the optimal
// location to the values in Location, and checks for convergence. The type
// implementing GlobalMethod must make sure that the fields of Location are valid
// and consistent.
//
// IterateGlobal must not return InitIteration and PostIteration operations. These are
// reserved for the clients to be passed to Recorders. A Method must also not
// combine the Evaluation operations with the Iteration operations.
type GlobalMethod interface { type GlobalMethod interface {
Needser Needser
// InitGlobal communicates the input dimension and maximum number of tasks, // InitGlobal takes as input the problem dimension and number of available
// and returns the number of concurrent processes. The return must be less // concurrent tasks, and returns the number of concurrent processes to be used.
// than or equal to tasks. // The returned value must be less than or equal to tasks.
InitGlobal(dim, tasks int) int InitGlobal(dim, tasks int) int
// RunGlobal runs a global optimization. The method sends GlobalTasks on
// IterateGlobal retrieves information from the location associated with // the operation channel (for performing function evaluations, major
// the given task ID, and returns the next operation to perform with that // iterations, etc.). The result of the tasks will be returned on Result.
// Location. IterateGlobal may assume that the same pointer is associated // See the documentation for Operation types for the possible tasks.
// with the same task. //
IterateGlobal(task int, loc *Location) (Operation, error) // The caller of RunGlobal will signal the termination of the optimization
// (i.e. convergence from user settings) by sending a task with a PostIteration
// Done communicates that the optimization has concluded to allow for shutdown. // Op field on result. More tasks may still be sent on operation after this
// After Done is called, no more calls to IterateGlobal will be made. // occurs, but only MajorIteration operations will still be conducted
Done() // appropriately. Thus, it can not be guaranteed that all Evaluations sent
// on operation will be evaluated, however if an Evaluation is started,
// the results of that evaluation will be sent on results.
//
// The GlobalMethod must read from the result channel until it is closed.
// During this, the GlobalMethod may want to send new MajorIteration(s) on
// operation. GlobalMethod then must close operation, and return from RunGlobal.
//
// The last parameter to RunGlobal is a slice of tasks with length equal to
// the return from InitGlobal. GlobalTask has an ID field which may be
// set and modified by GlobalMethod, and must not be modified by the caller.
//
// GlobalMethod may have its own specific convergence criteria, which can
// be communicated using a MethodDone operation. This will trigger a
// PostIteration to be sent on result, and the MethodDone task will not be
// returned on result. The GlobalMethod must implement Statuser, and the
// call to Status must return a Status other than NotTerminated.
//
// The operation and result channels are guaranteed to have a buffer length
// equal to the return from InitGlobal.
RunGlobal(operation chan<- GlobalTask, result <-chan GlobalTask, tasks []GlobalTask)
} }
// Global uses a global optimizer to search for the global minimum of a // Global uses a global optimizer to search for the global minimum of a
@@ -86,30 +86,23 @@ type GlobalMethod interface {
// The third argument contains the settings for the minimization. The // The third argument contains the settings for the minimization. The
// DefaultSettingsGlobal function can be called for a Settings struct with the // DefaultSettingsGlobal function can be called for a Settings struct with the
// default values initialized. If settings == nil, the default settings are used. // default values initialized. If settings == nil, the default settings are used.
// Global optimization methods typically do not make assumptions about the number // All of the settings will be followed, but many of them may be counterproductive
// and location of local minima. Thus, the only convergence metric used is the // to use (such as GradientThreshold). Global cannot guarantee strict adherence
// function values found at major iterations of the optimization. Bounds on the // to the bounds specified when performing concurrent evaluations and updates.
// length of optimization are obeyed, such as the number of allowed function
// evaluations.
// //
// The final argument is the optimization method to use. If method == nil, then // The final argument is the optimization method to use. If method == nil, then
// an appropriate default is chosen based on the properties of the other arguments // an appropriate default is chosen based on the properties of the other arguments
// (dimension, gradient-free or gradient-based, etc.). // (dimension, gradient-free or gradient-based, etc.).
// //
// If method implements Statuser, method.Status is called before every call
// to method.Iterate. If the returned Status is not NotTerminated or the
// error is non-nil, the optimization run is terminated.
//
// Global returns a Result struct and any error that occurred. See the // Global returns a Result struct and any error that occurred. See the
// documentation of Result for more information. // documentation of Result for more information.
// //
// See the documentation for GlobalMethod for the details on implementing a method.
//
// Be aware that the default behavior of Global is to find the minimum. // Be aware that the default behavior of Global is to find the minimum.
// For certain functions and optimization methods, this process can take many // For certain functions and optimization methods, this process can take many
// function evaluations. The Settings input struct can be used to limit this, // function evaluations. The Settings input struct can be used to limit this,
// for example by modifying the maximum runtime or maximum function evaluations. // for example by modifying the maximum runtime or maximum function evaluations.
//
// Global cannot guarantee strict adherence to the bounds specified in Settings
// when performing concurrent evaluations and updates.
func Global(p Problem, dim int, settings *Settings, method GlobalMethod) (*Result, error) { func Global(p Problem, dim int, settings *Settings, method GlobalMethod) (*Result, error) {
startTime := time.Now() startTime := time.Now()
if method == nil { if method == nil {
@@ -157,23 +150,10 @@ func Global(p Problem, dim int, settings *Settings, method GlobalMethod) (*Resul
}, err }, err
} }
// minimizeGlobal is the high-level function for a Global optimization. It launches // minimizeGlobal performs a Global optimization. minimizeGlobal updates the
// concurrent workers to perform the mimization, and shuts them down properly // settings and optLoc, and returns the final Status and error.
// at the conclusion. func minimizeGlobal(prob *Problem, method GlobalMethod, settings *Settings, stats *Stats, optLoc *Location, startTime time.Time) (Status, error) {
func minimizeGlobal(p *Problem, method GlobalMethod, settings *Settings, stats *Stats, optLoc *Location, startTime time.Time) (status Status, err error) {
dim := len(optLoc.X) dim := len(optLoc.X)
statuser, _ := method.(Statuser)
gs := &globalStatus{
mux: &sync.RWMutex{},
stats: stats,
status: NotTerminated,
p: p,
startTime: startTime,
optLoc: optLoc,
settings: settings,
statuser: statuser,
}
nTasks := settings.Concurrent nTasks := settings.Concurrent
if nTasks == 0 { if nTasks == 0 {
nTasks = 1 nTasks = 1
@@ -184,158 +164,160 @@ func minimizeGlobal(p *Problem, method GlobalMethod, settings *Settings, stats *
} }
nTasks = newNTasks nTasks = newNTasks
// Launch optimization workers. Each worker is individually responsible // Launch the method. The method communicates tasks using the operations
// for maintaining stats and evaluating the function. // channel, and results is used to return the evaluated results.
var wg sync.WaitGroup operations := make(chan GlobalTask, nTasks)
for task := 0; task < nTasks; task++ { results := make(chan GlobalTask, nTasks)
wg.Add(1) go func() {
go func(task int) { tasks := make([]GlobalTask, nTasks)
defer wg.Done() for i := range tasks {
loc := newLocation(dim, method) tasks[i].Location = newLocation(dim, method)
x := make([]float64, dim)
globalWorker(task, method, gs, loc, x)
}(task)
}
wg.Wait()
method.Done()
return gs.status, gs.err
} }
method.RunGlobal(operations, results, tasks)
}()
// globalWorker runs the optimization steps for a single (concurrently-executing) // Algorithmic Overview:
// optimization task. // There are three pieces to performing a concurrent global optimization,
func globalWorker(task int, m GlobalMethod, g *globalStatus, loc *Location, x []float64) { // the distributor, the workers, and the stats combiner. At a high level,
// the distributor reads in tasks sent by method, sending evaluations to the
// workers, and forwarding other operations to the statsCombiner. The workers
// read these forwarded evaluation tasks, evaluate the relevant parts of Problem
// and forward the results on to the stats combiner. The stats combiner reads
// in results from the workers, as well as tasks from the distributor, and
// uses them to update optimization statistics (function evaluations, etc.)
// and to check optimization convergence.
//
// The complicated part is correctly shutting down the optimization. The
// procedure is as follows. First, the stats combiner closes done and sends
// a PostIteration to the method. The distributor then reads that done has
// been closed, and closes the channel with the workers. At this point, no
// more evaluation operations will be executed. As the workers finish their
// evaluations, they forward the results onto the stats combiner, and then
// signal their shutdown to the stats combiner. When all workers have successfully
// finished, the stats combiner closes the results channel, signaling to the
// method that all results have been collected. At this point, the method
// may send MajorIteration(s) to update an optimum location based on these
// last returned results, and then the method will close the operations channel.
// Now that no more tasks will be commanded by the method, the distributor
// closes statsChan, and with no more statistics to update the optimization
// concludes.
workerChan := make(chan GlobalTask) // Delegate tasks to the workers.
statsChan := make(chan GlobalTask) // Send evaluation updates.
done := make(chan struct{}) // Communicate the optimization is done.
// Read tasks from the method and distribute as appropriate.
distributor := func() {
for { for {
// Find Evaluation location select {
op, err := m.IterateGlobal(task, loc) case task := <-operations:
if err != nil { switch task.Op {
g.updateStatus(Failure, err)
break
}
// Evaluate location and/or update stats.
status := g.globalOperation(op, loc, x)
if status != NotTerminated {
break
}
}
}
// globalStatus coordinates access to information shared between concurrently
// executing optimization tasks.
type globalStatus struct {
mux *sync.RWMutex
stats *Stats
status Status
p *Problem
startTime time.Time
optLoc *Location
settings *Settings
method GlobalMethod
statuser Statuser
err error
}
// getStatus returns the current status of the optimization.
func (g *globalStatus) getStatus() Status {
var status Status
g.mux.RLock()
defer g.mux.RUnlock()
status = g.status
return status
}
func (g *globalStatus) incrementMajorIteration() {
g.mux.Lock()
defer g.mux.Unlock()
g.stats.MajorIterations++
}
func (g *globalStatus) updateOptLoc(loc *Location) {
g.mux.Lock()
defer g.mux.Unlock()
copyLocation(g.optLoc, loc)
}
// checkConvergence checks the convergence of the global optimization and returns
// the status
func (g *globalStatus) checkConvergence() Status {
g.mux.RLock()
defer g.mux.RUnlock()
return checkConvergence(g.optLoc, g.settings, false)
}
// updateStats updates the evaluation statistics for the given operation.
func (g *globalStatus) updateStats(op Operation) {
g.mux.Lock()
defer g.mux.Unlock()
updateEvaluationStats(g.stats, op)
}
// updateStatus updates the status and error fields of g. This update only happens
// if status == NotTerminated, so that the first different status is the one
// maintained.
func (g *globalStatus) updateStatus(s Status, err error) {
g.mux.Lock()
defer g.mux.Unlock()
if s != NotTerminated {
g.status = s
g.err = err
}
}
func (g *globalStatus) finishIteration(status Status, err error, loc *Location, op Operation) (Status, error) {
g.mux.Lock()
defer g.mux.Unlock()
return finishIteration(status, err, g.stats, g.settings, g.statuser, g.startTime, loc, op)
}
// globalOperation executes the requested operation at the given location.
// When modifying this function, keep in mind that it can be called concurrently.
// Uses of the internal fields should be through the methods of globalStatus and
// protected by a mutex where appropriate.
func (g *globalStatus) globalOperation(op Operation, loc *Location, x []float64) Status {
// Do a quick check to see if one of the other workers converged in the meantime.
status := g.getStatus()
if status != NotTerminated {
return status
}
var err error
switch op {
case NoOperation:
case InitIteration: case InitIteration:
panic("optimize: GlobalMethod returned InitIteration") panic("optimize: GlobalMethod returned InitIteration")
case PostIteration: case PostIteration:
panic("optimize: GlobalMethod returned PostIteration") panic("optimize: GlobalMethod returned PostIteration")
case NoOperation, MajorIteration, MethodDone:
statsChan <- task
default:
if !task.Op.isEvaluation() {
panic("global: expecting evaluation operation")
}
workerChan <- task
}
case <-done:
// No more evaluations will be sent, shut down the workers, and
// read the final tasks.
close(workerChan)
for task := range operations {
if task.Op == MajorIteration {
statsChan <- task
}
}
close(statsChan)
return
}
}
}
go distributor()
// Evaluate the Problem concurrently.
worker := func() {
x := make([]float64, dim)
for task := range workerChan {
evaluate(prob, task.Location, task.Op, x)
statsChan <- task
}
// Signal successful worker completion.
statsChan <- GlobalTask{Op: signalDone}
}
for i := 0; i < nTasks; i++ {
go worker()
}
var (
workersDone int // effective wg for the workers
status Status
err error
finalStatus Status
finalError error
)
// Update optimization statistics and check convergence.
for task := range statsChan {
switch task.Op {
default:
if !task.Op.isEvaluation() {
panic("global: evaluation task expected")
}
updateEvaluationStats(stats, task.Op)
status, err = checkEvaluationLimits(prob, stats, settings)
case signalDone:
workersDone++
if workersDone == nTasks {
close(results)
}
continue
case NoOperation:
// Just send the task back.
case MajorIteration: case MajorIteration:
g.incrementMajorIteration() status = performMajorIteration(optLoc, task.Location, stats, startTime, settings)
g.updateOptLoc(loc) case MethodDone:
status = g.checkConvergence() statuser, ok := method.(Statuser)
default: // Any of the Evaluation operations. if !ok {
status, err = evaluate(g.p, loc, op, x) panic("optimize: global method returned MethodDone but is not a Statuser")
g.updateStats(op)
} }
status, err = statuser.Status()
status, err = g.finishIteration(status, err, loc, op) if status == NotTerminated {
panic("optimize: global method returned MethodDone but a NotTerminated status")
}
}
if settings.Recorder != nil && status == NotTerminated && err == nil {
stats.Runtime = time.Since(startTime)
// Allow err to be overloaded if the Recorder fails.
err = settings.Recorder.Record(task.Location, task.Op, stats)
if err != nil {
status = Failure
}
}
// If this is the first termination status, trigger the conclusion of
// the optimization.
if status != NotTerminated || err != nil { if status != NotTerminated || err != nil {
g.updateStatus(status, err) select {
case <-done:
default:
finalStatus = status
finalError = err
results <- GlobalTask{
Op: PostIteration,
} }
return status close(done)
}
// DefaultSettingsGlobal returns the default settings for Global optimization.
func DefaultSettingsGlobal() *Settings {
return &Settings{
FunctionThreshold: math.Inf(-1),
FunctionConverge: &FunctionConverge{
Absolute: 1e-10,
Iterations: 100,
},
} }
} }
func min(a, b int) int { // Send the result back to the Problem if there are still active workers.
if a < b { if workersDone != nTasks && task.Op != MethodDone {
return a results <- task
} }
return b }
return finalStatus, finalError
} }

View File

@@ -6,7 +6,6 @@ package optimize
import ( import (
"math" "math"
"sync"
"gonum.org/v1/gonum/stat/distmv" "gonum.org/v1/gonum/stat/distmv"
) )
@@ -16,9 +15,6 @@ import (
type GuessAndCheck struct { type GuessAndCheck struct {
Rander distmv.Rander Rander distmv.Rander
eval []bool
mux *sync.Mutex
bestF float64 bestF float64
bestX []float64 bestX []float64
} }
@@ -27,34 +23,68 @@ func (g *GuessAndCheck) Needs() struct{ Gradient, Hessian bool } {
return struct{ Gradient, Hessian bool }{false, false} return struct{ Gradient, Hessian bool }{false, false}
} }
// Done is a no-op: GuessAndCheck has nothing to clean up.
func (g *GuessAndCheck) Done() {
// No cleanup needed
}
func (g *GuessAndCheck) InitGlobal(dim, tasks int) int { func (g *GuessAndCheck) InitGlobal(dim, tasks int) int {
g.eval = make([]bool, tasks) if dim <= 0 {
panic(nonpositiveDimension)
}
if tasks < 0 {
panic(negativeTasks)
}
g.bestF = math.Inf(1) g.bestF = math.Inf(1)
g.bestX = resize(g.bestX, dim) g.bestX = resize(g.bestX, dim)
g.mux = &sync.Mutex{}
return tasks return tasks
} }
func (g *GuessAndCheck) IterateGlobal(task int, loc *Location) (Operation, error) { func (g *GuessAndCheck) sendNewLoc(operation chan<- GlobalTask, task GlobalTask) {
// Task is true if it contains a new function evaluation. g.Rander.Rand(task.X)
if g.eval[task] { task.Op = FuncEvaluation
g.eval[task] = false operation <- task
g.mux.Lock() }
if loc.F < g.bestF {
g.bestF = loc.F func (g *GuessAndCheck) updateMajor(operation chan<- GlobalTask, task GlobalTask) {
copy(g.bestX, loc.X) // Update the best value seen so far, and send a MajorIteration.
if task.F < g.bestF {
g.bestF = task.F
copy(g.bestX, task.X)
} else { } else {
loc.F = g.bestF task.F = g.bestF
copy(loc.X, g.bestX) copy(task.X, g.bestX)
} }
g.mux.Unlock() task.Op = MajorIteration
return MajorIteration, nil operation <- task
} }
g.eval[task] = true
g.Rander.Rand(loc.X) func (g *GuessAndCheck) RunGlobal(operation chan<- GlobalTask, result <-chan GlobalTask, tasks []GlobalTask) {
return FuncEvaluation, nil // Send initial tasks to evaluate
for _, task := range tasks {
g.sendNewLoc(operation, task)
}
// Read from the channel until PostIteration is sent.
Loop:
for {
task := <-result
switch task.Op {
default:
panic("unknown operation")
case PostIteration:
break Loop
case MajorIteration:
g.sendNewLoc(operation, task)
case FuncEvaluation:
g.updateMajor(operation, task)
}
}
// PostIteration was sent. Update the best new values.
for task := range result {
switch task.Op {
default:
panic("unknown operation")
case MajorIteration:
case FuncEvaluation:
g.updateMajor(operation, task)
}
}
close(operation)
} }

View File

@@ -27,6 +27,7 @@ func TestGuessAndCheck(t *testing.T) {
panic("bad test") panic("bad test")
} }
Global(problem, dim, nil, &GuessAndCheck{Rander: d}) Global(problem, dim, nil, &GuessAndCheck{Rander: d})
settings := DefaultSettingsGlobal() settings := DefaultSettingsGlobal()
settings.Concurrent = 5 settings.Concurrent = 5
settings.MajorIterations = 15 settings.MajorIterations = 15

View File

@@ -95,7 +95,7 @@ func Local(p Problem, initX []float64, settings *Settings, method Method) (*Resu
} }
// Check if the starting location satisfies the convergence criteria. // Check if the starting location satisfies the convergence criteria.
status := checkConvergence(optLoc, settings, true) status := checkLocationConvergence(optLoc, settings)
// Run optimization // Run optimization
if status == NotTerminated && err == nil { if status == NotTerminated && err == nil {
@@ -123,8 +123,6 @@ func minimize(p *Problem, method Method, settings *Settings, stats *Stats, optLo
copyLocation(loc, optLoc) copyLocation(loc, optLoc)
x := make([]float64, len(loc.X)) x := make([]float64, len(loc.X))
statuser, _ := method.(Statuser)
var op Operation var op Operation
op, err = method.Init(loc) op, err = method.Init(loc)
if err != nil { if err != nil {
@@ -143,18 +141,34 @@ func minimize(p *Problem, method Method, settings *Settings, stats *Stats, optLo
case PostIteration: case PostIteration:
panic("optimize: Method returned PostIteration") panic("optimize: Method returned PostIteration")
case MajorIteration: case MajorIteration:
copyLocation(optLoc, loc) status = performMajorIteration(optLoc, loc, stats, startTime, settings)
stats.MajorIterations++ case MethodDone:
status = checkConvergence(optLoc, settings, true) statuser, ok := method.(Statuser)
default: // Any of the Evaluation operations. if !ok {
status, err = evaluate(p, loc, op, x) panic("optimize: method returned MethodDone is not a Statuser")
updateEvaluationStats(stats, op) }
status, err = statuser.Status()
if status == NotTerminated {
panic("optimize: method returned MethodDone but a NotTerminated status")
}
default: // Any of the Evaluation operations.
evaluate(p, loc, op, x)
updateEvaluationStats(stats, op)
status, err = checkEvaluationLimits(p, stats, settings)
} }
status, err = finishIteration(status, err, stats, settings, statuser, startTime, loc, op)
if status != NotTerminated || err != nil { if status != NotTerminated || err != nil {
return return
} }
if settings.Recorder != nil {
stats.Runtime = time.Since(startTime)
err = settings.Recorder.Record(loc, op, stats)
if err != nil {
if status == NotTerminated {
status = Failure
}
return status, err
}
}
op, err = method.Iterate(loc) op, err = method.Iterate(loc)
if err != nil { if err != nil {

View File

@@ -13,6 +13,18 @@ import (
"gonum.org/v1/gonum/mat" "gonum.org/v1/gonum/mat"
) )
// Panic messages used when validating the dimension and task-count arguments
// of an optimization method.
const (
// nonpositiveDimension is the message for a dimension that is not positive.
nonpositiveDimension string = "optimize: non-positive input dimension"
// negativeTasks is the message for a negative number of tasks.
negativeTasks string = "optimize: negative input number of tasks"
)
// min returns the smaller of the two integers a and b.
func min(a, b int) int {
	smaller := b
	if a < b {
		smaller = a
	}
	return smaller
}
// newLocation allocates a new location structure of the appropriate size. It // newLocation allocates a new location structure of the appropriate size. It
// allocates memory based on the dimension and the values in Needs. The initial // allocates memory based on the dimension and the values in Needs. The initial
// function value is set to math.Inf(1). // function value is set to math.Inf(1).
@@ -74,18 +86,12 @@ func checkOptimization(p Problem, dim int, method Needser, recorder Recorder) er
} }
// evaluate evaluates the routines specified by the Operation at loc.X, and stores // evaluate evaluates the routines specified by the Operation at loc.X, and stores
// the answer into loc. loc.X is copied into x before // the answer into loc. loc.X is copied into x before evaluating in order to
// evaluating in order to prevent the routines from modifying it. // prevent the routines from modifying it.
func evaluate(p *Problem, loc *Location, op Operation, x []float64) (Status, error) { func evaluate(p *Problem, loc *Location, op Operation, x []float64) {
if !op.isEvaluation() { if !op.isEvaluation() {
panic(fmt.Sprintf("optimize: invalid evaluation %v", op)) panic(fmt.Sprintf("optimize: invalid evaluation %v", op))
} }
if p.Status != nil {
status, err := p.Status()
if err != nil || status != NotTerminated {
return status, err
}
}
copy(x, loc.X) copy(x, loc.X)
if op&FuncEvaluation != 0 { if op&FuncEvaluation != 0 {
loc.F = p.Func(x) loc.F = p.Func(x)
@@ -96,29 +102,6 @@ func evaluate(p *Problem, loc *Location, op Operation, x []float64) (Status, err
if op&HessEvaluation != 0 { if op&HessEvaluation != 0 {
p.Hess(loc.Hessian, x) p.Hess(loc.Hessian, x)
} }
return NotTerminated, nil
}
// checkConvergence returns NotTerminated if the Location does not satisfy the
// convergence criteria given by settings. Otherwise a corresponding status is
// returned.
// Unlike checkLimits, checkConvergence is called only at MajorIterations.
//
// If local is true, gradient convergence is also checked.
func checkConvergence(loc *Location, settings *Settings, local bool) Status {
if local && loc.Gradient != nil {
norm := floats.Norm(loc.Gradient, math.Inf(1))
if norm < settings.GradientThreshold {
return GradientThreshold
}
}
if loc.F < settings.FunctionThreshold {
return FunctionThreshold
}
if settings.FunctionConverge != nil {
return settings.FunctionConverge.FunctionConverged(loc.F)
}
return NotTerminated
} }
// updateEvaluationStats updates the statistics based on the operation. // updateEvaluationStats updates the statistics based on the operation.
@@ -134,70 +117,75 @@ func updateEvaluationStats(stats *Stats, op Operation) {
} }
} }
// checkLimits returns NotTerminated status if the various limits given by // checkLocationConvergence checks if the current optimal location satisfies
// settings have not been reached. Otherwise it returns a corresponding status. // any of the convergence criteria based on the function location.
// Unlike checkConvergence, checkLimits is called by Local and Global at _every_ //
// iteration. // checkLocationConvergence returns NotTerminated if the Location does not satisfy
func checkLimits(loc *Location, stats *Stats, settings *Settings) Status { // the convergence criteria given by settings. Otherwise a corresponding status is
// Check the objective function value for negative infinity because it // returned.
// could break the linesearches and -inf is the best we can do anyway. // Unlike checkLimits, checkConvergence is called only at MajorIterations.
func checkLocationConvergence(loc *Location, settings *Settings) Status {
if math.IsInf(loc.F, -1) { if math.IsInf(loc.F, -1) {
return FunctionNegativeInfinity return FunctionNegativeInfinity
} }
if loc.Gradient != nil {
if settings.MajorIterations > 0 && stats.MajorIterations >= settings.MajorIterations { norm := floats.Norm(loc.Gradient, math.Inf(1))
return IterationLimit if norm < settings.GradientThreshold {
return GradientThreshold
} }
if settings.FuncEvaluations > 0 && stats.FuncEvaluations >= settings.FuncEvaluations {
return FunctionEvaluationLimit
} }
if loc.F < settings.FunctionThreshold {
if settings.GradEvaluations > 0 && stats.GradEvaluations >= settings.GradEvaluations { return FunctionThreshold
return GradientEvaluationLimit
} }
if settings.FunctionConverge != nil {
if settings.HessEvaluations > 0 && stats.HessEvaluations >= settings.HessEvaluations { return settings.FunctionConverge.FunctionConverged(loc.F)
return HessianEvaluationLimit
} }
// TODO(vladimir-ch): It would be nice to update Runtime here.
if settings.Runtime > 0 && stats.Runtime >= settings.Runtime {
return RuntimeLimit
}
return NotTerminated return NotTerminated
} }
// finishIteration performs cleanup tasks at the end of an optimization iteration. // checkEvaluationLimits checks the optimization limits after an evaluation
// It checks the status, sends information to recorders, and updates the runtime. // Operation. It checks the number of evaluations (of various kinds) and checks
func finishIteration(status Status, err error, stats *Stats, settings *Settings, statuser Statuser, startTime time.Time, loc *Location, op Operation) (Status, error) { // the status of the Problem, if applicable.
if status != NotTerminated || err != nil { func checkEvaluationLimits(p *Problem, stats *Stats, settings *Settings) (Status, error) {
return status, err if p.Status != nil {
} status, err := p.Status()
if settings.Recorder != nil {
stats.Runtime = time.Since(startTime)
err = settings.Recorder.Record(loc, op, stats)
if err != nil {
if status == NotTerminated {
status = Failure
}
return status, err
}
}
stats.Runtime = time.Since(startTime)
status = checkLimits(loc, stats, settings)
if status != NotTerminated {
return status, nil
}
if statuser != nil {
status, err = statuser.Status()
if err != nil || status != NotTerminated { if err != nil || status != NotTerminated {
return status, err return status, err
} }
} }
return status, nil if settings.FuncEvaluations > 0 && stats.FuncEvaluations >= settings.FuncEvaluations {
return FunctionEvaluationLimit, nil
}
if settings.GradEvaluations > 0 && stats.GradEvaluations >= settings.GradEvaluations {
return GradientEvaluationLimit, nil
}
if settings.HessEvaluations > 0 && stats.HessEvaluations >= settings.HessEvaluations {
return HessianEvaluationLimit, nil
}
return NotTerminated, nil
}
// checkIterationLimits reports whether a limit that is checked at
// MajorIterations has been reached. It returns IterationLimit when
// settings.MajorIterations is positive and stats.MajorIterations has reached
// it, RuntimeLimit when settings.Runtime is positive and stats.Runtime has
// reached it, and NotTerminated otherwise. The iteration limit takes
// precedence over the runtime limit. loc is currently not consulted.
func checkIterationLimits(loc *Location, stats *Stats, settings *Settings) Status {
	switch {
	case settings.MajorIterations > 0 && stats.MajorIterations >= settings.MajorIterations:
		return IterationLimit
	case settings.Runtime > 0 && stats.Runtime >= settings.Runtime:
		return RuntimeLimit
	}
	return NotTerminated
}
// performMajorIteration does all of the steps needed to perform a MajorIteration.
// It increments the iteration count, updates the optimal location, and checks
// the necessary convergence criteria.
func performMajorIteration(optLoc, loc *Location, stats *Stats, startTime time.Time, settings *Settings) Status {
copyLocation(optLoc, loc)
stats.MajorIterations++
stats.Runtime = time.Since(startTime)
status := checkLocationConvergence(optLoc, settings)
if status != NotTerminated {
return status
}
return checkIterationLimits(optLoc, stats, settings)
} }

View File

@@ -38,6 +38,10 @@ const (
// MajorIteration indicates that the next candidate location for // MajorIteration indicates that the next candidate location for
// an optimum has been found and convergence should be checked. // an optimum has been found and convergence should be checked.
MajorIteration MajorIteration
// MethodDone declares that the method is done running. A method must
// be a Statuser in order to use this operation, and after returning
// MethodDone, its Status method must return a status other than NotTerminated.
MethodDone
// FuncEvaluation specifies that the objective function // FuncEvaluation specifies that the objective function
// should be evaluated. // should be evaluated.
FuncEvaluation FuncEvaluation
@@ -47,6 +51,8 @@ const (
// HessEvaluation specifies that the Hessian // HessEvaluation specifies that the Hessian
// of the objective function should be evaluated. // of the objective function should be evaluated.
HessEvaluation HessEvaluation
// signalDone is used internally to signal completion.
signalDone
// Mask for the evaluating operations. // Mask for the evaluating operations.
evalMask = FuncEvaluation | GradEvaluation | HessEvaluation evalMask = FuncEvaluation | GradEvaluation | HessEvaluation
@@ -76,6 +82,8 @@ var operationNames = map[Operation]string{
InitIteration: "InitIteration", InitIteration: "InitIteration",
MajorIteration: "MajorIteration", MajorIteration: "MajorIteration",
PostIteration: "PostIteration", PostIteration: "PostIteration",
MethodDone: "MethodDone",
signalDone: "signalDone",
} }
// Location represents a location in the optimization procedure. // Location represents a location in the optimization procedure.
@@ -201,7 +209,7 @@ type Settings struct {
// Runtime is the maximum runtime allowed. RuntimeLimit status is returned // Runtime is the maximum runtime allowed. RuntimeLimit status is returned
// if the duration of the run is longer than this value. Runtime is only // if the duration of the run is longer than this value. Runtime is only
// checked at iterations of the Method. // checked at MajorIterations of the Method.
// If it equals zero, this setting has no effect. // If it equals zero, this setting has no effect.
// The default value is 0. // The default value is 0.
Runtime time.Duration Runtime time.Duration