blas/gonum: remove checkDMatrix helper

This commit is contained in:
Vladimir Chalupecky
2018-11-12 14:38:21 +01:00
committed by Vladimír Chalupecký
parent eb6a40d81a
commit 483706e54a
6 changed files with 114 additions and 49 deletions

View File

@@ -21,25 +21,75 @@ import (
// an m×n matrix, and alpha and beta are scalars. tA and tB specify whether A or
// B are transposed.
func (Implementation) Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
if tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans {
switch tA {
default:
panic(badTranspose)
case blas.NoTrans, blas.Trans, blas.ConjTrans:
}
if tB != blas.NoTrans && tB != blas.Trans && tB != blas.ConjTrans {
switch tB {
default:
panic(badTranspose)
case blas.NoTrans, blas.Trans, blas.ConjTrans:
}
if m < 0 {
panic(mLT0)
}
if n < 0 {
panic(nLT0)
}
if k < 0 {
panic(kLT0)
}
aTrans := tA == blas.Trans || tA == blas.ConjTrans
if aTrans {
checkDMatrix('a', k, m, a, lda)
if lda < max(1, m) {
panic(badLdA)
}
} else {
checkDMatrix('a', m, k, a, lda)
if lda < max(1, k) {
panic(badLdA)
}
}
bTrans := tB == blas.Trans || tB == blas.ConjTrans
if bTrans {
checkDMatrix('b', n, k, b, ldb)
if ldb < max(1, k) {
panic(badLdB)
}
} else {
checkDMatrix('b', k, n, b, ldb)
if ldb < max(1, n) {
panic(badLdB)
}
}
if ldc < max(1, n) {
panic(badLdC)
}
// Quick return if possible.
if m == 0 || n == 0 || ((alpha == 0 || k == 0) && beta == 1) {
return
}
if aTrans {
if len(a) < (k-1)*lda+m {
panic(shortA)
}
} else {
if len(a) < (m-1)*lda+k {
panic(shortA)
}
}
if bTrans {
if len(b) < (n-1)*ldb+k {
panic(shortB)
}
} else {
if len(b) < (k-1)*ldb+n {
panic(shortB)
}
}
if len(c) < (m-1)*ldc+n {
panic(shortC)
}
checkDMatrix('c', m, n, c, ldc)
// scale c
if beta != 1 {

View File

@@ -37,36 +37,6 @@ func min(a, b int) int {
return a
}
func checkSMatrix(name byte, m, n int, a []float32, lda int) {
if m < 0 {
panic(mLT0)
}
if n < 0 {
panic(nLT0)
}
if lda < n {
panic("blas: illegal stride of " + string(name))
}
if len(a) < (m-1)*lda+n {
panic("blas: index of " + string(name) + " out of range")
}
}
func checkDMatrix(name byte, m, n int, a []float64, lda int) {
if m < 0 {
panic(mLT0)
}
if n < 0 {
panic(nLT0)
}
if lda < n {
panic("blas: illegal stride of " + string(name))
}
if len(a) < (m-1)*lda+n {
panic("blas: index of " + string(name) + " out of range")
}
}
func checkZMatrix(name byte, m, n int, a []complex128, lda int) {
if m < 0 {
panic(mLT0)

View File

@@ -15,7 +15,6 @@ var _ blas.Float64Level2 = Implementation{}
// A += alpha * x * y^T
// where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar.
func (Implementation) Dger(m, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64, lda int) {
// Check inputs
if m < 0 {
panic(mLT0)
}
@@ -520,7 +519,6 @@ func (Implementation) Dtrsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int,
// where A is an n×n symmetric matrix, x and y are vectors, and alpha and
// beta are scalars.
func (Implementation) Dsymv(ul blas.Uplo, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int) {
// Check inputs
if ul != blas.Lower && ul != blas.Upper {
panic(badUplo)
}

View File

@@ -19,7 +19,6 @@ var _ blas.Float32Level2 = Implementation{}
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Sger(m, n int, alpha float32, x []float32, incX int, y []float32, incY int, a []float32, lda int) {
// Check inputs
if m < 0 {
panic(mLT0)
}
@@ -532,7 +531,6 @@ func (Implementation) Strsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int,
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Ssymv(ul blas.Uplo, n int, alpha float32, a []float32, lda int, x []float32, incX int, beta float32, y []float32, incY int) {
// Check inputs
if ul != blas.Lower && ul != blas.Upper {
panic(badUplo)
}

View File

@@ -25,25 +25,75 @@ import (
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Sgemm(tA, tB blas.Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) {
if tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans {
switch tA {
default:
panic(badTranspose)
case blas.NoTrans, blas.Trans, blas.ConjTrans:
}
if tB != blas.NoTrans && tB != blas.Trans && tB != blas.ConjTrans {
switch tB {
default:
panic(badTranspose)
case blas.NoTrans, blas.Trans, blas.ConjTrans:
}
if m < 0 {
panic(mLT0)
}
if n < 0 {
panic(nLT0)
}
if k < 0 {
panic(kLT0)
}
aTrans := tA == blas.Trans || tA == blas.ConjTrans
if aTrans {
checkSMatrix('a', k, m, a, lda)
if lda < max(1, m) {
panic(badLdA)
}
} else {
checkSMatrix('a', m, k, a, lda)
if lda < max(1, k) {
panic(badLdA)
}
}
bTrans := tB == blas.Trans || tB == blas.ConjTrans
if bTrans {
checkSMatrix('b', n, k, b, ldb)
if ldb < max(1, k) {
panic(badLdB)
}
} else {
checkSMatrix('b', k, n, b, ldb)
if ldb < max(1, n) {
panic(badLdB)
}
}
if ldc < max(1, n) {
panic(badLdC)
}
// Quick return if possible.
if m == 0 || n == 0 || ((alpha == 0 || k == 0) && beta == 1) {
return
}
if aTrans {
if len(a) < (k-1)*lda+m {
panic(shortA)
}
} else {
if len(a) < (m-1)*lda+k {
panic(shortA)
}
}
if bTrans {
if len(b) < (n-1)*ldb+k {
panic(shortB)
}
} else {
if len(b) < (k-1)*ldb+n {
panic(shortB)
}
}
if len(c) < (m-1)*ldc+n {
panic(shortC)
}
checkSMatrix('c', m, n, c, ldc)
// scale c
if beta != 1 {

View File

@@ -123,7 +123,6 @@ echo -e '// Code generated by "go generate gonum.org/v1/gonum/blas/gonum”; DO
cat dgemm.go \
| gofmt -r 'float64 -> float32' \
| gofmt -r 'sliceView64 -> sliceView32' \
| gofmt -r 'checkDMatrix -> checkSMatrix' \
\
| gofmt -r 'dgemmParallel -> sgemmParallel' \
| gofmt -r 'computeNumBlocks64 -> computeNumBlocks32' \