blob: cc03bb389b63e9a9056c58972d93c74a282fccfe [file] [log] [blame]
// Copyright ©2014 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gonum
import (
"testing"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/blas"
)
// TestDgemmParallel checks dgemmParallel against dgemmSerial for a set of
// matrix shapes chosen around multiples of blockSize and minParBlock.
//
// Each shape is run with every combination of transposed and non-transposed
// a and b. The transpose is therefore not part of the table; the previous
// per-case tA/tB fields were never read and only suggested otherwise.
func TestDgemmParallel(t *testing.T) {
	transposes := []blas.Transpose{blas.NoTrans, blas.Trans}
	for i, test := range []struct {
		m, n, k int
		alpha   float64
	}{
		{m: 3, n: 4, k: 2, alpha: 2.5},
		{m: blockSize*2 + 5, n: 3, k: 2, alpha: 2.5},
		{m: 3, n: blockSize * 2, k: 2, alpha: 2.5},
		{m: 2, n: 3, k: blockSize*3 - 2, alpha: 2.5},
		{m: blockSize * minParBlock, n: 3, k: 2, alpha: 2.5},
		{m: 3, n: blockSize * minParBlock, k: 2, alpha: 2.5},
		{m: 2, n: 3, k: blockSize * minParBlock, alpha: 2.5},
		{m: blockSize*minParBlock + 1, n: blockSize * minParBlock, k: 3, alpha: 2.5},
		{m: 3, n: blockSize*minParBlock + 2, k: blockSize * 3, alpha: 2.5},
		{m: blockSize * minParBlock, n: 3, k: blockSize * minParBlock, alpha: 2.5},
		{m: blockSize * minParBlock, n: blockSize * minParBlock, k: blockSize * 3, alpha: 2.5},
		{m: blockSize + blockSize/2, n: blockSize + blockSize/2, k: blockSize + blockSize/2, alpha: 2.5},
	} {
		// tB is the outer loop so the call order matches the original:
		// (N,N), (T,N), (N,T), (T,T). The helper draws from a shared random
		// source, so keeping the order keeps the generated matrices the same.
		for _, tB := range transposes {
			for _, tA := range transposes {
				testMatchParallelSerial(t, i, tA, tB, test.m, test.n, test.k, test.alpha)
			}
		}
	}
}
// testMatchParallelSerial builds random operands a and b and a random matrix c.
// It runs dgemmSerial on a clone of c and dgemmParallel on c itself, using the
// same a, b and alpha, then reports a test failure if:
//   - either routine modified a or b, or
//   - the two results differ by more than an absolute tolerance of 1e-12
//     (as judged by equalWithinAbs).
//
// i is the table case index, used only in failure messages. tA and tB select
// whether a and b are stored transposed: op(a) is m×k and op(b) is k×n.
// Failure messages include the transpose flags, because each case index is
// run with several transpose combinations.
func testMatchParallelSerial(t *testing.T, i int, tA, tB blas.Transpose, m, n, k int, alpha float64) {
	// Treat any non-NoTrans value as transposed, both for the storage shape and
	// for the flag passed to the kernels, so the two can never disagree. (For
	// real matrices ConjTrans is the same as Trans.)
	aTrans := tA != blas.NoTrans
	bTrans := tB != blas.NoTrans

	rowA, colA := m, k
	if aTrans {
		rowA, colA = k, m
	}
	rowB, colB := k, n
	if bTrans {
		rowB, colB = n, k
	}

	a := randmat(rowA, colA, colA)
	b := randmat(rowB, colB, colB)
	c := randmat(m, n, n)

	aClone := a.clone()
	bClone := b.clone()
	cClone := c.clone()

	// The matrices are created with stride equal to their column count, so
	// each leading dimension is just that count.
	lda, ldb, ldc := colA, colB, n

	dgemmSerial(aTrans, bTrans, m, n, k, a.data, lda, b.data, ldb, cClone.data, ldc, alpha)
	if !a.equal(aClone) || !b.equal(bClone) {
		// The reference run must leave its operands alone. Otherwise the
		// parallel run would see different inputs, and any later failure would
		// be blamed on the wrong routine.
		t.Errorf("Case %v (aTrans=%t, bTrans=%t): a or b changed during call to dgemmSerial", i, aTrans, bTrans)
		return
	}

	dgemmParallel(aTrans, bTrans, m, n, k, a.data, lda, b.data, ldb, c.data, ldc, alpha)
	if !a.equal(aClone) {
		t.Errorf("Case %v (aTrans=%t, bTrans=%t): a changed during call to dgemmParallel", i, aTrans, bTrans)
	}
	if !b.equal(bClone) {
		t.Errorf("Case %v (aTrans=%t, bTrans=%t): b changed during call to dgemmParallel", i, aTrans, bTrans)
	}
	if !c.equalWithinAbs(cClone, 1e-12) {
		t.Errorf("Case %v (aTrans=%t, bTrans=%t): answer not equal parallel and serial", i, aTrans, bTrans)
	}
}
// randmat returns an r×c general64 with the given row stride.
//
// The backing slice holds r*stride+c values, all drawn in order from
// rand.Float64. The trailing c elements sit past the last row and are
// filled as well.
func randmat(r, c, stride int) general64 {
	g := general64{
		rows:   r,
		cols:   c,
		stride: stride,
		data:   make([]float64, r*stride+c),
	}
	for idx := range g.data {
		g.data[idx] = rand.Float64()
	}
	return g
}