mirror of
https://github.com/gonum/gonum.git
synced 2025-09-27 03:26:04 +08:00
109 lines
2.9 KiB
Go
109 lines
2.9 KiB
Go
// Copyright ©2014 The Gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package optimize
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"math"
|
|
"os"
|
|
"time"
|
|
|
|
"gonum.org/v1/gonum/floats"
|
|
)
|
|
|
|
// printerHeadings holds the column titles in the order they are printed.
// Indices 0-3 are the base columns that are always shown, 4-5 are the
// gradient columns and 6 is the Hessian column.
var printerHeadings = [...]string{
	0: "Iter",
	1: "Runtime",
	2: "FuncEvals",
	3: "Func",
	4: "GradEvals",
	5: "|Gradient|∞",
	6: "HessEvals",
}
|
|
|
|
// Format templates shared by the heading line and the value line so that
// the columns of both line up.
const (
	// printerBaseTmpl formats the four columns that are always printed.
	printerBaseTmpl = "%9v %16v %9v %22v"

	// printerGradTmpl is appended to the base template when
	// loc.Gradient != nil.
	printerGradTmpl = " %9v %22v"

	// printerHessTmpl is appended to the base template when
	// loc.Hessian != nil.
	printerHessTmpl = " %9v"
)
|
|
|
|
// Compile-time assertion that *Printer satisfies the Recorder interface.
var _ Recorder = (*Printer)(nil)
|
|
|
|
// Printer reports the progress of an optimization run as column-format text.
// A Printer obtained from NewPrinter writes to os.Stdout.
type Printer struct {
	// Writer is the destination of all headings and value lines.
	Writer io.Writer
	// HeadingInterval is the number of value lines printed before the
	// column headings are repeated.
	HeadingInterval int
	// ValueInterval is the time that must have elapsed since the previous
	// value line before another one is printed. The line for PostIteration
	// is printed regardless of this interval.
	ValueInterval time.Duration

	// lastHeading counts the value lines printed since the headings were
	// last written, including the line that accompanied them.
	lastHeading int
	// lastValue is the time at which the most recent value line was printed.
	lastValue time.Time
}
|
|
|
|
func NewPrinter() *Printer {
|
|
return &Printer{
|
|
Writer: os.Stdout,
|
|
HeadingInterval: 30,
|
|
ValueInterval: 500 * time.Millisecond,
|
|
}
|
|
}
|
|
|
|
func (p *Printer) Init() error {
|
|
p.lastHeading = p.HeadingInterval // So the headings are printed the first time.
|
|
p.lastValue = time.Now().Add(-p.ValueInterval) // So the values are printed the first time.
|
|
return nil
|
|
}
|
|
|
|
func (p *Printer) Record(loc *Location, op Operation, stats *Stats) error {
|
|
if op != MajorIteration && op != InitIteration && op != PostIteration {
|
|
return nil
|
|
}
|
|
|
|
// Print values always on PostIteration or when ValueInterval has elapsed.
|
|
printValues := time.Since(p.lastValue) > p.ValueInterval || op == PostIteration
|
|
if !printValues {
|
|
// Return early if not printing anything.
|
|
return nil
|
|
}
|
|
|
|
// Print heading when HeadingInterval lines have been printed, but never on PostIteration.
|
|
printHeading := p.lastHeading >= p.HeadingInterval && op != PostIteration
|
|
if printHeading {
|
|
p.lastHeading = 1
|
|
} else {
|
|
p.lastHeading++
|
|
}
|
|
|
|
if printHeading {
|
|
headings := "\n" + fmt.Sprintf(printerBaseTmpl, printerHeadings[0], printerHeadings[1], printerHeadings[2], printerHeadings[3])
|
|
if loc.Gradient != nil {
|
|
headings += fmt.Sprintf(printerGradTmpl, printerHeadings[4], printerHeadings[5])
|
|
}
|
|
if loc.Hessian != nil {
|
|
headings += fmt.Sprintf(printerHessTmpl, printerHeadings[6])
|
|
}
|
|
_, err := fmt.Fprintln(p.Writer, headings)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
values := fmt.Sprintf(printerBaseTmpl, stats.MajorIterations, stats.Runtime, stats.FuncEvaluations, loc.F)
|
|
if loc.Gradient != nil {
|
|
values += fmt.Sprintf(printerGradTmpl, stats.GradEvaluations, floats.Norm(loc.Gradient, math.Inf(1)))
|
|
}
|
|
if loc.Hessian != nil {
|
|
values += fmt.Sprintf(printerHessTmpl, stats.HessEvaluations)
|
|
}
|
|
_, err := fmt.Fprintln(p.Writer, values)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
p.lastValue = time.Now()
|
|
return nil
|
|
}
|