blas/gonum: improve dgemmParallel performance

name                  old time/op  new time/op  delta
DgemmSmSmSm-32         820ns ± 1%   823ns ± 0%     ~     (p=0.127 n=5+5)
DgemmMedMedMed-32      137µs ± 1%   139µs ± 0%   +1.12%  (p=0.008 n=5+5)
DgemmMedLgMed-32       463µs ± 0%   450µs ± 0%   -2.88%  (p=0.008 n=5+5)
DgemmLgLgLg-32        25.0ms ± 1%  24.9ms ± 0%     ~     (p=1.000 n=5+5)
DgemmLgSmLg-32         685µs ± 1%   694µs ± 1%   +1.40%  (p=0.008 n=5+5)
DgemmLgLgSm-32         808µs ± 1%   761µs ± 0%   -5.77%  (p=0.008 n=5+5)
DgemmHgHgSm-32        71.7ms ± 0%  68.5ms ± 0%   -4.40%  (p=0.008 n=5+5)
DgemmMedMedMedTNT-32   345µs ±10%   228µs ± 1%  -33.97%  (p=0.008 n=5+5)
DgemmMedMedMedNTT-32   142µs ± 0%   149µs ± 1%   +5.05%  (p=0.008 n=5+5)
DgemmMedMedMedTT-32    584µs ±33%   417µs ± 4%  -28.48%  (p=0.008 n=5+5)
commit 028ce68c53 (parent 19ac2540b2)
Author:    Egon Elbre
Date:      2020-03-13 17:26:39 +02:00
Committer: Dan Kortschak
3 changed files with 34 additions and 78 deletions
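
The change drops the fixed pool of workers that was fed {i, j} block indices over a buffered channel, and instead launches one goroutine per block of C, throttled by a semaphore sized to GOMAXPROCS(0). Below is a minimal standalone sketch of that pattern; the grid dimensions and processBlock are illustrative stand-ins, not part of the gonum code.

package main

import (
	"fmt"
	"runtime"
	"sync"
)

// processBlock stands in for the per-block work (dgemmSerial on a sub-block).
func processBlock(i, j int) {
	fmt.Println("block", i, j)
}

func main() {
	const m, n, blockSize = 256, 256, 64

	// workerLimit is a counting semaphore: a send reserves a slot and a
	// receive releases it, so at most GOMAXPROCS(0) blocks are processed
	// concurrently.
	workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0))

	var wg sync.WaitGroup
	wg.Add((m / blockSize) * (n / blockSize))

	for i := 0; i < m; i += blockSize {
		for j := 0; j < n; j += blockSize {
			workerLimit <- struct{}{} // wait for a free slot
			go func(i, j int) {
				defer func() {
					wg.Done()
					<-workerLimit // free the slot
				}()
				processBlock(i, j)
			}(i, j)
		}
	}
	wg.Wait()
}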

@@ -158,31 +158,24 @@ func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []f
 		return
 	}
-	nWorkers := runtime.GOMAXPROCS(0)
-	if parBlocks < nWorkers {
-		nWorkers = parBlocks
-	}
-	// There is a tradeoff between the workers having to wait for work
-	// and a large buffer making operations slow.
-	buf := buffMul * nWorkers
-	if buf > parBlocks {
-		buf = parBlocks
-	}
+	// workerLimit acts a number of maximum concurrent workers,
+	// with the limit set to the number of procs available.
+	workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0))
-	sendChan := make(chan subMul, buf)
-	// Launch workers. A worker receives an {i, j} submatrix of c, and computes
-	// A_ik B_ki (or the transposed version) storing the result in c_ij. When the
-	// channel is finally closed, it signals to the waitgroup that it has finished
-	// computing.
+	// wg is used to wait for all
 	var wg sync.WaitGroup
-	for i := 0; i < nWorkers; i++ {
-		wg.Add(1)
-		go func() {
-			defer wg.Done()
-			for sub := range sendChan {
-				i := sub.i
-				j := sub.j
+	wg.Add(parBlocks)
+	defer wg.Wait()
+	for i := 0; i < m; i += blockSize {
+		for j := 0; j < n; j += blockSize {
+			workerLimit <- struct{}{}
+			go func(i, j int) {
+				defer func() {
+					wg.Done()
+					<-workerLimit
+				}()
 				leni := blockSize
 				if i+leni > m {
 					leni = m - i
@@ -213,21 +206,9 @@ func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []f
 					}
 					dgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
 				}
-			}
-		}()
-	}
-	// Send out all of the {i, j} subblocks for computation.
-	for i := 0; i < m; i += blockSize {
-		for j := 0; j < n; j += blockSize {
-			sendChan <- subMul{
-				i: i,
-				j: j,
-			}
+			}(i, j)
 		}
 	}
-	close(sendChan)
-	wg.Wait()
 }
 // dgemmSerial is serial matrix multiply
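
Note that each goroutine in the new code receives its block indices as parameters, go func(i, j int) { ... }(i, j), rather than capturing the loop variables. Under the Go loop-variable semantics in effect at the time (before Go 1.22), i and j are reused across iterations, so capturing them in the closure would let every goroutine observe whatever values the loops hold when it runs. A small, hypothetical illustration of the safe form used here:

package main

import (
	"fmt"
	"sync"
)

func main() {
	var wg sync.WaitGroup
	wg.Add(3)
	// Passing i by value gives each goroutine its own copy, so the
	// goroutines always see the indices 0, 1 and 2 (in some order).
	for i := 0; i < 3; i++ {
		go func(i int) {
			defer wg.Done()
			fmt.Println("worker saw", i)
		}(i)
	}
	wg.Wait()
}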

@@ -19,14 +19,8 @@ type Implementation struct{}
 const (
 	blockSize   = 64 // b x b matrix
 	minParBlock = 4  // minimum number of blocks needed to go parallel
-	buffMul     = 4  // how big is the buffer relative to the number of workers
 )
-// subMul is a common type shared by [SD]gemm.
-type subMul struct {
-	i, j int // index of block
-}
 func max(a, b int) int {
 	if a > b {
 		return a

@@ -162,31 +162,24 @@ func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []f
 		return
 	}
-	nWorkers := runtime.GOMAXPROCS(0)
-	if parBlocks < nWorkers {
-		nWorkers = parBlocks
-	}
-	// There is a tradeoff between the workers having to wait for work
-	// and a large buffer making operations slow.
-	buf := buffMul * nWorkers
-	if buf > parBlocks {
-		buf = parBlocks
-	}
+	// workerLimit acts a number of maximum concurrent workers,
+	// with the limit set to the number of procs available.
+	workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0))
-	sendChan := make(chan subMul, buf)
-	// Launch workers. A worker receives an {i, j} submatrix of c, and computes
-	// A_ik B_ki (or the transposed version) storing the result in c_ij. When the
-	// channel is finally closed, it signals to the waitgroup that it has finished
-	// computing.
+	// wg is used to wait for all
 	var wg sync.WaitGroup
-	for i := 0; i < nWorkers; i++ {
-		wg.Add(1)
-		go func() {
-			defer wg.Done()
-			for sub := range sendChan {
-				i := sub.i
-				j := sub.j
+	wg.Add(parBlocks)
+	defer wg.Wait()
+	for i := 0; i < m; i += blockSize {
+		for j := 0; j < n; j += blockSize {
+			workerLimit <- struct{}{}
+			go func(i, j int) {
+				defer func() {
+					wg.Done()
+					<-workerLimit
+				}()
 				leni := blockSize
 				if i+leni > m {
 					leni = m - i
@@ -217,21 +210,9 @@ func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []f
 					}
 					sgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
 				}
-			}
-		}()
-	}
-	// Send out all of the {i, j} subblocks for computation.
-	for i := 0; i < m; i += blockSize {
-		for j := 0; j < n; j += blockSize {
-			sendChan <- subMul{
-				i: i,
-				j: j,
-			}
+			}(i, j)
 		}
 	}
-	close(sendChan)
-	wg.Wait()
 }
 // sgemmSerial is serial matrix multiply
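
For completeness, a small usage sketch that exercises the changed path through the public API. The import paths are the standard gonum module paths, and the 512x512 sizes are chosen only so that the block count (8*8 = 64 blocks of C) exceeds minParBlock and the parallel path is taken rather than the serial fallback:

package main

import (
	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/blas/gonum"
)

func main() {
	const m, n, k = 512, 512, 512
	a := make([]float64, m*k)
	b := make([]float64, k*n)
	c := make([]float64, m*n)
	for i := range a {
		a[i] = 1
	}
	for i := range b {
		b[i] = 2
	}

	// C = 1*A*B + 0*C. With these sizes dgemmParallel (modified above)
	// splits C into blockSize x blockSize blocks and multiplies them
	// concurrently, bounded by GOMAXPROCS.
	var impl gonum.Implementation
	impl.Dgemm(blas.NoTrans, blas.NoTrans, m, n, k, 1, a, k, b, n, 0, c, n)
	_ = c
}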