| // Copyright ©2017 The Gonum Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package fd |
| |
| import "gonum.org/v1/gonum/floats" |
| |
| // Gradient estimates the gradient of the multivariate function f at the |
| // location x. If dst is not nil, the result will be stored in-place into dst |
| // and returned, otherwise a new slice will be allocated first. Finite |
| // difference formula and other options are specified by settings. If settings is |
| // nil, the gradient will be estimated using the Forward formula and a default |
| // step size. |
| // |
| // Gradient panics if the length of dst and x is not equal, or if the derivative |
| // order of the formula is not 1. |
| func Gradient(dst []float64, f func([]float64) float64, x []float64, settings *Settings) []float64 { |
| if dst == nil { |
| dst = make([]float64, len(x)) |
| } |
| if len(dst) != len(x) { |
| panic("fd: slice length mismatch") |
| } |
| |
| // Default settings. |
| formula := Forward |
| step := formula.Step |
| var originValue float64 |
| var originKnown, concurrent bool |
| |
| // Use user settings if provided. |
| if settings != nil { |
| if !settings.Formula.isZero() { |
| formula = settings.Formula |
| step = formula.Step |
| checkFormula(formula) |
| if formula.Derivative != 1 { |
| panic(badDerivOrder) |
| } |
| } |
| if settings.Step != 0 { |
| step = settings.Step |
| } |
| originKnown = settings.OriginKnown |
| originValue = settings.OriginValue |
| concurrent = settings.Concurrent |
| } |
| |
| evals := len(formula.Stencil) * len(x) |
| nWorkers := computeWorkers(concurrent, evals) |
| |
| hasOrigin := usesOrigin(formula.Stencil) |
| // Copy x in case it is modified during the call. |
| xcopy := make([]float64, len(x)) |
| if hasOrigin && !originKnown { |
| copy(xcopy, x) |
| originValue = f(xcopy) |
| } |
| |
| if nWorkers == 1 { |
| for i := range xcopy { |
| var deriv float64 |
| for _, pt := range formula.Stencil { |
| if pt.Loc == 0 { |
| deriv += pt.Coeff * originValue |
| continue |
| } |
| // Copying the data anew has two benefits. First, it |
| // avoids floating point issues where adding and then |
| // subtracting the step don't return to the exact same |
| // location. Secondly, it protects against the function |
| // modifying the input data. |
| copy(xcopy, x) |
| xcopy[i] += pt.Loc * step |
| deriv += pt.Coeff * f(xcopy) |
| } |
| dst[i] = deriv / step |
| } |
| return dst |
| } |
| |
| sendChan := make(chan fdrun, evals) |
| ansChan := make(chan fdrun, evals) |
| quit := make(chan struct{}) |
| defer close(quit) |
| |
| // Launch workers. Workers receive an index and a step, and compute the answer. |
| for i := 0; i < nWorkers; i++ { |
| go func(sendChan <-chan fdrun, ansChan chan<- fdrun, quit <-chan struct{}) { |
| xcopy := make([]float64, len(x)) |
| for { |
| select { |
| case <-quit: |
| return |
| case run := <-sendChan: |
| // See above comment on the copy. |
| copy(xcopy, x) |
| xcopy[run.idx] += run.pt.Loc * step |
| run.result = f(xcopy) |
| ansChan <- run |
| } |
| } |
| }(sendChan, ansChan, quit) |
| } |
| |
| // Launch the distributor. Distributor sends the cases to be computed. |
| go func(sendChan chan<- fdrun, ansChan chan<- fdrun) { |
| for i := range x { |
| for _, pt := range formula.Stencil { |
| if pt.Loc == 0 { |
| // Answer already known. Send the answer on the answer channel. |
| ansChan <- fdrun{ |
| idx: i, |
| pt: pt, |
| result: originValue, |
| } |
| continue |
| } |
| // Answer not known, send the answer to be computed. |
| sendChan <- fdrun{ |
| idx: i, |
| pt: pt, |
| } |
| } |
| } |
| }(sendChan, ansChan) |
| |
| for i := range dst { |
| dst[i] = 0 |
| } |
| // Read in all of the results. |
| for i := 0; i < evals; i++ { |
| run := <-ansChan |
| dst[run.idx] += run.pt.Coeff * run.result |
| } |
| floats.Scale(1/step, dst) |
| return dst |
| } |
| |
| type fdrun struct { |
| idx int |
| pt Point |
| result float64 |
| } |