mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 15:16:59 +08:00
blas/gonum: improve dgemmParallel performance
name old time/op new time/op delta DgemmSmSmSm-32 820ns ± 1% 823ns ± 0% ~ (p=0.127 n=5+5) DgemmMedMedMed-32 137µs ± 1% 139µs ± 0% +1.12% (p=0.008 n=5+5) DgemmMedLgMed-32 463µs ± 0% 450µs ± 0% -2.88% (p=0.008 n=5+5) DgemmLgLgLg-32 25.0ms ± 1% 24.9ms ± 0% ~ (p=1.000 n=5+5) DgemmLgSmLg-32 685µs ± 1% 694µs ± 1% +1.40% (p=0.008 n=5+5) DgemmLgLgSm-32 808µs ± 1% 761µs ± 0% -5.77% (p=0.008 n=5+5) DgemmHgHgSm-32 71.7ms ± 0% 68.5ms ± 0% -4.40% (p=0.008 n=5+5) DgemmMedMedMedTNT-32 345µs ±10% 228µs ± 1% -33.97% (p=0.008 n=5+5) DgemmMedMedMedNTT-32 142µs ± 0% 149µs ± 1% +5.05% (p=0.008 n=5+5) DgemmMedMedMedTT-32 584µs ±33% 417µs ± 4% -28.48% (p=0.008 n=5+5)
This commit is contained in:

committed by
Dan Kortschak

parent
19ac2540b2
commit
028ce68c53
@@ -158,31 +158,24 @@ func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []f
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nWorkers := runtime.GOMAXPROCS(0)
|
// workerLimit acts a number of maximum concurrent workers,
|
||||||
if parBlocks < nWorkers {
|
// with the limit set to the number of procs available.
|
||||||
nWorkers = parBlocks
|
workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0))
|
||||||
}
|
|
||||||
// There is a tradeoff between the workers having to wait for work
|
|
||||||
// and a large buffer making operations slow.
|
|
||||||
buf := buffMul * nWorkers
|
|
||||||
if buf > parBlocks {
|
|
||||||
buf = parBlocks
|
|
||||||
}
|
|
||||||
|
|
||||||
sendChan := make(chan subMul, buf)
|
// wg is used to wait for all
|
||||||
|
|
||||||
// Launch workers. A worker receives an {i, j} submatrix of c, and computes
|
|
||||||
// A_ik B_ki (or the transposed version) storing the result in c_ij. When the
|
|
||||||
// channel is finally closed, it signals to the waitgroup that it has finished
|
|
||||||
// computing.
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for i := 0; i < nWorkers; i++ {
|
wg.Add(parBlocks)
|
||||||
wg.Add(1)
|
defer wg.Wait()
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
for i := 0; i < m; i += blockSize {
|
||||||
for sub := range sendChan {
|
for j := 0; j < n; j += blockSize {
|
||||||
i := sub.i
|
workerLimit <- struct{}{}
|
||||||
j := sub.j
|
go func(i, j int) {
|
||||||
|
defer func() {
|
||||||
|
wg.Done()
|
||||||
|
<-workerLimit
|
||||||
|
}()
|
||||||
|
|
||||||
leni := blockSize
|
leni := blockSize
|
||||||
if i+leni > m {
|
if i+leni > m {
|
||||||
leni = m - i
|
leni = m - i
|
||||||
@@ -213,21 +206,9 @@ func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []f
|
|||||||
}
|
}
|
||||||
dgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
|
dgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
|
||||||
}
|
}
|
||||||
}
|
}(i, j)
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send out all of the {i, j} subblocks for computation.
|
|
||||||
for i := 0; i < m; i += blockSize {
|
|
||||||
for j := 0; j < n; j += blockSize {
|
|
||||||
sendChan <- subMul{
|
|
||||||
i: i,
|
|
||||||
j: j,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
close(sendChan)
|
|
||||||
wg.Wait()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// dgemmSerial is serial matrix multiply
|
// dgemmSerial is serial matrix multiply
|
||||||
|
@@ -19,14 +19,8 @@ type Implementation struct{}
|
|||||||
const (
|
const (
|
||||||
blockSize = 64 // b x b matrix
|
blockSize = 64 // b x b matrix
|
||||||
minParBlock = 4 // minimum number of blocks needed to go parallel
|
minParBlock = 4 // minimum number of blocks needed to go parallel
|
||||||
buffMul = 4 // how big is the buffer relative to the number of workers
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// subMul is a common type shared by [SD]gemm.
|
|
||||||
type subMul struct {
|
|
||||||
i, j int // index of block
|
|
||||||
}
|
|
||||||
|
|
||||||
func max(a, b int) int {
|
func max(a, b int) int {
|
||||||
if a > b {
|
if a > b {
|
||||||
return a
|
return a
|
||||||
|
@@ -162,31 +162,24 @@ func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []f
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nWorkers := runtime.GOMAXPROCS(0)
|
// workerLimit acts a number of maximum concurrent workers,
|
||||||
if parBlocks < nWorkers {
|
// with the limit set to the number of procs available.
|
||||||
nWorkers = parBlocks
|
workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0))
|
||||||
}
|
|
||||||
// There is a tradeoff between the workers having to wait for work
|
|
||||||
// and a large buffer making operations slow.
|
|
||||||
buf := buffMul * nWorkers
|
|
||||||
if buf > parBlocks {
|
|
||||||
buf = parBlocks
|
|
||||||
}
|
|
||||||
|
|
||||||
sendChan := make(chan subMul, buf)
|
// wg is used to wait for all
|
||||||
|
|
||||||
// Launch workers. A worker receives an {i, j} submatrix of c, and computes
|
|
||||||
// A_ik B_ki (or the transposed version) storing the result in c_ij. When the
|
|
||||||
// channel is finally closed, it signals to the waitgroup that it has finished
|
|
||||||
// computing.
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for i := 0; i < nWorkers; i++ {
|
wg.Add(parBlocks)
|
||||||
wg.Add(1)
|
defer wg.Wait()
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
for i := 0; i < m; i += blockSize {
|
||||||
for sub := range sendChan {
|
for j := 0; j < n; j += blockSize {
|
||||||
i := sub.i
|
workerLimit <- struct{}{}
|
||||||
j := sub.j
|
go func(i, j int) {
|
||||||
|
defer func() {
|
||||||
|
wg.Done()
|
||||||
|
<-workerLimit
|
||||||
|
}()
|
||||||
|
|
||||||
leni := blockSize
|
leni := blockSize
|
||||||
if i+leni > m {
|
if i+leni > m {
|
||||||
leni = m - i
|
leni = m - i
|
||||||
@@ -217,21 +210,9 @@ func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []f
|
|||||||
}
|
}
|
||||||
sgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
|
sgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
|
||||||
}
|
}
|
||||||
}
|
}(i, j)
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send out all of the {i, j} subblocks for computation.
|
|
||||||
for i := 0; i < m; i += blockSize {
|
|
||||||
for j := 0; j < n; j += blockSize {
|
|
||||||
sendChan <- subMul{
|
|
||||||
i: i,
|
|
||||||
j: j,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
close(sendChan)
|
|
||||||
wg.Wait()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// sgemmSerial is serial matrix multiply
|
// sgemmSerial is serial matrix multiply
|
||||||
|
Reference in New Issue
Block a user