| // Copyright ©2014 The Gonum Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package gonum |
| |
| import ( |
| "testing" |
| |
| "golang.org/x/exp/rand" |
| |
| "gonum.org/v1/gonum/blas" |
| ) |
| |
| func TestDgemmParallel(t *testing.T) { |
| for i, test := range []struct { |
| m int |
| n int |
| k int |
| alpha float64 |
| tA blas.Transpose |
| tB blas.Transpose |
| }{ |
| { |
| m: 3, |
| n: 4, |
| k: 2, |
| alpha: 2.5, |
| tA: blas.NoTrans, |
| tB: blas.NoTrans, |
| }, |
| { |
| m: blockSize*2 + 5, |
| n: 3, |
| k: 2, |
| alpha: 2.5, |
| tA: blas.NoTrans, |
| tB: blas.NoTrans, |
| }, |
| { |
| m: 3, |
| n: blockSize * 2, |
| k: 2, |
| alpha: 2.5, |
| tA: blas.NoTrans, |
| tB: blas.NoTrans, |
| }, |
| { |
| m: 2, |
| n: 3, |
| k: blockSize*3 - 2, |
| alpha: 2.5, |
| tA: blas.NoTrans, |
| tB: blas.NoTrans, |
| }, |
| { |
| m: blockSize * minParBlock, |
| n: 3, |
| k: 2, |
| alpha: 2.5, |
| tA: blas.NoTrans, |
| tB: blas.NoTrans, |
| }, |
| { |
| m: 3, |
| n: blockSize * minParBlock, |
| k: 2, |
| alpha: 2.5, |
| tA: blas.NoTrans, |
| tB: blas.NoTrans, |
| }, |
| { |
| m: 2, |
| n: 3, |
| k: blockSize * minParBlock, |
| alpha: 2.5, |
| tA: blas.NoTrans, |
| tB: blas.NoTrans, |
| }, |
| { |
| m: blockSize*minParBlock + 1, |
| n: blockSize * minParBlock, |
| k: 3, |
| alpha: 2.5, |
| tA: blas.NoTrans, |
| tB: blas.NoTrans, |
| }, |
| { |
| m: 3, |
| n: blockSize*minParBlock + 2, |
| k: blockSize * 3, |
| alpha: 2.5, |
| tA: blas.NoTrans, |
| tB: blas.NoTrans, |
| }, |
| { |
| m: blockSize * minParBlock, |
| n: 3, |
| k: blockSize * minParBlock, |
| alpha: 2.5, |
| tA: blas.NoTrans, |
| tB: blas.NoTrans, |
| }, |
| { |
| m: blockSize * minParBlock, |
| n: blockSize * minParBlock, |
| k: blockSize * 3, |
| alpha: 2.5, |
| tA: blas.NoTrans, |
| tB: blas.NoTrans, |
| }, |
| { |
| m: blockSize + blockSize/2, |
| n: blockSize + blockSize/2, |
| k: blockSize + blockSize/2, |
| alpha: 2.5, |
| tA: blas.NoTrans, |
| tB: blas.NoTrans, |
| }, |
| } { |
| testMatchParallelSerial(t, i, blas.NoTrans, blas.NoTrans, test.m, test.n, test.k, test.alpha) |
| testMatchParallelSerial(t, i, blas.Trans, blas.NoTrans, test.m, test.n, test.k, test.alpha) |
| testMatchParallelSerial(t, i, blas.NoTrans, blas.Trans, test.m, test.n, test.k, test.alpha) |
| testMatchParallelSerial(t, i, blas.Trans, blas.Trans, test.m, test.n, test.k, test.alpha) |
| } |
| } |
| |
| func testMatchParallelSerial(t *testing.T, i int, tA, tB blas.Transpose, m, n, k int, alpha float64) { |
| var ( |
| rowA, colA int |
| rowB, colB int |
| ) |
| if tA == blas.NoTrans { |
| rowA = m |
| colA = k |
| } else { |
| rowA = k |
| colA = m |
| } |
| if tB == blas.NoTrans { |
| rowB = k |
| colB = n |
| } else { |
| rowB = n |
| colB = k |
| } |
| a := randmat(rowA, colA, colA) |
| b := randmat(rowB, colB, colB) |
| c := randmat(m, n, n) |
| |
| aClone := a.clone() |
| bClone := b.clone() |
| cClone := c.clone() |
| |
| lda := colA |
| ldb := colB |
| ldc := n |
| dgemmSerial(tA == blas.Trans, tB == blas.Trans, m, n, k, a.data, lda, b.data, ldb, cClone.data, ldc, alpha) |
| dgemmParallel(tA == blas.Trans, tB == blas.Trans, m, n, k, a.data, lda, b.data, ldb, c.data, ldc, alpha) |
| if !a.equal(aClone) { |
| t.Errorf("Case %v: a changed during call to dgemmParallel", i) |
| } |
| if !b.equal(bClone) { |
| t.Errorf("Case %v: b changed during call to dgemmParallel", i) |
| } |
| if !c.equalWithinAbs(cClone, 1e-12) { |
| t.Errorf("Case %v: answer not equal parallel and serial", i) |
| } |
| } |
| |
| func randmat(r, c, stride int) general64 { |
| data := make([]float64, r*stride+c) |
| for i := range data { |
| data[i] = rand.Float64() |
| } |
| return general64{ |
| data: data, |
| rows: r, |
| cols: c, |
| stride: stride, |
| } |
| } |