| // Copyright ©2015 The gonum Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package gonum |
| |
| import ( |
| "gonum.org/v1/gonum/blas" |
| "gonum.org/v1/gonum/blas/blas64" |
| ) |
| |
| // Dgetri computes the inverse of the matrix A using the LU factorization computed |
| // by Dgetrf. On entry, a contains the PLU decomposition of A as computed by |
| // Dgetrf and on exit contains the reciprocal of the original matrix. |
| // |
| // Dgetri will not perform the inversion if the matrix is singular, and returns |
| // a boolean indicating whether the inversion was successful. |
| // |
| // work is temporary storage, and lwork specifies the usable memory length. |
| // At minimum, lwork >= n and this function will panic otherwise. |
| // Dgetri is a blocked inversion, but the block size is limited |
| // by the temporary space available. If lwork == -1, instead of performing Dgetri, |
| // the optimal work length will be stored into work[0]. |
| func (impl Implementation) Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) (ok bool) { |
| checkMatrix(n, n, a, lda) |
| if len(ipiv) < n { |
| panic(badIpiv) |
| } |
| nb := impl.Ilaenv(1, "DGETRI", " ", n, -1, -1, -1) |
| if lwork == -1 { |
| work[0] = float64(n * nb) |
| return true |
| } |
| if lwork < n { |
| panic(badWork) |
| } |
| if len(work) < lwork { |
| panic(badWork) |
| } |
| if n == 0 { |
| return true |
| } |
| ok = impl.Dtrtri(blas.Upper, blas.NonUnit, n, a, lda) |
| if !ok { |
| return false |
| } |
| nbmin := 2 |
| ldwork := nb |
| if nb > 1 && nb < n { |
| iws := max(ldwork*n, 1) |
| if lwork < iws { |
| nb = lwork / ldwork |
| nbmin = max(2, impl.Ilaenv(2, "DGETRI", " ", n, -1, -1, -1)) |
| } |
| } |
| bi := blas64.Implementation() |
| // TODO(btracey): Replace this with a more row-major oriented algorithm. |
| if nb < nbmin || nb >= n { |
| // Unblocked code. |
| for j := n - 1; j >= 0; j-- { |
| for i := j + 1; i < n; i++ { |
| work[i*ldwork] = a[i*lda+j] |
| a[i*lda+j] = 0 |
| } |
| if j < n { |
| bi.Dgemv(blas.NoTrans, n, n-j-1, -1, a[(j+1):], lda, work[(j+1)*ldwork:], ldwork, 1, a[j:], lda) |
| } |
| } |
| } else { |
| nn := ((n - 1) / nb) * nb |
| for j := nn; j >= 0; j -= nb { |
| jb := min(nb, n-j) |
| for jj := j; jj < j+jb-1; jj++ { |
| for i := jj + 1; i < n; i++ { |
| work[i*ldwork+(jj-j)] = a[i*lda+jj] |
| a[i*lda+jj] = 0 |
| } |
| } |
| if j+jb < n { |
| bi.Dgemm(blas.NoTrans, blas.NoTrans, n, jb, n-j-jb, -1, a[(j+jb):], lda, work[(j+jb)*ldwork:], ldwork, 1, a[j:], lda) |
| bi.Dtrsm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, jb, 1, work[j*ldwork:], ldwork, a[j:], lda) |
| } |
| } |
| } |
| for j := n - 2; j >= 0; j-- { |
| jp := ipiv[j] |
| if jp != j { |
| bi.Dswap(n, a[j:], lda, a[jp:], lda) |
| } |
| } |
| return true |
| } |