lapack/gonum: avoid NaN in Dlatrs
diff --git a/lapack/gonum/dlatrs.go b/lapack/gonum/dlatrs.go
index 37ac2fe..57af343 100644
--- a/lapack/gonum/dlatrs.go
+++ b/lapack/gonum/dlatrs.go
@@ -9,6 +9,7 @@
"gonum.org/v1/gonum/blas"
"gonum.org/v1/gonum/blas/blas64"
+ "gonum.org/v1/gonum/lapack"
)
// Dlatrs solves a triangular system of equations scaled to prevent overflow. It
@@ -43,9 +44,11 @@
panic(badLdA)
}
+ scale = 1
+
// Quick return if possible.
if n == 0 {
- return 0
+ return scale
}
switch {
@@ -62,7 +65,6 @@
smlnum := dlamchS / dlamchP
bignum := 1 / smlnum
- scale = 1
bi := blas64.Implementation()
@@ -86,8 +88,54 @@
if tmax <= bignum {
tscal = 1
} else {
- tscal = 1 / (smlnum * tmax)
- bi.Dscal(n, tscal, cnorm, 1)
+ // Avoid NaN generation if entries in cnorm exceed the overflow
+ // threshold, see https://github.com/Reference-LAPACK/lapack/issues/714
+ if tmax <= dlamchO {
+ // Case 1: All entries in cnorm are valid floating-point numbers.
+ tscal = 1 / (smlnum * tmax)
+ bi.Dscal(n, tscal, cnorm, 1)
+ } else {
+ // Case 2: At least one column norm of A cannot be represented as a
+ // floating-point number. Find the offdiagonal entry A[i,j] with the
+ // largest absolute value. If this entry is finite, compute tscal
+ // from it.
+ tmax = 0
+ if upper {
+ for j := 1; j < n; j++ {
+ tmax = math.Max(impl.Dlange(lapack.MaxAbs, j, 1, a[j:], lda, nil), tmax)
+ }
+ } else {
+ for j := 0; j < n-1; j++ {
+ tmax = math.Max(impl.Dlange(lapack.MaxAbs, n-j-1, 1, a[(j+1)*lda+j:], lda, nil), tmax)
+ }
+ }
+ if tmax <= dlamchO {
+ tscal = 1 / (smlnum * tmax)
+ for j := 0; j < n; j++ {
+ if cnorm[j] <= dlamchO {
+ cnorm[j] *= tscal
+ } else {
+ // Recompute the 1-norm without introducing Infinity in
+ // the summation.
+ cnorm[j] = 0
+ if upper {
+ for i := 0; i < j; i++ {
+ cnorm[j] += tscal * math.Abs(a[i*lda+j])
+ }
+ } else {
+ for i := j + 1; i < n; i++ {
+ cnorm[j] += tscal * math.Abs(a[i*lda+j])
+ }
+ }
+ }
+ }
+ } else {
+ // At least one offdiagonal entry of A is not finite (Inf or NaN).
+ // Rely on Dtrsv to propagate Inf and NaN.
+ bi.Dtrsv(uplo, trans, diag, n, a, lda, x, 1)
+ return scale
+ }
+ }
}
// Compute a bound on the computed solution vector to see if bi.Dtrsv can be used.
diff --git a/lapack/gonum/lapack.go b/lapack/gonum/lapack.go
index 3b914aa..98167ff 100644
--- a/lapack/gonum/lapack.go
+++ b/lapack/gonum/lapack.go
@@ -4,7 +4,11 @@
package gonum
-import "gonum.org/v1/gonum/lapack"
+import (
+ "math"
+
+ "gonum.org/v1/gonum/lapack"
+)
// Implementation is the native Go implementation of LAPACK routines. It
// is built on top of calls to the return of blas64.Implementation(), so while
@@ -49,6 +53,10 @@
// For IEEE this is 2^{-1022}.
dlamchS = 0x1p-1022
+ // dlamchO is the overflow threshold, the largest number that is not
+ // an infinity.
+ dlamchO = math.MaxFloat64
+
// (rtmin,rtmax) is a range of well-scaled numbers whose square
// or sum of squares is also safe.
// drtmin is sqrt(dlamchS/dlamchP)
diff --git a/lapack/testlapack/dlatrs.go b/lapack/testlapack/dlatrs.go
index fbac6ec..99484f6 100644
--- a/lapack/testlapack/dlatrs.go
+++ b/lapack/testlapack/dlatrs.go
@@ -30,6 +30,9 @@
if n < 6 {
imats = append(imats, 19)
}
+ if n == 3 {
+ imats = append(imats, -1)
+ }
for _, imat := range imats {
testDlatrs(t, impl, imat, uplo, trans, n, lda, rnd)
}
@@ -46,11 +49,29 @@
b := nanSlice(n)
work := make([]float64, 3*n)
- // Generate triangular test matrix and right hand side.
- diag := dlattr(imat, uplo, trans, n, a, lda, b, work, rnd)
- if imat <= 10 {
- // b has not been generated.
- dlarnv(b, 3, rnd)
+ var diag blas.Diag
+ switch imat {
+ default:
+ // Generate triangular test matrix and right hand side.
+ diag = dlattr(imat, uplo, trans, n, a, lda, b, work, rnd)
+ if imat <= 10 {
+ // b has not been generated.
+ dlarnv(b, 3, rnd)
+ }
+ case -1:
+ // Test case from https://github.com/Reference-LAPACK/lapack/issues/714
+ diag = blas.NonUnit
+ v := math.MaxFloat64
+ if uplo == blas.Upper {
+ a[0], a[1], a[2] = v, v, v
+ a[lda+1], a[lda+2] = v, v
+ a[2*lda+2] = v
+ } else {
+ a[0] = v
+ a[lda], a[lda+1] = v, v
+ a[2*lda], a[2*lda+1], a[2*lda+2] = v, v, v
+ }
+ b[0], b[1], b[2] = v, 0, v
}
cnorm := nanSlice(n)