diff --git a/optimize/cmaes.go b/optimize/cmaes.go index bb7d0835..7b13368b 100644 --- a/optimize/cmaes.go +++ b/optimize/cmaes.go @@ -7,7 +7,6 @@ package optimize import ( "math" "sort" - "sync" "golang.org/x/exp/rand" @@ -103,15 +102,15 @@ type CmaEsChol struct { mean []float64 chol mat.Cholesky - // Parallel fields. - mux sync.Mutex // protect access to evals. - wg sync.WaitGroup // wait for simulations to finish before iterating. - taskIdxs []int // Stores which simulation the task ran. - evals []int // remaining evaluations in this iteration. - // Overall best. bestX []float64 bestF float64 + + // Synchronization. + sentIdx int + receivedIdx int + operation chan<- GlobalTask + updateErr error } var ( @@ -123,21 +122,26 @@ func (cma *CmaEsChol) Needs() struct{ Gradient, Hessian bool } { return struct{ Gradient, Hessian bool }{false, false} } -func (cma *CmaEsChol) Done() {} - -// Status returns the status of the method. -func (cma *CmaEsChol) Status() (Status, error) { +func (cma *CmaEsChol) methodConverged() Status { sd := cma.StopLogDet switch { case math.IsNaN(sd): - return NotTerminated, nil + return NotTerminated case sd == 0: sd = float64(cma.dim) * -36.8413614879 // ln(1e-16) } if cma.chol.LogDet() < sd { - return MethodConverge, nil + return MethodConverge } - return NotTerminated, nil + return NotTerminated +} + +// Status returns the status of the method. +func (cma *CmaEsChol) Status() (Status, error) { + if cma.updateErr != nil { + return Failure, cma.updateErr + } + return cma.methodConverged(), nil } func (cma *CmaEsChol) InitGlobal(dim, tasks int) int { @@ -226,90 +230,169 @@ func (cma *CmaEsChol) InitGlobal(dim, tasks int) int { cma.chol = chol } - cma.evals = make([]int, cma.pop) - for i := range cma.evals { - cma.evals[i] = i - } - cma.bestX = resize(cma.bestX, dim) cma.bestF = math.Inf(1) + cma.sentIdx = 0 + cma.receivedIdx = 0 + cma.operation = nil + cma.updateErr = nil t := min(tasks, cma.pop) - cma.taskIdxs = make([]int, t) - for i := 0; i < t; i++ { - cma.taskIdxs[i] = -1 - } - // Get a new mutex and waitgroup so that if the structure is reused there - // aren't residual interactions with the previous optimization. - cma.mux = sync.Mutex{} - cma.wg = sync.WaitGroup{} return t } -func (cma *CmaEsChol) IterateGlobal(task int, loc *Location) (Operation, error) { - // Check the status of the incoming task. If it is a number, it means - // that task contains a valid location. - idx := cma.taskIdxs[task] - if idx != -1 { - cma.fs[idx] = loc.F - cma.wg.Done() +func (cma *CmaEsChol) sendInitTasks(tasks []GlobalTask) { + for i, task := range tasks { + cma.sendTask(i, task) } + cma.sentIdx = len(tasks) +} - // Get the next task and send it to be run if there is a next task to be run. - // If all of the tasks have been run, perform an update step. Note that the - // use of this mutex means that only one task can proceed, all of the - // other tasks should get stuck and then get a new location. - cma.mux.Lock() - if len(cma.evals) != 0 { - // There are still tasks to evaluate. Grab one and remove it from the list. - newIdx := cma.evals[len(cma.evals)-1] - cma.evals = cma.evals[:len(cma.evals)-1] - cma.wg.Add(1) - cma.mux.Unlock() +// sendTask generates a sample and sends the task. It does not update the cma index. 
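+// Callers advance cma.sentIdx themselves after a successful send, following
+// the pattern used in RunGlobal below:
+//
+//  cma.sendTask(cma.sentIdx, task)
+//  cma.sentIdx++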
+func (cma *CmaEsChol) sendTask(idx int, task GlobalTask) { + task.ID = idx + task.Op = FuncEvaluation + distmv.NormalRand(cma.xs.RawRowView(idx), cma.mean, &cma.chol, cma.Src) + copy(task.X, cma.xs.RawRowView(idx)) + cma.operation <- task +} - // Sample x and send it to be evaluated. - distmv.NormalRand(cma.xs.RawRowView(newIdx), cma.mean, &cma.chol, cma.Src) - copy(loc.X, cma.xs.RawRowView(newIdx)) - cma.taskIdxs[task] = newIdx - return FuncEvaluation, nil +// bestIdx returns the best index in the functions. Returns -1 if all values +// are NaN. +func (cma *CmaEsChol) bestIdx() int { + best := -1 + bestVal := math.Inf(1) + for i, v := range cma.fs { + if math.IsNaN(v) { + continue + } + // Use equality in case somewhere evaluates to +inf. + if v <= bestVal { + best = i + bestVal = v + } } - // There are no more tasks to evaluate. This means the iteration is over. - // Find the best current f, update the parameters, and re-establish - // the evaluations to run. + return best +} - // Wait for all of the outstanding tasks to finish, so the full set of functions - // has been evaluated. - cma.wg.Wait() - - // Find the best f out of all the tasks. - best := floats.MinIdx(cma.fs) - bestF := cma.fs[best] - bestX := cma.xs.RawRowView(best) +// findBestAndUpdateTask finds the best task in the current list, updates the +// new best overall, and then stores the best location into task. +func (cma *CmaEsChol) findBestAndUpdateTask(task GlobalTask) GlobalTask { + // Find and update the best location. + // Don't use floats because there may be NaN values. + best := cma.bestIdx() + bestF := math.NaN() + bestX := cma.xs.RawRowView(0) + if best != -1 { + bestF = cma.fs[best] + bestX = cma.xs.RawRowView(best) + } if cma.ForgetBest { - loc.F = bestF - copy(loc.X, bestX) + task.F = bestF + copy(task.X, bestX) } else { if bestF < cma.bestF { cma.bestF = bestF copy(cma.bestX, bestX) } - loc.F = cma.bestF - copy(loc.X, cma.bestX) + task.F = cma.bestF + copy(task.X, cma.bestX) } - - cma.taskIdxs[task] = -1 - - // Update the parameters of the distribution - err := cma.update() - - // Reset the tasks - cma.evals = cma.evals[:cma.pop] - - cma.mux.Unlock() - return MajorIteration, err + return task } -// update computes the new parameters (mean, cholesky, etc.) +func (cma *CmaEsChol) RunGlobal(operations chan<- GlobalTask, results <-chan GlobalTask, tasks []GlobalTask) { + cma.operation = operations + // Send the initial tasks. We know there are at most as many tasks as elements + // of the population. + cma.sendInitTasks(tasks) + +Loop: + for { + result := <-results + switch result.Op { + default: + panic("unknown operation") + case PostIteration: + break Loop + case MajorIteration: + // The last thing we did was update all of the tasks and send the + // major iteration. Now we can send a group of tasks again. + cma.sendInitTasks(tasks) + case FuncEvaluation: + cma.receivedIdx++ + cma.fs[result.ID] = result.F + switch { + case cma.sentIdx < cma.pop: + // There are still tasks to evaluate. Send the next. + cma.sendTask(cma.sentIdx, result) + cma.sentIdx++ + case cma.receivedIdx < cma.pop: + // All the tasks have been sent, but not all of them have been received. + // Need to wait until all are back. + continue Loop + default: + // All of the evaluations have been received. + if cma.receivedIdx != cma.pop { + panic("bad logic") + } + cma.receivedIdx = 0 + cma.sentIdx = 0 + + task := cma.findBestAndUpdateTask(result) + // Update the parameters and send a MajorIteration or a convergence. 
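+				// update refits the sampling distribution (mean and
+				// Cholesky factor) to the evaluated population; any
+				// error it returns is stored in updateErr and is
+				// surfaced through Status.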
+				err := cma.update()
+				// Invalidate the existing data so stale values are not reused.
+				for i := range cma.fs {
+					cma.fs[i] = math.NaN()
+					cma.xs.Set(i, 0, math.NaN())
+				}
+				switch {
+				case err != nil:
+					cma.updateErr = err
+					task.Op = MethodDone
+				case cma.methodConverged() != NotTerminated:
+					task.Op = MethodDone
+				default:
+					task.Op = MajorIteration
+					task.ID = -1
+				}
+				operations <- task
+			}
+		}
+	}
+
+	// The method has been told to stop. Collect the outstanding evaluation
+	// results so the best of them can still be reported.
+	for task := range results {
+		switch task.Op {
+		case MajorIteration:
+		case FuncEvaluation:
+			cma.fs[task.ID] = task.F
+		default:
+			panic("unknown operation")
+		}
+	}
+	// Send the new best value if the evaluation is better than any we've
+	// found so far. Keep this separate from findBestAndUpdateTask so that
+	// we only send an iteration if we find a better location.
+	if !cma.ForgetBest {
+		best := cma.bestIdx()
+		if best != -1 && cma.fs[best] < cma.bestF {
+			task := tasks[0]
+			task.F = cma.fs[best]
+			copy(task.X, cma.xs.RawRowView(best))
+			task.Op = MajorIteration
+			task.ID = -1
+			operations <- task
+		}
+	}
+	close(operations)
+}
+
+// update computes the new parameters (mean, cholesky, etc.). It does not
+// update any of the synchronization fields (sentIdx, receivedIdx).
 func (cma *CmaEsChol) update() error {
 	// Sort the function values to find the elite samples.
 	ftmp := make([]float64, cma.pop)
diff --git a/optimize/cmaes_test.go b/optimize/cmaes_test.go
index bca90c7c..eb6b74b8 100644
--- a/optimize/cmaes_test.go
+++ b/optimize/cmaes_test.go
@@ -21,7 +21,7 @@ type cmaTestCase struct {
 	problem  Problem
 	method   *CmaEsChol
 	settings *Settings
-	good     func(*Result, error) error
+	good     func(result *Result, err error, concurrent int) error
 }
 
 func cmaTestCases() []cmaTestCase {
@@ -42,7 +42,7 @@ func cmaTestCases() []cmaTestCase {
 			settings: &Settings{
 				FunctionThreshold: 0.01,
 			},
-			good: func(result *Result, err error) error {
+			good: func(result *Result, err error, concurrent int) error {
 				if result.Status != FunctionThreshold {
 					return errors.New("result not function threshold")
 				}
@@ -63,7 +63,7 @@ func cmaTestCases() []cmaTestCase {
 			settings: &Settings{
 				FunctionThreshold: math.Inf(-1),
 			},
-			good: func(result *Result, err error) error {
+			good: func(result *Result, err error, concurrent int) error {
 				if result.Status != MethodConverge {
 					return errors.New("result not method converge")
 				}
@@ -82,24 +82,30 @@ func cmaTestCases() []cmaTestCase {
 			},
 			method: &CmaEsChol{
 				Population: 100,
+				ForgetBest: true, // Otherwise may get an update at the end.
 			},
 			settings: &Settings{
 				FunctionThreshold: math.Inf(-1),
 				MajorIterations:   10,
 			},
-			good: func(result *Result, err error) error {
+			good: func(result *Result, err error, concurrent int) error {
 				if result.Status != IterationLimit {
 					return errors.New("result not iteration limit")
 				}
-				if result.FuncEvaluations != 1000 {
-					return errors.New("wrong number of evaluations")
+				threshLower := 10
+				threshUpper := 10
+				if concurrent != 0 {
+					// Could have one more from the final update.
+					threshUpper++
+				}
+				if result.MajorIterations < threshLower || result.MajorIterations > threshUpper {
+					return errors.New("wrong number of iterations")
 				}
 				return nil
 			},
 		},
 		{
-			// Test that works properly in parallel, and stops with some
-			// number of function evaluations.
+			// Test that the optimization stops after a set number of function evaluations.
dim: 5, problem: Problem{ Func: functions.ExtendedRosenbrock{}.Func, @@ -108,18 +114,22 @@ func cmaTestCases() []cmaTestCase { Population: 100, }, settings: &Settings{ - Concurrent: 5, FunctionThreshold: math.Inf(-1), FuncEvaluations: 250, // Somewhere in the middle of an iteration. }, - good: func(result *Result, err error) error { + good: func(result *Result, err error, concurrent int) error { if result.Status != FunctionEvaluationLimit { return errors.New("result not function evaluations") } - if result.FuncEvaluations < 250 { + threshLower := 250 + threshUpper := 251 + if concurrent != 0 { + threshUpper = threshLower + concurrent + } + if result.FuncEvaluations < threshLower { return errors.New("too few function evaluations") } - if result.FuncEvaluations > 250+4 { // can't guarantee exactly, because could grab extras in parallel first. + if result.FuncEvaluations > threshUpper { return errors.New("too many function evaluations") } return nil @@ -137,7 +147,7 @@ func cmaTestCases() []cmaTestCase { settings: &Settings{ FunctionThreshold: math.Inf(-1), }, - good: func(result *Result, err error) error { + good: func(result *Result, err error, concurrent int) error { if result.Status != MethodConverge { return errors.New("result not method converge") } @@ -157,15 +167,16 @@ func cmaTestCases() []cmaTestCase { Population: 100, // Increase the population size to reduce noise. InitMean: localMinMean, InitCholesky: &localMinChol, + ForgetBest: true, // So that if it accidentally finds a better place we still converge to the minimum. }, settings: &Settings{ FunctionThreshold: math.Inf(-1), }, - good: func(result *Result, err error) error { + good: func(result *Result, err error, concurrent int) error { if result.Status != MethodConverge { return errors.New("result not method converge") } - if !floats.EqualApprox(result.X, []float64{2, -2}, 1e-2) { + if !floats.EqualApprox(result.X, []float64{2, -2}, 3e-2) { return errors.New("local minimum not found") } return nil @@ -181,14 +192,22 @@ func TestCmaEsChol(t *testing.T) { method.Src = src // Run and check that the expected termination occurs. result, err := Global(test.problem, test.dim, test.settings, method) - if testErr := test.good(result, err); testErr != nil { + if testErr := test.good(result, err, test.settings.Concurrent); testErr != nil { t.Errorf("cas %d: %v", i, testErr) } // Run a second time to make sure there are no residual effects result, err = Global(test.problem, test.dim, test.settings, method) - if testErr := test.good(result, err); testErr != nil { + if testErr := test.good(result, err, test.settings.Concurrent); testErr != nil { t.Errorf("cas %d second: %v", i, testErr) } + + // Test the problem in parallel. + test.settings.Concurrent = 5 + result, err = Global(test.problem, test.dim, test.settings, method) + if testErr := test.good(result, err, test.settings.Concurrent); testErr != nil { + t.Errorf("cas %d concurrent: %v", i, testErr) + } + test.settings.Concurrent = 0 } } diff --git a/optimize/global.go b/optimize/global.go index ff44e831..f2544751 100644 --- a/optimize/global.go +++ b/optimize/global.go @@ -6,65 +6,65 @@ package optimize import ( "math" - "sync" "time" ) -var ( - nonpositiveDimension string = "optimize: non-positive input dimension" - negativeTasks string = "optimize: negative input number of tasks" -) +// DefaultSettingsGlobal returns the default settings for Global optimization. 
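+//
+// A minimal caller-side sketch (p and dim stand in for the caller's Problem
+// and dimension):
+//
+//  settings := DefaultSettingsGlobal()
+//  settings.Concurrent = 4
+//  result, err := Global(p, dim, settings, nil)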
+func DefaultSettingsGlobal() *Settings { + return &Settings{ + FunctionThreshold: math.Inf(-1), + FunctionConverge: &FunctionConverge{ + Absolute: 1e-10, + Iterations: 100, + }, + } +} -// GlobalMethod is an optimization method which seeks to find the global minimum -// of an objective function. -// -// At the beginning of the optimization, InitGlobal is called to communicate -// the dimension of the input and maximum number of concurrent tasks. -// The actual number of concurrent tasks will be set from the return of InitGlobal, -// which must not be greater than the input tasks. -// -// During the optimization, a reverse-communication interface is used between -// the GlobalMethod and the caller. -// GlobalMethod acts as a client that asks the caller to perform -// needed operations given the return from IterateGlobal. -// This allows and enforces automation of maintaining statistics and checking for -// (various types of) convergence. -// -// The return from IterateGlobal can be an Evaluation, a MajorIteration or NoOperation. -// -// An evaluation is one or more of the Evaluation operations (FuncEvaluation, -// GradEvaluation, etc.) combined with the bitwise or operator. In an evaluation -// operation, the requested fields of Problem will be evaluated at the value -// in Location.X, filling the corresponding fields of Location. These values -// can be retrieved and used upon the next call to IterateGlobal with that task id. -// The GlobalMethod interface requires that entries of Location are not modified -// aside from the commanded evaluations. Thus, the type implementing GlobalMethod -// may use multiple Operations to set the Location fields at a particular x value. -// -// When IterateGlobal declares MajorIteration, the caller updates the optimal -// location to the values in Location, and checks for convergence. The type -// implementing GlobalMethod must make sure that the fields of Location are valid -// and consistent. -// -// IterateGlobal must not return InitIteration and PostIteration operations. These are -// reserved for the clients to be passed to Recorders. A Method must also not -// combine the Evaluation operations with the Iteration operations. +// GlobalTask is a type to communicate between the GlobalMethod and the outer +// calling script. +type GlobalTask struct { + ID int + Op Operation + *Location +} + +// GlobalMethod is a type which can search for a global optimum for an objective function. type GlobalMethod interface { Needser - // InitGlobal communicates the input dimension and maximum number of tasks, - // and returns the number of concurrent processes. The return must be less - // than or equal to tasks. + // InitGlobal takes as input the problem dimension and number of available + // concurrent tasks, and returns the number of concurrent processes to be used. + // The returned value must be less than or equal to tasks. InitGlobal(dim, tasks int) int - - // IterateGlobal retrieves information from the location associated with - // the given task ID, and returns the next operation to perform with that - // Location. IterateGlobal may assume that the same pointer is associated - // with the same task. - IterateGlobal(task int, loc *Location) (Operation, error) - - // Done communicates that the optimization has concluded to allow for shutdown. - // After Done is called, no more calls to IterateGlobal will be made. - Done() + // RunGlobal runs a global optimization. 
The method sends GlobalTasks on
+	// the operation channel (for performing function evaluations, major
+	// iterations, etc.). The results of the tasks will be returned on result.
+	// See the documentation for the Operation types for the possible tasks.
+	//
+	// The caller of RunGlobal will signal the termination of the optimization
+	// (i.e. convergence from user settings) by sending a task with a
+	// PostIteration Op field on result. More tasks may still be sent on
+	// operation after this occurs, but only MajorIteration operations will
+	// still be conducted appropriately. Thus, it cannot be guaranteed that
+	// all Evaluations sent on operation will be evaluated; however, if an
+	// Evaluation is started, the result of that evaluation will be sent on
+	// result.
+	//
+	// The GlobalMethod must read from the result channel until it is closed.
+	// During this, the GlobalMethod may want to send new MajorIteration(s) on
+	// operation. The GlobalMethod must then close operation and return from
+	// RunGlobal.
+	//
+	// The last parameter to RunGlobal is a slice of tasks with length equal to
+	// the return from InitGlobal. GlobalTask has an ID field which may be
+	// set and modified by GlobalMethod, and must not be modified by the caller.
+	//
+	// GlobalMethod may have its own specific convergence criteria, which can
+	// be communicated using a MethodDone operation. This will trigger a
+	// PostIteration to be sent on result, and the MethodDone task will not be
+	// returned on result. The GlobalMethod must implement Statuser, and the
+	// call to Status must return a Status other than NotTerminated.
+	//
+	// The operation and result channels are guaranteed to have a buffer length
+	// equal to the return from InitGlobal.
+	RunGlobal(operation chan<- GlobalTask, result <-chan GlobalTask, tasks []GlobalTask)
 }
 
 // Global uses a global optimizer to search for the global minimum of a
@@ -86,30 +86,23 @@ type GlobalMethod interface {
 // The third argument contains the settings for the minimization. The
 // DefaultGlobalSettings function can be called for a Settings struct with the
 // default values initialized. If settings == nil, the default settings are used.
-// Global optimization methods typically do not make assumptions about the number
-// and location of local minima. Thus, the only convergence metric used is the
-// function values found at major iterations of the optimization. Bounds on the
-// length of optimization are obeyed, such as the number of allowed function
-// evaluations.
+// All of the settings will be followed, but many of them may be counterproductive
+// to use (such as GradientThreshold). Global cannot guarantee strict adherence
+// to the bounds specified when performing concurrent evaluations and updates.
 //
 // The final argument is the optimization method to use. If method == nil, then
 // an appropriate default is chosen based on the properties of the other arguments
 // (dimension, gradient-free or gradient-based, etc.).
 //
-// If method implements Statuser, method.Status is called before every call
-// to method.Iterate. If the returned Status is not NotTerminated or the
-// error is non-nil, the optimization run is terminated.
-//
 // Global returns a Result struct and any error that occurred. See the
 // documentation of Result for more information.
 //
+// See the documentation for GlobalMethod for the details on implementing a method.
+//
// Be aware that the default behavior of Global is to find the minimum.
// For certain functions and optimization methods, this process can take many // function evaluations. The Settings input struct can be used to limit this, // for example by modifying the maximum runtime or maximum function evaluations. -// -// Global cannot guarantee strict adherence to the bounds specified in Settings -// when performing concurrent evaluations and updates. func Global(p Problem, dim int, settings *Settings, method GlobalMethod) (*Result, error) { startTime := time.Now() if method == nil { @@ -157,23 +150,10 @@ func Global(p Problem, dim int, settings *Settings, method GlobalMethod) (*Resul }, err } -// minimizeGlobal is the high-level function for a Global optimization. It launches -// concurrent workers to perform the mimization, and shuts them down properly -// at the conclusion. -func minimizeGlobal(p *Problem, method GlobalMethod, settings *Settings, stats *Stats, optLoc *Location, startTime time.Time) (status Status, err error) { +// minimizeGlobal performs a Global optimization. minimizeGlobal updates the +// settings and optLoc, and returns the final Status and error. +func minimizeGlobal(prob *Problem, method GlobalMethod, settings *Settings, stats *Stats, optLoc *Location, startTime time.Time) (Status, error) { dim := len(optLoc.X) - statuser, _ := method.(Statuser) - gs := &globalStatus{ - mux: &sync.RWMutex{}, - stats: stats, - status: NotTerminated, - p: p, - startTime: startTime, - optLoc: optLoc, - settings: settings, - statuser: statuser, - } - nTasks := settings.Concurrent if nTasks == 0 { nTasks = 1 @@ -184,158 +164,160 @@ func minimizeGlobal(p *Problem, method GlobalMethod, settings *Settings, stats * } nTasks = newNTasks - // Launch optimization workers. Each worker is individually responsible - // for maintaining stats and evaluating the function. - var wg sync.WaitGroup - for task := 0; task < nTasks; task++ { - wg.Add(1) - go func(task int) { - defer wg.Done() - loc := newLocation(dim, method) - x := make([]float64, dim) - globalWorker(task, method, gs, loc, x) - }(task) - } - wg.Wait() - method.Done() - return gs.status, gs.err -} - -// globalWorker runs the optimization steps for a single (concurrently-executing) -// optimization task. -func globalWorker(task int, m GlobalMethod, g *globalStatus, loc *Location, x []float64) { - for { - // Find Evaluation location - op, err := m.IterateGlobal(task, loc) - if err != nil { - g.updateStatus(Failure, err) - break + // Launch the method. The method communicates tasks using the operations + // channel, and results is used to return the evaluated results. + operations := make(chan GlobalTask, nTasks) + results := make(chan GlobalTask, nTasks) + go func() { + tasks := make([]GlobalTask, nTasks) + for i := range tasks { + tasks[i].Location = newLocation(dim, method) } + method.RunGlobal(operations, results, tasks) + }() - // Evaluate location and/or update stats. - status := g.globalOperation(op, loc, x) - if status != NotTerminated { - break + // Algorithmic Overview: + // There are three pieces to performing a concurrent global optimization, + // the distributor, the workers, and the stats combiner. At a high level, + // the distributor reads in tasks sent by method, sending evaluations to the + // workers, and forwarding other operations to the statsCombiner. The workers + // read these forwarded evaluation tasks, evaluate the relevant parts of Problem + // and forward the results on to the stats combiner. 
The stats combiner reads + // in results from the workers, as well as tasks from the distributor, and + // uses them to update optimization statistics (function evaluations, etc.) + // and to check optimization convergence. + // + // The complicated part is correctly shutting down the optimization. The + // procedure is as follows. First, the stats combiner closes done and sends + // a PostIteration to the method. The distributor then reads that done has + // been closed, and closes the channel with the workers. At this point, no + // more evaluation operations will be executed. As the workers finish their + // evaluations, they forward the results onto the stats combiner, and then + // signal their shutdown to the stats combiner. When all workers have successfully + // finished, the stats combiner closes the results channel, signaling to the + // method that all results have been collected. At this point, the method + // may send MajorIteration(s) to update an optimum location based on these + // last returned results, and then the method will close the operations channel. + // Now that no more tasks will be commanded by the method, the distributor + // closes statsChan, and with no more statistics to update the optimization + // concludes. + + workerChan := make(chan GlobalTask) // Delegate tasks to the workers. + statsChan := make(chan GlobalTask) // Send evaluation updates. + done := make(chan struct{}) // Communicate the optimization is done. + + // Read tasks from the method and distribute as appropriate. + distributor := func() { + for { + select { + case task := <-operations: + switch task.Op { + case InitIteration: + panic("optimize: GlobalMethod returned InitIteration") + case PostIteration: + panic("optimize: GlobalMethod returned PostIteration") + case NoOperation, MajorIteration, MethodDone: + statsChan <- task + default: + if !task.Op.isEvaluation() { + panic("global: expecting evaluation operation") + } + workerChan <- task + } + case <-done: + // No more evaluations will be sent, shut down the workers, and + // read the final tasks. + close(workerChan) + for task := range operations { + if task.Op == MajorIteration { + statsChan <- task + } + } + close(statsChan) + return + } } } -} + go distributor() -// globalStatus coordinates access to information shared between concurrently -// executing optimization tasks. -type globalStatus struct { - mux *sync.RWMutex - stats *Stats - status Status - p *Problem - startTime time.Time - optLoc *Location - settings *Settings - method GlobalMethod - statuser Statuser - err error -} - -// getStatus returns the current status of the optimization. -func (g *globalStatus) getStatus() Status { - var status Status - g.mux.RLock() - defer g.mux.RUnlock() - status = g.status - return status -} - -func (g *globalStatus) incrementMajorIteration() { - g.mux.Lock() - defer g.mux.Unlock() - g.stats.MajorIterations++ -} - -func (g *globalStatus) updateOptLoc(loc *Location) { - g.mux.Lock() - defer g.mux.Unlock() - copyLocation(g.optLoc, loc) -} - -// checkConvergence checks the convergence of the global optimization and returns -// the status -func (g *globalStatus) checkConvergence() Status { - g.mux.RLock() - defer g.mux.RUnlock() - return checkConvergence(g.optLoc, g.settings, false) -} - -// updateStats updates the evaluation statistics for the given operation. 
-func (g *globalStatus) updateStats(op Operation) { - g.mux.Lock() - defer g.mux.Unlock() - updateEvaluationStats(g.stats, op) -} - -// updateStatus updates the status and error fields of g. This update only happens -// if status == NotTerminated, so that the first different status is the one -// maintained. -func (g *globalStatus) updateStatus(s Status, err error) { - g.mux.Lock() - defer g.mux.Unlock() - if s != NotTerminated { - g.status = s - g.err = err + // Evaluate the Problem concurrently. + worker := func() { + x := make([]float64, dim) + for task := range workerChan { + evaluate(prob, task.Location, task.Op, x) + statsChan <- task + } + // Signal successful worker completion. + statsChan <- GlobalTask{Op: signalDone} } -} - -func (g *globalStatus) finishIteration(status Status, err error, loc *Location, op Operation) (Status, error) { - g.mux.Lock() - defer g.mux.Unlock() - return finishIteration(status, err, g.stats, g.settings, g.statuser, g.startTime, loc, op) -} - -// globalOperation executes the requested operation at the given location. -// When modifying this function, keep in mind that it can be called concurrently. -// Uses of the internal fields should be through the methods of globalStatus and -// protected by a mutex where appropriate. -func (g *globalStatus) globalOperation(op Operation, loc *Location, x []float64) Status { - // Do a quick check to see if one of the other workers converged in the meantime. - status := g.getStatus() - if status != NotTerminated { - return status - } - var err error - switch op { - case NoOperation: - case InitIteration: - panic("optimize: GlobalMethod returned InitIteration") - case PostIteration: - panic("optimize: GlobalMethod returned PostIteration") - case MajorIteration: - g.incrementMajorIteration() - g.updateOptLoc(loc) - status = g.checkConvergence() - default: // Any of the Evaluation operations. - status, err = evaluate(g.p, loc, op, x) - g.updateStats(op) + for i := 0; i < nTasks; i++ { + go worker() } - status, err = g.finishIteration(status, err, loc, op) - if status != NotTerminated || err != nil { - g.updateStatus(status, err) - } - return status -} + var ( + workersDone int // effective wg for the workers + status Status + err error + finalStatus Status + finalError error + ) -// DefaultSettingsGlobal returns the default settings for Global optimization. -func DefaultSettingsGlobal() *Settings { - return &Settings{ - FunctionThreshold: math.Inf(-1), - FunctionConverge: &FunctionConverge{ - Absolute: 1e-10, - Iterations: 100, - }, - } -} + // Update optimization statistics and check convergence. + for task := range statsChan { + switch task.Op { + default: + if !task.Op.isEvaluation() { + panic("global: evaluation task expected") + } + updateEvaluationStats(stats, task.Op) + status, err = checkEvaluationLimits(prob, stats, settings) + case signalDone: + workersDone++ + if workersDone == nTasks { + close(results) + } + continue + case NoOperation: + // Just send the task back. 
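+			// The send at the bottom of this loop returns it to the method.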
+ case MajorIteration: + status = performMajorIteration(optLoc, task.Location, stats, startTime, settings) + case MethodDone: + statuser, ok := method.(Statuser) + if !ok { + panic("optimize: global method returned MethodDone but is not a Statuser") + } + status, err = statuser.Status() + if status == NotTerminated { + panic("optimize: global method returned MethodDone but a NotTerminated status") + } + } + if settings.Recorder != nil && status == NotTerminated && err == nil { + stats.Runtime = time.Since(startTime) + // Allow err to be overloaded if the Recorder fails. + err = settings.Recorder.Record(task.Location, task.Op, stats) + if err != nil { + status = Failure + } + } + // If this is the first termination status, trigger the conclusion of + // the optimization. + if status != NotTerminated || err != nil { + select { + case <-done: + default: + finalStatus = status + finalError = err + results <- GlobalTask{ + Op: PostIteration, + } + close(done) + } + } -func min(a, b int) int { - if a < b { - return a + // Send the result back to the Problem if there are still active workers. + if workersDone != nTasks && task.Op != MethodDone { + results <- task + } } - return b + return finalStatus, finalError } diff --git a/optimize/guessandcheck.go b/optimize/guessandcheck.go index 84440865..ffaec00d 100644 --- a/optimize/guessandcheck.go +++ b/optimize/guessandcheck.go @@ -6,7 +6,6 @@ package optimize import ( "math" - "sync" "gonum.org/v1/gonum/stat/distmv" ) @@ -16,9 +15,6 @@ import ( type GuessAndCheck struct { Rander distmv.Rander - eval []bool - - mux *sync.Mutex bestF float64 bestX []float64 } @@ -27,34 +23,68 @@ func (g *GuessAndCheck) Needs() struct{ Gradient, Hessian bool } { return struct{ Gradient, Hessian bool }{false, false} } -func (g *GuessAndCheck) Done() { - // No cleanup needed -} - func (g *GuessAndCheck) InitGlobal(dim, tasks int) int { - g.eval = make([]bool, tasks) + if dim <= 0 { + panic(nonpositiveDimension) + } + if tasks < 0 { + panic(negativeTasks) + } g.bestF = math.Inf(1) g.bestX = resize(g.bestX, dim) - g.mux = &sync.Mutex{} return tasks } -func (g *GuessAndCheck) IterateGlobal(task int, loc *Location) (Operation, error) { - // Task is true if it contains a new function evaluation. - if g.eval[task] { - g.eval[task] = false - g.mux.Lock() - if loc.F < g.bestF { - g.bestF = loc.F - copy(g.bestX, loc.X) - } else { - loc.F = g.bestF - copy(loc.X, g.bestX) - } - g.mux.Unlock() - return MajorIteration, nil - } - g.eval[task] = true - g.Rander.Rand(loc.X) - return FuncEvaluation, nil +func (g *GuessAndCheck) sendNewLoc(operation chan<- GlobalTask, task GlobalTask) { + g.Rander.Rand(task.X) + task.Op = FuncEvaluation + operation <- task +} + +func (g *GuessAndCheck) updateMajor(operation chan<- GlobalTask, task GlobalTask) { + // Update the best value seen so far, and send a MajorIteration. + if task.F < g.bestF { + g.bestF = task.F + copy(g.bestX, task.X) + } else { + task.F = g.bestF + copy(task.X, g.bestX) + } + task.Op = MajorIteration + operation <- task +} + +func (g *GuessAndCheck) RunGlobal(operation chan<- GlobalTask, result <-chan GlobalTask, tasks []GlobalTask) { + // Send initial tasks to evaluate + for _, task := range tasks { + g.sendNewLoc(operation, task) + } + + // Read from the channel until PostIteration is sent. 
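+	// Each returned FuncEvaluation is folded into the best value seen so far
+	// and answered with a MajorIteration; each returned MajorIteration
+	// triggers a fresh random sample.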
+Loop:
+	for {
+		task := <-result
+		switch task.Op {
+		default:
+			panic("unknown operation")
+		case PostIteration:
+			break Loop
+		case MajorIteration:
+			g.sendNewLoc(operation, task)
+		case FuncEvaluation:
+			g.updateMajor(operation, task)
+		}
+	}
+
+	// PostIteration was sent. Collect the remaining results and update the
+	// best value found.
+	for task := range result {
+		switch task.Op {
+		default:
+			panic("unknown operation")
+		case MajorIteration:
+		case FuncEvaluation:
+			g.updateMajor(operation, task)
+		}
+	}
+	close(operation)
+}
diff --git a/optimize/guessandcheck_test.go b/optimize/guessandcheck_test.go
index f7bf1629..54787f5d 100644
--- a/optimize/guessandcheck_test.go
+++ b/optimize/guessandcheck_test.go
@@ -27,6 +27,7 @@ func TestGuessAndCheck(t *testing.T) {
 		panic("bad test")
 	}
 	Global(problem, dim, nil, &GuessAndCheck{Rander: d})
+
 	settings := DefaultSettingsGlobal()
 	settings.Concurrent = 5
 	settings.MajorIterations = 15
diff --git a/optimize/local.go b/optimize/local.go
index 0bc4f810..6759fe54 100644
--- a/optimize/local.go
+++ b/optimize/local.go
@@ -95,7 +95,7 @@ func Local(p Problem, initX []float64, settings *Settings, method Method) (*Resu
 	}
 
 	// Check if the starting location satisfies the convergence criteria.
-	status := checkConvergence(optLoc, settings, true)
+	status := checkLocationConvergence(optLoc, settings)
 
 	// Run optimization
 	if status == NotTerminated && err == nil {
@@ -123,8 +123,6 @@ func minimize(p *Problem, method Method, settings *Settings, stats *Stats, optLo
 	copyLocation(loc, optLoc)
 	x := make([]float64, len(loc.X))
 
-	statuser, _ := method.(Statuser)
-
 	var op Operation
 	op, err = method.Init(loc)
 	if err != nil {
@@ -143,18 +141,34 @@
 		case PostIteration:
 			panic("optimize: Method returned PostIteration")
 		case MajorIteration:
-			copyLocation(optLoc, loc)
-			stats.MajorIterations++
-			status = checkConvergence(optLoc, settings, true)
+			status = performMajorIteration(optLoc, loc, stats, startTime, settings)
+		case MethodDone:
+			statuser, ok := method.(Statuser)
+			if !ok {
+				panic("optimize: method returned MethodDone but is not a Statuser")
+			}
+			status, err = statuser.Status()
+			if status == NotTerminated {
+				panic("optimize: method returned MethodDone but a NotTerminated status")
+			}
 		default: // Any of the Evaluation operations.
-			status, err = evaluate(p, loc, op, x)
+			evaluate(p, loc, op, x)
 			updateEvaluationStats(stats, op)
+			status, err = checkEvaluationLimits(p, stats, settings)
 		}
-
-		status, err = finishIteration(status, err, stats, settings, statuser, startTime, loc, op)
 		if status != NotTerminated || err != nil {
 			return
 		}
+		if settings.Recorder != nil {
+			stats.Runtime = time.Since(startTime)
+			err = settings.Recorder.Record(loc, op, stats)
+			if err != nil {
+				if status == NotTerminated {
+					status = Failure
+				}
+				return status, err
+			}
+		}
 		op, err = method.Iterate(loc)
 		if err != nil {
diff --git a/optimize/minimize.go b/optimize/minimize.go
index 940083e9..e1c93976 100644
--- a/optimize/minimize.go
+++ b/optimize/minimize.go
@@ -13,6 +13,18 @@ import (
 	"gonum.org/v1/gonum/mat"
 )
 
+const (
+	nonpositiveDimension string = "optimize: non-positive input dimension"
+	negativeTasks        string = "optimize: negative input number of tasks"
+)
+
+func min(a, b int) int {
+	if a < b {
+		return a
+	}
+	return b
+}
+
 // newLocation allocates a new location structure of the appropriate size. It
 // allocates memory based on the dimension and the values in Needs. The initial
 // function value is set to math.Inf(1).
@@ -74,18 +86,12 @@ func checkOptimization(p Problem, dim int, method Needser, recorder Recorder) er
 }
 
 // evaluate evaluates the routines specified by the Operation at loc.X, and stores
-// the answer into loc. loc.X is copied into x before
-// evaluating in order to prevent the routines from modifying it.
-func evaluate(p *Problem, loc *Location, op Operation, x []float64) (Status, error) {
+// the answer into loc. loc.X is copied into x before evaluating in order to
+// prevent the routines from modifying it.
+func evaluate(p *Problem, loc *Location, op Operation, x []float64) {
 	if !op.isEvaluation() {
 		panic(fmt.Sprintf("optimize: invalid evaluation %v", op))
 	}
-	if p.Status != nil {
-		status, err := p.Status()
-		if err != nil || status != NotTerminated {
-			return status, err
-		}
-	}
 	copy(x, loc.X)
 	if op&FuncEvaluation != 0 {
 		loc.F = p.Func(x)
@@ -96,29 +102,6 @@ func evaluate(p *Problem, loc *Location, op Operation, x []float64) (Status, err
 	if op&HessEvaluation != 0 {
 		p.Hess(loc.Hessian, x)
 	}
-	return NotTerminated, nil
-}
-
-// checkConvergence returns NotTerminated if the Location does not satisfy the
-// convergence criteria given by settings. Otherwise a corresponding status is
-// returned.
-// Unlike checkLimits, checkConvergence is called only at MajorIterations.
-//
-// If local is true, gradient convergence is also checked.
-func checkConvergence(loc *Location, settings *Settings, local bool) Status {
-	if local && loc.Gradient != nil {
-		norm := floats.Norm(loc.Gradient, math.Inf(1))
-		if norm < settings.GradientThreshold {
-			return GradientThreshold
-		}
-	}
-	if loc.F < settings.FunctionThreshold {
-		return FunctionThreshold
-	}
-	if settings.FunctionConverge != nil {
-		return settings.FunctionConverge.FunctionConverged(loc.F)
-	}
-	return NotTerminated
 }
 
 // updateEvaluationStats updates the statistics based on the operation.
@@ -134,70 +117,75 @@ func updateEvaluationStats(stats *Stats, op Operation) {
 	}
 }
 
-// checkLimits returns NotTerminated status if the various limits given by
-// settings have not been reached. Otherwise it returns a corresponding status.
-// Unlike checkConvergence, checkLimits is called by Local and Global at _every_
-// iteration.
-func checkLimits(loc *Location, stats *Stats, settings *Settings) Status {
-	// Check the objective function value for negative infinity because it
-	// could break the linesearches and -inf is the best we can do anyway.
+// checkLocationConvergence returns NotTerminated if the current optimal
+// location does not satisfy any of the location-based convergence criteria
+// given by settings. Otherwise a corresponding status is returned.
+// Unlike checkEvaluationLimits, checkLocationConvergence is called only at
+// MajorIterations.
+func checkLocationConvergence(loc *Location, settings *Settings) Status { if math.IsInf(loc.F, -1) { return FunctionNegativeInfinity } - - if settings.MajorIterations > 0 && stats.MajorIterations >= settings.MajorIterations { - return IterationLimit + if loc.Gradient != nil { + norm := floats.Norm(loc.Gradient, math.Inf(1)) + if norm < settings.GradientThreshold { + return GradientThreshold + } } - - if settings.FuncEvaluations > 0 && stats.FuncEvaluations >= settings.FuncEvaluations { - return FunctionEvaluationLimit + if loc.F < settings.FunctionThreshold { + return FunctionThreshold } - - if settings.GradEvaluations > 0 && stats.GradEvaluations >= settings.GradEvaluations { - return GradientEvaluationLimit + if settings.FunctionConverge != nil { + return settings.FunctionConverge.FunctionConverged(loc.F) } - - if settings.HessEvaluations > 0 && stats.HessEvaluations >= settings.HessEvaluations { - return HessianEvaluationLimit - } - - // TODO(vladimir-ch): It would be nice to update Runtime here. - if settings.Runtime > 0 && stats.Runtime >= settings.Runtime { - return RuntimeLimit - } - return NotTerminated } -// finishIteration performs cleanup tasks at the end of an optimization iteration. -// It checks the status, sends information to recorders, and updates the runtime. -func finishIteration(status Status, err error, stats *Stats, settings *Settings, statuser Statuser, startTime time.Time, loc *Location, op Operation) (Status, error) { - if status != NotTerminated || err != nil { - return status, err - } - - if settings.Recorder != nil { - stats.Runtime = time.Since(startTime) - err = settings.Recorder.Record(loc, op, stats) - if err != nil { - if status == NotTerminated { - status = Failure - } - return status, err - } - } - - stats.Runtime = time.Since(startTime) - status = checkLimits(loc, stats, settings) - if status != NotTerminated { - return status, nil - } - - if statuser != nil { - status, err = statuser.Status() +// checkEvaluationLimits checks the optimization limits after an evaluation +// Operation. It checks the number of evaluations (of various kinds) and checks +// the status of the Problem, if applicable. +func checkEvaluationLimits(p *Problem, stats *Stats, settings *Settings) (Status, error) { + if p.Status != nil { + status, err := p.Status() if err != nil || status != NotTerminated { return status, err } } - return status, nil + if settings.FuncEvaluations > 0 && stats.FuncEvaluations >= settings.FuncEvaluations { + return FunctionEvaluationLimit, nil + } + if settings.GradEvaluations > 0 && stats.GradEvaluations >= settings.GradEvaluations { + return GradientEvaluationLimit, nil + } + if settings.HessEvaluations > 0 && stats.HessEvaluations >= settings.HessEvaluations { + return HessianEvaluationLimit, nil + } + return NotTerminated, nil +} + +// checkIterationLimits checks the limits on iterations affected by MajorIteration. +func checkIterationLimits(loc *Location, stats *Stats, settings *Settings) Status { + if settings.MajorIterations > 0 && stats.MajorIterations >= settings.MajorIterations { + return IterationLimit + } + if settings.Runtime > 0 && stats.Runtime >= settings.Runtime { + return RuntimeLimit + } + return NotTerminated +} + +// performMajorIteration does all of the steps needed to perform a MajorIteration. +// It increments the iteration count, updates the optimal location, and checks +// the necessary convergence criteria. 
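+// Location convergence takes precedence: if the location both converges and
+// reaches an iteration limit, the convergence status is returned.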
+func performMajorIteration(optLoc, loc *Location, stats *Stats, startTime time.Time, settings *Settings) Status {
+	copyLocation(optLoc, loc)
+	stats.MajorIterations++
+	stats.Runtime = time.Since(startTime)
+	status := checkLocationConvergence(optLoc, settings)
+	if status != NotTerminated {
+		return status
+	}
+	return checkIterationLimits(optLoc, stats, settings)
+}
diff --git a/optimize/types.go b/optimize/types.go
index 118e7eb6..10bc7f28 100644
--- a/optimize/types.go
+++ b/optimize/types.go
@@ -38,6 +38,10 @@ const (
 	// MajorIteration indicates that the next candidate location for
 	// an optimum has been found and convergence should be checked.
 	MajorIteration
+	// MethodDone declares that the method is done running. A method must
+	// be a Statuser in order to use this Operation, and after returning
+	// MethodDone, Status must return a Status other than NotTerminated.
+	MethodDone
 	// FuncEvaluation specifies that the objective function
 	// should be evaluated.
 	FuncEvaluation
@@ -47,6 +51,8 @@ const (
 	// HessEvaluation specifies that the Hessian
 	// of the objective function should be evaluated.
 	HessEvaluation
+	// signalDone is used internally to signal completion.
+	signalDone
 
 	// Mask for the evaluating operations.
 	evalMask = FuncEvaluation | GradEvaluation | HessEvaluation
@@ -76,6 +82,8 @@ var operationNames = map[Operation]string{
 	InitIteration:  "InitIteration",
 	MajorIteration: "MajorIteration",
 	PostIteration:  "PostIteration",
+	MethodDone:     "MethodDone",
+	signalDone:     "signalDone",
 }
 
 // Location represents a location in the optimization procedure.
@@ -201,7 +209,7 @@ type Settings struct {
 	// Runtime is the maximum runtime allowed. RuntimeLimit status is returned
 	// if the duration of the run is longer than this value. Runtime is only
-	// checked at iterations of the Method.
+	// checked at MajorIterations of the Method.
 	// If it equals zero, this setting has no effect.
 	// The default value is 0.
 	Runtime time.Duration
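For reference, a caller-side sketch of the reworked API, mirroring the concurrent CMA-ES test above (the dimension, seed, and Concurrent value are arbitrary choices, not part of the diff):

package main

import (
	"fmt"
	"math"

	"golang.org/x/exp/rand"

	"gonum.org/v1/gonum/optimize"
	"gonum.org/v1/gonum/optimize/functions"
)

func main() {
	problem := optimize.Problem{Func: functions.ExtendedRosenbrock{}.Func}

	// CmaEsChol implements GlobalMethod through RunGlobal.
	method := &optimize.CmaEsChol{Population: 100}
	method.Src = rand.New(rand.NewSource(1))

	settings := optimize.DefaultSettingsGlobal()
	settings.Concurrent = 5 // Evaluate up to five samples in parallel.
	settings.FunctionThreshold = math.Inf(-1)
	settings.MajorIterations = 10

	result, err := optimize.Global(problem, 10, settings, method)
	if err != nil {
		fmt.Println(err)
	}
	fmt.Println(result.Status, result.F)
}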