mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 23:26:52 +08:00

name                   old time/op  new time/op  delta
DgemmSmSmSm-32          820ns ± 1%   823ns ± 0%     ~     (p=0.127 n=5+5)
DgemmMedMedMed-32       137µs ± 1%   139µs ± 0%   +1.12%  (p=0.008 n=5+5)
DgemmMedLgMed-32        463µs ± 0%   450µs ± 0%   -2.88%  (p=0.008 n=5+5)
DgemmLgLgLg-32         25.0ms ± 1%  24.9ms ± 0%     ~     (p=1.000 n=5+5)
DgemmLgSmLg-32          685µs ± 1%   694µs ± 1%   +1.40%  (p=0.008 n=5+5)
DgemmLgLgSm-32          808µs ± 1%   761µs ± 0%   -5.77%  (p=0.008 n=5+5)
DgemmHgHgSm-32         71.7ms ± 0%  68.5ms ± 0%   -4.40%  (p=0.008 n=5+5)
DgemmMedMedMedTNT-32    345µs ±10%   228µs ± 1%  -33.97%  (p=0.008 n=5+5)
DgemmMedMedMedNTT-32    142µs ± 0%   149µs ± 1%   +5.05%  (p=0.008 n=5+5)
DgemmMedMedMedTT-32     584µs ±33%   417µs ± 4%  -28.48%  (p=0.008 n=5+5)
300 lines
7.9 KiB
Go
300 lines
7.9 KiB
Go
// Code generated by "go generate gonum.org/v1/gonum/blas/gonum"; DO NOT EDIT.
|
||
|
||
// Copyright ©2014 The Gonum Authors. All rights reserved.
|
||
// Use of this source code is governed by a BSD-style
|
||
// license that can be found in the LICENSE file.
|
||
|
||
package gonum
|
||
|
||
import (
|
||
"runtime"
|
||
"sync"
|
||
|
||
"gonum.org/v1/gonum/blas"
|
||
"gonum.org/v1/gonum/internal/asm/f32"
|
||
)
|
||
|
||
// Sgemm performs one of the matrix-matrix operations
|
||
// C = alpha * A * B + beta * C
|
||
// C = alpha * Aᵀ * B + beta * C
|
||
// C = alpha * A * Bᵀ + beta * C
|
||
// C = alpha * Aᵀ * Bᵀ + beta * C
|
||
// where A is an m×k or k×m dense matrix, B is an n×k or k×n dense matrix, C is
|
||
// an m×n matrix, and alpha and beta are scalars. tA and tB specify whether A or
|
||
// B are transposed.
|
||
//
|
||
// Float32 implementations are autogenerated and not directly tested.
|
||
func (Implementation) Sgemm(tA, tB blas.Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) {
|
||
switch tA {
|
||
default:
|
||
panic(badTranspose)
|
||
case blas.NoTrans, blas.Trans, blas.ConjTrans:
|
||
}
|
||
switch tB {
|
||
default:
|
||
panic(badTranspose)
|
||
case blas.NoTrans, blas.Trans, blas.ConjTrans:
|
||
}
|
||
if m < 0 {
|
||
panic(mLT0)
|
||
}
|
||
if n < 0 {
|
||
panic(nLT0)
|
||
}
|
||
if k < 0 {
|
||
panic(kLT0)
|
||
}
|
||
aTrans := tA == blas.Trans || tA == blas.ConjTrans
|
||
if aTrans {
|
||
if lda < max(1, m) {
|
||
panic(badLdA)
|
||
}
|
||
} else {
|
||
if lda < max(1, k) {
|
||
panic(badLdA)
|
||
}
|
||
}
|
||
bTrans := tB == blas.Trans || tB == blas.ConjTrans
|
||
if bTrans {
|
||
if ldb < max(1, k) {
|
||
panic(badLdB)
|
||
}
|
||
} else {
|
||
if ldb < max(1, n) {
|
||
panic(badLdB)
|
||
}
|
||
}
|
||
if ldc < max(1, n) {
|
||
panic(badLdC)
|
||
}
|
||
|
||
// Quick return if possible.
|
||
if m == 0 || n == 0 {
|
||
return
|
||
}
|
||
|
||
// For zero matrix size the following slice length checks are trivially satisfied.
|
||
if aTrans {
|
||
if len(a) < (k-1)*lda+m {
|
||
panic(shortA)
|
||
}
|
||
} else {
|
||
if len(a) < (m-1)*lda+k {
|
||
panic(shortA)
|
||
}
|
||
}
|
||
if bTrans {
|
||
if len(b) < (n-1)*ldb+k {
|
||
panic(shortB)
|
||
}
|
||
} else {
|
||
if len(b) < (k-1)*ldb+n {
|
||
panic(shortB)
|
||
}
|
||
}
|
||
if len(c) < (m-1)*ldc+n {
|
||
panic(shortC)
|
||
}
|
||
|
||
// Quick return if possible.
|
||
if (alpha == 0 || k == 0) && beta == 1 {
|
||
return
|
||
}
|
||
|
||
// scale c
|
||
if beta != 1 {
|
||
if beta == 0 {
|
||
for i := 0; i < m; i++ {
|
||
ctmp := c[i*ldc : i*ldc+n]
|
||
for j := range ctmp {
|
||
ctmp[j] = 0
|
||
}
|
||
}
|
||
} else {
|
||
for i := 0; i < m; i++ {
|
||
ctmp := c[i*ldc : i*ldc+n]
|
||
for j := range ctmp {
|
||
ctmp[j] *= beta
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
sgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
|
||
}
|
||
|
||
func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
|
||
// dgemmParallel computes a parallel matrix multiplication by partitioning
|
||
// a and b into sub-blocks, and updating c with the multiplication of the sub-block
|
||
// In all cases,
|
||
// A = [ A_11 A_12 ... A_1j
|
||
// A_21 A_22 ... A_2j
|
||
// ...
|
||
// A_i1 A_i2 ... A_ij]
|
||
//
|
||
// and same for B. All of the submatrix sizes are blockSize×blockSize except
|
||
// at the edges.
|
||
//
|
||
// In all cases, there is one dimension for each matrix along which
|
||
// C must be updated sequentially.
|
||
// Cij = \sum_k Aik Bki, (A * B)
|
||
// Cij = \sum_k Aki Bkj, (Aᵀ * B)
|
||
// Cij = \sum_k Aik Bjk, (A * Bᵀ)
|
||
// Cij = \sum_k Aki Bjk, (Aᵀ * Bᵀ)
|
||
//
|
||
// This code computes one {i, j} block sequentially along the k dimension,
|
||
// and computes all of the {i, j} blocks concurrently. This
|
||
// partitioning allows Cij to be updated in-place without race-conditions.
|
||
// Instead of launching a goroutine for each possible concurrent computation,
|
||
// a number of worker goroutines are created and channels are used to pass
|
||
// available and completed cases.
|
||
//
|
||
// http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
|
||
// multiplies, though this code does not copy matrices to attempt to eliminate
|
||
// cache misses.
|
||
|
||
maxKLen := k
|
||
parBlocks := blocks(m, blockSize) * blocks(n, blockSize)
|
||
if parBlocks < minParBlock {
|
||
// The matrix multiplication is small in the dimensions where it can be
|
||
// computed concurrently. Just do it in serial.
|
||
sgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
|
||
return
|
||
}
|
||
|
||
// workerLimit acts a number of maximum concurrent workers,
|
||
// with the limit set to the number of procs available.
|
||
workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0))
|
||
|
||
// wg is used to wait for all
|
||
var wg sync.WaitGroup
|
||
wg.Add(parBlocks)
|
||
defer wg.Wait()
|
||
|
||
for i := 0; i < m; i += blockSize {
|
||
for j := 0; j < n; j += blockSize {
|
||
workerLimit <- struct{}{}
|
||
go func(i, j int) {
|
||
defer func() {
|
||
wg.Done()
|
||
<-workerLimit
|
||
}()
|
||
|
||
leni := blockSize
|
||
if i+leni > m {
|
||
leni = m - i
|
||
}
|
||
lenj := blockSize
|
||
if j+lenj > n {
|
||
lenj = n - j
|
||
}
|
||
|
||
cSub := sliceView32(c, ldc, i, j, leni, lenj)
|
||
|
||
// Compute A_ik B_kj for all k
|
||
for k := 0; k < maxKLen; k += blockSize {
|
||
lenk := blockSize
|
||
if k+lenk > maxKLen {
|
||
lenk = maxKLen - k
|
||
}
|
||
var aSub, bSub []float32
|
||
if aTrans {
|
||
aSub = sliceView32(a, lda, k, i, lenk, leni)
|
||
} else {
|
||
aSub = sliceView32(a, lda, i, k, leni, lenk)
|
||
}
|
||
if bTrans {
|
||
bSub = sliceView32(b, ldb, j, k, lenj, lenk)
|
||
} else {
|
||
bSub = sliceView32(b, ldb, k, j, lenk, lenj)
|
||
}
|
||
sgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
|
||
}
|
||
}(i, j)
|
||
}
|
||
}
|
||
}
|
||
|
||
// sgemmSerial is serial matrix multiply
|
||
func sgemmSerial(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
|
||
switch {
|
||
case !aTrans && !bTrans:
|
||
sgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
|
||
return
|
||
case aTrans && !bTrans:
|
||
sgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
|
||
return
|
||
case !aTrans && bTrans:
|
||
sgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
|
||
return
|
||
case aTrans && bTrans:
|
||
sgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
|
||
return
|
||
default:
|
||
panic("unreachable")
|
||
}
|
||
}
|
||
|
||
// sgemmSerial where neither a nor b are transposed
|
||
func sgemmSerialNotNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
|
||
// This style is used instead of the literal [i*stride +j]) is used because
|
||
// approximately 5 times faster as of go 1.3.
|
||
for i := 0; i < m; i++ {
|
||
ctmp := c[i*ldc : i*ldc+n]
|
||
for l, v := range a[i*lda : i*lda+k] {
|
||
tmp := alpha * v
|
||
if tmp != 0 {
|
||
f32.AxpyUnitary(tmp, b[l*ldb:l*ldb+n], ctmp)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// sgemmSerial where neither a is transposed and b is not
|
||
func sgemmSerialTransNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
|
||
// This style is used instead of the literal [i*stride +j]) is used because
|
||
// approximately 5 times faster as of go 1.3.
|
||
for l := 0; l < k; l++ {
|
||
btmp := b[l*ldb : l*ldb+n]
|
||
for i, v := range a[l*lda : l*lda+m] {
|
||
tmp := alpha * v
|
||
if tmp != 0 {
|
||
ctmp := c[i*ldc : i*ldc+n]
|
||
f32.AxpyUnitary(tmp, btmp, ctmp)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// sgemmSerial where neither a is not transposed and b is
|
||
func sgemmSerialNotTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
|
||
// This style is used instead of the literal [i*stride +j]) is used because
|
||
// approximately 5 times faster as of go 1.3.
|
||
for i := 0; i < m; i++ {
|
||
atmp := a[i*lda : i*lda+k]
|
||
ctmp := c[i*ldc : i*ldc+n]
|
||
for j := 0; j < n; j++ {
|
||
ctmp[j] += alpha * f32.DotUnitary(atmp, b[j*ldb:j*ldb+k])
|
||
}
|
||
}
|
||
}
|
||
|
||
// sgemmSerial where both are transposed
|
||
func sgemmSerialTransTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
|
||
// This style is used instead of the literal [i*stride +j]) is used because
|
||
// approximately 5 times faster as of go 1.3.
|
||
for l := 0; l < k; l++ {
|
||
for i, v := range a[l*lda : l*lda+m] {
|
||
tmp := alpha * v
|
||
if tmp != 0 {
|
||
ctmp := c[i*ldc : i*ldc+n]
|
||
f32.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// sliceView32 returns the part of the row-major matrix a (leading dimension
// lda) that starts at element (i, j) and ends just after element
// (i+r-1, j+c-1). This is the r×c window anchored at (i, j).
//
// The result shares a's backing array. When r > 1 it also spans the elements
// lying between the window's rows, so callers must keep indexing it with the
// original leading dimension lda.
func sliceView32(a []float32, lda, i, j, r, c int) []float32 {
	first := i*lda + j
	end := (i+r-1)*lda + j + c
	return a[first:end]
}
|