blob: f3dfdc839b67fbe086c02a747d37cdba79533ff9 [file] [log] [blame]
 // Copyright ©2013 The Gonum Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package mat import ( "math" "testing" "golang.org/x/exp/rand" "gonum.org/v1/gonum/blas/blas64" ) func TestQR(t *testing.T) { t.Parallel() rnd := rand.New(rand.NewSource(1)) for _, test := range []struct { m, n int }{ {5, 5}, {10, 5}, } { m := test.m n := test.n a := NewDense(m, n, nil) for i := 0; i < m; i++ { for j := 0; j < n; j++ { a.Set(i, j, rnd.NormFloat64()) } } var want Dense want.CloneFrom(a) var qr QR qr.Factorize(a) var q, r Dense qr.QTo(&q) if !isOrthonormal(&q, 1e-10) { t.Errorf("Q is not orthonormal: m = %v, n = %v", m, n) } qr.RTo(&r) var got Dense got.Mul(&q, &r) if !EqualApprox(&got, &want, 1e-12) { t.Errorf("QR does not equal original matrix. \nWant: %v\nGot: %v", want, got) } } } func isOrthonormal(q *Dense, tol float64) bool { m, n := q.Dims() if m != n { return false } for i := 0; i < m; i++ { for j := i; j < m; j++ { dot := blas64.Dot(blas64.Vector{N: m, Inc: 1, Data: q.mat.Data[i*q.mat.Stride:]}, blas64.Vector{N: m, Inc: 1, Data: q.mat.Data[j*q.mat.Stride:]}) // Dot product should be 1 if i == j and 0 otherwise. if i == j && math.Abs(dot-1) > tol { return false } if i != j && math.Abs(dot) > tol { return false } } } return true } func TestQRSolveTo(t *testing.T) { t.Parallel() rnd := rand.New(rand.NewSource(1)) for _, trans := range []bool{false, true} { for _, test := range []struct { m, n, bc int }{ {5, 5, 1}, {10, 5, 1}, {5, 5, 3}, {10, 5, 3}, } { m := test.m n := test.n bc := test.bc a := NewDense(m, n, nil) for i := 0; i < m; i++ { for j := 0; j < n; j++ { a.Set(i, j, rnd.Float64()) } } br := m if trans { br = n } b := NewDense(br, bc, nil) for i := 0; i < br; i++ { for j := 0; j < bc; j++ { b.Set(i, j, rnd.Float64()) } } var x Dense var qr QR qr.Factorize(a) err := qr.SolveTo(&x, trans, b) if err != nil { t.Errorf("unexpected error from QR solve: %v", err) } // Test that the normal equations hold. // Aᵀ * A * x = Aᵀ * b if !trans // A * Aᵀ * x = A * b if trans var lhs Dense var rhs Dense if trans { var tmp Dense tmp.Mul(a, a.T()) lhs.Mul(&tmp, &x) rhs.Mul(a, b) } else { var tmp Dense tmp.Mul(a.T(), a) lhs.Mul(&tmp, &x) rhs.Mul(a.T(), b) } if !EqualApprox(&lhs, &rhs, 1e-10) { t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs) } } } // TODO(btracey): Add in testOneInput when it exists. } func TestQRSolveVecTo(t *testing.T) { t.Parallel() rnd := rand.New(rand.NewSource(1)) for _, trans := range []bool{false, true} { for _, test := range []struct { m, n int }{ {5, 5}, {10, 5}, } { m := test.m n := test.n a := NewDense(m, n, nil) for i := 0; i < m; i++ { for j := 0; j < n; j++ { a.Set(i, j, rnd.Float64()) } } br := m if trans { br = n } b := NewVecDense(br, nil) for i := 0; i < br; i++ { b.SetVec(i, rnd.Float64()) } var x VecDense var qr QR qr.Factorize(a) err := qr.SolveVecTo(&x, trans, b) if err != nil { t.Errorf("unexpected error from QR solve: %v", err) } // Test that the normal equations hold. // Aᵀ * A * x = Aᵀ * b if !trans // A * Aᵀ * x = A * b if trans var lhs Dense var rhs Dense if trans { var tmp Dense tmp.Mul(a, a.T()) lhs.Mul(&tmp, &x) rhs.Mul(a, b) } else { var tmp Dense tmp.Mul(a.T(), a) lhs.Mul(&tmp, &x) rhs.Mul(a.T(), b) } if !EqualApprox(&lhs, &rhs, 1e-10) { t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs) } } } // TODO(btracey): Add in testOneInput when it exists. } func TestQRSolveCondTo(t *testing.T) { t.Parallel() for _, test := range []*Dense{ NewDense(2, 2, []float64{1, 0, 0, 1e-20}), NewDense(3, 2, []float64{1, 0, 0, 1e-20, 0, 0}), } { m, _ := test.Dims() var qr QR qr.Factorize(test) b := NewDense(m, 2, nil) var x Dense if err := qr.SolveTo(&x, false, b); err == nil { t.Error("No error for near-singular matrix in matrix solve.") } bvec := NewVecDense(m, nil) var xvec VecDense if err := qr.SolveVecTo(&xvec, false, bvec); err == nil { t.Error("No error for near-singular matrix in matrix solve.") } } }