blob: c48de9379cbcec5d72b05449ad4f729da5791e77 [file] [log] [blame]
// Copyright ©2018 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mat
import (
"fmt"
"reflect"
"testing"
"gonum.org/v1/gonum/blas"
"gonum.org/v1/gonum/blas/blas64"
)
// TestNewTriBand checks that NewTriBandDense builds the expected internal
// blas64.TriangularBand representation from band-packed data, that the TriBand
// and Dims accessors report the construction parameters, and that the result
// compares equal both to the expected TriBandDense and to its dense
// equivalent. Entries of -1 in the packed data occupy storage slots that fall
// outside the matrix and therefore never appear in the dense form.
func TestNewTriBand(t *testing.T) {
	t.Parallel()
	for cas, test := range []struct {
		data  []float64
		n, k  int
		kind  TriKind
		mat   *TriBandDense
		dense *Dense
	}{
		{
			data: []float64{1, 2, 3},
			n:    3, k: 0,
			kind: Upper,
			mat: &TriBandDense{
				mat: blas64.TriangularBand{
					Diag: blas.NonUnit,
					Uplo: blas.Upper,
					N:    3, K: 0,
					Data:   []float64{1, 2, 3},
					Stride: 1,
				},
			},
			dense: NewDense(3, 3, []float64{
				1, 0, 0,
				0, 2, 0,
				0, 0, 3,
			}),
		},
		{
			data: []float64{
				1, 2,
				3, 4,
				5, 6,
				7, 8,
				9, 10,
				11, -1,
			},
			n: 6, k: 1,
			kind: Upper,
			mat: &TriBandDense{
				mat: blas64.TriangularBand{
					Diag: blas.NonUnit,
					Uplo: blas.Upper,
					N:    6, K: 1,
					Data: []float64{
						1, 2,
						3, 4,
						5, 6,
						7, 8,
						9, 10,
						11, -1,
					},
					Stride: 2,
				},
			},
			dense: NewDense(6, 6, []float64{
				1, 2, 0, 0, 0, 0,
				0, 3, 4, 0, 0, 0,
				0, 0, 5, 6, 0, 0,
				0, 0, 0, 7, 8, 0,
				0, 0, 0, 0, 9, 10,
				0, 0, 0, 0, 0, 11,
			}),
		},
		{
			data: []float64{
				1, 2, 3,
				4, 5, 6,
				7, 8, 9,
				10, 11, 12,
				13, 14, -1,
				15, -1, -1,
			},
			n: 6, k: 2,
			kind: Upper,
			mat: &TriBandDense{
				mat: blas64.TriangularBand{
					Diag: blas.NonUnit,
					Uplo: blas.Upper,
					N:    6, K: 2,
					Data: []float64{
						1, 2, 3,
						4, 5, 6,
						7, 8, 9,
						10, 11, 12,
						13, 14, -1,
						15, -1, -1,
					},
					Stride: 3,
				},
			},
			dense: NewDense(6, 6, []float64{
				1, 2, 3, 0, 0, 0,
				0, 4, 5, 6, 0, 0,
				0, 0, 7, 8, 9, 0,
				0, 0, 0, 10, 11, 12,
				0, 0, 0, 0, 13, 14,
				0, 0, 0, 0, 0, 15,
			}),
		},
		{
			data: []float64{
				-1, 1,
				2, 3,
				4, 5,
				6, 7,
				8, 9,
				10, 11,
			},
			n: 6, k: 1,
			kind: Lower,
			mat: &TriBandDense{
				mat: blas64.TriangularBand{
					Diag: blas.NonUnit,
					Uplo: blas.Lower,
					N:    6, K: 1,
					Data: []float64{
						-1, 1,
						2, 3,
						4, 5,
						6, 7,
						8, 9,
						10, 11,
					},
					Stride: 2,
				},
			},
			dense: NewDense(6, 6, []float64{
				1, 0, 0, 0, 0, 0,
				2, 3, 0, 0, 0, 0,
				0, 4, 5, 0, 0, 0,
				0, 0, 6, 7, 0, 0,
				0, 0, 0, 8, 9, 0,
				0, 0, 0, 0, 10, 11,
			}),
		},
		{
			data: []float64{
				-1, -1, 1,
				-1, 2, 3,
				4, 5, 6,
				7, 8, 9,
				10, 11, 12,
				13, 14, 15,
			},
			n: 6, k: 2,
			kind: Lower,
			mat: &TriBandDense{
				mat: blas64.TriangularBand{
					Diag: blas.NonUnit,
					Uplo: blas.Lower,
					N:    6, K: 2,
					Data: []float64{
						-1, -1, 1,
						-1, 2, 3,
						4, 5, 6,
						7, 8, 9,
						10, 11, 12,
						13, 14, 15,
					},
					Stride: 3,
				},
			},
			dense: NewDense(6, 6, []float64{
				1, 0, 0, 0, 0, 0,
				2, 3, 0, 0, 0, 0,
				4, 5, 6, 0, 0, 0,
				0, 7, 8, 9, 0, 0,
				0, 0, 10, 11, 12, 0,
				0, 0, 0, 13, 14, 15,
			}),
		},
	} {
		triBand := NewTriBandDense(test.n, test.k, test.kind, test.data)
		r, c := triBand.Dims()
		n, k, kind := triBand.TriBand()
		if n != test.n {
			t.Errorf("unexpected triband size for test %d: got: %d want: %d", cas, n, test.n)
		}
		if k != test.k {
			t.Errorf("unexpected triband bandwidth for test %d: got: %d want: %d", cas, k, test.k)
		}
		if kind != test.kind {
			// Report the triangle kind (Upper/Lower), not the bandwidth.
			t.Errorf("unexpected triband kind for test %d: got: %v want: %v", cas, kind, test.kind)
		}
		// A triangular band matrix is square: both dimensions equal n.
		if r != n {
			t.Errorf("unexpected number of rows for test %d: got: %d want: %d", cas, r, n)
		}
		if c != n {
			t.Errorf("unexpected number of cols for test %d: got: %d want: %d", cas, c, n)
		}
		// Structural check of the internal representation, including the
		// padding slots, followed by element-wise comparisons.
		if !reflect.DeepEqual(triBand, test.mat) {
			t.Errorf("unexpected value via reflect for test %d: got: %v want: %v", cas, triBand, test.mat)
		}
		if !Equal(triBand, test.mat) {
			t.Errorf("unexpected value via mat.Equal for test %d: got: %v want: %v", cas, triBand, test.mat)
		}
		if !Equal(triBand, test.dense) {
			t.Errorf("unexpected value via mat.Equal(band, dense) for test %d:\ngot:\n% v\nwant:\n% v", cas, Formatted(triBand), Formatted(test.dense))
		}
	}
}
// TestTriBandAtSetUpper exercises At and SetTriBand on a 6×6 triangular band
// matrix with bandwidth 2, for both the Upper and Lower kinds. It checks that:
//   - out-of-range row/column indices panic with ErrRowAccess/ErrColAccess;
//   - writes on the wrong side of the diagonal panic with ErrTriangleSet;
//   - writes inside the triangle but outside the band panic with ErrBandSet;
//   - every non-zero entry read via At can be written back via SetTriBand,
//     updating the matching slot of the backing data slice, while padding
//     slots (stored as -1) are left untouched.
func TestTriBandAtSetUpper(t *testing.T) {
	t.Parallel()
	for _, kind := range []TriKind{Upper, Lower} {
		var data []float64
		if kind == Upper {
			// 1  2  3  0  0  0
			// 0  4  5  6  0  0
			// 0  0  7  8  9  0
			// 0  0  0 10 11 12
			// 0  0  0  0 13 14
			// 0  0  0  0  0 15
			data = []float64{
				1, 2, 3,
				4, 5, 6,
				7, 8, 9,
				10, 11, 12,
				13, 14, -1,
				15, -1, -1,
			}
		} else {
			//  1  0  0  0  0  0
			//  2  3  0  0  0  0
			//  4  5  6  0  0  0
			//  0  7  8  9  0  0
			//  0  0 10 11 12  0
			//  0  0  0 13 14 15
			data = []float64{
				-1, -1, 1,
				-1, 2, 3,
				4, 5, 6,
				7, 8, 9,
				10, 11, 12,
				13, 14, 15,
			}
		}
		// data is used directly as the backing store, so writes through
		// SetTriBand are observable in data below.
		band := NewTriBandDense(6, 2, kind, data)

		rows, cols := band.Dims()
		// Check At out of bounds.
		for _, row := range []int{-1, rows, rows + 1} {
			panicked, message := panics(func() { band.At(row, 0) })
			if !panicked || message != ErrRowAccess.Error() {
				t.Errorf("expected panic for invalid row access N=%d r=%d (panicked=%t message=%q)", rows, row, panicked, message)
			}
		}
		for _, col := range []int{-1, cols, cols + 1} {
			panicked, message := panics(func() { band.At(0, col) })
			if !panicked || message != ErrColAccess.Error() {
				t.Errorf("expected panic for invalid column access N=%d c=%d (panicked=%t message=%q)", cols, col, panicked, message)
			}
		}

		// Check Set out of bounds.
		// First, check outside the matrix bounds.
		for _, row := range []int{-1, rows, rows + 1} {
			panicked, message := panics(func() { band.SetTriBand(row, 0, 1.2) })
			if !panicked || message != ErrRowAccess.Error() {
				t.Errorf("expected panic for invalid row access N=%d r=%d (panicked=%t message=%q)", rows, row, panicked, message)
			}
		}
		for _, col := range []int{-1, cols, cols + 1} {
			panicked, message := panics(func() { band.SetTriBand(0, col, 1.2) })
			if !panicked || message != ErrColAccess.Error() {
				t.Errorf("expected panic for invalid column access N=%d c=%d (panicked=%t message=%q)", cols, col, panicked, message)
			}
		}

		// Next, check outside the triangular bounds. Coordinates are given
		// for Upper (below the diagonal) and mirrored for Lower.
		for _, s := range []struct{ r, c int }{
			{3, 2},
		} {
			if kind == Lower {
				s.r, s.c = s.c, s.r
			}
			panicked, message := panics(func() { band.SetTriBand(s.r, s.c, 1.2) })
			if !panicked || message != ErrTriangleSet.Error() {
				t.Errorf("expected panic for invalid triangular access N=%d, r=%d c=%d (panicked=%t message=%q)", cols, s.r, s.c, panicked, message)
			}
		}

		// Finally, check inside the triangle, but outside the band. As
		// above, coordinates are given for Upper and mirrored for Lower.
		for _, s := range []struct{ r, c int }{
			{1, 5},
		} {
			if kind == Lower {
				s.r, s.c = s.c, s.r
			}
			panicked, message := panics(func() { band.SetTriBand(s.r, s.c, 1.2) })
			if !panicked || message != ErrBandSet.Error() {
				t.Errorf("expected panic for out-of-band access N=%d, r=%d c=%d (panicked=%t message=%q)", cols, s.r, s.c, panicked, message)
			}
		}

		// Test that At and Set work correctly: shift every non-zero entry by
		// offset and verify each data slot moved by exactly offset, except
		// the -1 padding slots, which must be unchanged.
		offset := 100.0
		dataCopy := make([]float64, len(data))
		copy(dataCopy, data)
		for i := 0; i < rows; i++ {
			for j := 0; j < cols; j++ {
				v := band.At(i, j)
				if v != 0 {
					band.SetTriBand(i, j, v+offset)
				}
			}
		}
		for i, v := range dataCopy {
			if v == -1 {
				if data[i] != -1 {
					t.Errorf("Set changed unexpected entry. Want %v, got %v", -1, data[i])
				}
			} else {
				if v != data[i]-offset {
					t.Errorf("Set incorrectly changed for %v. got %v, want %v", v, data[i], v+offset)
				}
			}
		}
	}
}
// TestTriBandDenseZero verifies that Zero clears exactly the stored band
// elements of a TriBandDense. Each fixture marks slots that must become zero
// with 1 and slots that must survive untouched with -1. The Stride of 5
// exceeds K+1, so every row also has trailing storage beyond the band.
func TestTriBandDenseZero(t *testing.T) {
	t.Parallel()
	fixtures := []*TriBandDense{
		{
			mat: blas64.TriangularBand{
				Uplo:   blas.Upper,
				N:      6,
				K:      2,
				Stride: 5,
				Data: []float64{
					1, 1, 1, -1, -1,
					1, 1, 1, -1, -1,
					1, 1, 1, -1, -1,
					1, 1, 1, -1, -1,
					1, 1, -1, -1, -1,
					1, -1, -1, -1, -1,
				},
			},
		},
		{
			mat: blas64.TriangularBand{
				Uplo:   blas.Lower,
				N:      6,
				K:      2,
				Stride: 5,
				Data: []float64{
					-1, -1, 1, -1, -1,
					-1, 1, 1, -1, -1,
					1, 1, 1, -1, -1,
					1, 1, 1, -1, -1,
					1, 1, 1, -1, -1,
					1, 1, 1, -1, -1,
				},
			},
		},
	}
	for _, band := range fixtures {
		// Snapshot the markers before zeroing so each slot can be judged
		// against what it held originally.
		before := make([]float64, len(band.mat.Data))
		copy(before, band.mat.Data)

		band.Zero()

		for idx, after := range band.mat.Data {
			switch marker := before[idx]; {
			case marker != -1 && after != 0:
				t.Errorf("Matrix not zeroed in bounds")
			case marker == -1 && after != -1:
				t.Errorf("Matrix zeroed out of bounds")
			}
		}
	}
}
// TestTriBandDiagView passes a set of upper and lower TriBandDense matrices
// to the package helper testDiagView, identifying each by its position in
// the list. The matrices cover 1×1, purely diagonal (K = 0) and 6×6 K = 2
// cases. In the band-packed data, -1 fills storage slots outside the matrix.
func TestTriBandDiagView(t *testing.T) {
	t.Parallel()

	upperK2 := []float64{
		1, 2, 3,
		4, 5, 6,
		7, 8, 9,
		10, 11, 12,
		13, 14, -1,
		15, -1, -1,
	}
	lowerK2 := []float64{
		-1, -1, 1,
		-1, 2, 3,
		4, 5, 6,
		7, 8, 9,
		10, 11, 12,
		13, 14, 15,
	}

	mats := []*TriBandDense{
		NewTriBandDense(1, 0, Upper, []float64{1}),
		NewTriBandDense(4, 0, Upper, []float64{1, 2, 3, 4}),
		NewTriBandDense(6, 2, Upper, upperK2),
		NewTriBandDense(1, 0, Lower, []float64{1}),
		NewTriBandDense(4, 0, Lower, []float64{1, 2, 3, 4}),
		NewTriBandDense(6, 2, Lower, lowerK2),
	}
	for i := range mats {
		testDiagView(t, i, mats[i])
	}
}
// TestTriBandDenseSolveTo checks TriBandDense.SolveTo on a 5×5 upper and a
// 5×5 lower triangular band matrix with bandwidth 2 and two right-hand sides.
// Every case is solved for both A and Aᵀ. It is also solved both with a fresh
// destination and with the destination aliasing the right-hand side. A
// solution X is accepted when the 1-norm of the residual op(A)·X − B is at
// most tol.
func TestTriBandDenseSolveTo(t *testing.T) {
	t.Parallel()
	const tol = 1e-15
	for tc, test := range []struct {
		a *TriBandDense
		b *Dense
	}{
		{
			a: NewTriBandDense(5, 2, Upper, []float64{
				-0.34, -0.49, -0.51,
				-0.25, -0.5, 1.03,
				-1.1, 0.3, -0.82,
				1.69, 0.69, -2.22,
				-0.62, 1.22, -0.85,
			}),
			b: NewDense(5, 2, []float64{
				0.44, 1.34,
				0.07, -1.45,
				-0.32, -0.88,
				-0.09, -0.15,
				-1.17, -0.19,
			}),
		},
		{
			a: NewTriBandDense(5, 2, Lower, []float64{
				0, 0, -0.34,
				0, -0.49, -0.25,
				-0.51, -0.5, -1.1,
				1.03, 0.3, 1.69,
				-0.82, 0.69, -0.62,
			}),
			b: NewDense(5, 2, []float64{
				0.44, 1.34,
				0.07, -1.45,
				-0.32, -0.88,
				-0.09, -0.15,
				-1.17, -0.19,
			}),
		},
	} {
		a := test.a
		for _, trans := range []bool{false, true} {
			for _, dstSameAsB := range []bool{false, true} {
				name := fmt.Sprintf("Case %d,trans=%v,dstSameAsB=%v", tc, trans, dstSameAsB)
				n, nrhs := test.b.Dims()
				var dst Dense
				var err error
				if dstSameAsB {
					// Solve in place: dst starts as a copy of b and is passed
					// as both the destination and the right-hand side.
					dst = *NewDense(n, nrhs, nil)
					dst.Copy(test.b)
					err = a.SolveTo(&dst, trans, &dst)
				} else {
					// NOTE(review): asBasicMatrix presumably hides the concrete
					// *Dense type so SolveTo takes its general Matrix path —
					// confirm against its definition.
					tmp := NewDense(n, nrhs, nil)
					tmp.Copy(test.b)
					err = a.SolveTo(&dst, trans, asBasicMatrix(tmp))
				}
				if err != nil {
					// Surface the actual error and keep checking the remaining
					// combinations rather than aborting the whole test.
					t.Errorf("%v: unexpected error from SolveTo: %v", name, err)
					continue
				}
				var resid Dense
				if trans {
					resid.Mul(a.T(), &dst)
				} else {
					resid.Mul(a, &dst)
				}
				resid.Sub(&resid, test.b)
				diff := Norm(&resid, 1)
				if diff > tol {
					t.Errorf("%v: unexpected result; diff=%v,want<=%v", name, diff, tol)
				}
			}
		}
	}
}
// TestTriBandDenseSolveVecTo checks TriBandDense.SolveVecTo on a 5×5 upper and
// a 5×5 lower triangular band matrix with bandwidth 2. Every case is solved
// for both A and Aᵀ. It is also solved both with a fresh destination vector
// and with the destination aliasing the right-hand side. A solution x is
// accepted when the 1-norm of the residual op(A)·x − b is at most tol.
func TestTriBandDenseSolveVecTo(t *testing.T) {
	t.Parallel()
	const tol = 1e-15
	for tc, test := range []struct {
		a *TriBandDense
		b *VecDense
	}{
		{
			a: NewTriBandDense(5, 2, Upper, []float64{
				-0.34, -0.49, -0.51,
				-0.25, -0.5, 1.03,
				-1.1, 0.3, -0.82,
				1.69, 0.69, -2.22,
				-0.62, 1.22, -0.85,
			}),
			b: NewVecDense(5, []float64{
				0.44,
				0.07,
				-0.32,
				-0.09,
				-1.17,
			}),
		},
		{
			a: NewTriBandDense(5, 2, Lower, []float64{
				0, 0, -0.34,
				0, -0.49, -0.25,
				-0.51, -0.5, -1.1,
				1.03, 0.3, 1.69,
				-0.82, 0.69, -0.62,
			}),
			b: NewVecDense(5, []float64{
				0.44,
				0.07,
				-0.32,
				-0.09,
				-1.17,
			}),
		},
	} {
		a := test.a
		for _, trans := range []bool{false, true} {
			for _, dstSameAsB := range []bool{false, true} {
				name := fmt.Sprintf("Case %d,trans=%v,dstSameAsB=%v", tc, trans, dstSameAsB)
				n, _ := test.b.Dims()
				var dst VecDense
				var err error
				if dstSameAsB {
					// Solve in place: dst starts as a copy of b and is passed
					// as both the destination and the right-hand side.
					dst = *NewVecDense(n, nil)
					dst.CopyVec(test.b)
					err = a.SolveVecTo(&dst, trans, &dst)
				} else {
					// NOTE(review): asBasicVector presumably hides the concrete
					// *VecDense type so SolveVecTo takes its general Vector
					// path — confirm against its definition.
					tmp := NewVecDense(n, nil)
					tmp.CopyVec(test.b)
					err = a.SolveVecTo(&dst, trans, asBasicVector(tmp))
				}
				if err != nil {
					// Surface the actual error and keep checking the remaining
					// combinations rather than aborting the whole test.
					t.Errorf("%v: unexpected error from SolveVecTo: %v", name, err)
					continue
				}
				var resid VecDense
				if trans {
					resid.MulVec(a.T(), &dst)
				} else {
					resid.MulVec(a, &dst)
				}
				resid.SubVec(&resid, test.b)
				diff := Norm(&resid, 1)
				if diff > tol {
					t.Errorf("%v: unexpected result; diff=%v,want<=%v", name, diff, tol)
				}
			}
		}
	}
}