blob: ec2f4e051fdff7ed9caa971ab2d66ddc5811d002 [file] [log] [blame]
// Copyright ©2017 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mat
import (
"fmt"
"testing"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/floats/scalar"
)
// TestGSVD exercises GSVD.Factorize on pairs of random matrices A (m×n) and
// B (p×n) whose entries are drawn from a standard normal distribution.
//
// For every size in the table it runs ten trials and checks that:
//   - Factorize reports success and leaves A and B unmodified,
//   - Uᵀ * A * Q equals Σ₁ * [ 0 R ] and Vᵀ * B * Q equals Σ₂ * [ 0 R ]
//     to within tol,
//   - the value pairs returned by ValuesA and ValuesB satisfy c² + s² = 1,
//   - a GSVDNone factorization yields the same values as the full one.
func TestGSVD(t *testing.T) {
	t.Parallel()
	const tol = 1e-10
	for _, test := range []struct {
		m, p, n int
	}{
		{5, 3, 5},
		{5, 3, 3},
		{3, 3, 5},
		{5, 5, 5},
		{5, 5, 3},
		{3, 5, 5},
		{150, 150, 150},
		{200, 150, 150},
		{150, 150, 200},
		{150, 200, 150},
		{200, 200, 150},
		{150, 200, 200},
	} {
		// Per-iteration copies captured by the parallel subtest closure below.
		m := test.m
		p := test.p
		n := test.n
		t.Run(fmt.Sprintf("%v", test), func(t *testing.T) {
			t.Parallel()
			// Each subtest owns a generator with a fixed seed, so the inputs
			// are reproducible and no source is shared between subtests.
			rnd := rand.New(rand.NewSource(1))
			for trial := 0; trial < 10; trial++ {
				a := NewDense(m, n, nil)
				for i := range a.mat.Data {
					a.mat.Data[i] = rnd.NormFloat64()
				}
				aCopy := DenseCopyOf(a)
				b := NewDense(p, n, nil)
				for i := range b.mat.Data {
					b.mat.Data[i] = rnd.NormFloat64()
				}
				bCopy := DenseCopyOf(b)

				// Test Full decomposition.
				var gsvd GSVD
				ok := gsvd.Factorize(a, b, GSVDU|GSVDV|GSVDQ)
				if !ok {
					// Stop here: everything below reads factors from the
					// decomposition, which is not available after a failure.
					t.Fatalf("GSVD factorization failed")
				}
				if !Equal(a, aCopy) {
					t.Errorf("A changed during call to GSVD.Factorize with GSVDU|GSVDV|GSVDQ")
				}
				if !Equal(b, bCopy) {
					t.Errorf("B changed during call to GSVD.Factorize with GSVDU|GSVDV|GSVDQ")
				}
				c, s, sigma1, sigma2, zeroR, u, v, q := extractGSVD(&gsvd)
				var ansU, ansV, d1R, d2R Dense
				ansU.Product(u.T(), a, q)
				ansV.Product(v.T(), b, q)
				d1R.Mul(sigma1, zeroR)
				d2R.Mul(sigma2, zeroR)
				if !EqualApprox(&ansU, &d1R, tol) {
					t.Errorf("Answer mismatch with GSVDU|GSVDV|GSVDQ\nUᵀ * A * Q:\n% 0.2f\nΣ₁ * [ 0 R ]:\n% 0.2f",
						Formatted(&ansU), Formatted(&d1R))
				}
				if !EqualApprox(&ansV, &d2R, tol) {
					// Argument order follows the labels: the product first,
					// then the expected Σ₂ * [ 0 R ].
					t.Errorf("Answer mismatch with GSVDU|GSVDV|GSVDQ\nVᵀ * B *Q:\n% 0.2f\nΣ₂ * [ 0 R ]:\n% 0.2f",
						Formatted(&ansV), Formatted(&d2R))
				}

				// Check C^2 + S^2 = I.
				for i := range c {
					d := c[i]*c[i] + s[i]*s[i]
					if !scalar.EqualWithinAbsOrRel(d, 1, 1e-14, 1e-14) {
						t.Errorf("c_%d^2 + s_%d^2 != 1: got: %v", i, i, d)
					}
				}

				// Test None decomposition.
				ok = gsvd.Factorize(a, b, GSVDNone)
				if !ok {
					// The value comparisons below need a successful
					// factorization, so abort rather than continue.
					t.Fatalf("GSVD factorization failed")
				}
				if !Equal(a, aCopy) {
					t.Errorf("A changed during call to GSVD with GSVDNone")
				}
				if !Equal(b, bCopy) {
					t.Errorf("B changed during call to GSVD with GSVDNone")
				}
				cNone := gsvd.ValuesA(nil)
				if !floats.EqualApprox(c, cNone, tol) {
					t.Errorf("Singular value mismatch between GSVDU|GSVDV|GSVDQ and GSVDNone decomposition")
				}
				sNone := gsvd.ValuesB(nil)
				if !floats.EqualApprox(s, sNone, tol) {
					t.Errorf("Singular value mismatch between GSVDU|GSVDV|GSVDQ and GSVDNone decomposition")
				}
			}
		})
	}
}
// extractGSVD pulls every component out of gsvd in one call.
//
// It returns the value slices from ValuesA (c) and ValuesB (s), followed by
// freshly allocated matrices filled by SigmaATo (s1), SigmaBTo (s2),
// ZeroRTo (zR), UTo (u), VTo (v) and QTo (q).
func extractGSVD(gsvd *GSVD) (c, s []float64, s1, s2, zR, u, v, q *Dense) {
	// Each destination starts as an empty Dense and is handed to the
	// matching extraction method.
	s1 = new(Dense)
	gsvd.SigmaATo(s1)

	s2 = new(Dense)
	gsvd.SigmaBTo(s2)

	zR = new(Dense)
	gsvd.ZeroRTo(zR)

	u = new(Dense)
	gsvd.UTo(u)

	v = new(Dense)
	gsvd.VTo(v)

	q = new(Dense)
	gsvd.QTo(q)

	// Passing nil asks the methods to supply the result slices themselves.
	c, s = gsvd.ValuesA(nil), gsvd.ValuesB(nil)
	return c, s, s1, s2, zR, u, v, q
}