// blob: 5e07e07db690f954c356544a6ae1c79d37081fcf [file] [log] [blame]

package testblas
import (
"testing"
"gonum.org/v1/gonum/blas"
"gonum.org/v1/gonum/floats"
)
// Dsymmer is the single-method interface that DsymmTest exercises.
//
// Judging by the expected values in DsymmTest, Dsymm is expected to overwrite
// c with alpha*A*B + beta*C when s is blas.Left and with alpha*B*A + beta*C
// when s is blas.Right, where A is symmetric and only the triangle selected by
// ul is read. The slices are flattened row by row with lda, ldb and ldc as the
// row strides (DsymmTest passes the row lengths of a, b and c respectively).
type Dsymmer interface {
// Dsymm receives the side, the triangle selector, the dimensions m and n of
// b and c, the two scalars, and the three flattened matrices with their
// strides; the result is written into c in place.
Dsymm(s blas.Side, ul blas.Uplo, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int)
}
// DsymmTest checks blasser.Dsymm against hand-computed results.
//
// Every combination of side (blas.Left, blas.Right) and triangle (blas.Upper,
// blas.Lower) is run with both a 3×4 and a 4×3 shape for b and c, always with
// alpha = 2 and beta = 3. The expected value is alpha*A*B + beta*C for
// blas.Left and alpha*B*A + beta*C for blas.Right, where A is the symmetric
// matrix defined by the triangle of a that ul selects. The unselected triangle
// of a is filled with zeros, so an implementation that reads the wrong
// triangle produces a different result and fails the comparison.
//
// Matrices are written as rows and flattened before the call; the row length
// is used as the stride. A mismatch is reported with t.Errorf together with
// the index of the failing case.
func DsymmTest(t *testing.T, blasser Dsymmer) {
	for i, test := range []struct {
		m     int
		n     int
		side  blas.Side
		ul    blas.Uplo
		a     [][]float64
		b     [][]float64
		c     [][]float64
		alpha float64
		beta  float64
		ans   [][]float64
	}{
		{
			side: blas.Left,
			ul:   blas.Upper,
			m:    3,
			n:    4,
			a: [][]float64{
				{2, 3, 4},
				{0, 6, 7},
				{0, 0, 10},
			},
			b: [][]float64{
				{2, 3, 4, 8},
				{5, 6, 7, 15},
				{8, 9, 10, 20},
			},
			c: [][]float64{
				{8, 12, 2, 1},
				{9, 12, 9, 9},
				{12, 1, -1, 5},
			},
			alpha: 2,
			beta:  3,
			ans: [][]float64{
				{126, 156, 144, 285},
				{211, 252, 275, 535},
				{282, 291, 327, 689},
			},
		},
		{
			side: blas.Left,
			ul:   blas.Upper,
			m:    4,
			n:    3,
			a: [][]float64{
				{2, 3, 4, 8},
				{0, 6, 7, 9},
				{0, 0, 10, 10},
				{0, 0, 0, 11},
			},
			b: [][]float64{
				{2, 3, 4},
				{5, 6, 7},
				{8, 9, 10},
				{2, 1, 1},
			},
			c: [][]float64{
				{8, 12, 2},
				{9, 12, 9},
				{12, 1, -1},
				{1, 9, 5},
			},
			alpha: 2,
			beta:  3,
			ans: [][]float64{
				{158, 172, 160},
				{247, 270, 293},
				{322, 311, 347},
				{329, 385, 427},
			},
		},
		{
			side: blas.Left,
			ul:   blas.Lower,
			m:    3,
			n:    4,
			a: [][]float64{
				{2, 0, 0},
				{3, 6, 0},
				{4, 7, 10},
			},
			b: [][]float64{
				{2, 3, 4, 8},
				{5, 6, 7, 15},
				{8, 9, 10, 20},
			},
			c: [][]float64{
				{8, 12, 2, 1},
				{9, 12, 9, 9},
				{12, 1, -1, 5},
			},
			alpha: 2,
			beta:  3,
			ans: [][]float64{
				{126, 156, 144, 285},
				{211, 252, 275, 535},
				{282, 291, 327, 689},
			},
		},
		{
			side: blas.Left,
			ul:   blas.Lower,
			m:    4,
			n:    3,
			a: [][]float64{
				{2, 0, 0, 0},
				{3, 6, 0, 0},
				{4, 7, 10, 0},
				{8, 9, 10, 11},
			},
			b: [][]float64{
				{2, 3, 4},
				{5, 6, 7},
				{8, 9, 10},
				{2, 1, 1},
			},
			c: [][]float64{
				{8, 12, 2},
				{9, 12, 9},
				{12, 1, -1},
				{1, 9, 5},
			},
			alpha: 2,
			beta:  3,
			ans: [][]float64{
				{158, 172, 160},
				{247, 270, 293},
				{322, 311, 347},
				{329, 385, 427},
			},
		},
		{
			// A is stored in the upper triangle so that the off-diagonal
			// entries are actually read for Right/Upper. Lower-triangular
			// storage here would make the referenced matrix purely diagonal.
			side: blas.Right,
			ul:   blas.Upper,
			m:    3,
			n:    4,
			a: [][]float64{
				{2, 3, 4, 3},
				{0, 6, 7, 4},
				{0, 0, 10, 5},
				{0, 0, 0, 6},
			},
			b: [][]float64{
				{2, 3, 4, 9},
				{5, 6, 7, -3},
				{8, 9, 10, -2},
			},
			c: [][]float64{
				{8, 12, 2, 10},
				{9, 12, 9, 10},
				{12, 1, -1, 10},
			},
			alpha: 2,
			beta:  3,
			ans: [][]float64{
				{136, 212, 234, 214},
				{121, 212, 261, 142},
				{190, 283, 367, 226},
			},
		},
		{
			// Upper storage of the same symmetric matrix, b and c as the
			// Right/Lower 4×3 case below, so the expected result is identical.
			side: blas.Right,
			ul:   blas.Upper,
			m:    4,
			n:    3,
			a: [][]float64{
				{2, 3, 4},
				{0, 6, 7},
				{0, 0, 10},
			},
			b: [][]float64{
				{2, 3, 4},
				{5, 6, 7},
				{8, 9, 10},
				{2, 1, 1},
			},
			c: [][]float64{
				{8, 12, 2},
				{9, 12, 9},
				{12, 1, -1},
				{1, 9, 5},
			},
			alpha: 2,
			beta:  3,
			ans: [][]float64{
				{82, 140, 144},
				{139, 236, 291},
				{202, 299, 387},
				{25, 65, 65},
			},
		},
		{
			side: blas.Right,
			ul:   blas.Lower,
			m:    3,
			n:    4,
			a: [][]float64{
				{2, 0, 0, 0},
				{3, 6, 0, 0},
				{4, 7, 10, 0},
				{3, 4, 5, 6},
			},
			b: [][]float64{
				{2, 3, 4, 2},
				{5, 6, 7, 1},
				{8, 9, 10, 1},
			},
			c: [][]float64{
				{8, 12, 2, 1},
				{9, 12, 9, 9},
				{12, 1, -1, 5},
			},
			alpha: 2,
			beta:  3,
			ans: [][]float64{
				{94, 156, 164, 103},
				{145, 244, 301, 187},
				{208, 307, 397, 247},
			},
		},
		{
			side: blas.Right,
			ul:   blas.Lower,
			m:    4,
			n:    3,
			a: [][]float64{
				{2, 0, 0},
				{3, 6, 0},
				{4, 7, 10},
			},
			b: [][]float64{
				{2, 3, 4},
				{5, 6, 7},
				{8, 9, 10},
				{2, 1, 1},
			},
			c: [][]float64{
				{8, 12, 2},
				{9, 12, 9},
				{12, 1, -1},
				{1, 9, 5},
			},
			alpha: 2,
			beta:  3,
			ans: [][]float64{
				{82, 140, 144},
				{139, 236, 291},
				{202, 299, 387},
				{25, 65, 65},
			},
		},
	} {
		aFlat := flatten(test.a)
		bFlat := flatten(test.b)
		cFlat := flatten(test.c)
		ansFlat := flatten(test.ans)
		// a is square (m×m for Left, n×n for Right), so its stride is its own
		// row length; b and c are both m×n and therefore share the stride n.
		blasser.Dsymm(test.side, test.ul, test.m, test.n, test.alpha, aFlat, len(test.a[0]), bFlat, test.n, test.beta, cFlat, test.n)
		// Dsymm writes its result into cFlat in place.
		if !floats.EqualApprox(cFlat, ansFlat, 1e-14) {
			t.Errorf("Case %v: Want %v, got %v.", i, ansFlat, cFlat)
		}
	}
}