mirror of
https://github.com/gonum/gonum.git
synced 2025-11-03 11:21:14 +08:00
blas/gonum: add Zgemm with test
This commit is contained in:
committed by
Vladimír Chalupecký
parent
1fc0fba783
commit
bf4bccac52
@@ -135,9 +135,6 @@ func (Implementation) Cher2k(ul blas.Uplo, t blas.Transpose, n, k int, alpha com
|
||||
|
||||
// Level 3 complex128 routines.
|
||||
|
||||
func (Implementation) Zgemm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
|
||||
panic(noComplex)
|
||||
}
|
||||
func (Implementation) Zsymm(s blas.Side, ul blas.Uplo, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
|
||||
panic(noComplex)
|
||||
}
|
||||
|
||||
@@ -7,9 +7,12 @@ package gonum
|
||||
import (
|
||||
"math"
|
||||
|
||||
"gonum.org/v1/gonum/blas"
|
||||
"gonum.org/v1/gonum/internal/asm/c128"
|
||||
)
|
||||
|
||||
var _ blas.Complex128Level1 = Implementation{}
|
||||
|
||||
// Dzasum returns the sum of the absolute values of the elements of x
|
||||
// \sum_i |Re(x[i])| + |Im(x[i])|
|
||||
// Dzasum returns 0 if incX is negative.
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"gonum.org/v1/gonum/internal/asm/c128"
|
||||
)
|
||||
|
||||
var _ blas.Complex128Level2 = Implementation{}
|
||||
|
||||
// Zgbmv performs one of the matrix-vector operations
|
||||
// y = alpha * A * x + beta * y if trans = blas.NoTrans
|
||||
// y = alpha * A^T * x + beta * y if trans = blas.Trans
|
||||
|
||||
254
blas/gonum/level3cmplx128.go
Normal file
254
blas/gonum/level3cmplx128.go
Normal file
@@ -0,0 +1,254 @@
|
||||
// Copyright ©2019 The Gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package gonum
|
||||
|
||||
import (
|
||||
"math/cmplx"
|
||||
|
||||
"gonum.org/v1/gonum/blas"
|
||||
)
|
||||
|
||||
var _ blas.Complex128Level3 = Implementation{}
|
||||
|
||||
// Zgemm performs one of the matrix-matrix operations
|
||||
// C = alpha * op(A) * op(B) + beta * C
|
||||
// where op(X) is one of
|
||||
// op(X) = X or op(X) = X^T or op(X) = X^H,
|
||||
// alpha and beta are scalars, and A, B and C are matrices, with op(A) an m×k matrix,
|
||||
// op(B) a k×n matrix and C an m×n matrix.
|
||||
func (Implementation) Zgemm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
|
||||
switch tA {
|
||||
default:
|
||||
panic(badTranspose)
|
||||
case blas.NoTrans, blas.Trans, blas.ConjTrans:
|
||||
}
|
||||
switch tB {
|
||||
default:
|
||||
panic(badTranspose)
|
||||
case blas.NoTrans, blas.Trans, blas.ConjTrans:
|
||||
}
|
||||
switch {
|
||||
case m < 0:
|
||||
panic(mLT0)
|
||||
case n < 0:
|
||||
panic(nLT0)
|
||||
case k < 0:
|
||||
panic(kLT0)
|
||||
}
|
||||
rowA, colA := m, k
|
||||
if tA != blas.NoTrans {
|
||||
rowA, colA = k, m
|
||||
}
|
||||
if lda < max(1, colA) {
|
||||
panic(badLdA)
|
||||
}
|
||||
rowB, colB := k, n
|
||||
if tB != blas.NoTrans {
|
||||
rowB, colB = n, k
|
||||
}
|
||||
if ldb < max(1, colB) {
|
||||
panic(badLdB)
|
||||
}
|
||||
if ldc < max(1, n) {
|
||||
panic(badLdC)
|
||||
}
|
||||
|
||||
// Quick return if possible.
|
||||
if m == 0 || n == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// For zero matrix size the following slice length checks are trivially satisfied.
|
||||
if len(a) < (rowA-1)*lda+colA {
|
||||
panic(shortA)
|
||||
}
|
||||
if len(b) < (rowB-1)*ldb+colB {
|
||||
panic(shortB)
|
||||
}
|
||||
if len(c) < (m-1)*ldc+n {
|
||||
panic(shortC)
|
||||
}
|
||||
|
||||
// Quick return if possible.
|
||||
if (alpha == 0 || k == 0) && beta == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
if alpha == 0 {
|
||||
if beta == 0 {
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
c[i*ldc+j] = 0
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
c[i*ldc+j] *= beta
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
switch tA {
|
||||
case blas.NoTrans:
|
||||
switch tB {
|
||||
case blas.NoTrans:
|
||||
// Form C = alpha * A * B + beta * C.
|
||||
for i := 0; i < m; i++ {
|
||||
if beta == 0 {
|
||||
for j := 0; j < n; j++ {
|
||||
c[i*ldc+j] = 0
|
||||
}
|
||||
} else if beta != 1 {
|
||||
for j := 0; j < n; j++ {
|
||||
c[i*ldc+j] *= beta
|
||||
}
|
||||
}
|
||||
for l := 0; l < k; l++ {
|
||||
tmp := alpha * a[i*lda+l]
|
||||
for j := 0; j < n; j++ {
|
||||
c[i*ldc+j] += tmp * b[l*ldb+j]
|
||||
}
|
||||
}
|
||||
}
|
||||
case blas.Trans:
|
||||
// Form C = alpha * A * B^T + beta * C.
|
||||
for i := 0; i < m; i++ {
|
||||
if beta == 0 {
|
||||
for j := 0; j < n; j++ {
|
||||
c[i*ldc+j] = 0
|
||||
}
|
||||
} else if beta != 1 {
|
||||
for j := 0; j < n; j++ {
|
||||
c[i*ldc+j] *= beta
|
||||
}
|
||||
}
|
||||
for l := 0; l < k; l++ {
|
||||
tmp := alpha * a[i*lda+l]
|
||||
for j := 0; j < n; j++ {
|
||||
c[i*ldc+j] += tmp * b[j*ldb+l]
|
||||
}
|
||||
}
|
||||
}
|
||||
case blas.ConjTrans:
|
||||
// Form C = alpha * A * B^H + beta * C.
|
||||
for i := 0; i < m; i++ {
|
||||
if beta == 0 {
|
||||
for j := 0; j < n; j++ {
|
||||
c[i*ldc+j] = 0
|
||||
}
|
||||
} else if beta != 1 {
|
||||
for j := 0; j < n; j++ {
|
||||
c[i*ldc+j] *= beta
|
||||
}
|
||||
}
|
||||
for l := 0; l < k; l++ {
|
||||
tmp := alpha * a[i*lda+l]
|
||||
for j := 0; j < n; j++ {
|
||||
c[i*ldc+j] += tmp * cmplx.Conj(b[j*ldb+l])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case blas.Trans:
|
||||
switch tB {
|
||||
case blas.NoTrans:
|
||||
// Form C = alpha * A^T * B + beta * C.
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
var tmp complex128
|
||||
for l := 0; l < k; l++ {
|
||||
tmp += a[l*lda+i] * b[l*ldb+j]
|
||||
}
|
||||
if beta == 0 {
|
||||
c[i*ldc+j] = alpha * tmp
|
||||
} else {
|
||||
c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
|
||||
}
|
||||
}
|
||||
}
|
||||
case blas.Trans:
|
||||
// Form C = alpha * A^T * B^T + beta * C.
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
var tmp complex128
|
||||
for l := 0; l < k; l++ {
|
||||
tmp += a[l*lda+i] * b[j*ldb+l]
|
||||
}
|
||||
if beta == 0 {
|
||||
c[i*ldc+j] = alpha * tmp
|
||||
} else {
|
||||
c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
|
||||
}
|
||||
}
|
||||
}
|
||||
case blas.ConjTrans:
|
||||
// Form C = alpha * A^T * B^H + beta * C.
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
var tmp complex128
|
||||
for l := 0; l < k; l++ {
|
||||
tmp += a[l*lda+i] * cmplx.Conj(b[j*ldb+l])
|
||||
}
|
||||
if beta == 0 {
|
||||
c[i*ldc+j] = alpha * tmp
|
||||
} else {
|
||||
c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case blas.ConjTrans:
|
||||
switch tB {
|
||||
case blas.NoTrans:
|
||||
// Form C = alpha * A^H * B + beta * C.
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
var tmp complex128
|
||||
for l := 0; l < k; l++ {
|
||||
tmp += cmplx.Conj(a[l*lda+i]) * b[l*ldb+j]
|
||||
}
|
||||
if beta == 0 {
|
||||
c[i*ldc+j] = alpha * tmp
|
||||
} else {
|
||||
c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
|
||||
}
|
||||
}
|
||||
}
|
||||
case blas.Trans:
|
||||
// Form C = alpha * A^H * B^T + beta * C.
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
var tmp complex128
|
||||
for l := 0; l < k; l++ {
|
||||
tmp += cmplx.Conj(a[l*lda+i]) * b[j*ldb+l]
|
||||
}
|
||||
if beta == 0 {
|
||||
c[i*ldc+j] = alpha * tmp
|
||||
} else {
|
||||
c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
|
||||
}
|
||||
}
|
||||
}
|
||||
case blas.ConjTrans:
|
||||
// Form C = alpha * A^H * B^H + beta * C.
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
var tmp complex128
|
||||
for l := 0; l < k; l++ {
|
||||
tmp += cmplx.Conj(a[l*lda+i]) * cmplx.Conj(b[j*ldb+l])
|
||||
}
|
||||
if beta == 0 {
|
||||
c[i*ldc+j] = alpha * tmp
|
||||
} else {
|
||||
c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
15
blas/gonum/level3cmplx128_test.go
Normal file
15
blas/gonum/level3cmplx128_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright ©2019 The Gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package gonum
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gonum.org/v1/gonum/blas/testblas"
|
||||
)
|
||||
|
||||
func TestZgemm(t *testing.T) {
|
||||
testblas.ZgemmTest(t, impl)
|
||||
}
|
||||
@@ -9,7 +9,9 @@ import (
|
||||
"math/cmplx"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/exp/rand"
|
||||
"gonum.org/v1/gonum/blas"
|
||||
"gonum.org/v1/gonum/floats"
|
||||
)
|
||||
|
||||
// throwPanic will throw unexpected panics if true, or will just report them as errors if false
|
||||
@@ -523,3 +525,130 @@ func zPackTriBand(k, ldab int, uplo blas.Uplo, n int, a []complex128, lda int) [
|
||||
}
|
||||
return ab
|
||||
}
|
||||
|
||||
// zEqualApprox returns whether the slices a and b are approximately equal.
|
||||
func zEqualApprox(a, b []complex128, tol float64) bool {
|
||||
if len(a) != len(b) {
|
||||
panic("mismatched slice length")
|
||||
}
|
||||
for i, ai := range a {
|
||||
if !floats.EqualWithinAbs(cmplx.Abs(ai), cmplx.Abs(b[i]), tol) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// rndComplex128 returns a complex128 with random components.
|
||||
func rndComplex128(rnd *rand.Rand) complex128 {
|
||||
return complex(rnd.NormFloat64(), rnd.NormFloat64())
|
||||
}
|
||||
|
||||
// zmm returns the result of one of the matrix-matrix operations
|
||||
// alpha * op(A) * op(B) + beta * C
|
||||
// where op(X) is one of
|
||||
// op(X) = X or op(X) = X^T or op(X) = X^H,
|
||||
// alpha and beta are scalars, and A, B and C are matrices, with op(A) an m×k matrix,
|
||||
// op(B) a k×n matrix and C an m×n matrix.
|
||||
//
|
||||
// The returned slice is newly allocated, has the same length as c and the
|
||||
// matrix it represents has the stride ldc. Out-of-range elements are equal to
|
||||
// those of C to ease comparison of results from BLAS Level 3 functions.
|
||||
func zmm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) []complex128 {
|
||||
r := make([]complex128, len(c))
|
||||
copy(r, c)
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
r[i*ldc+j] = 0
|
||||
}
|
||||
}
|
||||
switch tA {
|
||||
case blas.NoTrans:
|
||||
switch tB {
|
||||
case blas.NoTrans:
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
for l := 0; l < k; l++ {
|
||||
r[i*ldc+j] += a[i*lda+l] * b[l*ldb+j]
|
||||
}
|
||||
}
|
||||
}
|
||||
case blas.Trans:
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
for l := 0; l < k; l++ {
|
||||
r[i*ldc+j] += a[i*lda+l] * b[j*ldb+l]
|
||||
}
|
||||
}
|
||||
}
|
||||
case blas.ConjTrans:
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
for l := 0; l < k; l++ {
|
||||
r[i*ldc+j] += a[i*lda+l] * cmplx.Conj(b[j*ldb+l])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case blas.Trans:
|
||||
switch tB {
|
||||
case blas.NoTrans:
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
for l := 0; l < k; l++ {
|
||||
r[i*ldc+j] += a[l*lda+i] * b[l*ldb+j]
|
||||
}
|
||||
}
|
||||
}
|
||||
case blas.Trans:
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
for l := 0; l < k; l++ {
|
||||
r[i*ldc+j] += a[l*lda+i] * b[j*ldb+l]
|
||||
}
|
||||
}
|
||||
}
|
||||
case blas.ConjTrans:
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
for l := 0; l < k; l++ {
|
||||
r[i*ldc+j] += a[l*lda+i] * cmplx.Conj(b[j*ldb+l])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case blas.ConjTrans:
|
||||
switch tB {
|
||||
case blas.NoTrans:
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
for l := 0; l < k; l++ {
|
||||
r[i*ldc+j] += cmplx.Conj(a[l*lda+i]) * b[l*ldb+j]
|
||||
}
|
||||
}
|
||||
}
|
||||
case blas.Trans:
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
for l := 0; l < k; l++ {
|
||||
r[i*ldc+j] += cmplx.Conj(a[l*lda+i]) * b[j*ldb+l]
|
||||
}
|
||||
}
|
||||
}
|
||||
case blas.ConjTrans:
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
for l := 0; l < k; l++ {
|
||||
r[i*ldc+j] += cmplx.Conj(a[l*lda+i]) * cmplx.Conj(b[j*ldb+l])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
r[i*ldc+j] = alpha*r[i*ldc+j] + beta*c[i*ldc+j]
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
117
blas/testblas/zgemm.go
Normal file
117
blas/testblas/zgemm.go
Normal file
@@ -0,0 +1,117 @@
|
||||
// Copyright ©2019 The Gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testblas
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/exp/rand"
|
||||
"gonum.org/v1/gonum/blas"
|
||||
)
|
||||
|
||||
type Zgemmer interface {
|
||||
Zgemm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int)
|
||||
}
|
||||
|
||||
func ZgemmTest(t *testing.T, impl Zgemmer) {
|
||||
for _, tA := range []blas.Transpose{blas.NoTrans, blas.Trans, blas.ConjTrans} {
|
||||
for _, tB := range []blas.Transpose{blas.NoTrans, blas.Trans, blas.ConjTrans} {
|
||||
name := transString(tA) + "-" + transString(tB)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
for _, m := range []int{0, 1, 2, 5, 10} {
|
||||
for _, n := range []int{0, 1, 2, 5, 10} {
|
||||
for _, k := range []int{0, 1, 2, 7, 11} {
|
||||
zgemmTest(t, impl, tA, tB, m, n, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// transString returns a string representation of blas.Transpose.
|
||||
func transString(t blas.Transpose) string {
|
||||
switch t {
|
||||
case blas.NoTrans:
|
||||
return "NoTrans"
|
||||
case blas.Trans:
|
||||
return "Trans"
|
||||
case blas.ConjTrans:
|
||||
return "ConjTrans"
|
||||
}
|
||||
return "unknown trans"
|
||||
}
|
||||
|
||||
func zgemmTest(t *testing.T, impl Zgemmer, tA, tB blas.Transpose, m, n, k int) {
|
||||
const tol = 1e-13
|
||||
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
|
||||
rowA, colA := m, k
|
||||
if tA != blas.NoTrans {
|
||||
rowA, colA = k, m
|
||||
}
|
||||
rowB, colB := k, n
|
||||
if tB != blas.NoTrans {
|
||||
rowB, colB = n, k
|
||||
}
|
||||
|
||||
for _, lda := range []int{max(1, colA), colA + 2} {
|
||||
for _, ldb := range []int{max(1, colB), colB + 3} {
|
||||
for _, ldc := range []int{max(1, n), n + 4} {
|
||||
for _, alpha := range []complex128{0, 1, complex(0.7, -0.9)} {
|
||||
for _, beta := range []complex128{0, 1, complex(1.3, -1.1)} {
|
||||
// Allocate the matrix A and fill it with random numbers.
|
||||
a := make([]complex128, rowA*lda)
|
||||
for i := range a {
|
||||
a[i] = rndComplex128(rnd)
|
||||
}
|
||||
// Create a copy of A.
|
||||
aCopy := make([]complex128, len(a))
|
||||
copy(aCopy, a)
|
||||
|
||||
// Allocate the matrix B and fill it with random numbers.
|
||||
b := make([]complex128, rowB*ldb)
|
||||
for i := range b {
|
||||
b[i] = rndComplex128(rnd)
|
||||
}
|
||||
// Create a copy of B.
|
||||
bCopy := make([]complex128, len(b))
|
||||
copy(bCopy, b)
|
||||
|
||||
// Allocate the matrix C and fill it with random numbers.
|
||||
c := make([]complex128, m*ldc)
|
||||
for i := range c {
|
||||
c[i] = rndComplex128(rnd)
|
||||
}
|
||||
|
||||
// Compute the expected result using an internal Zgemm implementation.
|
||||
want := zmm(tA, tB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
|
||||
|
||||
// Compute a result using Zgemm.
|
||||
impl.Zgemm(tA, tB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
|
||||
|
||||
prefix := fmt.Sprintf("m=%v,n=%v,k=%v,lda=%v,ldb=%v,ldc=%v,alpha=%v,beta=%v", m, n, k, lda, ldb, ldc, alpha, beta)
|
||||
|
||||
if !zsame(a, aCopy) {
|
||||
t.Errorf("%v: unexpected modification of A", prefix)
|
||||
continue
|
||||
}
|
||||
if !zsame(b, bCopy) {
|
||||
t.Errorf("%v: unexpected modification of B", prefix)
|
||||
continue
|
||||
}
|
||||
|
||||
if !zEqualApprox(c, want, tol) {
|
||||
t.Errorf("%v: unexpected result,\nwant=%v\ngot =%v\n", prefix, want, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user