blob: 04ff14f97b100984ea4ba9215df31df2d51231a0 [file] [log] [blame]
// Copyright ©2017 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package optimize
import (
"errors"
"math"
"testing"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/mat"
"gonum.org/v1/gonum/optimize/functions"
)
// functionThresholdConverger is a test convergence criterion that is driven
// solely by the objective value: it signals termination once the value falls
// strictly below Threshold.
type functionThresholdConverger struct {
	// Threshold is the objective value below which the run is considered
	// converged.
	Threshold float64
}
// Init is a no-op: the converger keeps no per-run state, so the problem
// dimension is ignored.
func (functionThresholdConverger) Init(dim int) {
	// Nothing to set up.
}
// Converged returns FunctionThreshold when the function value stored in loc
// is strictly less than f.Threshold, and NotTerminated in every other case.
func (f functionThresholdConverger) Converged(loc *Location) Status {
	switch {
	case loc.F < f.Threshold:
		return FunctionThreshold
	default:
		return NotTerminated
	}
}
// cmaTestCase describes a single CmaEsChol optimization scenario together
// with the check used to validate its outcome.
type cmaTestCase struct {
	// dim is the problem dimension; it sizes the zero-valued starting point
	// that the test driver allocates when initX is nil.
	dim int
	// problem is the optimization problem handed to Minimize.
	problem Problem
	// method is the CmaEsChol instance under test.
	method *CmaEsChol
	// initX is the optional starting location; nil selects the origin.
	initX []float64
	// settings holds the Minimize settings for the run.
	settings *Settings
	// good inspects the outcome of Minimize and returns a non-nil error
	// describing the mismatch when the outcome is not the expected one.
	// concurrent is the value of settings.Concurrent used for the run.
	good func(result *Result, err error, concurrent int) error
}
// cmaTestCases returns the table of scenarios exercised by TestCmaEsChol.
// Each entry pairs a problem and a CmaEsChol configuration with a good
// function that validates the result of the corresponding Minimize call.
//
// It panics if the fixture matrix used for the local-minimum case cannot be
// Cholesky-factorized, since the resulting factorization would be unusable.
func cmaTestCases() []cmaTestCase {
	// Fixture for the local-minimum case: a starting point near (2, -2) and
	// the Cholesky factorization of the 2×2 matrix 0.01*I, which is supplied
	// to the method as InitCholesky.
	localMinMean := []float64{2.2, -2.2}
	s := mat.NewSymDense(2, []float64{0.01, 0, 0, 0.01})
	var localMinChol mat.Cholesky
	if ok := localMinChol.Factorize(s); !ok {
		// A factorization that reports failure must not be used, so stop
		// here instead of feeding it to the optimizer.
		panic("optimize: test fixture matrix is not positive definite")
	}
	return []cmaTestCase{
		{
			// Test that the method can find a small function value.
			dim: 10,
			problem: Problem{
				Func: functions.ExtendedRosenbrock{}.Func,
			},
			method: &CmaEsChol{
				StopLogDet: math.NaN(),
			},
			settings: &Settings{
				Converger: functionThresholdConverger{0.01},
			},
			good: func(result *Result, err error, concurrent int) error {
				if result.Status != FunctionThreshold {
					return errors.New("result not function threshold")
				}
				if result.F > 0.01 {
					return errors.New("result not sufficiently small")
				}
				return nil
			},
		},
		{
			// Test that the method can stop when the covariance gets small.
			// For this case, also test that it is really at a minimum.
			dim: 2,
			problem: Problem{
				Func: functions.ExtendedRosenbrock{}.Func,
			},
			method: &CmaEsChol{},
			settings: &Settings{
				Converger: NeverTerminate{},
			},
			good: func(result *Result, err error, concurrent int) error {
				if result.Status != MethodConverge {
					return errors.New("result not method converge")
				}
				if result.F > 1e-12 {
					return errors.New("minimum not found")
				}
				return nil
			},
		},
		{
			// Test that population works properly and that the run stops
			// after a certain number of iterations.
			dim: 3,
			problem: Problem{
				Func: functions.ExtendedRosenbrock{}.Func,
			},
			method: &CmaEsChol{
				Population: 100,
				ForgetBest: true, // Otherwise may get an update at the end.
			},
			settings: &Settings{
				MajorIterations: 10,
				Converger:       NeverTerminate{},
			},
			good: func(result *Result, err error, concurrent int) error {
				if result.Status != IterationLimit {
					return errors.New("result not iteration limit")
				}
				threshLower := 10
				threshUpper := 10
				if concurrent != 0 {
					// Could have one more from final update.
					threshUpper++
				}
				if result.MajorIterations < threshLower || result.MajorIterations > threshUpper {
					return errors.New("wrong number of iterations")
				}
				return nil
			},
		},
		{
			// Test that work stops after some number of function evaluations.
			dim: 5,
			problem: Problem{
				Func: functions.ExtendedRosenbrock{}.Func,
			},
			method: &CmaEsChol{
				Population: 100,
			},
			settings: &Settings{
				FuncEvaluations: 250, // Somewhere in the middle of an iteration.
				Converger:       NeverTerminate{},
			},
			good: func(result *Result, err error, concurrent int) error {
				if result.Status != FunctionEvaluationLimit {
					return errors.New("result not function evaluations")
				}
				threshLower := 250
				threshUpper := 251
				if concurrent != 0 {
					// Allow up to one extra evaluation per concurrent worker.
					threshUpper = threshLower + concurrent
				}
				if result.FuncEvaluations < threshLower {
					return errors.New("too few function evaluations")
				}
				if result.FuncEvaluations > threshUpper {
					return errors.New("too many function evaluations")
				}
				return nil
			},
		},
		{
			// Test that the global minimum is found with the right initialization.
			dim: 2,
			problem: Problem{
				Func: functions.Rastrigin{}.Func,
			},
			method: &CmaEsChol{
				Population: 100, // Increase the population size to reduce noise.
			},
			settings: &Settings{
				Converger: NeverTerminate{},
			},
			good: func(result *Result, err error, concurrent int) error {
				if result.Status != MethodConverge {
					return errors.New("result not method converge")
				}
				if !floats.EqualApprox(result.X, []float64{0, 0}, 1e-6) {
					return errors.New("global minimum not found")
				}
				return nil
			},
		},
		{
			// Test that a local minimum is found (with a different initialization).
			dim: 2,
			problem: Problem{
				Func: functions.Rastrigin{}.Func,
			},
			initX: localMinMean,
			method: &CmaEsChol{
				Population:   100, // Increase the population size to reduce noise.
				InitCholesky: &localMinChol,
				ForgetBest:   true, // So that if it accidentally finds a better place we still converge to the minimum.
			},
			settings: &Settings{
				Converger: NeverTerminate{},
			},
			good: func(result *Result, err error, concurrent int) error {
				if result.Status != MethodConverge {
					return errors.New("result not method converge")
				}
				if !floats.EqualApprox(result.X, []float64{2, -2}, 3e-2) {
					return errors.New("local minimum not found")
				}
				return nil
			},
		},
	}
}
// TestCmaEsChol runs every scenario from cmaTestCases three times against
// the same method value: once serially, a second time serially to detect
// state left over from the first run, and once with Settings.Concurrent set
// to 5. Each outcome is validated by the scenario's good function.
func TestCmaEsChol(t *testing.T) {
	t.Parallel()
	for i, test := range cmaTestCases() {
		// Use a fixed seed so every scenario starts from the same source.
		src := rand.New(rand.NewSource(1))
		method := test.method
		method.Src = src
		initX := test.initX
		if initX == nil {
			initX = make([]float64, test.dim)
		}

		// run performs one minimization with the current settings and
		// reports any mismatch; label distinguishes the runs in messages.
		run := func(label string) {
			result, err := Minimize(test.problem, initX, test.settings, method)
			if result == nil {
				// The good functions dereference result unconditionally, so
				// report a missing result instead of panicking inside them.
				t.Errorf("case %d%s: Minimize returned no result: %v", i, label, err)
				return
			}
			if testErr := test.good(result, err, test.settings.Concurrent); testErr != nil {
				t.Errorf("case %d%s: %v", i, label, testErr)
			}
		}

		// Run and check that the expected termination occurs.
		run("")

		// Run a second time to make sure there are no residual effects.
		run(" second")

		// Test the problem in parallel, then restore the serial setting.
		test.settings.Concurrent = 5
		run(" concurrent")
		test.settings.Concurrent = 0
	}
}