blob: b7129f953806a6ca500aab3b0e0f3f75a2d6b3f1 [file] [log] [blame]
// Copyright ©2013 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mat
import (
"testing"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/floats"
)
// TestSVD exercises SVD.Factorize on hand-coded and random matrices.
//
// For the hand-coded cases it compares the singular values and the thin U and
// V factors with expected values, checks that U*Σ*Vᵀ reconstructs A, and
// verifies that UTo and VTo panic with the expected message when the
// corresponding factor was not requested during factorization.
//
// For the random cases it checks that the input matrix is left unchanged, that
// both the full and the thin factors reconstruct A, and that the full, thin
// and values-only factorizations agree on the singular values.
func TestSVD(t *testing.T) {
	t.Parallel()
	rnd := rand.New(rand.NewSource(1))
	// Hand coded tests
	for _, test := range []struct {
		a *Dense
		u *Dense
		v *Dense
		s []float64
	}{
		{
			a: NewDense(4, 2, []float64{2, 4, 1, 3, 0, 0, 0, 0}),
			u: NewDense(4, 2, []float64{
				-0.8174155604703632, -0.5760484367663209,
				-0.5760484367663209, 0.8174155604703633,
				0, 0,
				0, 0,
			}),
			v: NewDense(2, 2, []float64{
				-0.4045535848337571, -0.9145142956773044,
				-0.9145142956773044, 0.4045535848337571,
			}),
			s: []float64{5.464985704219041, 0.365966190626258},
		},
		{
			// Issue #5.
			a: NewDense(3, 11, []float64{
				1, 1, 0, 1, 0, 0, 0, 0, 0, 11, 1,
				1, 0, 0, 0, 0, 0, 1, 0, 0, 12, 2,
				1, 1, 0, 0, 0, 0, 0, 0, 1, 13, 3,
			}),
			u: NewDense(3, 3, []float64{
				-0.5224167862273765, 0.7864430360363114, 0.3295270133658976,
				-0.5739526766688285, -0.03852203026050301, -0.8179818935216693,
				-0.6306021141833781, -0.6164603833618163, 0.4715056408282468,
			}),
			v: NewDense(11, 3, []float64{
				-0.08123293141915189, 0.08528085505260324, -0.013165501690885152,
				-0.05423546426886932, 0.1102707844980355, 0.622210623111631,
				0, 0, 0,
				-0.0245733326078166, 0.510179651760153, 0.25596360803140994,
				0, 0, 0,
				0, 0, 0,
				-0.026997467150282436, -0.024989929445430496, -0.6353761248025164,
				0, 0, 0,
				-0.029662131661052707, -0.3999088672621176, 0.3662470150802212,
				-0.9798839760830571, 0.11328174160898856, -0.047702613241813366,
				-0.16755466189153964, -0.7395268089170608, 0.08395240366704032,
			}),
			s: []float64{21.259500881097434, 1.5415021616856566, 1.2873979074613628},
		},
	} {
		var svd SVD
		ok := svd.Factorize(test.a, SVDThin)
		if !ok {
			t.Errorf("SVD failed")
			// The remaining checks are only meaningful for a successful
			// factorization.
			continue
		}
		s, u, v := extractSVD(&svd)
		if !floats.EqualApprox(s, test.s, 1e-10) {
			t.Errorf("Singular value mismatch. Got %v, want %v.", s, test.s)
		}
		if !EqualApprox(u, test.u, 1e-10) {
			t.Errorf("U mismatch.\nGot:\n%v\nWant:\n%v", Formatted(u), Formatted(test.u))
		}
		if !EqualApprox(v, test.v, 1e-10) {
			t.Errorf("V mismatch.\nGot:\n%v\nWant:\n%v", Formatted(v), Formatted(test.v))
		}

		// Rebuild A from the thin factors: A = U * Σ * Vᵀ.
		m, n := test.a.Dims()
		k := min(m, n)
		sigma := NewDense(k, k, nil)
		for i := 0; i < k; i++ {
			sigma.Set(i, i, s[i])
		}
		var ans Dense
		ans.Product(u, sigma, v.T())
		if !EqualApprox(test.a, &ans, 1e-10) {
			t.Errorf("A reconstruction mismatch.\nGot:\n%v\nWant:\n%v\n", Formatted(&ans), Formatted(test.a))
		}

		// Each kind below requests only one of U and V, so extracting the
		// other factor must panic.
		for _, kind := range []SVDKind{
			SVDThinU, SVDFullU, SVDThinV, SVDFullV,
		} {
			var svd SVD
			if !svd.Factorize(test.a, kind) {
				t.Errorf("SVD failed for kind %v", kind)
				continue
			}
			if kind&SVDThinU == 0 && kind&SVDFullU == 0 {
				panicked, message := panics(func() {
					var dst Dense
					svd.UTo(&dst)
				})
				if !panicked {
					t.Error("expected panic with no U matrix requested")
					continue
				}
				want := "svd: u not computed during factorization"
				if message != want {
					t.Errorf("unexpected message: got:%q want:%q", message, want)
				}
			}
			if kind&SVDThinV == 0 && kind&SVDFullV == 0 {
				panicked, message := panics(func() {
					var dst Dense
					svd.VTo(&dst)
				})
				if !panicked {
					t.Error("expected panic with no V matrix requested")
					continue
				}
				want := "svd: v not computed during factorization"
				if message != want {
					t.Errorf("unexpected message: got:%q want:%q", message, want)
				}
			}
		}
	}

	for _, test := range []struct {
		m, n int
	}{
		{5, 5},
		{5, 3},
		{3, 5},
		{150, 150},
		{200, 150},
		{150, 200},
	} {
		m := test.m
		n := test.n
		k := min(m, n)
		for trial := 0; trial < 10; trial++ {
			a := NewDense(m, n, nil)
			for i := range a.mat.Data {
				a.mat.Data[i] = rnd.NormFloat64()
			}
			// Kept to detect modification of the input by Factorize.
			aCopy := DenseCopyOf(a)

			// Test Full decomposition.
			var svd SVD
			ok := svd.Factorize(a, SVDFull)
			if !ok {
				t.Errorf("SVD factorization failed")
				continue
			}
			if !Equal(a, aCopy) {
				t.Errorf("A changed during call to SVD with full")
			}
			s, u, v := extractSVD(&svd)
			// With full factors Σ has the same shape as A.
			sigma := NewDense(m, n, nil)
			for i := 0; i < k; i++ {
				sigma.Set(i, i, s[i])
			}
			var ansFull Dense
			ansFull.Product(u, sigma, v.T())
			if !EqualApprox(&ansFull, a, 1e-8) {
				t.Errorf("Answer mismatch when SVDFull")
			}

			// Test Thin decomposition.
			ok = svd.Factorize(a, SVDThin)
			if !ok {
				t.Errorf("SVD factorization failed")
				continue
			}
			if !Equal(a, aCopy) {
				t.Errorf("A changed during call to SVD with Thin")
			}
			sThin, u, v := extractSVD(&svd)
			if !floats.EqualApprox(s, sThin, 1e-8) {
				t.Errorf("Singular value mismatch between Full and Thin decomposition")
			}
			// With thin factors Σ is square of order min(m, n).
			sigma = NewDense(k, k, nil)
			for i := 0; i < k; i++ {
				sigma.Set(i, i, sThin[i])
			}
			var ansThin Dense
			ansThin.Product(u, sigma, v.T())
			if !EqualApprox(&ansThin, a, 1e-8) {
				t.Errorf("Answer mismatch when SVDThin")
			}

			// Test None decomposition.
			ok = svd.Factorize(a, SVDNone)
			if !ok {
				t.Errorf("SVD factorization failed")
				continue
			}
			if !Equal(a, aCopy) {
				t.Errorf("A changed during call to SVD with none")
			}
			sNone := make([]float64, k)
			svd.Values(sNone)
			if !floats.EqualApprox(s, sNone, 1e-8) {
				t.Errorf("Singular value mismatch between Full and None decomposition")
			}
		}
	}
}
// extractSVD pulls the pieces of a factorization out of svd: the singular
// values as returned by Values(nil), and the U and V factors written by UTo
// and VTo into zero-value Dense receivers.
func extractSVD(svd *SVD) (s []float64, u, v *Dense) {
	var left, right Dense
	svd.UTo(&left)
	svd.VTo(&right)
	s = svd.Values(nil)
	return s, &left, &right
}
// TestSVDSolveTo checks SVD.SolveTo with matrix right-hand sides.
//
// The hand-coded cases compare the computed solution with a known answer;
// cases for which Rank reports zero are skipped. The random cases verify
// that the solution satisfies the normal equations Aᵀ * A * x = Aᵀ * b.
func TestSVDSolveTo(t *testing.T) {
	t.Parallel()
	rnd := rand.New(rand.NewSource(1))

	// Hand-coded cases.
	for idx, test := range []struct {
		a      []float64
		m, n   int
		b      []float64
		bc     int
		rcond  float64
		want   []float64
		wm, wn int
	}{
		{
			a: []float64{6}, m: 1, n: 1,
			b: []float64{3}, bc: 1,
			want: []float64{0.5}, wm: 1, wn: 1,
		},
		{
			a: []float64{
				1, 0, 0,
				0, 1, 0,
				0, 0, 1,
			}, m: 3, n: 3,
			b: []float64{3, 2, 1}, bc: 1,
			want: []float64{3, 2, 1}, wm: 3, wn: 1,
		},
		{
			a: []float64{
				0.8147, 0.9134, 0.5528,
				0.9058, 0.6324, 0.8723,
				0.1270, 0.0975, 0.7612,
			}, m: 3, n: 3,
			b: []float64{0.278, 0.547, 0.958}, bc: 1,
			want: []float64{
				-0.932687281002860,
				0.303963920182067,
				1.375216503507109,
			}, wm: 3, wn: 1,
		},
		{
			a: []float64{
				0.8147, 0.9134, 0.5528,
				0.9058, 0.6324, 0.8723,
			}, m: 2, n: 3,
			b: []float64{0.278, 0.547}, bc: 1,
			want: []float64{
				0.25919787248965376,
				-0.25560256266441034,
				0.5432324059702451,
			}, wm: 3, wn: 1,
		},
		{
			a: []float64{
				0.8147, 0.9134, 0.9,
				0.9058, 0.6324, 0.9,
				0.1270, 0.0975, 0.1,
				1.6, 2.8, -3.5,
			}, m: 4, n: 3,
			b: []float64{0.278, 0.547, -0.958, 1.452}, bc: 1,
			want: []float64{
				0.820970340787782,
				-0.218604626527306,
				-0.212938815234215,
			}, wm: 3, wn: 1,
		},
		{
			a: []float64{
				0.8147, 0.9134, 0.231, -1.65,
				0.9058, 0.6324, 0.9, 0.72,
				0.1270, 0.0975, 0.1, 1.723,
				1.6, 2.8, -3.5, 0.987,
				7.231, 9.154, 1.823, 0.9,
			}, m: 5, n: 4,
			b: []float64{
				0.278, 8.635,
				0.547, 9.125,
				-0.958, -0.762,
				1.452, 1.444,
				1.999, -7.234,
			}, bc: 2,
			want: []float64{
				1.863006789511373, 44.467887791812750,
				-1.127270935407224, -34.073794226035126,
				-0.527926457947330, -8.032133759788573,
				-0.248621916204897, -2.366366415805275,
			}, wm: 4, wn: 2,
		},
		{
			// Test rank-deficient case compared with numpy.
			// >>> import numpy as np
			// >>> b = np.array([[-2.3181340317357653],
			// ... [-0.7146777651358073],
			// ... [1.8361340927945298],
			// ... [-0.35699930593018775],
			// ... [-1.6359508076249094]])
			// >>> A = np.array([[-1.7854591879711257, -0.42687285925779594, -0.12730256811265162],
			// ... [-0.5728984211439724, -0.10093393134001777, -0.1181901192353067],
			// ... [1.2484316018707418, 0.5646683943038734, -0.48229492403243485],
			// ... [0.10174927665169475, -0.5805410929482445, 1.3054473231942054],
			// ... [-1.134174808195733, -0.4732430202414438, 0.3528489486370508]])
			// >>> np.linalg.lstsq(A, b, rcond=None)
			// (array([[ 1.21208422],
			// [ 0.41541503],
			// [-0.18320349]]), array([], dtype=float64), 2, array([2.68451480e+00, 1.52593185e+00, 6.82840229e-17]))
			a: []float64{
				-1.7854591879711257, -0.42687285925779594, -0.12730256811265162,
				-0.5728984211439724, -0.10093393134001777, -0.1181901192353067,
				1.2484316018707418, 0.5646683943038734, -0.48229492403243485,
				0.10174927665169475, -0.5805410929482445, 1.3054473231942054,
				-1.134174808195733, -0.4732430202414438, 0.3528489486370508,
			}, m: 5, n: 3,
			b: []float64{
				-2.3181340317357653,
				-0.7146777651358073,
				1.8361340927945298,
				-0.35699930593018775,
				-1.6359508076249094,
			}, bc: 1,
			rcond: 1e-15,
			want: []float64{
				1.2120842180372118,
				0.4154150318658529,
				-0.1832034870198265,
			}, wm: 3, wn: 1,
		},
		{
			a: []float64{
				0, 0,
				0, 0,
			}, m: 2, n: 2,
			b: []float64{3, 2}, bc: 1,
		},
		{
			a: []float64{
				0, 0,
				0, 0,
				0, 0,
			}, m: 3, n: 2,
			b: []float64{3, 2, 1}, bc: 1,
		},
		{
			a: []float64{
				0, 0, 0,
				0, 0, 0,
			}, m: 2, n: 3,
			b: []float64{3, 2}, bc: 1,
		},
	} {
		coeffs := NewDense(test.m, test.n, test.a)
		rhs := NewDense(test.m, test.bc, test.b)
		// Cases without an expected answer leave expected nil; they are the
		// all-zero matrices, which are skipped below when the rank is zero.
		var expected *Dense
		if test.want != nil {
			expected = NewDense(test.wm, test.wn, test.want)
		}

		var fact SVD
		if !fact.Factorize(coeffs, SVDFull) {
			t.Errorf("unexpected factorization failure for test %d", idx)
			continue
		}

		var got Dense
		rank := fact.Rank(test.rcond)
		if rank == 0 {
			continue
		}
		fact.SolveTo(&got, rhs, rank)
		if EqualApprox(&got, expected, 1e-12) {
			continue
		}
		t.Errorf("Solve answer mismatch. Want %v, got %v", expected, got)
	}

	// Random Cases.
	for idx, test := range []struct {
		m, n, bc int
		rcond    float64
	}{
		{m: 5, n: 5, bc: 1},
		{m: 5, n: 10, bc: 1},
		{m: 10, n: 5, bc: 1},
		{m: 5, n: 5, bc: 7},
		{m: 5, n: 10, bc: 7},
		{m: 10, n: 5, bc: 7},
		{m: 5, n: 5, bc: 12},
		{m: 5, n: 10, bc: 12},
		{m: 10, n: 5, bc: 12},
	} {
		rows, cols, nrhs := test.m, test.n, test.bc

		// Draw the entries of A and then of B in row-major order.
		aData := make([]float64, rows*cols)
		for k := range aData {
			aData[k] = rnd.Float64()
		}
		a := NewDense(rows, cols, aData)

		bData := make([]float64, rows*nrhs)
		for k := range bData {
			bData[k] = rnd.Float64()
		}
		b := NewDense(rows, nrhs, bData)

		var fact SVD
		if !fact.Factorize(a, SVDFull) {
			t.Errorf("unexpected factorization failure for test %d", idx)
			continue
		}

		var x Dense
		rank := fact.Rank(test.rcond)
		if rank == 0 {
			continue
		}
		fact.SolveTo(&x, b, rank)

		// Test that the normal equations hold.
		// Aᵀ * A * x = Aᵀ * b
		var ata, lhs, rhs Dense
		ata.Mul(a.T(), a)
		lhs.Mul(&ata, &x)
		rhs.Mul(a.T(), b)
		if EqualApprox(&lhs, &rhs, 1e-10) {
			continue
		}
		t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs)
	}
}
// TestSVDSolveVecTo checks SVD.SolveVecTo with vector right-hand sides.
//
// The hand-coded cases compare the computed solution with a known answer;
// cases for which Rank reports zero are skipped. The random cases verify
// that the solution satisfies the normal equations Aᵀ * A * x = Aᵀ * b.
func TestSVDSolveVecTo(t *testing.T) {
	t.Parallel()
	rnd := rand.New(rand.NewSource(1))

	// Hand-coded cases.
	for idx, test := range []struct {
		a     []float64
		m, n  int
		b     []float64
		rcond float64
		want  []float64
	}{
		{
			a: []float64{6}, m: 1, n: 1,
			b:    []float64{3},
			want: []float64{0.5},
		},
		{
			a: []float64{
				1, 0, 0,
				0, 1, 0,
				0, 0, 1,
			}, m: 3, n: 3,
			b:    []float64{3, 2, 1},
			want: []float64{3, 2, 1},
		},
		{
			a: []float64{
				0.8147, 0.9134, 0.5528,
				0.9058, 0.6324, 0.8723,
				0.1270, 0.0975, 0.7612,
			}, m: 3, n: 3,
			b:    []float64{0.278, 0.547, 0.958},
			want: []float64{-0.932687281002860, 0.303963920182067, 1.375216503507109},
		},
		{
			a: []float64{
				0.8147, 0.9134, 0.5528,
				0.9058, 0.6324, 0.8723,
			}, m: 2, n: 3,
			b:    []float64{0.278, 0.547},
			want: []float64{0.25919787248965376, -0.25560256266441034, 0.5432324059702451},
		},
		{
			a: []float64{
				0.8147, 0.9134, 0.9,
				0.9058, 0.6324, 0.9,
				0.1270, 0.0975, 0.1,
				1.6, 2.8, -3.5,
			}, m: 4, n: 3,
			b:    []float64{0.278, 0.547, -0.958, 1.452},
			want: []float64{0.820970340787782, -0.218604626527306, -0.212938815234215},
		},
		{
			// Test rank-deficient case compared with numpy.
			// >>> import numpy as np
			// >>> b = np.array([[-2.3181340317357653],
			// ... [-0.7146777651358073],
			// ... [1.8361340927945298],
			// ... [-0.35699930593018775],
			// ... [-1.6359508076249094]])
			// >>> A = np.array([[-1.7854591879711257, -0.42687285925779594, -0.12730256811265162],
			// ... [-0.5728984211439724, -0.10093393134001777, -0.1181901192353067],
			// ... [1.2484316018707418, 0.5646683943038734, -0.48229492403243485],
			// ... [0.10174927665169475, -0.5805410929482445, 1.3054473231942054],
			// ... [-1.134174808195733, -0.4732430202414438, 0.3528489486370508]])
			// >>> np.linalg.lstsq(A, b, rcond=None)
			// (array([[ 1.21208422],
			// [ 0.41541503],
			// [-0.18320349]]), array([], dtype=float64), 2, array([2.68451480e+00, 1.52593185e+00, 6.82840229e-17]))
			a: []float64{
				-1.7854591879711257, -0.42687285925779594, -0.12730256811265162,
				-0.5728984211439724, -0.10093393134001777, -0.1181901192353067,
				1.2484316018707418, 0.5646683943038734, -0.48229492403243485,
				0.10174927665169475, -0.5805410929482445, 1.3054473231942054,
				-1.134174808195733, -0.4732430202414438, 0.3528489486370508,
			}, m: 5, n: 3,
			b:     []float64{-2.3181340317357653, -0.7146777651358073, 1.8361340927945298, -0.35699930593018775, -1.6359508076249094},
			rcond: 1e-15,
			want:  []float64{1.2120842180372118, 0.4154150318658529, -0.1832034870198265},
		},
		{
			a: []float64{
				0, 0,
				0, 0,
			}, m: 2, n: 2,
			b: []float64{3, 2},
		},
		{
			a: []float64{
				0, 0,
				0, 0,
				0, 0,
			}, m: 3, n: 2,
			b: []float64{3, 2, 1},
		},
		{
			a: []float64{
				0, 0, 0,
				0, 0, 0,
			}, m: 2, n: 3,
			b: []float64{3, 2},
		},
	} {
		coeffs := NewDense(test.m, test.n, test.a)
		rhs := NewVecDense(len(test.b), test.b)
		// Cases without an expected answer leave expected nil; they are the
		// all-zero matrices, which are skipped below when the rank is zero.
		var expected *VecDense
		if test.want != nil {
			expected = NewVecDense(len(test.want), test.want)
		}

		var fact SVD
		if !fact.Factorize(coeffs, SVDFull) {
			t.Errorf("unexpected factorization failure for test %d", idx)
			continue
		}

		var got VecDense
		rank := fact.Rank(test.rcond)
		if rank == 0 {
			continue
		}
		fact.SolveVecTo(&got, rhs, rank)
		if EqualApprox(&got, expected, 1e-12) {
			continue
		}
		t.Errorf("Solve answer mismatch. Want %v, got %v", expected, got)
	}

	// Random Cases.
	for idx, test := range []struct {
		m, n  int
		rcond float64
	}{
		{m: 5, n: 5},
		{m: 5, n: 10},
		{m: 10, n: 5},
		{m: 5, n: 5},
		{m: 5, n: 10},
		{m: 10, n: 5},
		{m: 5, n: 5},
		{m: 5, n: 10},
		{m: 10, n: 5},
	} {
		rows, cols := test.m, test.n

		// Draw the entries of A in row-major order and then those of b.
		aData := make([]float64, rows*cols)
		for k := range aData {
			aData[k] = rnd.Float64()
		}
		a := NewDense(rows, cols, aData)

		bData := make([]float64, rows)
		for k := range bData {
			bData[k] = rnd.Float64()
		}
		b := NewVecDense(rows, bData)

		var fact SVD
		if !fact.Factorize(a, SVDFull) {
			t.Errorf("unexpected factorization failure for test %d", idx)
			continue
		}

		var x VecDense
		rank := fact.Rank(test.rcond)
		if rank == 0 {
			continue
		}
		fact.SolveVecTo(&x, b, rank)

		// Test that the normal equations hold.
		// Aᵀ * A * x = Aᵀ * b
		var ata, lhs, rhs Dense
		ata.Mul(a.T(), a)
		lhs.Mul(&ata, &x)
		rhs.Mul(a.T(), b)
		if EqualApprox(&lhs, &rhs, 1e-10) {
			continue
		}
		t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs)
	}
}