| // Copyright ©2015 The gonum Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package gonum |
| |
| import ( |
| "math" |
| |
| "gonum.org/v1/gonum/blas" |
| "gonum.org/v1/gonum/blas/blas64" |
| ) |
| |
| // Dlatrs solves a triangular system of equations scaled to prevent overflow. It |
| // solves |
| // A * x = scale * b if trans == blas.NoTrans |
| // A^T * x = scale * b if trans == blas.Trans |
| // where the scale s is set for numeric stability. |
| // |
| // A is an n×n triangular matrix. On entry, the slice x contains the values of |
| // of b, and on exit it contains the solution vector x. |
| // |
| // If normin == true, cnorm is an input and cnorm[j] contains the norm of the off-diagonal |
| // part of the j^th column of A. If trans == blas.NoTrans, cnorm[j] must be greater |
| // than or equal to the infinity norm, and greater than or equal to the one-norm |
| // otherwise. If normin == false, then cnorm is treated as an output, and is set |
| // to contain the 1-norm of the off-diagonal part of the j^th column of A. |
| // |
| // Dlatrs is an internal routine. It is exported for testing purposes. |
| func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n int, a []float64, lda int, x []float64, cnorm []float64) (scale float64) { |
| if uplo != blas.Upper && uplo != blas.Lower { |
| panic(badUplo) |
| } |
| if trans != blas.Trans && trans != blas.NoTrans { |
| panic(badTrans) |
| } |
| if diag != blas.Unit && diag != blas.NonUnit { |
| panic(badDiag) |
| } |
| upper := uplo == blas.Upper |
| noTrans := trans == blas.NoTrans |
| nonUnit := diag == blas.NonUnit |
| |
| if n < 0 { |
| panic(nLT0) |
| } |
| checkMatrix(n, n, a, lda) |
| checkVector(n, x, 1) |
| checkVector(n, cnorm, 1) |
| |
| if n == 0 { |
| return 0 |
| } |
| smlnum := dlamchS / dlamchP |
| bignum := 1 / smlnum |
| scale = 1 |
| bi := blas64.Implementation() |
| if !normin { |
| if upper { |
| cnorm[0] = 0 |
| for j := 1; j < n; j++ { |
| cnorm[j] = bi.Dasum(j, a[j:], lda) |
| } |
| } else { |
| for j := 0; j < n-1; j++ { |
| cnorm[j] = bi.Dasum(n-j-1, a[(j+1)*lda+j:], lda) |
| } |
| cnorm[n-1] = 0 |
| } |
| } |
| // Scale the column norms by tscal if the maximum element in cnorm is greater than bignum. |
| imax := bi.Idamax(n, cnorm, 1) |
| tmax := cnorm[imax] |
| var tscal float64 |
| if tmax <= bignum { |
| tscal = 1 |
| } else { |
| tscal = 1 / (smlnum * tmax) |
| bi.Dscal(n, tscal, cnorm, 1) |
| } |
| |
| // Compute a bound on the computed solution vector to see if bi.Dtrsv can be used. |
| j := bi.Idamax(n, x, 1) |
| xmax := math.Abs(x[j]) |
| xbnd := xmax |
| var grow float64 |
| var jfirst, jlast, jinc int |
| if noTrans { |
| if upper { |
| jfirst = n - 1 |
| jlast = -1 |
| jinc = -1 |
| } else { |
| jfirst = 0 |
| jlast = n |
| jinc = 1 |
| } |
| // Compute the growth in A * x = b. |
| if tscal != 1 { |
| grow = 0 |
| goto Solve |
| } |
| if nonUnit { |
| grow = 1 / math.Max(xbnd, smlnum) |
| xbnd = grow |
| for j := jfirst; j != jlast; j += jinc { |
| if grow <= smlnum { |
| goto Solve |
| } |
| tjj := math.Abs(a[j*lda+j]) |
| xbnd = math.Min(xbnd, math.Min(1, tjj)*grow) |
| if tjj+cnorm[j] >= smlnum { |
| grow *= tjj / (tjj + cnorm[j]) |
| } else { |
| grow = 0 |
| } |
| } |
| grow = xbnd |
| } else { |
| grow = math.Min(1, 1/math.Max(xbnd, smlnum)) |
| for j := jfirst; j != jlast; j += jinc { |
| if grow <= smlnum { |
| goto Solve |
| } |
| grow *= 1 / (1 + cnorm[j]) |
| } |
| } |
| } else { |
| if upper { |
| jfirst = 0 |
| jlast = n |
| jinc = 1 |
| } else { |
| jfirst = n - 1 |
| jlast = -1 |
| jinc = -1 |
| } |
| if tscal != 1 { |
| grow = 0 |
| goto Solve |
| } |
| if nonUnit { |
| grow = 1 / (math.Max(xbnd, smlnum)) |
| xbnd = grow |
| for j := jfirst; j != jlast; j += jinc { |
| if grow <= smlnum { |
| goto Solve |
| } |
| xj := 1 + cnorm[j] |
| grow = math.Min(grow, xbnd/xj) |
| tjj := math.Abs(a[j*lda+j]) |
| if xj > tjj { |
| xbnd *= tjj / xj |
| } |
| } |
| grow = math.Min(grow, xbnd) |
| } else { |
| grow = math.Min(1, 1/math.Max(xbnd, smlnum)) |
| for j := jfirst; j != jlast; j += jinc { |
| if grow <= smlnum { |
| goto Solve |
| } |
| xj := 1 + cnorm[j] |
| grow /= xj |
| } |
| } |
| } |
| |
| Solve: |
| if grow*tscal > smlnum { |
| // Use the Level 2 BLAS solve if the reciprocal of the bound on |
| // elements of X is not too small. |
| bi.Dtrsv(uplo, trans, diag, n, a, lda, x, 1) |
| if tscal != 1 { |
| bi.Dscal(n, 1/tscal, cnorm, 1) |
| } |
| return scale |
| } |
| |
| // Use a Level 1 BLAS solve, scaling intermediate results. |
| if xmax > bignum { |
| scale = bignum / xmax |
| bi.Dscal(n, scale, x, 1) |
| xmax = bignum |
| } |
| if noTrans { |
| for j := jfirst; j != jlast; j += jinc { |
| xj := math.Abs(x[j]) |
| var tjj, tjjs float64 |
| if nonUnit { |
| tjjs = a[j*lda+j] * tscal |
| } else { |
| tjjs = tscal |
| if tscal == 1 { |
| goto Skip1 |
| } |
| } |
| tjj = math.Abs(tjjs) |
| if tjj > smlnum { |
| if tjj < 1 { |
| if xj > tjj*bignum { |
| rec := 1 / xj |
| bi.Dscal(n, rec, x, 1) |
| scale *= rec |
| xmax *= rec |
| } |
| } |
| x[j] /= tjjs |
| xj = math.Abs(x[j]) |
| } else if tjj > 0 { |
| if xj > tjj*bignum { |
| rec := (tjj * bignum) / xj |
| if cnorm[j] > 1 { |
| rec /= cnorm[j] |
| } |
| bi.Dscal(n, rec, x, 1) |
| scale *= rec |
| xmax *= rec |
| } |
| x[j] /= tjjs |
| xj = math.Abs(x[j]) |
| } else { |
| for i := 0; i < n; i++ { |
| x[i] = 0 |
| } |
| x[j] = 1 |
| xj = 1 |
| scale = 0 |
| xmax = 0 |
| } |
| Skip1: |
| if xj > 1 { |
| rec := 1 / xj |
| if cnorm[j] > (bignum-xmax)*rec { |
| rec *= 0.5 |
| bi.Dscal(n, rec, x, 1) |
| scale *= rec |
| } |
| } else if xj*cnorm[j] > bignum-xmax { |
| bi.Dscal(n, 0.5, x, 1) |
| scale *= 0.5 |
| } |
| if upper { |
| if j > 0 { |
| bi.Daxpy(j, -x[j]*tscal, a[j:], lda, x, 1) |
| i := bi.Idamax(j, x, 1) |
| xmax = math.Abs(x[i]) |
| } |
| } else { |
| if j < n-1 { |
| bi.Daxpy(n-j-1, -x[j]*tscal, a[(j+1)*lda+j:], lda, x[j+1:], 1) |
| i := j + bi.Idamax(n-j-1, x[j+1:], 1) |
| xmax = math.Abs(x[i]) |
| } |
| } |
| } |
| } else { |
| for j := jfirst; j != jlast; j += jinc { |
| xj := math.Abs(x[j]) |
| uscal := tscal |
| rec := 1 / math.Max(xmax, 1) |
| var tjjs float64 |
| if cnorm[j] > (bignum-xj)*rec { |
| rec *= 0.5 |
| if nonUnit { |
| tjjs = a[j*lda+j] * tscal |
| } else { |
| tjjs = tscal |
| } |
| tjj := math.Abs(tjjs) |
| if tjj > 1 { |
| rec = math.Min(1, rec*tjj) |
| uscal /= tjjs |
| } |
| if rec < 1 { |
| bi.Dscal(n, rec, x, 1) |
| scale *= rec |
| xmax *= rec |
| } |
| } |
| var sumj float64 |
| if uscal == 1 { |
| if upper { |
| sumj = bi.Ddot(j, a[j:], lda, x, 1) |
| } else if j < n-1 { |
| sumj = bi.Ddot(n-j-1, a[(j+1)*lda+j:], lda, x[j+1:], 1) |
| } |
| } else { |
| if upper { |
| for i := 0; i < j; i++ { |
| sumj += (a[i*lda+j] * uscal) * x[i] |
| } |
| } else if j < n { |
| for i := j + 1; i < n; i++ { |
| sumj += (a[i*lda+j] * uscal) * x[i] |
| } |
| } |
| } |
| if uscal == tscal { |
| x[j] -= sumj |
| xj := math.Abs(x[j]) |
| var tjjs float64 |
| if nonUnit { |
| tjjs = a[j*lda+j] * tscal |
| } else { |
| tjjs = tscal |
| if tscal == 1 { |
| goto Skip2 |
| } |
| } |
| tjj := math.Abs(tjjs) |
| if tjj > smlnum { |
| if tjj < 1 { |
| if xj > tjj*bignum { |
| rec = 1 / xj |
| bi.Dscal(n, rec, x, 1) |
| scale *= rec |
| xmax *= rec |
| } |
| } |
| x[j] /= tjjs |
| } else if tjj > 0 { |
| if xj > tjj*bignum { |
| rec = (tjj * bignum) / xj |
| bi.Dscal(n, rec, x, 1) |
| scale *= rec |
| xmax *= rec |
| } |
| x[j] /= tjjs |
| } else { |
| for i := 0; i < n; i++ { |
| x[i] = 0 |
| } |
| x[j] = 1 |
| scale = 0 |
| xmax = 0 |
| } |
| } else { |
| x[j] = x[j]/tjjs - sumj |
| } |
| Skip2: |
| xmax = math.Max(xmax, math.Abs(x[j])) |
| } |
| } |
| scale /= tscal |
| if tscal != 1 { |
| bi.Dscal(n, 1/tscal, cnorm, 1) |
| } |
| return scale |
| } |