// Code generated by "go generate gonum.org/v1/gonum/blas/gonum"; DO NOT EDIT.
| |
| // Copyright ©2014 The Gonum Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package gonum |
| |
| import ( |
| "runtime" |
| "sync" |
| |
| "gonum.org/v1/gonum/blas" |
| "gonum.org/v1/gonum/internal/asm/f32" |
| ) |
| |
// Sgemm performs one of the matrix-matrix operations
//  C = alpha * A * B + beta * C
//  C = alpha * Aᵀ * B + beta * C
//  C = alpha * A * Bᵀ + beta * C
//  C = alpha * Aᵀ * Bᵀ + beta * C
// where A is an m×k or k×m dense matrix, B is an n×k or k×n dense matrix, C is
// an m×n matrix, and alpha and beta are scalars. tA and tB specify whether A or
// B are transposed.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Sgemm(tA, tB blas.Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) {
	// Validate the transpose flags first, then the dimensions, then the
	// leading dimensions, so a call with several invalid arguments panics
	// on the same one as the reference BLAS.
	switch tA {
	default:
		panic(badTranspose)
	case blas.NoTrans, blas.Trans, blas.ConjTrans:
	}
	switch tB {
	default:
		panic(badTranspose)
	case blas.NoTrans, blas.Trans, blas.ConjTrans:
	}
	if m < 0 {
		panic(mLT0)
	}
	if n < 0 {
		panic(nLT0)
	}
	if k < 0 {
		panic(kLT0)
	}
	// When A is transposed it is stored k×m, so its leading dimension must
	// cover m columns; otherwise it is stored m×k and must cover k columns.
	aTrans := tA == blas.Trans || tA == blas.ConjTrans
	if aTrans {
		if lda < max(1, m) {
			panic(badLdA)
		}
	} else {
		if lda < max(1, k) {
			panic(badLdA)
		}
	}
	// Likewise B is stored n×k when transposed and k×n otherwise.
	bTrans := tB == blas.Trans || tB == blas.ConjTrans
	if bTrans {
		if ldb < max(1, k) {
			panic(badLdB)
		}
	} else {
		if ldb < max(1, n) {
			panic(badLdB)
		}
	}
	if ldc < max(1, n) {
		panic(badLdC)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if aTrans {
		if len(a) < (k-1)*lda+m {
			panic(shortA)
		}
	} else {
		if len(a) < (m-1)*lda+k {
			panic(shortA)
		}
	}
	if bTrans {
		if len(b) < (n-1)*ldb+k {
			panic(shortB)
		}
	} else {
		if len(b) < (k-1)*ldb+n {
			panic(shortB)
		}
	}
	if len(c) < (m-1)*ldc+n {
		panic(shortC)
	}

	// Quick return if possible: with alpha == 0 or k == 0 the product term
	// vanishes, and beta == 1 leaves C unchanged.
	if (alpha == 0 || k == 0) && beta == 1 {
		return
	}

	// Scale c by beta, row by row. The beta == 0 case assigns zero rather
	// than multiplying, so existing NaN/Inf entries in C are overwritten
	// instead of propagated.
	if beta != 1 {
		if beta == 0 {
			for i := 0; i < m; i++ {
				ctmp := c[i*ldc : i*ldc+n]
				for j := range ctmp {
					ctmp[j] = 0
				}
			}
		} else {
			for i := 0; i < m; i++ {
				ctmp := c[i*ldc : i*ldc+n]
				for j := range ctmp {
					ctmp[j] *= beta
				}
			}
		}
	}

	// The remaining alpha*op(A)*op(B) accumulation is delegated to the
	// (possibly parallel) multiplication kernel.
	sgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
}
| |
func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// sgemmParallel computes a parallel matrix multiplication by partitioning
	// a and b into sub-blocks, and updating c with the multiplication of the sub-block.
	// In all cases,
	// A = [	A_11	A_12 ...	A_1j
	//		A_21	A_22 ...	A_2j
	//			...
	//		A_i1	A_i2 ...	A_ij]
	//
	// and same for B. All of the submatrix sizes are blockSize×blockSize except
	// at the edges.
	//
	// In all cases, there is one dimension for each matrix along which
	// C must be updated sequentially.
	// Cij = \sum_k Aik Bkj,	(A * B)
	// Cij = \sum_k Aki Bkj,	(Aᵀ * B)
	// Cij = \sum_k Aik Bjk,	(A * Bᵀ)
	// Cij = \sum_k Aki Bjk,	(Aᵀ * Bᵀ)
	//
	// This code computes one {i, j} block sequentially along the k dimension,
	// and computes all of the {i, j} blocks concurrently. This
	// partitioning allows Cij to be updated in-place without race-conditions.
	// Instead of launching a goroutine for each possible concurrent computation,
	// the number of simultaneously live workers is bounded by a semaphore
	// channel, and a WaitGroup is used to wait for completion.
	//
	// http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
	// multiplies, though this code does not copy matrices to attempt to eliminate
	// cache misses.

	maxKLen := k
	parBlocks := blocks(m, blockSize) * blocks(n, blockSize)
	if parBlocks < minParBlock {
		// The matrix multiplication is small in the dimensions where it can be
		// computed concurrently. Just do it in serial.
		sgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	}

	// workerLimit acts as a limit on the number of concurrent workers,
	// with the limit set to the number of procs available.
	workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0))

	// wg is used to wait for all block workers to finish before returning.
	var wg sync.WaitGroup
	wg.Add(parBlocks)
	defer wg.Wait()

	for i := 0; i < m; i += blockSize {
		for j := 0; j < n; j += blockSize {
			// Acquire a worker slot before spawning so that at most
			// GOMAXPROCS goroutines run at once.
			workerLimit <- struct{}{}
			go func(i, j int) {
				defer func() {
					wg.Done()
					<-workerLimit
				}()

				// Clip the block dimensions at the matrix edges.
				leni := blockSize
				if i+leni > m {
					leni = m - i
				}
				lenj := blockSize
				if j+lenj > n {
					lenj = n - j
				}

				cSub := sliceView32(c, ldc, i, j, leni, lenj)

				// Compute A_ik B_kj for all k.
				for k := 0; k < maxKLen; k += blockSize {
					lenk := blockSize
					if k+lenk > maxKLen {
						lenk = maxKLen - k
					}
					// The row/column roles of the block offsets swap when
					// the corresponding matrix is transposed.
					var aSub, bSub []float32
					if aTrans {
						aSub = sliceView32(a, lda, k, i, lenk, leni)
					} else {
						aSub = sliceView32(a, lda, i, k, leni, lenk)
					}
					if bTrans {
						bSub = sliceView32(b, ldb, j, k, lenj, lenk)
					} else {
						bSub = sliceView32(b, ldb, k, j, lenk, lenj)
					}
					sgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
				}
			}(i, j)
		}
	}
}
| |
| // sgemmSerial is serial matrix multiply |
| func sgemmSerial(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) { |
| switch { |
| case !aTrans && !bTrans: |
| sgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha) |
| return |
| case aTrans && !bTrans: |
| sgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha) |
| return |
| case !aTrans && bTrans: |
| sgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha) |
| return |
| case aTrans && bTrans: |
| sgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha) |
| return |
| default: |
| panic("unreachable") |
| } |
| } |
| |
| // sgemmSerial where neither a nor b are transposed |
| func sgemmSerialNotNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) { |
| // This style is used instead of the literal [i*stride +j]) is used because |
| // approximately 5 times faster as of go 1.3. |
| for i := 0; i < m; i++ { |
| ctmp := c[i*ldc : i*ldc+n] |
| for l, v := range a[i*lda : i*lda+k] { |
| tmp := alpha * v |
| if tmp != 0 { |
| f32.AxpyUnitary(tmp, b[l*ldb:l*ldb+n], ctmp) |
| } |
| } |
| } |
| } |
| |
| // sgemmSerial where neither a is transposed and b is not |
| func sgemmSerialTransNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) { |
| // This style is used instead of the literal [i*stride +j]) is used because |
| // approximately 5 times faster as of go 1.3. |
| for l := 0; l < k; l++ { |
| btmp := b[l*ldb : l*ldb+n] |
| for i, v := range a[l*lda : l*lda+m] { |
| tmp := alpha * v |
| if tmp != 0 { |
| ctmp := c[i*ldc : i*ldc+n] |
| f32.AxpyUnitary(tmp, btmp, ctmp) |
| } |
| } |
| } |
| } |
| |
| // sgemmSerial where neither a is not transposed and b is |
| func sgemmSerialNotTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) { |
| // This style is used instead of the literal [i*stride +j]) is used because |
| // approximately 5 times faster as of go 1.3. |
| for i := 0; i < m; i++ { |
| atmp := a[i*lda : i*lda+k] |
| ctmp := c[i*ldc : i*ldc+n] |
| for j := 0; j < n; j++ { |
| ctmp[j] += alpha * f32.DotUnitary(atmp, b[j*ldb:j*ldb+k]) |
| } |
| } |
| } |
| |
| // sgemmSerial where both are transposed |
| func sgemmSerialTransTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) { |
| // This style is used instead of the literal [i*stride +j]) is used because |
| // approximately 5 times faster as of go 1.3. |
| for l := 0; l < k; l++ { |
| for i, v := range a[l*lda : l*lda+m] { |
| tmp := alpha * v |
| if tmp != 0 { |
| ctmp := c[i*ldc : i*ldc+n] |
| f32.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0) |
| } |
| } |
| } |
| } |
| |
// sliceView32 returns the portion of a (a row-major matrix with leading
// dimension lda) that starts at element (i, j) and spans an r×c block. The
// returned slice runs from the block's first element through its last, so it
// includes the stride gaps between the block's rows.
func sliceView32(a []float32, lda, i, j, r, c int) []float32 {
	start := i*lda + j
	return a[start : start+(r-1)*lda+c]
}