From 028ce68c53fb89bc62d0a6476b97a841f457bb5c Mon Sep 17 00:00:00 2001 From: Egon Elbre Date: Fri, 13 Mar 2020 17:26:39 +0200 Subject: [PATCH] blas/gonum: improve dgemmParallel performance MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit name old time/op new time/op delta DgemmSmSmSm-32 820ns ± 1% 823ns ± 0% ~ (p=0.127 n=5+5) DgemmMedMedMed-32 137µs ± 1% 139µs ± 0% +1.12% (p=0.008 n=5+5) DgemmMedLgMed-32 463µs ± 0% 450µs ± 0% -2.88% (p=0.008 n=5+5) DgemmLgLgLg-32 25.0ms ± 1% 24.9ms ± 0% ~ (p=1.000 n=5+5) DgemmLgSmLg-32 685µs ± 1% 694µs ± 1% +1.40% (p=0.008 n=5+5) DgemmLgLgSm-32 808µs ± 1% 761µs ± 0% -5.77% (p=0.008 n=5+5) DgemmHgHgSm-32 71.7ms ± 0% 68.5ms ± 0% -4.40% (p=0.008 n=5+5) DgemmMedMedMedTNT-32 345µs ±10% 228µs ± 1% -33.97% (p=0.008 n=5+5) DgemmMedMedMedNTT-32 142µs ± 0% 149µs ± 1% +5.05% (p=0.008 n=5+5) DgemmMedMedMedTT-32 584µs ±33% 417µs ± 4% -28.48% (p=0.008 n=5+5) --- blas/gonum/dgemm.go | 53 +++++++++++++++------------------------------ blas/gonum/gonum.go | 6 ----- blas/gonum/sgemm.go | 53 +++++++++++++++------------------------------ 3 files changed, 34 insertions(+), 78 deletions(-) diff --git a/blas/gonum/dgemm.go b/blas/gonum/dgemm.go index 167dd27c..9ebf6b2a 100644 --- a/blas/gonum/dgemm.go +++ b/blas/gonum/dgemm.go @@ -158,31 +158,24 @@ func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []f return } - nWorkers := runtime.GOMAXPROCS(0) - if parBlocks < nWorkers { - nWorkers = parBlocks - } - // There is a tradeoff between the workers having to wait for work - // and a large buffer making operations slow. - buf := buffMul * nWorkers - if buf > parBlocks { - buf = parBlocks - } + // workerLimit caps the number of concurrent workers, + // with the limit set to the number of procs available. + workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0)) - sendChan := make(chan subMul, buf) - - // Launch workers. 
A worker receives an {i, j} submatrix of c, and computes - // A_ik B_ki (or the transposed version) storing the result in c_ij. When the - // channel is finally closed, it signals to the waitgroup that it has finished - // computing. + // wg is used to wait for all workers to finish. var wg sync.WaitGroup - for i := 0; i < nWorkers; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for sub := range sendChan { - i := sub.i - j := sub.j + wg.Add(parBlocks) + defer wg.Wait() + + for i := 0; i < m; i += blockSize { + for j := 0; j < n; j += blockSize { + workerLimit <- struct{}{} + go func(i, j int) { + defer func() { + wg.Done() + <-workerLimit + }() + leni := blockSize if i+leni > m { leni = m - i @@ -213,21 +206,9 @@ func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []f } dgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha) } - } - }() - } - - // Send out all of the {i, j} subblocks for computation. - for i := 0; i < m; i += blockSize { - for j := 0; j < n; j += blockSize { - sendChan <- subMul{ - i: i, - j: j, - } + }(i, j) } } - close(sendChan) - wg.Wait() } // dgemmSerial is serial matrix multiply diff --git a/blas/gonum/gonum.go b/blas/gonum/gonum.go index 8ab8d43e..602a8f3e 100644 --- a/blas/gonum/gonum.go +++ b/blas/gonum/gonum.go @@ -19,14 +19,8 @@ type Implementation struct{} const ( blockSize = 64 // b x b matrix minParBlock = 4 // minimum number of blocks needed to go parallel - buffMul = 4 // how big is the buffer relative to the number of workers ) -// subMul is a common type shared by [SD]gemm. 
-type subMul struct { - i, j int // index of block -} - func max(a, b int) int { if a > b { return a diff --git a/blas/gonum/sgemm.go b/blas/gonum/sgemm.go index 079b94ce..7514c6c3 100644 --- a/blas/gonum/sgemm.go +++ b/blas/gonum/sgemm.go @@ -162,31 +162,24 @@ func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []f return } - nWorkers := runtime.GOMAXPROCS(0) - if parBlocks < nWorkers { - nWorkers = parBlocks - } - // There is a tradeoff between the workers having to wait for work - // and a large buffer making operations slow. - buf := buffMul * nWorkers - if buf > parBlocks { - buf = parBlocks - } + // workerLimit caps the number of concurrent workers, + // with the limit set to the number of procs available. + workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0)) - sendChan := make(chan subMul, buf) - - // Launch workers. A worker receives an {i, j} submatrix of c, and computes - // A_ik B_ki (or the transposed version) storing the result in c_ij. When the - // channel is finally closed, it signals to the waitgroup that it has finished - // computing. + // wg is used to wait for all workers to finish. var wg sync.WaitGroup - for i := 0; i < nWorkers; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for sub := range sendChan { - i := sub.i - j := sub.j + wg.Add(parBlocks) + defer wg.Wait() + + for i := 0; i < m; i += blockSize { + for j := 0; j < n; j += blockSize { + workerLimit <- struct{}{} + go func(i, j int) { + defer func() { + wg.Done() + <-workerLimit + }() + leni := blockSize if i+leni > m { leni = m - i @@ -217,21 +210,9 @@ func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []f } sgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha) } - } - }() - } - - // Send out all of the {i, j} subblocks for computation. 
- for i := 0; i < m; i += blockSize { - for j := 0; j < n; j += blockSize { - sendChan <- subMul{ - i: i, - j: j, - } + }(i, j) } } - close(sendChan) - wg.Wait() } // sgemmSerial is serial matrix multiply