mirror of
https://github.com/gonum/gonum.git
synced 2025-10-04 14:52:57 +08:00
blas/gonum: improve dgemmParallel performance
name old time/op new time/op delta DgemmSmSmSm-32 820ns ± 1% 823ns ± 0% ~ (p=0.127 n=5+5) DgemmMedMedMed-32 137µs ± 1% 139µs ± 0% +1.12% (p=0.008 n=5+5) DgemmMedLgMed-32 463µs ± 0% 450µs ± 0% -2.88% (p=0.008 n=5+5) DgemmLgLgLg-32 25.0ms ± 1% 24.9ms ± 0% ~ (p=1.000 n=5+5) DgemmLgSmLg-32 685µs ± 1% 694µs ± 1% +1.40% (p=0.008 n=5+5) DgemmLgLgSm-32 808µs ± 1% 761µs ± 0% -5.77% (p=0.008 n=5+5) DgemmHgHgSm-32 71.7ms ± 0% 68.5ms ± 0% -4.40% (p=0.008 n=5+5) DgemmMedMedMedTNT-32 345µs ±10% 228µs ± 1% -33.97% (p=0.008 n=5+5) DgemmMedMedMedNTT-32 142µs ± 0% 149µs ± 1% +5.05% (p=0.008 n=5+5) DgemmMedMedMedTT-32 584µs ±33% 417µs ± 4% -28.48% (p=0.008 n=5+5)
This commit is contained in:

committed by
Dan Kortschak

parent
19ac2540b2
commit
028ce68c53
@@ -158,31 +158,24 @@ func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []f
|
||||
return
|
||||
}
|
||||
|
||||
nWorkers := runtime.GOMAXPROCS(0)
|
||||
if parBlocks < nWorkers {
|
||||
nWorkers = parBlocks
|
||||
}
|
||||
// There is a tradeoff between the workers having to wait for work
|
||||
// and a large buffer making operations slow.
|
||||
buf := buffMul * nWorkers
|
||||
if buf > parBlocks {
|
||||
buf = parBlocks
|
||||
}
|
||||
// workerLimit acts a number of maximum concurrent workers,
|
||||
// with the limit set to the number of procs available.
|
||||
workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0))
|
||||
|
||||
sendChan := make(chan subMul, buf)
|
||||
|
||||
// Launch workers. A worker receives an {i, j} submatrix of c, and computes
|
||||
// A_ik B_ki (or the transposed version) storing the result in c_ij. When the
|
||||
// channel is finally closed, it signals to the waitgroup that it has finished
|
||||
// computing.
|
||||
// wg is used to wait for all
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < nWorkers; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for sub := range sendChan {
|
||||
i := sub.i
|
||||
j := sub.j
|
||||
wg.Add(parBlocks)
|
||||
defer wg.Wait()
|
||||
|
||||
for i := 0; i < m; i += blockSize {
|
||||
for j := 0; j < n; j += blockSize {
|
||||
workerLimit <- struct{}{}
|
||||
go func(i, j int) {
|
||||
defer func() {
|
||||
wg.Done()
|
||||
<-workerLimit
|
||||
}()
|
||||
|
||||
leni := blockSize
|
||||
if i+leni > m {
|
||||
leni = m - i
|
||||
@@ -213,21 +206,9 @@ func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []f
|
||||
}
|
||||
dgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Send out all of the {i, j} subblocks for computation.
|
||||
for i := 0; i < m; i += blockSize {
|
||||
for j := 0; j < n; j += blockSize {
|
||||
sendChan <- subMul{
|
||||
i: i,
|
||||
j: j,
|
||||
}
|
||||
}(i, j)
|
||||
}
|
||||
}
|
||||
close(sendChan)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// dgemmSerial is serial matrix multiply
|
||||
|
@@ -19,14 +19,8 @@ type Implementation struct{}
|
||||
const (
|
||||
blockSize = 64 // b x b matrix
|
||||
minParBlock = 4 // minimum number of blocks needed to go parallel
|
||||
buffMul = 4 // how big is the buffer relative to the number of workers
|
||||
)
|
||||
|
||||
// subMul is a common type shared by [SD]gemm.
|
||||
type subMul struct {
|
||||
i, j int // index of block
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
|
@@ -162,31 +162,24 @@ func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []f
|
||||
return
|
||||
}
|
||||
|
||||
nWorkers := runtime.GOMAXPROCS(0)
|
||||
if parBlocks < nWorkers {
|
||||
nWorkers = parBlocks
|
||||
}
|
||||
// There is a tradeoff between the workers having to wait for work
|
||||
// and a large buffer making operations slow.
|
||||
buf := buffMul * nWorkers
|
||||
if buf > parBlocks {
|
||||
buf = parBlocks
|
||||
}
|
||||
// workerLimit acts a number of maximum concurrent workers,
|
||||
// with the limit set to the number of procs available.
|
||||
workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0))
|
||||
|
||||
sendChan := make(chan subMul, buf)
|
||||
|
||||
// Launch workers. A worker receives an {i, j} submatrix of c, and computes
|
||||
// A_ik B_ki (or the transposed version) storing the result in c_ij. When the
|
||||
// channel is finally closed, it signals to the waitgroup that it has finished
|
||||
// computing.
|
||||
// wg is used to wait for all
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < nWorkers; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for sub := range sendChan {
|
||||
i := sub.i
|
||||
j := sub.j
|
||||
wg.Add(parBlocks)
|
||||
defer wg.Wait()
|
||||
|
||||
for i := 0; i < m; i += blockSize {
|
||||
for j := 0; j < n; j += blockSize {
|
||||
workerLimit <- struct{}{}
|
||||
go func(i, j int) {
|
||||
defer func() {
|
||||
wg.Done()
|
||||
<-workerLimit
|
||||
}()
|
||||
|
||||
leni := blockSize
|
||||
if i+leni > m {
|
||||
leni = m - i
|
||||
@@ -217,21 +210,9 @@ func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []f
|
||||
}
|
||||
sgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Send out all of the {i, j} subblocks for computation.
|
||||
for i := 0; i < m; i += blockSize {
|
||||
for j := 0; j < n; j += blockSize {
|
||||
sendChan <- subMul{
|
||||
i: i,
|
||||
j: j,
|
||||
}
|
||||
}(i, j)
|
||||
}
|
||||
}
|
||||
close(sendChan)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// sgemmSerial is serial matrix multiply
|
||||
|
Reference in New Issue
Block a user