| // Copyright ©2017 The Gonum Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package fd |
| |
| import ( |
| "math" |
| "sync" |
| |
| "gonum.org/v1/gonum/mat" |
| ) |
| |
| // Hessian approximates the Hessian matrix of the multivariate function f at |
| // the location x. That is |
| // H_{i,j} = ∂^2 f(x)/∂x_i ∂x_j |
| // The resulting H will be stored in dst. Finite difference formula and other |
| // options are specified by settings. If settings is nil, the Hessian will be |
| // estimated using the Forward formula and a default step size. |
| // |
| // If the dst matrix is empty it will be resized to the correct dimensions, |
| // otherwise the dimensions of dst must match the length of x or Hessian will panic. |
| // Hessian will panic if the derivative order of the formula is not 1. |
| func Hessian(dst *mat.SymDense, f func(x []float64) float64, x []float64, settings *Settings) { |
| n := len(x) |
| if dst.IsEmpty() { |
| *dst = *(dst.GrowSym(n).(*mat.SymDense)) |
| } else if dst.Symmetric() != n { |
| panic("hessian: dst size mismatch") |
| } |
| dst.Zero() |
| |
| // Default settings. |
| formula := Forward |
| step := math.Sqrt(formula.Step) // Use the sqrt because taking derivatives of derivatives. |
| var originValue float64 |
| var originKnown, concurrent bool |
| |
| // Use user settings if provided. |
| if settings != nil { |
| if !settings.Formula.isZero() { |
| formula = settings.Formula |
| step = math.Sqrt(formula.Step) |
| checkFormula(formula) |
| if formula.Derivative != 1 { |
| panic(badDerivOrder) |
| } |
| } |
| if settings.Step != 0 { |
| if settings.Step < 0 { |
| panic(negativeStep) |
| } |
| step = settings.Step |
| } |
| originKnown = settings.OriginKnown |
| originValue = settings.OriginValue |
| concurrent = settings.Concurrent |
| } |
| |
| evals := n * (n + 1) / 2 * len(formula.Stencil) * len(formula.Stencil) |
| for _, pt := range formula.Stencil { |
| if pt.Loc == 0 { |
| evals -= n * (n + 1) / 2 |
| break |
| } |
| } |
| |
| nWorkers := computeWorkers(concurrent, evals) |
| if nWorkers == 1 { |
| hessianSerial(dst, f, x, formula.Stencil, step, originKnown, originValue) |
| return |
| } |
| hessianConcurrent(dst, nWorkers, evals, f, x, formula.Stencil, step, originKnown, originValue) |
| } |
| |
| func hessianSerial(dst *mat.SymDense, f func(x []float64) float64, x []float64, stencil []Point, step float64, originKnown bool, originValue float64) { |
| n := len(x) |
| xCopy := make([]float64, n) |
| fo := func() float64 { |
| // Copy x in case it is modified during the call. |
| copy(xCopy, x) |
| return f(x) |
| } |
| is2 := 1 / (step * step) |
| origin := getOrigin(originKnown, originValue, fo, stencil) |
| for i := 0; i < n; i++ { |
| for j := i; j < n; j++ { |
| var hess float64 |
| for _, pti := range stencil { |
| for _, ptj := range stencil { |
| var v float64 |
| if pti.Loc == 0 && ptj.Loc == 0 { |
| v = origin |
| } else { |
| // Copying the data anew has two benefits. First, it |
| // avoids floating point issues where adding and then |
| // subtracting the step don't return to the exact same |
| // location. Secondly, it protects against the function |
| // modifying the input data. |
| copy(xCopy, x) |
| xCopy[i] += pti.Loc * step |
| xCopy[j] += ptj.Loc * step |
| v = f(xCopy) |
| } |
| hess += v * pti.Coeff * ptj.Coeff * is2 |
| } |
| } |
| dst.SetSym(i, j, hess) |
| } |
| } |
| } |
| |
| func hessianConcurrent(dst *mat.SymDense, nWorkers, evals int, f func(x []float64) float64, x []float64, stencil []Point, step float64, originKnown bool, originValue float64) { |
| n := dst.Symmetric() |
| type run struct { |
| i, j int |
| iIdx, jIdx int |
| result float64 |
| } |
| |
| send := make(chan run, evals) |
| ans := make(chan run, evals) |
| |
| var originWG sync.WaitGroup |
| hasOrigin := usesOrigin(stencil) |
| if hasOrigin { |
| originWG.Add(1) |
| // Launch worker to compute the origin. |
| go func() { |
| defer originWG.Done() |
| xCopy := make([]float64, len(x)) |
| copy(xCopy, x) |
| originValue = f(xCopy) |
| }() |
| } |
| |
| var workerWG sync.WaitGroup |
| // Launch workers. |
| for i := 0; i < nWorkers; i++ { |
| workerWG.Add(1) |
| go func(send <-chan run, ans chan<- run) { |
| defer workerWG.Done() |
| xCopy := make([]float64, len(x)) |
| for r := range send { |
| if stencil[r.iIdx].Loc == 0 && stencil[r.jIdx].Loc == 0 { |
| originWG.Wait() |
| r.result = originValue |
| } else { |
| // See hessianSerial for comment on the copy. |
| copy(xCopy, x) |
| xCopy[r.i] += stencil[r.iIdx].Loc * step |
| xCopy[r.j] += stencil[r.jIdx].Loc * step |
| r.result = f(xCopy) |
| } |
| ans <- r |
| } |
| }(send, ans) |
| } |
| |
| // Launch the distributor, which sends all of runs. |
| go func(send chan<- run) { |
| for i := 0; i < n; i++ { |
| for j := i; j < n; j++ { |
| for iIdx := range stencil { |
| for jIdx := range stencil { |
| send <- run{ |
| i: i, j: j, iIdx: iIdx, jIdx: jIdx, |
| } |
| } |
| } |
| } |
| } |
| close(send) |
| // Wait for all the workers to quit, then close the ans channel. |
| workerWG.Wait() |
| close(ans) |
| }(send) |
| |
| is2 := 1 / (step * step) |
| // Read in the results. |
| for r := range ans { |
| v := r.result * stencil[r.iIdx].Coeff * stencil[r.jIdx].Coeff * is2 |
| v += dst.At(r.i, r.j) |
| dst.SetSym(r.i, r.j, v) |
| } |
| } |