mirror of
https://github.com/gonum/gonum.git
synced 2025-10-12 10:30:17 +08:00
blas/gonum: remove checkDMatrix helper
This commit is contained in:

committed by
Vladimír Chalupecký

parent
eb6a40d81a
commit
483706e54a
@@ -21,25 +21,75 @@ import (
|
||||
// an m×n matrix, and alpha and beta are scalars. tA and tB specify whether A or
|
||||
// B are transposed.
|
||||
func (Implementation) Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
|
||||
if tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans {
|
||||
switch tA {
|
||||
default:
|
||||
panic(badTranspose)
|
||||
case blas.NoTrans, blas.Trans, blas.ConjTrans:
|
||||
}
|
||||
if tB != blas.NoTrans && tB != blas.Trans && tB != blas.ConjTrans {
|
||||
switch tB {
|
||||
default:
|
||||
panic(badTranspose)
|
||||
case blas.NoTrans, blas.Trans, blas.ConjTrans:
|
||||
}
|
||||
if m < 0 {
|
||||
panic(mLT0)
|
||||
}
|
||||
if n < 0 {
|
||||
panic(nLT0)
|
||||
}
|
||||
if k < 0 {
|
||||
panic(kLT0)
|
||||
}
|
||||
aTrans := tA == blas.Trans || tA == blas.ConjTrans
|
||||
if aTrans {
|
||||
checkDMatrix('a', k, m, a, lda)
|
||||
if lda < max(1, m) {
|
||||
panic(badLdA)
|
||||
}
|
||||
} else {
|
||||
checkDMatrix('a', m, k, a, lda)
|
||||
if lda < max(1, k) {
|
||||
panic(badLdA)
|
||||
}
|
||||
}
|
||||
bTrans := tB == blas.Trans || tB == blas.ConjTrans
|
||||
if bTrans {
|
||||
checkDMatrix('b', n, k, b, ldb)
|
||||
if ldb < max(1, k) {
|
||||
panic(badLdB)
|
||||
}
|
||||
} else {
|
||||
checkDMatrix('b', k, n, b, ldb)
|
||||
if ldb < max(1, n) {
|
||||
panic(badLdB)
|
||||
}
|
||||
}
|
||||
if ldc < max(1, n) {
|
||||
panic(badLdC)
|
||||
}
|
||||
|
||||
// Quick return if possible.
|
||||
if m == 0 || n == 0 || ((alpha == 0 || k == 0) && beta == 1) {
|
||||
return
|
||||
}
|
||||
|
||||
if aTrans {
|
||||
if len(a) < (k-1)*lda+m {
|
||||
panic(shortA)
|
||||
}
|
||||
} else {
|
||||
if len(a) < (m-1)*lda+k {
|
||||
panic(shortA)
|
||||
}
|
||||
}
|
||||
if bTrans {
|
||||
if len(b) < (n-1)*ldb+k {
|
||||
panic(shortB)
|
||||
}
|
||||
} else {
|
||||
if len(b) < (k-1)*ldb+n {
|
||||
panic(shortB)
|
||||
}
|
||||
}
|
||||
if len(c) < (m-1)*ldc+n {
|
||||
panic(shortC)
|
||||
}
|
||||
checkDMatrix('c', m, n, c, ldc)
|
||||
|
||||
// scale c
|
||||
if beta != 1 {
|
||||
|
@@ -37,36 +37,6 @@ func min(a, b int) int {
|
||||
return a
|
||||
}
|
||||
|
||||
func checkSMatrix(name byte, m, n int, a []float32, lda int) {
|
||||
if m < 0 {
|
||||
panic(mLT0)
|
||||
}
|
||||
if n < 0 {
|
||||
panic(nLT0)
|
||||
}
|
||||
if lda < n {
|
||||
panic("blas: illegal stride of " + string(name))
|
||||
}
|
||||
if len(a) < (m-1)*lda+n {
|
||||
panic("blas: index of " + string(name) + " out of range")
|
||||
}
|
||||
}
|
||||
|
||||
func checkDMatrix(name byte, m, n int, a []float64, lda int) {
|
||||
if m < 0 {
|
||||
panic(mLT0)
|
||||
}
|
||||
if n < 0 {
|
||||
panic(nLT0)
|
||||
}
|
||||
if lda < n {
|
||||
panic("blas: illegal stride of " + string(name))
|
||||
}
|
||||
if len(a) < (m-1)*lda+n {
|
||||
panic("blas: index of " + string(name) + " out of range")
|
||||
}
|
||||
}
|
||||
|
||||
func checkZMatrix(name byte, m, n int, a []complex128, lda int) {
|
||||
if m < 0 {
|
||||
panic(mLT0)
|
||||
|
@@ -15,7 +15,6 @@ var _ blas.Float64Level2 = Implementation{}
|
||||
// A += alpha * x * y^T
|
||||
// where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar.
|
||||
func (Implementation) Dger(m, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64, lda int) {
|
||||
// Check inputs
|
||||
if m < 0 {
|
||||
panic(mLT0)
|
||||
}
|
||||
@@ -520,7 +519,6 @@ func (Implementation) Dtrsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int,
|
||||
// where A is an n×n symmetric matrix, x and y are vectors, and alpha and
|
||||
// beta are scalars.
|
||||
func (Implementation) Dsymv(ul blas.Uplo, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int) {
|
||||
// Check inputs
|
||||
if ul != blas.Lower && ul != blas.Upper {
|
||||
panic(badUplo)
|
||||
}
|
||||
|
@@ -19,7 +19,6 @@ var _ blas.Float32Level2 = Implementation{}
|
||||
//
|
||||
// Float32 implementations are autogenerated and not directly tested.
|
||||
func (Implementation) Sger(m, n int, alpha float32, x []float32, incX int, y []float32, incY int, a []float32, lda int) {
|
||||
// Check inputs
|
||||
if m < 0 {
|
||||
panic(mLT0)
|
||||
}
|
||||
@@ -532,7 +531,6 @@ func (Implementation) Strsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int,
|
||||
//
|
||||
// Float32 implementations are autogenerated and not directly tested.
|
||||
func (Implementation) Ssymv(ul blas.Uplo, n int, alpha float32, a []float32, lda int, x []float32, incX int, beta float32, y []float32, incY int) {
|
||||
// Check inputs
|
||||
if ul != blas.Lower && ul != blas.Upper {
|
||||
panic(badUplo)
|
||||
}
|
||||
|
@@ -25,25 +25,75 @@ import (
|
||||
//
|
||||
// Float32 implementations are autogenerated and not directly tested.
|
||||
func (Implementation) Sgemm(tA, tB blas.Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) {
|
||||
if tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans {
|
||||
switch tA {
|
||||
default:
|
||||
panic(badTranspose)
|
||||
case blas.NoTrans, blas.Trans, blas.ConjTrans:
|
||||
}
|
||||
if tB != blas.NoTrans && tB != blas.Trans && tB != blas.ConjTrans {
|
||||
switch tB {
|
||||
default:
|
||||
panic(badTranspose)
|
||||
case blas.NoTrans, blas.Trans, blas.ConjTrans:
|
||||
}
|
||||
if m < 0 {
|
||||
panic(mLT0)
|
||||
}
|
||||
if n < 0 {
|
||||
panic(nLT0)
|
||||
}
|
||||
if k < 0 {
|
||||
panic(kLT0)
|
||||
}
|
||||
aTrans := tA == blas.Trans || tA == blas.ConjTrans
|
||||
if aTrans {
|
||||
checkSMatrix('a', k, m, a, lda)
|
||||
if lda < max(1, m) {
|
||||
panic(badLdA)
|
||||
}
|
||||
} else {
|
||||
checkSMatrix('a', m, k, a, lda)
|
||||
if lda < max(1, k) {
|
||||
panic(badLdA)
|
||||
}
|
||||
}
|
||||
bTrans := tB == blas.Trans || tB == blas.ConjTrans
|
||||
if bTrans {
|
||||
checkSMatrix('b', n, k, b, ldb)
|
||||
if ldb < max(1, k) {
|
||||
panic(badLdB)
|
||||
}
|
||||
} else {
|
||||
checkSMatrix('b', k, n, b, ldb)
|
||||
if ldb < max(1, n) {
|
||||
panic(badLdB)
|
||||
}
|
||||
}
|
||||
if ldc < max(1, n) {
|
||||
panic(badLdC)
|
||||
}
|
||||
|
||||
// Quick return if possible.
|
||||
if m == 0 || n == 0 || ((alpha == 0 || k == 0) && beta == 1) {
|
||||
return
|
||||
}
|
||||
|
||||
if aTrans {
|
||||
if len(a) < (k-1)*lda+m {
|
||||
panic(shortA)
|
||||
}
|
||||
} else {
|
||||
if len(a) < (m-1)*lda+k {
|
||||
panic(shortA)
|
||||
}
|
||||
}
|
||||
if bTrans {
|
||||
if len(b) < (n-1)*ldb+k {
|
||||
panic(shortB)
|
||||
}
|
||||
} else {
|
||||
if len(b) < (k-1)*ldb+n {
|
||||
panic(shortB)
|
||||
}
|
||||
}
|
||||
if len(c) < (m-1)*ldc+n {
|
||||
panic(shortC)
|
||||
}
|
||||
checkSMatrix('c', m, n, c, ldc)
|
||||
|
||||
// scale c
|
||||
if beta != 1 {
|
||||
|
@@ -123,7 +123,6 @@ echo -e '// Code generated by "go generate gonum.org/v1/gonum/blas/gonum”; DO
|
||||
cat dgemm.go \
|
||||
| gofmt -r 'float64 -> float32' \
|
||||
| gofmt -r 'sliceView64 -> sliceView32' \
|
||||
| gofmt -r 'checkDMatrix -> checkSMatrix' \
|
||||
\
|
||||
| gofmt -r 'dgemmParallel -> sgemmParallel' \
|
||||
| gofmt -r 'computeNumBlocks64 -> computeNumBlocks32' \
|
||||
|
Reference in New Issue
Block a user