blob: 0c9bc281cfbcb7e0e475195d66d0ca0660bdeebe [file] [log] [blame]
 // Copyright ©2015 The Gonum Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package testlapack import ( "testing" "golang.org/x/exp/rand" "gonum.org/v1/gonum/blas" "gonum.org/v1/gonum/blas/blas64" ) type Dgetrier interface { Dgetrfer Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) bool } func DgetriTest(t *testing.T, impl Dgetrier) { const tol = 1e-13 rnd := rand.New(rand.NewSource(1)) bi := blas64.Implementation() for _, test := range []struct { n, lda int }{ {5, 0}, {5, 8}, {45, 0}, {45, 50}, {63, 70}, {64, 70}, {65, 0}, {65, 70}, {66, 70}, {150, 0}, {150, 250}, } { n := test.n lda := test.lda if lda == 0 { lda = n } // Generate a random well conditioned matrix perm := rnd.Perm(n) a := make([]float64, n*lda) for i := 0; i < n; i++ { a[i*lda+perm[i]] = 1 } for i := range a { a[i] += 0.01 * rnd.Float64() } aCopy := make([]float64, len(a)) copy(aCopy, a) ipiv := make([]int, n) // Compute LU decomposition. impl.Dgetrf(n, n, a, lda, ipiv) // Test with various workspace sizes. for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} { ainv := make([]float64, len(a)) copy(ainv, a) var lwork int switch wl { case minimumWork: lwork = max(1, n) case mediumWork: work := make([]float64, 1) impl.Dgetri(n, ainv, lda, ipiv, work, -1) lwork = max(int(work[0])-2*n, n) case optimumWork: work := make([]float64, 1) impl.Dgetri(n, ainv, lda, ipiv, work, -1) lwork = int(work[0]) } work := make([]float64, lwork) // Compute inverse. ok := impl.Dgetri(n, ainv, lda, ipiv, work, lwork) if !ok { t.Errorf("Unexpected singular matrix.") } // Check that A(inv) * A = I. ans := make([]float64, len(ainv)) bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, aCopy, lda, ainv, lda, 0, ans, lda) // The tolerance is so high because computing matrix inverses is very unstable. dist := distFromIdentity(n, ans, lda) if dist > tol { t.Errorf("|Inv(A) * A - I|_inf = %v is too large. n = %v, lda = %v", dist, n, lda) } } } }