| // Copyright ©2017 The Gonum Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package mat |
| |
| import ( |
| "fmt" |
| "testing" |
| |
| "golang.org/x/exp/rand" |
| |
| "gonum.org/v1/gonum/floats" |
| "gonum.org/v1/gonum/floats/scalar" |
| ) |
| |
| func TestGSVD(t *testing.T) { |
| t.Parallel() |
| |
| const tol = 1e-10 |
| for _, test := range []struct { |
| m, p, n int |
| }{ |
| {5, 3, 5}, |
| {5, 3, 3}, |
| {3, 3, 5}, |
| {5, 5, 5}, |
| {5, 5, 3}, |
| {3, 5, 5}, |
| {150, 150, 150}, |
| {200, 150, 150}, |
| {150, 150, 200}, |
| {150, 200, 150}, |
| {200, 200, 150}, |
| {150, 200, 200}, |
| } { |
| m := test.m |
| p := test.p |
| n := test.n |
| t.Run(fmt.Sprintf("%v", test), func(t *testing.T) { |
| t.Parallel() |
| |
| rnd := rand.New(rand.NewSource(1)) |
| for trial := 0; trial < 10; trial++ { |
| a := NewDense(m, n, nil) |
| for i := range a.mat.Data { |
| a.mat.Data[i] = rnd.NormFloat64() |
| } |
| aCopy := DenseCopyOf(a) |
| |
| b := NewDense(p, n, nil) |
| for i := range b.mat.Data { |
| b.mat.Data[i] = rnd.NormFloat64() |
| } |
| bCopy := DenseCopyOf(b) |
| |
| // Test Full decomposition. |
| var gsvd GSVD |
| ok := gsvd.Factorize(a, b, GSVDU|GSVDV|GSVDQ) |
| if !ok { |
| t.Errorf("GSVD factorization failed") |
| } |
| if !Equal(a, aCopy) { |
| t.Errorf("A changed during call to GSVD.Factorize with GSVDU|GSVDV|GSVDQ") |
| } |
| if !Equal(b, bCopy) { |
| t.Errorf("B changed during call to GSVD.Factorize with GSVDU|GSVDV|GSVDQ") |
| } |
| c, s, sigma1, sigma2, zeroR, u, v, q := extractGSVD(&gsvd) |
| var ansU, ansV, d1R, d2R Dense |
| ansU.Product(u.T(), a, q) |
| ansV.Product(v.T(), b, q) |
| d1R.Mul(sigma1, zeroR) |
| d2R.Mul(sigma2, zeroR) |
| if !EqualApprox(&ansU, &d1R, tol) { |
| t.Errorf("Answer mismatch with GSVDU|GSVDV|GSVDQ\nUᵀ * A * Q:\n% 0.2f\nΣ₁ * [ 0 R ]:\n% 0.2f", |
| Formatted(&ansU), Formatted(&d1R)) |
| } |
| if !EqualApprox(&ansV, &d2R, tol) { |
| t.Errorf("Answer mismatch with GSVDU|GSVDV|GSVDQ\nVᵀ * B *Q:\n% 0.2f\nΣ₂ * [ 0 R ]:\n% 0.2f", |
| Formatted(&d2R), Formatted(&ansV)) |
| } |
| |
| // Check C^2 + S^2 = I. |
| for i := range c { |
| d := c[i]*c[i] + s[i]*s[i] |
| if !scalar.EqualWithinAbsOrRel(d, 1, 1e-14, 1e-14) { |
| t.Errorf("c_%d^2 + s_%d^2 != 1: got: %v", i, i, d) |
| } |
| } |
| |
| // Test None decomposition. |
| ok = gsvd.Factorize(a, b, GSVDNone) |
| if !ok { |
| t.Errorf("GSVD factorization failed") |
| } |
| if !Equal(a, aCopy) { |
| t.Errorf("A changed during call to GSVD with GSVDNone") |
| } |
| if !Equal(b, bCopy) { |
| t.Errorf("B changed during call to GSVD with GSVDNone") |
| } |
| cNone := gsvd.ValuesA(nil) |
| if !floats.EqualApprox(c, cNone, tol) { |
| t.Errorf("Singular value mismatch between GSVDU|GSVDV|GSVDQ and GSVDNone decomposition") |
| } |
| sNone := gsvd.ValuesB(nil) |
| if !floats.EqualApprox(s, sNone, tol) { |
| t.Errorf("Singular value mismatch between GSVDU|GSVDV|GSVDQ and GSVDNone decomposition") |
| } |
| } |
| }) |
| |
| } |
| } |
| |
| func extractGSVD(gsvd *GSVD) (c, s []float64, s1, s2, zR, u, v, q *Dense) { |
| s1 = &Dense{} |
| s2 = &Dense{} |
| zR = &Dense{} |
| u = &Dense{} |
| v = &Dense{} |
| q = &Dense{} |
| gsvd.SigmaATo(s1) |
| gsvd.SigmaBTo(s2) |
| gsvd.ZeroRTo(zR) |
| gsvd.UTo(u) |
| gsvd.VTo(v) |
| gsvd.QTo(q) |
| c = gsvd.ValuesA(nil) |
| s = gsvd.ValuesB(nil) |
| return c, s, s1, s2, zR, u, v, q |
| } |