diff --git a/optimize/backtracking.go b/optimize/backtracking.go index 2ab44e44..7b4426f9 100644 --- a/optimize/backtracking.go +++ b/optimize/backtracking.go @@ -28,6 +28,7 @@ var _ Linesearcher = (*Backtracking)(nil) type Backtracking struct { DecreaseFactor float64 // Constant factor in the sufficient decrease (Armijo) condition. ContractionFactor float64 // Step size multiplier at each iteration (step *= ContractionFactor). + MinimumStepSize float64 // Smallest allowed step size; line search fails if step shrinks below this value. stepSize float64 initF float64 @@ -50,12 +51,18 @@ func (b *Backtracking) Init(f, g float64, step float64) Operation { if b.DecreaseFactor == 0 { b.DecreaseFactor = defaultBacktrackingDecrease } + if b.MinimumStepSize == 0 { + b.MinimumStepSize = minimumBacktrackingStepSize + } if b.ContractionFactor <= 0 || b.ContractionFactor >= 1 { panic("backtracking: ContractionFactor must be between 0 and 1") } if b.DecreaseFactor <= 0 || b.DecreaseFactor >= 1 { panic("backtracking: DecreaseFactor must be between 0 and 1") } + if b.MinimumStepSize < 0 { + panic("backtracking: MinimumStepSize must be positive") + } b.stepSize = step b.initF = f @@ -75,7 +82,7 @@ func (b *Backtracking) Iterate(f, _ float64) (Operation, float64, error) { return b.lastOp, b.stepSize, nil } b.stepSize *= b.ContractionFactor - if b.stepSize < minimumBacktrackingStepSize { + if b.stepSize < b.MinimumStepSize { b.lastOp = NoOperation return b.lastOp, b.stepSize, ErrLinesearcherFailure } diff --git a/optimize/unconstrained_test.go b/optimize/unconstrained_test.go index c7165271..3c31395b 100644 --- a/optimize/unconstrained_test.go +++ b/optimize/unconstrained_test.go @@ -82,19 +82,19 @@ var gradFreeTests = []unconstrainedTest{ }, x: []float64{-5, 4, 16, 3}, }, - { + { name: "Sphere5D", p: Problem{ Func: functions.Sphere{}.Func, }, - x: []float64{0.00001, 1.00001, 2.00001, 3.00001, 4.00001}, + x: []float64{0.00001, 1.00001, 2.00001, 3.00001, 4.00001}, }, { name: "Sphere10D", p: Problem{ Func: functions.Sphere{}.Func, }, - x: []float64{0.00001, 1.00001, 2.00001, 3.00001, 4.00001, 5.00001, 6.00001, 7.00001, 8.00001, 9.00001}, + x: []float64{0.00001, 1.00001, 2.00001, 3.00001, 4.00001, 5.00001, 6.00001, 7.00001, 8.00001, 9.00001}, }, } @@ -1060,6 +1060,16 @@ func TestGradientDescentBacktracking(t *testing.T) { }) } +func TestGradientDescentBacktrackingWithMinimumStepSize(t *testing.T) { + t.Parallel() + testLocal(t, gradientDescentTests, &GradientDescent{ + Linesearcher: &Backtracking{ + DecreaseFactor: 0.1, + MinimumStepSize: 2e-8, + }, + }) +} + func TestGradientDescentBisection(t *testing.T) { t.Parallel() testLocal(t, gradientDescentTests, &GradientDescent{