mirror of
https://github.com/gonum/gonum.git
synced 2025-09-27 03:26:04 +08:00
232 lines
6.2 KiB
Go
232 lines
6.2 KiB
Go
// Copyright ©2017 The Gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package optimize
|
|
|
|
import (
|
|
"errors"
|
|
"math"
|
|
"math/rand/v2"
|
|
"testing"
|
|
|
|
"gonum.org/v1/gonum/floats"
|
|
"gonum.org/v1/gonum/mat"
|
|
"gonum.org/v1/gonum/optimize/functions"
|
|
)
|
|
|
|
type functionThresholdConverger struct {
|
|
Threshold float64
|
|
}
|
|
|
|
func (functionThresholdConverger) Init(dim int) {}
|
|
|
|
func (f functionThresholdConverger) Converged(loc *Location) Status {
|
|
if loc.F < f.Threshold {
|
|
return FunctionThreshold
|
|
}
|
|
return NotTerminated
|
|
}
|
|
|
|
type cmaTestCase struct {
|
|
dim int
|
|
problem Problem
|
|
method *CmaEsChol
|
|
initX []float64
|
|
settings *Settings
|
|
good func(result *Result, err error, concurrent int) error
|
|
}
|
|
|
|
func cmaTestCases() []cmaTestCase {
|
|
localMinMean := []float64{2.2, -2.2}
|
|
s := mat.NewSymDense(2, []float64{0.01, 0, 0, 0.01})
|
|
var localMinChol mat.Cholesky
|
|
localMinChol.Factorize(s)
|
|
return []cmaTestCase{
|
|
{
|
|
// Test that can find a small value.
|
|
dim: 10,
|
|
problem: Problem{
|
|
Func: functions.ExtendedRosenbrock{}.Func,
|
|
},
|
|
method: &CmaEsChol{
|
|
StopLogDet: math.NaN(),
|
|
},
|
|
settings: &Settings{
|
|
Converger: functionThresholdConverger{0.01},
|
|
},
|
|
good: func(result *Result, err error, concurrent int) error {
|
|
if result.Status != FunctionThreshold {
|
|
return errors.New("result not function threshold")
|
|
}
|
|
if result.F > 0.01 {
|
|
return errors.New("result not sufficiently small")
|
|
}
|
|
return nil
|
|
},
|
|
},
|
|
{
|
|
// Test that can stop when the covariance gets small.
|
|
// For this case, also test that it is really at a minimum.
|
|
dim: 2,
|
|
problem: Problem{
|
|
Func: functions.ExtendedRosenbrock{}.Func,
|
|
},
|
|
method: &CmaEsChol{},
|
|
settings: &Settings{
|
|
Converger: NeverTerminate{},
|
|
},
|
|
good: func(result *Result, err error, concurrent int) error {
|
|
if result.Status != MethodConverge {
|
|
return errors.New("result not method converge")
|
|
}
|
|
if result.F > 1e-12 {
|
|
return errors.New("minimum not found")
|
|
}
|
|
return nil
|
|
},
|
|
},
|
|
{
|
|
// Test that population works properly and it stops after a certain
|
|
// number of iterations.
|
|
dim: 3,
|
|
problem: Problem{
|
|
Func: functions.ExtendedRosenbrock{}.Func,
|
|
},
|
|
method: &CmaEsChol{
|
|
Population: 100,
|
|
ForgetBest: true, // Otherwise may get an update at the end.
|
|
},
|
|
settings: &Settings{
|
|
MajorIterations: 10,
|
|
Converger: NeverTerminate{},
|
|
},
|
|
good: func(result *Result, err error, concurrent int) error {
|
|
if result.Status != IterationLimit {
|
|
return errors.New("result not iteration limit")
|
|
}
|
|
threshLower := 10
|
|
threshUpper := 10
|
|
if concurrent != 0 {
|
|
// Could have one more from final update.
|
|
threshUpper++
|
|
}
|
|
if result.MajorIterations < threshLower || result.MajorIterations > threshUpper {
|
|
return errors.New("wrong number of iterations")
|
|
}
|
|
return nil
|
|
},
|
|
},
|
|
{
|
|
// Test that work stops with some number of function evaluations.
|
|
dim: 5,
|
|
problem: Problem{
|
|
Func: functions.ExtendedRosenbrock{}.Func,
|
|
},
|
|
method: &CmaEsChol{
|
|
Population: 100,
|
|
},
|
|
settings: &Settings{
|
|
FuncEvaluations: 250, // Somewhere in the middle of an iteration.
|
|
Converger: NeverTerminate{},
|
|
},
|
|
good: func(result *Result, err error, concurrent int) error {
|
|
if result.Status != FunctionEvaluationLimit {
|
|
return errors.New("result not function evaluations")
|
|
}
|
|
threshLower := 250
|
|
threshUpper := 251
|
|
if concurrent != 0 {
|
|
threshUpper = threshLower + concurrent
|
|
}
|
|
if result.FuncEvaluations < threshLower {
|
|
return errors.New("too few function evaluations")
|
|
}
|
|
if result.FuncEvaluations > threshUpper {
|
|
return errors.New("too many function evaluations")
|
|
}
|
|
return nil
|
|
},
|
|
},
|
|
{
|
|
// Test that the global minimum is found with the right initialization.
|
|
dim: 2,
|
|
problem: Problem{
|
|
Func: functions.Rastrigin{}.Func,
|
|
},
|
|
method: &CmaEsChol{
|
|
Population: 200, // Increase the population size to reduce noise.
|
|
},
|
|
settings: &Settings{
|
|
Converger: NeverTerminate{},
|
|
},
|
|
good: func(result *Result, err error, concurrent int) error {
|
|
if result.Status != MethodConverge {
|
|
return errors.New("result not method converge")
|
|
}
|
|
if !floats.EqualApprox(result.X, []float64{0, 0}, 1e-6) {
|
|
return errors.New("global minimum not found")
|
|
}
|
|
return nil
|
|
},
|
|
},
|
|
{
|
|
// Test that a local minimum is found (with a different initialization).
|
|
dim: 2,
|
|
problem: Problem{
|
|
Func: functions.Rastrigin{}.Func,
|
|
},
|
|
initX: localMinMean,
|
|
method: &CmaEsChol{
|
|
Population: 100, // Increase the population size to reduce noise.
|
|
InitCholesky: &localMinChol,
|
|
ForgetBest: true, // So that if it accidentally finds a better place we still converge to the minimum.
|
|
},
|
|
settings: &Settings{
|
|
Converger: NeverTerminate{},
|
|
},
|
|
good: func(result *Result, err error, concurrent int) error {
|
|
if result.Status != MethodConverge {
|
|
return errors.New("result not method converge")
|
|
}
|
|
if !floats.EqualApprox(result.X, []float64{2, -2}, 3e-2) {
|
|
return errors.New("local minimum not found")
|
|
}
|
|
return nil
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func TestCmaEsChol(t *testing.T) {
|
|
t.Parallel()
|
|
for i, test := range cmaTestCases() {
|
|
src := rand.New(rand.NewPCG(1, 1))
|
|
method := test.method
|
|
method.Src = src
|
|
initX := test.initX
|
|
if initX == nil {
|
|
initX = make([]float64, test.dim)
|
|
}
|
|
// Run and check that the expected termination occurs.
|
|
result, err := Minimize(test.problem, initX, test.settings, method)
|
|
if testErr := test.good(result, err, test.settings.Concurrent); testErr != nil {
|
|
t.Errorf("cas %d: %v", i, testErr)
|
|
}
|
|
|
|
// Run a second time to make sure there are no residual effects
|
|
result, err = Minimize(test.problem, initX, test.settings, method)
|
|
if testErr := test.good(result, err, test.settings.Concurrent); testErr != nil {
|
|
t.Errorf("cas %d second: %v", i, testErr)
|
|
}
|
|
|
|
// Test the problem in parallel.
|
|
test.settings.Concurrent = 5
|
|
result, err = Minimize(test.problem, initX, test.settings, method)
|
|
if testErr := test.good(result, err, test.settings.Concurrent); testErr != nil {
|
|
t.Errorf("cas %d concurrent: %v", i, testErr)
|
|
}
|
|
test.settings.Concurrent = 0
|
|
}
|
|
}
|