diff --git a/blas/gonum/dgemm.go b/blas/gonum/dgemm.go index 167dd27c..9ebf6b2a 100644 --- a/blas/gonum/dgemm.go +++ b/blas/gonum/dgemm.go @@ -158,31 +158,24 @@ func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []f return } - nWorkers := runtime.GOMAXPROCS(0) - if parBlocks < nWorkers { - nWorkers = parBlocks - } - // There is a tradeoff between the workers having to wait for work - // and a large buffer making operations slow. - buf := buffMul * nWorkers - if buf > parBlocks { - buf = parBlocks - } + // workerLimit acts a number of maximum concurrent workers, + // with the limit set to the number of procs available. + workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0)) - sendChan := make(chan subMul, buf) - - // Launch workers. A worker receives an {i, j} submatrix of c, and computes - // A_ik B_ki (or the transposed version) storing the result in c_ij. When the - // channel is finally closed, it signals to the waitgroup that it has finished - // computing. + // wg is used to wait for all var wg sync.WaitGroup - for i := 0; i < nWorkers; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for sub := range sendChan { - i := sub.i - j := sub.j + wg.Add(parBlocks) + defer wg.Wait() + + for i := 0; i < m; i += blockSize { + for j := 0; j < n; j += blockSize { + workerLimit <- struct{}{} + go func(i, j int) { + defer func() { + wg.Done() + <-workerLimit + }() + leni := blockSize if i+leni > m { leni = m - i @@ -213,21 +206,9 @@ func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []f } dgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha) } - } - }() - } - - // Send out all of the {i, j} subblocks for computation. - for i := 0; i < m; i += blockSize { - for j := 0; j < n; j += blockSize { - sendChan <- subMul{ - i: i, - j: j, - } + }(i, j) } } - close(sendChan) - wg.Wait() } // dgemmSerial is serial matrix multiply diff --git a/blas/gonum/gonum.go b/blas/gonum/gonum.go index 8ab8d43e..602a8f3e 100644 --- a/blas/gonum/gonum.go +++ b/blas/gonum/gonum.go @@ -19,14 +19,8 @@ type Implementation struct{} const ( blockSize = 64 // b x b matrix minParBlock = 4 // minimum number of blocks needed to go parallel - buffMul = 4 // how big is the buffer relative to the number of workers ) -// subMul is a common type shared by [SD]gemm. -type subMul struct { - i, j int // index of block -} - func max(a, b int) int { if a > b { return a diff --git a/blas/gonum/sgemm.go b/blas/gonum/sgemm.go index 079b94ce..7514c6c3 100644 --- a/blas/gonum/sgemm.go +++ b/blas/gonum/sgemm.go @@ -162,31 +162,24 @@ func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []f return } - nWorkers := runtime.GOMAXPROCS(0) - if parBlocks < nWorkers { - nWorkers = parBlocks - } - // There is a tradeoff between the workers having to wait for work - // and a large buffer making operations slow. - buf := buffMul * nWorkers - if buf > parBlocks { - buf = parBlocks - } + // workerLimit acts a number of maximum concurrent workers, + // with the limit set to the number of procs available. + workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0)) - sendChan := make(chan subMul, buf) - - // Launch workers. A worker receives an {i, j} submatrix of c, and computes - // A_ik B_ki (or the transposed version) storing the result in c_ij. When the - // channel is finally closed, it signals to the waitgroup that it has finished - // computing. + // wg is used to wait for all var wg sync.WaitGroup - for i := 0; i < nWorkers; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for sub := range sendChan { - i := sub.i - j := sub.j + wg.Add(parBlocks) + defer wg.Wait() + + for i := 0; i < m; i += blockSize { + for j := 0; j < n; j += blockSize { + workerLimit <- struct{}{} + go func(i, j int) { + defer func() { + wg.Done() + <-workerLimit + }() + leni := blockSize if i+leni > m { leni = m - i @@ -217,21 +210,9 @@ func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []f } sgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha) } - } - }() - } - - // Send out all of the {i, j} subblocks for computation. - for i := 0; i < m; i += blockSize { - for j := 0; j < n; j += blockSize { - sendChan <- subMul{ - i: i, - j: j, - } + }(i, j) } } - close(sendChan) - wg.Wait() } // sgemmSerial is serial matrix multiply