mat: add TriBandDense.SolveVec and SolveVecTo
diff --git a/mat/triband.go b/mat/triband.go
index 41975a0..6afcb97 100644
--- a/mat/triband.go
+++ b/mat/triband.go
@@ -5,8 +5,11 @@
package mat
import (
+ "math"
+
"gonum.org/v1/gonum/blas"
"gonum.org/v1/gonum/blas/blas64"
+ "gonum.org/v1/gonum/lapack/lapack64"
)
var (
@@ -458,6 +461,64 @@
return tr
}
+// SolveTo solves a triangular system T * X = B or Tᵀ * X = B where T is an
+// n×n triangular band matrix represented by the receiver and B is a given
+// n×nrhs matrix. If T is non-singular, the result will be stored into dst and
+// nil will be returned. If T is singular, the contents of dst will be undefined
+// and a Condition error will be returned.
+func (t *TriBandDense) SolveTo(dst *Dense, trans bool, b Matrix) error {
+ n, nrhs := b.Dims()
+ if n != t.mat.N {
+ panic(ErrShape)
+ }
+ if b, ok := b.(RawMatrixer); ok && dst != b {
+ dst.checkOverlap(b.RawMatrix())
+ }
+ dst.reuseAsNonZeroed(n, nrhs)
+ if dst != b {
+ dst.Copy(b)
+ }
+ var ok bool
+ if trans {
+ ok = lapack64.Tbtrs(blas.Trans, t.mat, dst.mat)
+ } else {
+ ok = lapack64.Tbtrs(blas.NoTrans, t.mat, dst.mat)
+ }
+ if !ok {
+ return Condition(math.Inf(1))
+ }
+ return nil
+}
+
+// SolveVecTo solves a triangular system T * x = b or Tᵀ * x = b where T is an
+// n×n triangular band matrix represented by the receiver and b is a given
+// n-vector. If T is non-singular, the result will be stored into dst and nil
+// will be returned. If T is singular, the contents of dst will be undefined and
+// a Condition error will be returned.
+func (t *TriBandDense) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
+ n, nrhs := b.Dims()
+ if n != t.mat.N || nrhs != 1 {
+ panic(ErrShape)
+ }
+ if b, ok := b.(RawVectorer); ok && dst != b {
+ dst.checkOverlap(b.RawVector())
+ }
+ dst.reuseAsNonZeroed(n)
+ if dst != b {
+ dst.CopyVec(b)
+ }
+ var ok bool
+ if trans {
+ ok = lapack64.Tbtrs(blas.Trans, t.mat, dst.asGeneral())
+ } else {
+ ok = lapack64.Tbtrs(blas.NoTrans, t.mat, dst.asGeneral())
+ }
+ if !ok {
+ return Condition(math.Inf(1))
+ }
+ return nil
+}
+
func copySymBandIntoTriBand(dst *TriBandDense, s SymBanded) {
n, k, upper := dst.TriBand()
ns, ks := s.SymBand()
diff --git a/mat/triband_test.go b/mat/triband_test.go
index 6757a69..c48de93 100644
--- a/mat/triband_test.go
+++ b/mat/triband_test.go
@@ -5,6 +5,7 @@
package mat
import (
+ "fmt"
"reflect"
"testing"
@@ -414,3 +415,163 @@
testDiagView(t, cas, test)
}
}
+
+func TestTriBandDenseSolveTo(t *testing.T) {
+ t.Parallel()
+
+ const tol = 1e-15
+
+ for tc, test := range []struct {
+ a *TriBandDense
+ b *Dense
+ }{
+ {
+ a: NewTriBandDense(5, 2, Upper, []float64{
+ -0.34, -0.49, -0.51,
+ -0.25, -0.5, 1.03,
+ -1.1, 0.3, -0.82,
+ 1.69, 0.69, -2.22,
+ -0.62, 1.22, -0.85,
+ }),
+ b: NewDense(5, 2, []float64{
+ 0.44, 1.34,
+ 0.07, -1.45,
+ -0.32, -0.88,
+ -0.09, -0.15,
+ -1.17, -0.19,
+ }),
+ },
+ {
+ a: NewTriBandDense(5, 2, Lower, []float64{
+ 0, 0, -0.34,
+ 0, -0.49, -0.25,
+ -0.51, -0.5, -1.1,
+ 1.03, 0.3, 1.69,
+ -0.82, 0.69, -0.62,
+ }),
+ b: NewDense(5, 2, []float64{
+ 0.44, 1.34,
+ 0.07, -1.45,
+ -0.32, -0.88,
+ -0.09, -0.15,
+ -1.17, -0.19,
+ }),
+ },
+ } {
+ a := test.a
+ for _, trans := range []bool{false, true} {
+ for _, dstSameAsB := range []bool{false, true} {
+ name := fmt.Sprintf("Case %d,trans=%v,dstSameAsB=%v", tc, trans, dstSameAsB)
+
+ n, nrhs := test.b.Dims()
+ var dst Dense
+ var err error
+ if dstSameAsB {
+ dst = *NewDense(n, nrhs, nil)
+ dst.Copy(test.b)
+ err = a.SolveTo(&dst, trans, &dst)
+ } else {
+ tmp := NewDense(n, nrhs, nil)
+ tmp.Copy(test.b)
+ err = a.SolveTo(&dst, trans, asBasicMatrix(tmp))
+ }
+
+ if err != nil {
+ t.Fatalf("%v: unexpected error from SolveTo", name)
+ }
+
+ var resid Dense
+ if trans {
+ resid.Mul(a.T(), &dst)
+ } else {
+ resid.Mul(a, &dst)
+ }
+ resid.Sub(&resid, test.b)
+ diff := Norm(&resid, 1)
+ if diff > tol {
+ t.Errorf("%v: unexpected result; diff=%v,want<=%v", name, diff, tol)
+ }
+ }
+ }
+ }
+}
+
+func TestTriBandDenseSolveVecTo(t *testing.T) {
+ t.Parallel()
+
+ const tol = 1e-15
+
+ for tc, test := range []struct {
+ a *TriBandDense
+ b *VecDense
+ }{
+ {
+ a: NewTriBandDense(5, 2, Upper, []float64{
+ -0.34, -0.49, -0.51,
+ -0.25, -0.5, 1.03,
+ -1.1, 0.3, -0.82,
+ 1.69, 0.69, -2.22,
+ -0.62, 1.22, -0.85,
+ }),
+ b: NewVecDense(5, []float64{
+ 0.44,
+ 0.07,
+ -0.32,
+ -0.09,
+ -1.17,
+ }),
+ },
+ {
+ a: NewTriBandDense(5, 2, Lower, []float64{
+ 0, 0, -0.34,
+ 0, -0.49, -0.25,
+ -0.51, -0.5, -1.1,
+ 1.03, 0.3, 1.69,
+ -0.82, 0.69, -0.62,
+ }),
+ b: NewVecDense(5, []float64{
+ 0.44,
+ 0.07,
+ -0.32,
+ -0.09,
+ -1.17,
+ }),
+ },
+ } {
+ a := test.a
+ for _, trans := range []bool{false, true} {
+ for _, dstSameAsB := range []bool{false, true} {
+ name := fmt.Sprintf("Case %d,trans=%v,dstSameAsB=%v", tc, trans, dstSameAsB)
+
+ n, _ := test.b.Dims()
+ var dst VecDense
+ var err error
+ if dstSameAsB {
+ dst = *NewVecDense(n, nil)
+ dst.CopyVec(test.b)
+ err = a.SolveVecTo(&dst, trans, &dst)
+ } else {
+ tmp := NewVecDense(n, nil)
+ tmp.CopyVec(test.b)
+ err = a.SolveVecTo(&dst, trans, asBasicVector(tmp))
+ }
+
+ if err != nil {
+ t.Fatalf("%v: unexpected error from SolveVecTo", name)
+ }
+
+ var resid VecDense
+ if trans {
+ resid.MulVec(a.T(), &dst)
+ } else {
+ resid.MulVec(a, &dst)
+ }
+ resid.SubVec(&resid, test.b)
+ diff := Norm(&resid, 1)
+ if diff > tol {
+ t.Errorf("%v: unexpected result; diff=%v,want<=%v", name, diff, tol)
+ }
+ }
+ }
+ }
+}