blob: 0774ad731e3d68d93ba994db8c52ab1f971c6b44 [file] [log] [blame]
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fd
import (
"math"
"math/rand"
"testing"
"gonum.org/v1/gonum/floats"
)
type Rosenbrock struct {
nDim int
}
func (r Rosenbrock) F(x []float64) (sum float64) {
deriv := make([]float64, len(x))
return r.FDf(x, deriv)
}
func (r Rosenbrock) FDf(x []float64, deriv []float64) (sum float64) {
for i := range deriv {
deriv[i] = 0
}
for i := 0; i < len(x)-1; i++ {
sum += math.Pow(1-x[i], 2) + 100*math.Pow(x[i+1]-math.Pow(x[i], 2), 2)
}
for i := 0; i < len(x)-1; i++ {
deriv[i] += -1 * 2 * (1 - x[i])
deriv[i] += 2 * 100 * (x[i+1] - math.Pow(x[i], 2)) * (-2 * x[i])
}
for i := 1; i < len(x); i++ {
deriv[i] += 2 * 100 * (x[i] - math.Pow(x[i-1], 2))
}
return sum
}
func TestGradient(t *testing.T) {
rand.Seed(1)
for i, test := range []struct {
nDim int
tol float64
formula Formula
}{
{
nDim: 2,
tol: 2e-4,
formula: Forward,
},
{
nDim: 2,
tol: 1e-6,
formula: Central,
},
{
nDim: 40,
tol: 2e-4,
formula: Forward,
},
{
nDim: 40,
tol: 1e-6,
formula: Central,
},
} {
x := make([]float64, test.nDim)
for i := range x {
x[i] = rand.Float64()
}
xcopy := make([]float64, len(x))
copy(xcopy, x)
r := Rosenbrock{len(x)}
trueGradient := make([]float64, len(x))
r.FDf(x, trueGradient)
// Try with gradient nil.
gradient := Gradient(nil, r.F, x, &Settings{
Formula: test.formula,
})
if !floats.EqualApprox(gradient, trueGradient, test.tol) {
t.Errorf("Case %v: gradient mismatch in serial with nil. Want: %v, Got: %v.", i, trueGradient, gradient)
}
if !floats.Equal(x, xcopy) {
t.Errorf("Case %v: x modified during call to gradient in serial with nil.", i)
}
// Try with provided gradient.
for i := range gradient {
gradient[i] = rand.Float64()
}
Gradient(gradient, r.F, x, &Settings{
Formula: test.formula,
})
if !floats.EqualApprox(gradient, trueGradient, test.tol) {
t.Errorf("Case %v: gradient mismatch in serial. Want: %v, Got: %v.", i, trueGradient, gradient)
}
if !floats.Equal(x, xcopy) {
t.Errorf("Case %v: x modified during call to gradient in serial with non-nil.", i)
}
// Try with known value.
for i := range gradient {
gradient[i] = rand.Float64()
}
Gradient(gradient, r.F, x, &Settings{
Formula: test.formula,
OriginKnown: true,
OriginValue: r.F(x),
})
if !floats.EqualApprox(gradient, trueGradient, test.tol) {
t.Errorf("Case %v: gradient mismatch with known origin in serial. Want: %v, Got: %v.", i, trueGradient, gradient)
}
// Try with concurrent evaluation.
for i := range gradient {
gradient[i] = rand.Float64()
}
Gradient(gradient, r.F, x, &Settings{
Formula: test.formula,
Concurrent: true,
})
if !floats.EqualApprox(gradient, trueGradient, test.tol) {
t.Errorf("Case %v: gradient mismatch with unknown origin in parallel. Want: %v, Got: %v.", i, trueGradient, gradient)
}
if !floats.Equal(x, xcopy) {
t.Errorf("Case %v: x modified during call to gradient in parallel", i)
}
// Try with concurrent evaluation with origin known.
for i := range gradient {
gradient[i] = rand.Float64()
}
Gradient(gradient, r.F, x, &Settings{
Formula: test.formula,
Concurrent: true,
OriginKnown: true,
OriginValue: r.F(x),
})
if !floats.EqualApprox(gradient, trueGradient, test.tol) {
t.Errorf("Case %v: gradient mismatch with known origin in parallel. Want: %v, Got: %v.", i, trueGradient, gradient)
}
// Try with nil settings.
for i := range gradient {
gradient[i] = rand.Float64()
}
Gradient(gradient, r.F, x, nil)
if !floats.EqualApprox(gradient, trueGradient, test.tol) {
t.Errorf("Case %v: gradient mismatch with default settings. Want: %v, Got: %v.", i, trueGradient, gradient)
}
// Try with zero-valued settings.
for i := range gradient {
gradient[i] = rand.Float64()
}
Gradient(gradient, r.F, x, &Settings{})
if !floats.EqualApprox(gradient, trueGradient, test.tol) {
t.Errorf("Case %v: gradient mismatch with zero settings. Want: %v, Got: %v.", i, trueGradient, gradient)
}
}
}
func Panics(fun func()) (b bool) {
defer func() {
err := recover()
if err != nil {
b = true
}
}()
fun()
return
}
func TestGradientPanics(t *testing.T) {
// Test that it panics
if !Panics(func() {
Gradient([]float64{0.0}, func(x []float64) float64 { return x[0] * x[0] }, []float64{0.0, 0.0}, nil)
}) {
t.Errorf("Gradient did not panic with length mismatch")
}
}