| // Copyright ©2015 The gonum Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package gonum |
| |
| import ( |
| "gonum.org/v1/gonum/blas" |
| "gonum.org/v1/gonum/blas/blas64" |
| "gonum.org/v1/gonum/lapack" |
| ) |
| |
| // Dlarft forms the triangular factor T of a block reflector H, storing the answer |
| // in t. |
| // H = I - V * T * V^T if store == lapack.ColumnWise |
| // H = I - V^T * T * V if store == lapack.RowWise |
| // H is defined by a product of the elementary reflectors where |
| // H = H_0 * H_1 * ... * H_{k-1} if direct == lapack.Forward |
| // H = H_{k-1} * ... * H_1 * H_0 if direct == lapack.Backward |
| // |
| // t is a k×k triangular matrix. t is upper triangular if direct = lapack.Forward |
| // and lower triangular otherwise. This function will panic if t is not of |
| // sufficient size. |
| // |
| // store describes the storage of the elementary reflectors in v. Please see |
| // Dlarfb for a description of layout. |
| // |
| // tau contains the scalar factors of the elementary reflectors H_i. |
| // |
| // Dlarft is an internal routine. It is exported for testing purposes. |
| func (Implementation) Dlarft(direct lapack.Direct, store lapack.StoreV, n, k int, |
| v []float64, ldv int, tau []float64, t []float64, ldt int) { |
| if n == 0 { |
| return |
| } |
| if n < 0 || k < 0 { |
| panic(negDimension) |
| } |
| if direct != lapack.Forward && direct != lapack.Backward { |
| panic(badDirect) |
| } |
| if store != lapack.RowWise && store != lapack.ColumnWise { |
| panic(badStore) |
| } |
| if len(tau) < k { |
| panic(badTau) |
| } |
| checkMatrix(k, k, t, ldt) |
| bi := blas64.Implementation() |
| // TODO(btracey): There are a number of minor obvious loop optimizations here. |
| // TODO(btracey): It may be possible to rearrange some of the code so that |
| // index of 1 is more common in the Dgemv. |
| if direct == lapack.Forward { |
| prevlastv := n - 1 |
| for i := 0; i < k; i++ { |
| prevlastv = max(i, prevlastv) |
| if tau[i] == 0 { |
| for j := 0; j <= i; j++ { |
| t[j*ldt+i] = 0 |
| } |
| continue |
| } |
| var lastv int |
| if store == lapack.ColumnWise { |
| // skip trailing zeros |
| for lastv = n - 1; lastv >= i+1; lastv-- { |
| if v[lastv*ldv+i] != 0 { |
| break |
| } |
| } |
| for j := 0; j < i; j++ { |
| t[j*ldt+i] = -tau[i] * v[i*ldv+j] |
| } |
| j := min(lastv, prevlastv) |
| bi.Dgemv(blas.Trans, j-i, i, |
| -tau[i], v[(i+1)*ldv:], ldv, v[(i+1)*ldv+i:], ldv, |
| 1, t[i:], ldt) |
| } else { |
| for lastv = n - 1; lastv >= i+1; lastv-- { |
| if v[i*ldv+lastv] != 0 { |
| break |
| } |
| } |
| for j := 0; j < i; j++ { |
| t[j*ldt+i] = -tau[i] * v[j*ldv+i] |
| } |
| j := min(lastv, prevlastv) |
| bi.Dgemv(blas.NoTrans, i, j-i, |
| -tau[i], v[i+1:], ldv, v[i*ldv+i+1:], 1, |
| 1, t[i:], ldt) |
| } |
| bi.Dtrmv(blas.Upper, blas.NoTrans, blas.NonUnit, i, t, ldt, t[i:], ldt) |
| t[i*ldt+i] = tau[i] |
| if i > 1 { |
| prevlastv = max(prevlastv, lastv) |
| } else { |
| prevlastv = lastv |
| } |
| } |
| return |
| } |
| prevlastv := 0 |
| for i := k - 1; i >= 0; i-- { |
| if tau[i] == 0 { |
| for j := i; j < k; j++ { |
| t[j*ldt+i] = 0 |
| } |
| continue |
| } |
| var lastv int |
| if i < k-1 { |
| if store == lapack.ColumnWise { |
| for lastv = 0; lastv < i; lastv++ { |
| if v[lastv*ldv+i] != 0 { |
| break |
| } |
| } |
| for j := i + 1; j < k; j++ { |
| t[j*ldt+i] = -tau[i] * v[(n-k+i)*ldv+j] |
| } |
| j := max(lastv, prevlastv) |
| bi.Dgemv(blas.Trans, n-k+i-j, k-i-1, |
| -tau[i], v[j*ldv+i+1:], ldv, v[j*ldv+i:], ldv, |
| 1, t[(i+1)*ldt+i:], ldt) |
| } else { |
| for lastv = 0; lastv < i; lastv++ { |
| if v[i*ldv+lastv] != 0 { |
| break |
| } |
| } |
| for j := i + 1; j < k; j++ { |
| t[j*ldt+i] = -tau[i] * v[j*ldv+n-k+i] |
| } |
| j := max(lastv, prevlastv) |
| bi.Dgemv(blas.NoTrans, k-i-1, n-k+i-j, |
| -tau[i], v[(i+1)*ldv+j:], ldv, v[i*ldv+j:], 1, |
| 1, t[(i+1)*ldt+i:], ldt) |
| } |
| bi.Dtrmv(blas.Lower, blas.NoTrans, blas.NonUnit, k-i-1, |
| t[(i+1)*ldt+i+1:], ldt, |
| t[(i+1)*ldt+i:], ldt) |
| if i > 0 { |
| prevlastv = min(prevlastv, lastv) |
| } else { |
| prevlastv = lastv |
| } |
| } |
| t[i*ldt+i] = tau[i] |
| } |
| } |