| // Copyright ©2015 The gonum Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package testlapack |
| |
| import ( |
| "fmt" |
| "math/rand" |
| "testing" |
| |
| "gonum.org/v1/gonum/blas" |
| "gonum.org/v1/gonum/blas/blas64" |
| "gonum.org/v1/gonum/floats" |
| "gonum.org/v1/gonum/lapack" |
| ) |
| |
| type Dgesvder interface { |
| Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, lwork int) (ok bool) |
| } |
| |
| func DgesvdTest(t *testing.T, impl Dgesvder) { |
| rnd := rand.New(rand.NewSource(1)) |
| // TODO(btracey): Add tests for all of the cases when the SVD implementation |
| // is finished. |
| // TODO(btracey): Add tests for m > mnthr and n > mnthr when other SVD |
| // conditions are implemented. Right now mnthr is 5,000,000 which is too |
| // large to create a square matrix of that size. |
| for _, test := range []struct { |
| m, n, lda, ldu, ldvt int |
| }{ |
| {5, 5, 0, 0, 0}, |
| {5, 6, 0, 0, 0}, |
| {6, 5, 0, 0, 0}, |
| {5, 9, 0, 0, 0}, |
| {9, 5, 0, 0, 0}, |
| |
| {5, 5, 10, 11, 12}, |
| {5, 6, 10, 11, 12}, |
| {6, 5, 10, 11, 12}, |
| {5, 5, 10, 11, 12}, |
| {5, 9, 10, 11, 12}, |
| {9, 5, 10, 11, 12}, |
| |
| {300, 300, 0, 0, 0}, |
| {300, 400, 0, 0, 0}, |
| {400, 300, 0, 0, 0}, |
| {300, 600, 0, 0, 0}, |
| {600, 300, 0, 0, 0}, |
| |
| {300, 300, 400, 450, 460}, |
| {300, 400, 500, 550, 560}, |
| {400, 300, 550, 550, 560}, |
| {300, 600, 700, 750, 760}, |
| {600, 300, 700, 750, 760}, |
| } { |
| jobU := lapack.SVDAll |
| jobVT := lapack.SVDAll |
| |
| m := test.m |
| n := test.n |
| lda := test.lda |
| if lda == 0 { |
| lda = n |
| } |
| ldu := test.ldu |
| if ldu == 0 { |
| ldu = m |
| } |
| ldvt := test.ldvt |
| if ldvt == 0 { |
| ldvt = n |
| } |
| |
| a := make([]float64, m*lda) |
| for i := range a { |
| a[i] = rnd.NormFloat64() |
| } |
| |
| u := make([]float64, m*ldu) |
| for i := range u { |
| u[i] = rnd.NormFloat64() |
| } |
| |
| vt := make([]float64, n*ldvt) |
| for i := range vt { |
| vt[i] = rnd.NormFloat64() |
| } |
| |
| uAllOrig := make([]float64, len(u)) |
| copy(uAllOrig, u) |
| vtAllOrig := make([]float64, len(vt)) |
| copy(vtAllOrig, vt) |
| aCopy := make([]float64, len(a)) |
| copy(aCopy, a) |
| |
| s := make([]float64, min(m, n)) |
| |
| work := make([]float64, 1) |
| impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, -1) |
| |
| if !floats.Equal(a, aCopy) { |
| t.Errorf("a changed during call to get work length") |
| } |
| |
| work = make([]float64, int(work[0])) |
| impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work)) |
| |
| errStr := fmt.Sprintf("m = %v, n = %v, lda = %v, ldu = %v, ldv = %v", m, n, lda, ldu, ldvt) |
| svdCheck(t, false, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda) |
| svdCheckPartial(t, impl, lapack.SVDAll, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false) |
| |
| // Test InPlace |
| jobU = lapack.SVDInPlace |
| jobVT = lapack.SVDInPlace |
| copy(a, aCopy) |
| copy(u, uAllOrig) |
| copy(vt, vtAllOrig) |
| |
| impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work)) |
| svdCheck(t, true, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda) |
| svdCheckPartial(t, impl, lapack.SVDInPlace, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false) |
| } |
| } |
| |
| // svdCheckPartial checks that the singular values and vectors are computed when |
| // not all of them are computed. |
| func svdCheckPartial(t *testing.T, impl Dgesvder, job lapack.SVDJob, errStr string, uAllOrig, vtAllOrig, aCopy []float64, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, shortWork bool) { |
| rnd := rand.New(rand.NewSource(1)) |
| jobU := job |
| jobVT := job |
| // Compare the singular values when computed with {SVDNone, SVDNone.} |
| sCopy := make([]float64, len(s)) |
| copy(sCopy, s) |
| copy(a, aCopy) |
| for i := range s { |
| s[i] = rnd.Float64() |
| } |
| tmp1 := make([]float64, 1) |
| tmp2 := make([]float64, 1) |
| jobU = lapack.SVDNone |
| jobVT = lapack.SVDNone |
| |
| impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, -1) |
| work = make([]float64, int(work[0])) |
| lwork := len(work) |
| if shortWork { |
| lwork-- |
| } |
| ok := impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, lwork) |
| if !ok { |
| t.Errorf("Dgesvd did not complete successfully") |
| } |
| if !floats.EqualApprox(s, sCopy, 1e-10) { |
| t.Errorf("Singular value mismatch when singular vectors not computed: %s", errStr) |
| } |
| // Check that the singular vectors are correctly computed when the other |
| // is none. |
| uAll := make([]float64, len(u)) |
| copy(uAll, u) |
| vtAll := make([]float64, len(vt)) |
| copy(vtAll, vt) |
| |
| // Copy the original vectors so the data outside the matrix bounds is the same. |
| copy(u, uAllOrig) |
| copy(vt, vtAllOrig) |
| |
| jobU = job |
| jobVT = lapack.SVDNone |
| copy(a, aCopy) |
| for i := range s { |
| s[i] = rnd.Float64() |
| } |
| impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, -1) |
| work = make([]float64, int(work[0])) |
| lwork = len(work) |
| if shortWork { |
| lwork-- |
| } |
| impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, len(work)) |
| if !floats.EqualApprox(uAll, u, 1e-10) { |
| t.Errorf("U mismatch when VT is not computed: %s", errStr) |
| } |
| if !floats.EqualApprox(s, sCopy, 1e-10) { |
| t.Errorf("Singular value mismatch when U computed VT not") |
| } |
| jobU = lapack.SVDNone |
| jobVT = job |
| copy(a, aCopy) |
| for i := range s { |
| s[i] = rnd.Float64() |
| } |
| impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, -1) |
| work = make([]float64, int(work[0])) |
| lwork = len(work) |
| if shortWork { |
| lwork-- |
| } |
| impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, len(work)) |
| if !floats.EqualApprox(vtAll, vt, 1e-10) { |
| t.Errorf("VT mismatch when U is not computed: %s", errStr) |
| } |
| if !floats.EqualApprox(s, sCopy, 1e-10) { |
| t.Errorf("Singular value mismatch when VT computed U not") |
| } |
| } |
| |
| // svdCheck checks that the singular value decomposition correctly multiplies back |
| // to the original matrix. |
| func svdCheck(t *testing.T, thin bool, errStr string, m, n int, s, a, u []float64, ldu int, vt []float64, ldvt int, aCopy []float64, lda int) { |
| sigma := blas64.General{ |
| Rows: m, |
| Cols: n, |
| Stride: n, |
| Data: make([]float64, m*n), |
| } |
| for i := 0; i < min(m, n); i++ { |
| sigma.Data[i*sigma.Stride+i] = s[i] |
| } |
| |
| uMat := blas64.General{ |
| Rows: m, |
| Cols: m, |
| Stride: ldu, |
| Data: u, |
| } |
| vTMat := blas64.General{ |
| Rows: n, |
| Cols: n, |
| Stride: ldvt, |
| Data: vt, |
| } |
| if thin { |
| sigma.Rows = min(m, n) |
| sigma.Cols = min(m, n) |
| uMat.Cols = min(m, n) |
| vTMat.Rows = min(m, n) |
| } |
| |
| tmp := blas64.General{ |
| Rows: m, |
| Cols: n, |
| Stride: n, |
| Data: make([]float64, m*n), |
| } |
| ans := blas64.General{ |
| Rows: m, |
| Cols: n, |
| Stride: lda, |
| Data: make([]float64, m*lda), |
| } |
| copy(ans.Data, a) |
| |
| blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uMat, sigma, 0, tmp) |
| blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, vTMat, 0, ans) |
| |
| if !floats.EqualApprox(ans.Data, aCopy, 1e-8) { |
| t.Errorf("Decomposition mismatch. Trim = %v, %s", thin, errStr) |
| } |
| |
| if !thin { |
| // Check that U and V are orthogonal. |
| for i := 0; i < uMat.Rows; i++ { |
| for j := i + 1; j < uMat.Rows; j++ { |
| dot := blas64.Dot(uMat.Cols, |
| blas64.Vector{Inc: 1, Data: uMat.Data[i*uMat.Stride:]}, |
| blas64.Vector{Inc: 1, Data: uMat.Data[j*uMat.Stride:]}, |
| ) |
| if dot > 1e-8 { |
| t.Errorf("U not orthogonal %s", errStr) |
| } |
| } |
| } |
| for i := 0; i < vTMat.Rows; i++ { |
| for j := i + 1; j < vTMat.Rows; j++ { |
| dot := blas64.Dot(vTMat.Cols, |
| blas64.Vector{Inc: 1, Data: vTMat.Data[i*vTMat.Stride:]}, |
| blas64.Vector{Inc: 1, Data: vTMat.Data[j*vTMat.Stride:]}, |
| ) |
| if dot > 1e-8 { |
| t.Errorf("V not orthogonal %s", errStr) |
| } |
| } |
| } |
| } |
| } |