| // Copyright ©2018 The Gonum Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package optimize |
| |
| import ( |
| "testing" |
| |
| "golang.org/x/exp/rand" |
| |
| "gonum.org/v1/gonum/floats" |
| "gonum.org/v1/gonum/mat" |
| "gonum.org/v1/gonum/optimize/functions" |
| ) |
| |
| func TestListSearch(t *testing.T) { |
| t.Parallel() |
| rnd := rand.New(rand.NewSource(1)) |
| for cas, test := range []struct { |
| r, c int |
| shortEvals int |
| fun func([]float64) float64 |
| }{ |
| { |
| r: 100, |
| c: 10, |
| fun: functions.ExtendedRosenbrock{}.Func, |
| }, |
| } { |
| // Generate a random list of items. |
| r, c := test.r, test.c |
| locs := mat.NewDense(r, c, nil) |
| for i := 0; i < r; i++ { |
| for j := 0; j < c; j++ { |
| locs.Set(i, j, rnd.NormFloat64()) |
| } |
| } |
| |
| // Evaluate all of the items in the list and find the minimum value. |
| fs := make([]float64, r) |
| for i := 0; i < r; i++ { |
| fs[i] = test.fun(locs.RawRowView(i)) |
| } |
| minIdx := floats.MinIdx(fs) |
| |
| // Check that the global minimum is found under normal conditions. |
| p := Problem{Func: test.fun} |
| method := &ListSearch{ |
| Locs: locs, |
| } |
| settings := &Settings{ |
| Converger: NeverTerminate{}, |
| } |
| initX := make([]float64, c) |
| result, err := Minimize(p, initX, settings, method) |
| if err != nil { |
| t.Errorf("cas %v: error optimizing: %s", cas, err) |
| } |
| if result.Status != MethodConverge { |
| t.Errorf("cas %v: status should be MethodConverge", cas) |
| } |
| if !floats.Equal(result.X, locs.RawRowView(minIdx)) { |
| t.Errorf("cas %v: did not find minimum of whole list", cas) |
| } |
| |
| // Check that the optimization works concurrently. |
| concurrent := 6 |
| settings.Concurrent = concurrent |
| result, err = Minimize(p, initX, settings, method) |
| if err != nil { |
| t.Errorf("cas %v: error optimizing: %s", cas, err) |
| } |
| if result.Status != MethodConverge { |
| t.Errorf("cas %v: status should be MethodConverge", cas) |
| } |
| if !floats.Equal(result.X, locs.RawRowView(minIdx)) { |
| t.Errorf("cas %v: did not find minimum of whole list concurrent", cas) |
| } |
| |
| // Check that the optimization works concurrently with more than the number of samples. |
| settings.Concurrent = test.r + concurrent |
| result, err = Minimize(p, initX, settings, method) |
| if err != nil { |
| t.Errorf("cas %v: error optimizing: %s", cas, err) |
| } |
| if result.Status != MethodConverge { |
| t.Errorf("cas %v: status should be MethodConverge", cas) |
| } |
| if !floats.Equal(result.X, locs.RawRowView(minIdx)) { |
| t.Errorf("cas %v: did not find minimum of whole list concurrent", cas) |
| } |
| |
| // Check that cleanup happens properly by setting the minimum location |
| // to the last sample. |
| swapSamples(locs, fs, minIdx, test.r-1) |
| minIdx = test.r - 1 |
| settings.Concurrent = concurrent |
| result, err = Minimize(p, initX, settings, method) |
| if err != nil { |
| t.Errorf("cas %v: error optimizing: %s", cas, err) |
| } |
| if result.Status != MethodConverge { |
| t.Errorf("cas %v: status should be MethodConverge", cas) |
| } |
| if !floats.Equal(result.X, locs.RawRowView(minIdx)) { |
| t.Errorf("cas %v: did not find minimum of whole list last sample", cas) |
| } |
| |
| // Test that the correct optimum is found when the optimization ends early. |
| // Note that the above test swapped the list minimum to the last sample, |
| // so it's guaranteed that the minimum of the shortened list is not the |
| // same as the minimum of the whole list. |
| evals := test.r / 3 |
| minIdxFirst := floats.MinIdx(fs[:evals]) |
| settings.Concurrent = 0 |
| settings.FuncEvaluations = evals |
| result, err = Minimize(p, initX, settings, method) |
| if err != nil { |
| t.Errorf("cas %v: error optimizing: %s", cas, err) |
| } |
| if result.Status != FunctionEvaluationLimit { |
| t.Errorf("cas %v: status was not FunctionEvaluationLimit", cas) |
| } |
| if !floats.Equal(result.X, locs.RawRowView(minIdxFirst)) { |
| t.Errorf("cas %v: did not find minimum of shortened list serial", cas) |
| } |
| |
| // Test the same but concurrently. We can't guarantee a specific number |
| // of function evaluations concurrently, so make sure that the list optimum |
| // is not between [evals:evals+concurrent] |
| for floats.MinIdx(fs[:evals]) != floats.MinIdx(fs[:evals+concurrent]) { |
| // Swap the minimum index with a random element. |
| minIdxFirst := floats.MinIdx(fs[:evals+concurrent]) |
| new := rnd.Intn(evals) |
| swapSamples(locs, fs, minIdxFirst, new) |
| } |
| |
| minIdxFirst = floats.MinIdx(fs[:evals]) |
| settings.Concurrent = concurrent |
| result, err = Minimize(p, initX, settings, method) |
| if err != nil { |
| t.Errorf("cas %v: error optimizing: %s", cas, err) |
| } |
| if result.Status != FunctionEvaluationLimit { |
| t.Errorf("cas %v: status was not FunctionEvaluationLimit", cas) |
| } |
| if !floats.Equal(result.X, locs.RawRowView(minIdxFirst)) { |
| t.Errorf("cas %v: did not find minimum of shortened list concurrent", cas) |
| } |
| } |
| } |
| |
| func swapSamples(m *mat.Dense, f []float64, i, j int) { |
| f[i], f[j] = f[j], f[i] |
| row := mat.Row(nil, i, m) |
| m.SetRow(i, m.RawRowView(j)) |
| m.SetRow(j, row) |
| } |