| // Copyright ©2015 The Gonum Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package mat |
| |
| import ( |
| "testing" |
| |
| "golang.org/x/exp/rand" |
| ) |
| |
| func TestSolve(t *testing.T) { |
| t.Parallel() |
| rnd := rand.New(rand.NewSource(1)) |
| // Hand-coded cases. |
| for _, test := range []struct { |
| a [][]float64 |
| b [][]float64 |
| ans [][]float64 |
| shouldErr bool |
| }{ |
| { |
| a: [][]float64{{6}}, |
| b: [][]float64{{3}}, |
| ans: [][]float64{{0.5}}, |
| shouldErr: false, |
| }, |
| { |
| a: [][]float64{ |
| {1, 0, 0}, |
| {0, 1, 0}, |
| {0, 0, 1}, |
| }, |
| b: [][]float64{ |
| {3}, |
| {2}, |
| {1}, |
| }, |
| ans: [][]float64{ |
| {3}, |
| {2}, |
| {1}, |
| }, |
| shouldErr: false, |
| }, |
| { |
| a: [][]float64{ |
| {0.8147, 0.9134, 0.5528}, |
| {0.9058, 0.6324, 0.8723}, |
| {0.1270, 0.0975, 0.7612}, |
| }, |
| b: [][]float64{ |
| {0.278}, |
| {0.547}, |
| {0.958}, |
| }, |
| ans: [][]float64{ |
| {-0.932687281002860}, |
| {0.303963920182067}, |
| {1.375216503507109}, |
| }, |
| shouldErr: false, |
| }, |
| { |
| a: [][]float64{ |
| {0.8147, 0.9134, 0.5528}, |
| {0.9058, 0.6324, 0.8723}, |
| }, |
| b: [][]float64{ |
| {0.278}, |
| {0.547}, |
| }, |
| ans: [][]float64{ |
| {0.25919787248965376}, |
| {-0.25560256266441034}, |
| {0.5432324059702451}, |
| }, |
| shouldErr: false, |
| }, |
| { |
| a: [][]float64{ |
| {0.8147, 0.9134, 0.9}, |
| {0.9058, 0.6324, 0.9}, |
| {0.1270, 0.0975, 0.1}, |
| {1.6, 2.8, -3.5}, |
| }, |
| b: [][]float64{ |
| {0.278}, |
| {0.547}, |
| {-0.958}, |
| {1.452}, |
| }, |
| ans: [][]float64{ |
| {0.820970340787782}, |
| {-0.218604626527306}, |
| {-0.212938815234215}, |
| }, |
| shouldErr: false, |
| }, |
| { |
| a: [][]float64{ |
| {0.8147, 0.9134, 0.231, -1.65}, |
| {0.9058, 0.6324, 0.9, 0.72}, |
| {0.1270, 0.0975, 0.1, 1.723}, |
| {1.6, 2.8, -3.5, 0.987}, |
| {7.231, 9.154, 1.823, 0.9}, |
| }, |
| b: [][]float64{ |
| {0.278, 8.635}, |
| {0.547, 9.125}, |
| {-0.958, -0.762}, |
| {1.452, 1.444}, |
| {1.999, -7.234}, |
| }, |
| ans: [][]float64{ |
| {1.863006789511373, 44.467887791812750}, |
| {-1.127270935407224, -34.073794226035126}, |
| {-0.527926457947330, -8.032133759788573}, |
| {-0.248621916204897, -2.366366415805275}, |
| }, |
| shouldErr: false, |
| }, |
| { |
| a: [][]float64{ |
| {0, 0}, |
| {0, 0}, |
| }, |
| b: [][]float64{ |
| {3}, |
| {2}, |
| }, |
| ans: nil, |
| shouldErr: true, |
| }, |
| { |
| a: [][]float64{ |
| {0, 0}, |
| {0, 0}, |
| {0, 0}, |
| }, |
| b: [][]float64{ |
| {3}, |
| {2}, |
| {1}, |
| }, |
| ans: nil, |
| shouldErr: true, |
| }, |
| { |
| a: [][]float64{ |
| {0, 0, 0}, |
| {0, 0, 0}, |
| }, |
| b: [][]float64{ |
| {3}, |
| {2}, |
| }, |
| ans: nil, |
| shouldErr: true, |
| }, |
| } { |
| a := NewDense(flatten(test.a)) |
| b := NewDense(flatten(test.b)) |
| |
| var ans *Dense |
| if test.ans != nil { |
| ans = NewDense(flatten(test.ans)) |
| } |
| |
| var x Dense |
| err := x.Solve(a, b) |
| if err != nil { |
| if !test.shouldErr { |
| t.Errorf("Unexpected solve error: %s", err) |
| } |
| continue |
| } |
| if err == nil && test.shouldErr { |
| t.Errorf("Did not error during solve.") |
| continue |
| } |
| if !EqualApprox(&x, ans, 1e-12) { |
| t.Errorf("Solve answer mismatch. Want %v, got %v", ans, x) |
| } |
| } |
| |
| // Random Cases. |
| for _, test := range []struct { |
| m, n, bc int |
| }{ |
| {5, 5, 1}, |
| {5, 10, 1}, |
| {10, 5, 1}, |
| {5, 5, 7}, |
| {5, 10, 7}, |
| {10, 5, 7}, |
| {5, 5, 12}, |
| {5, 10, 12}, |
| {10, 5, 12}, |
| } { |
| m := test.m |
| n := test.n |
| bc := test.bc |
| a := NewDense(m, n, nil) |
| for i := 0; i < m; i++ { |
| for j := 0; j < n; j++ { |
| a.Set(i, j, rnd.Float64()) |
| } |
| } |
| br := m |
| b := NewDense(br, bc, nil) |
| for i := 0; i < br; i++ { |
| for j := 0; j < bc; j++ { |
| b.Set(i, j, rnd.Float64()) |
| } |
| } |
| var x Dense |
| err := x.Solve(a, b) |
| if err != nil { |
| t.Errorf("unexpected error from dense solve: %v", err) |
| } |
| |
| // Test that the normal equations hold. |
| // Aᵀ * A * x = Aᵀ * b |
| var tmp, lhs, rhs Dense |
| tmp.Mul(a.T(), a) |
| lhs.Mul(&tmp, &x) |
| rhs.Mul(a.T(), b) |
| if !EqualApprox(&lhs, &rhs, 1e-10) { |
| t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs) |
| } |
| } |
| |
| // Use testTwoInput. |
| method := func(receiver, a, b Matrix) { |
| type Solver interface { |
| Solve(a, b Matrix) error |
| } |
| rd := receiver.(Solver) |
| _ = rd.Solve(a, b) |
| } |
| denseComparison := func(receiver, a, b *Dense) { |
| _ = receiver.Solve(a, b) |
| } |
| testTwoInput(t, "Solve", &Dense{}, method, denseComparison, legalTypesAll, legalSizeSolve, 1e-7) |
| } |
| |
| func TestSolveVec(t *testing.T) { |
| t.Parallel() |
| rnd := rand.New(rand.NewSource(1)) |
| for _, test := range []struct { |
| m, n int |
| }{ |
| {5, 5}, |
| {5, 10}, |
| {10, 5}, |
| {5, 5}, |
| {5, 10}, |
| {10, 5}, |
| {5, 5}, |
| {5, 10}, |
| {10, 5}, |
| } { |
| m := test.m |
| n := test.n |
| a := NewDense(m, n, nil) |
| for i := 0; i < m; i++ { |
| for j := 0; j < n; j++ { |
| a.Set(i, j, rnd.Float64()) |
| } |
| } |
| br := m |
| b := NewVecDense(br, nil) |
| for i := 0; i < br; i++ { |
| b.SetVec(i, rnd.Float64()) |
| } |
| var x VecDense |
| err := x.SolveVec(a, b) |
| if err != nil { |
| t.Errorf("unexpected error from dense vector solve: %v", err) |
| } |
| |
| // Test that the normal equations hold. |
| // Aᵀ * A * x = Aᵀ * b |
| var tmp, lhs, rhs Dense |
| tmp.Mul(a.T(), a) |
| lhs.Mul(&tmp, &x) |
| rhs.Mul(a.T(), b) |
| if !EqualApprox(&lhs, &rhs, 1e-10) { |
| t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs) |
| } |
| } |
| |
| // Use testTwoInput |
| method := func(receiver, a, b Matrix) { |
| type SolveVecer interface { |
| SolveVec(a Matrix, b Vector) error |
| } |
| rd := receiver.(SolveVecer) |
| _ = rd.SolveVec(a, b.(Vector)) |
| } |
| denseComparison := func(receiver, a, b *Dense) { |
| _ = receiver.Solve(a, b) |
| } |
| testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVector, legalSizeSolve, 1e-12) |
| } |