| // Copyright ©2014 The gonum Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package optimize |
| |
| import ( |
| "fmt" |
| "io" |
| "math" |
| "os" |
| "time" |
| |
| "gonum.org/v1/gonum/floats" |
| ) |
| |
// printerHeadings holds the column titles in the order in which the columns
// are printed. Entries 0-3 are always printed, entries 4-5 are printed only
// when a gradient is available and entry 6 only when a Hessian is available.
var printerHeadings = [...]string{
	0: "Iter",
	1: "Runtime",
	2: "FuncEvals",
	3: "Func",
	4: "GradEvals",
	5: "|Gradient|∞",
	6: "HessEvals",
}
| |
// Format templates shared by the heading line and the value lines, so that
// both are laid out with identical column widths.
const (
	// printerBaseTmpl formats the four columns that are always printed.
	printerBaseTmpl = "%9v %16v %9v %22v"

	// printerGradTmpl is appended to the base template when loc.Gradient != nil.
	printerGradTmpl = " %9v %22v"

	// printerHessTmpl is appended to the base template when loc.Hessian != nil.
	printerHessTmpl = " %9v"
)
| |
// Printer reports the progress of an optimization run as aligned columns of
// text, one line per reported iteration, written to Writer. A Printer obtained
// from NewPrinter writes to os.Stdout.
type Printer struct {
	// Writer is the destination of all output.
	Writer io.Writer
	// HeadingInterval is the number of printed lines of values after which
	// the column headings are printed again.
	HeadingInterval int
	// ValueInterval is the time that must elapse after a line of values has
	// been printed before the next one is printed. Values for PostIteration
	// are printed regardless of this interval.
	ValueInterval time.Duration

	// lastHeading counts the lines of values printed since the headings
	// were last printed.
	lastHeading int
	// lastValue is the time at which the most recent line of values was
	// printed.
	lastValue time.Time
}
| |
| func NewPrinter() *Printer { |
| return &Printer{ |
| Writer: os.Stdout, |
| HeadingInterval: 30, |
| ValueInterval: 500 * time.Millisecond, |
| } |
| } |
| |
| func (p *Printer) Init() error { |
| p.lastHeading = p.HeadingInterval // So the headings are printed the first time. |
| p.lastValue = time.Now().Add(-p.ValueInterval) // So the values are printed the first time. |
| return nil |
| } |
| |
| func (p *Printer) Record(loc *Location, op Operation, stats *Stats) error { |
| if op != MajorIteration && op != InitIteration && op != PostIteration { |
| return nil |
| } |
| |
| // Print values always on PostIteration or when ValueInterval has elapsed. |
| printValues := time.Since(p.lastValue) > p.ValueInterval || op == PostIteration |
| if !printValues { |
| // Return early if not printing anything. |
| return nil |
| } |
| |
| // Print heading when HeadingInterval lines have been printed, but never on PostIteration. |
| printHeading := p.lastHeading >= p.HeadingInterval && op != PostIteration |
| if printHeading { |
| p.lastHeading = 1 |
| } else { |
| p.lastHeading++ |
| } |
| |
| if printHeading { |
| headings := "\n" + fmt.Sprintf(printerBaseTmpl, printerHeadings[0], printerHeadings[1], printerHeadings[2], printerHeadings[3]) |
| if loc.Gradient != nil { |
| headings += fmt.Sprintf(printerGradTmpl, printerHeadings[4], printerHeadings[5]) |
| } |
| if loc.Hessian != nil { |
| headings += fmt.Sprintf(printerHessTmpl, printerHeadings[6]) |
| } |
| _, err := fmt.Fprintln(p.Writer, headings) |
| if err != nil { |
| return err |
| } |
| } |
| |
| values := fmt.Sprintf(printerBaseTmpl, stats.MajorIterations, stats.Runtime, stats.FuncEvaluations, loc.F) |
| if loc.Gradient != nil { |
| values += fmt.Sprintf(printerGradTmpl, stats.GradEvaluations, floats.Norm(loc.Gradient, math.Inf(1))) |
| } |
| if loc.Hessian != nil { |
| values += fmt.Sprintf(printerHessTmpl, stats.HessEvaluations) |
| } |
| _, err := fmt.Fprintln(p.Writer, values) |
| if err != nil { |
| return err |
| } |
| |
| p.lastValue = time.Now() |
| return nil |
| } |