blas/gonum: add Zgemm with test

This commit is contained in:
Vladimir Chalupecky
2019-01-06 21:45:31 +01:00
committed by Vladimír Chalupecký
parent 1fc0fba783
commit bf4bccac52
7 changed files with 520 additions and 3 deletions

View File

@@ -135,9 +135,6 @@ func (Implementation) Cher2k(ul blas.Uplo, t blas.Transpose, n, k int, alpha com
// Level 3 complex128 routines.
func (Implementation) Zgemm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
panic(noComplex)
}
func (Implementation) Zsymm(s blas.Side, ul blas.Uplo, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
panic(noComplex)
}

View File

@@ -7,9 +7,12 @@ package gonum
import (
"math"
"gonum.org/v1/gonum/blas"
"gonum.org/v1/gonum/internal/asm/c128"
)
var _ blas.Complex128Level1 = Implementation{}
// Dzasum returns the sum of the absolute values of the elements of x
// \sum_i |Re(x[i])| + |Im(x[i])|
// Dzasum returns 0 if incX is negative.

View File

@@ -11,6 +11,8 @@ import (
"gonum.org/v1/gonum/internal/asm/c128"
)
var _ blas.Complex128Level2 = Implementation{}
// Zgbmv performs one of the matrix-vector operations
// y = alpha * A * x + beta * y if trans = blas.NoTrans
// y = alpha * A^T * x + beta * y if trans = blas.Trans

View File

@@ -0,0 +1,254 @@
// Copyright ©2019 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gonum
import (
"math/cmplx"
"gonum.org/v1/gonum/blas"
)
var _ blas.Complex128Level3 = Implementation{}
// Zgemm performs one of the matrix-matrix operations
// C = alpha * op(A) * op(B) + beta * C
// where op(X) is one of
// op(X) = X or op(X) = X^T or op(X) = X^H,
// alpha and beta are scalars, and A, B and C are matrices, with op(A) an m×k matrix,
// op(B) a k×n matrix and C an m×n matrix.
func (Implementation) Zgemm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
switch tA {
default:
panic(badTranspose)
case blas.NoTrans, blas.Trans, blas.ConjTrans:
}
switch tB {
default:
panic(badTranspose)
case blas.NoTrans, blas.Trans, blas.ConjTrans:
}
switch {
case m < 0:
panic(mLT0)
case n < 0:
panic(nLT0)
case k < 0:
panic(kLT0)
}
rowA, colA := m, k
if tA != blas.NoTrans {
rowA, colA = k, m
}
if lda < max(1, colA) {
panic(badLdA)
}
rowB, colB := k, n
if tB != blas.NoTrans {
rowB, colB = n, k
}
if ldb < max(1, colB) {
panic(badLdB)
}
if ldc < max(1, n) {
panic(badLdC)
}
// Quick return if possible.
if m == 0 || n == 0 {
return
}
// For zero matrix size the following slice length checks are trivially satisfied.
if len(a) < (rowA-1)*lda+colA {
panic(shortA)
}
if len(b) < (rowB-1)*ldb+colB {
panic(shortB)
}
if len(c) < (m-1)*ldc+n {
panic(shortC)
}
// Quick return if possible.
if (alpha == 0 || k == 0) && beta == 1 {
return
}
if alpha == 0 {
if beta == 0 {
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
c[i*ldc+j] = 0
}
}
} else {
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
c[i*ldc+j] *= beta
}
}
}
return
}
switch tA {
case blas.NoTrans:
switch tB {
case blas.NoTrans:
// Form C = alpha * A * B + beta * C.
for i := 0; i < m; i++ {
if beta == 0 {
for j := 0; j < n; j++ {
c[i*ldc+j] = 0
}
} else if beta != 1 {
for j := 0; j < n; j++ {
c[i*ldc+j] *= beta
}
}
for l := 0; l < k; l++ {
tmp := alpha * a[i*lda+l]
for j := 0; j < n; j++ {
c[i*ldc+j] += tmp * b[l*ldb+j]
}
}
}
case blas.Trans:
// Form C = alpha * A * B^T + beta * C.
for i := 0; i < m; i++ {
if beta == 0 {
for j := 0; j < n; j++ {
c[i*ldc+j] = 0
}
} else if beta != 1 {
for j := 0; j < n; j++ {
c[i*ldc+j] *= beta
}
}
for l := 0; l < k; l++ {
tmp := alpha * a[i*lda+l]
for j := 0; j < n; j++ {
c[i*ldc+j] += tmp * b[j*ldb+l]
}
}
}
case blas.ConjTrans:
// Form C = alpha * A * B^H + beta * C.
for i := 0; i < m; i++ {
if beta == 0 {
for j := 0; j < n; j++ {
c[i*ldc+j] = 0
}
} else if beta != 1 {
for j := 0; j < n; j++ {
c[i*ldc+j] *= beta
}
}
for l := 0; l < k; l++ {
tmp := alpha * a[i*lda+l]
for j := 0; j < n; j++ {
c[i*ldc+j] += tmp * cmplx.Conj(b[j*ldb+l])
}
}
}
}
case blas.Trans:
switch tB {
case blas.NoTrans:
// Form C = alpha * A^T * B + beta * C.
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
var tmp complex128
for l := 0; l < k; l++ {
tmp += a[l*lda+i] * b[l*ldb+j]
}
if beta == 0 {
c[i*ldc+j] = alpha * tmp
} else {
c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
}
}
}
case blas.Trans:
// Form C = alpha * A^T * B^T + beta * C.
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
var tmp complex128
for l := 0; l < k; l++ {
tmp += a[l*lda+i] * b[j*ldb+l]
}
if beta == 0 {
c[i*ldc+j] = alpha * tmp
} else {
c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
}
}
}
case blas.ConjTrans:
// Form C = alpha * A^T * B^H + beta * C.
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
var tmp complex128
for l := 0; l < k; l++ {
tmp += a[l*lda+i] * cmplx.Conj(b[j*ldb+l])
}
if beta == 0 {
c[i*ldc+j] = alpha * tmp
} else {
c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
}
}
}
}
case blas.ConjTrans:
switch tB {
case blas.NoTrans:
// Form C = alpha * A^H * B + beta * C.
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
var tmp complex128
for l := 0; l < k; l++ {
tmp += cmplx.Conj(a[l*lda+i]) * b[l*ldb+j]
}
if beta == 0 {
c[i*ldc+j] = alpha * tmp
} else {
c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
}
}
}
case blas.Trans:
// Form C = alpha * A^H * B^T + beta * C.
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
var tmp complex128
for l := 0; l < k; l++ {
tmp += cmplx.Conj(a[l*lda+i]) * b[j*ldb+l]
}
if beta == 0 {
c[i*ldc+j] = alpha * tmp
} else {
c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
}
}
}
case blas.ConjTrans:
// Form C = alpha * A^H * B^H + beta * C.
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
var tmp complex128
for l := 0; l < k; l++ {
tmp += cmplx.Conj(a[l*lda+i]) * cmplx.Conj(b[j*ldb+l])
}
if beta == 0 {
c[i*ldc+j] = alpha * tmp
} else {
c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
}
}
}
}
}
}

View File

@@ -0,0 +1,15 @@
// Copyright ©2019 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gonum
import (
"testing"
"gonum.org/v1/gonum/blas/testblas"
)
func TestZgemm(t *testing.T) {
testblas.ZgemmTest(t, impl)
}

View File

@@ -9,7 +9,9 @@ import (
"math/cmplx"
"testing"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/blas"
"gonum.org/v1/gonum/floats"
)
// throwPanic will throw unexpected panics if true, or will just report them as errors if false
@@ -523,3 +525,130 @@ func zPackTriBand(k, ldab int, uplo blas.Uplo, n int, a []complex128, lda int) [
}
return ab
}
// zEqualApprox returns whether the slices a and b are approximately equal.
func zEqualApprox(a, b []complex128, tol float64) bool {
if len(a) != len(b) {
panic("mismatched slice length")
}
for i, ai := range a {
if !floats.EqualWithinAbs(cmplx.Abs(ai), cmplx.Abs(b[i]), tol) {
return false
}
}
return true
}
// rndComplex128 returns a complex128 with random components.
func rndComplex128(rnd *rand.Rand) complex128 {
return complex(rnd.NormFloat64(), rnd.NormFloat64())
}
// zmm returns the result of one of the matrix-matrix operations
// alpha * op(A) * op(B) + beta * C
// where op(X) is one of
// op(X) = X or op(X) = X^T or op(X) = X^H,
// alpha and beta are scalars, and A, B and C are matrices, with op(A) an m×k matrix,
// op(B) a k×n matrix and C an m×n matrix.
//
// The returned slice is newly allocated, has the same length as c and the
// matrix it represents has the stride ldc. Out-of-range elements are equal to
// those of C to ease comparison of results from BLAS Level 3 functions.
func zmm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) []complex128 {
r := make([]complex128, len(c))
copy(r, c)
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
r[i*ldc+j] = 0
}
}
switch tA {
case blas.NoTrans:
switch tB {
case blas.NoTrans:
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
for l := 0; l < k; l++ {
r[i*ldc+j] += a[i*lda+l] * b[l*ldb+j]
}
}
}
case blas.Trans:
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
for l := 0; l < k; l++ {
r[i*ldc+j] += a[i*lda+l] * b[j*ldb+l]
}
}
}
case blas.ConjTrans:
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
for l := 0; l < k; l++ {
r[i*ldc+j] += a[i*lda+l] * cmplx.Conj(b[j*ldb+l])
}
}
}
}
case blas.Trans:
switch tB {
case blas.NoTrans:
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
for l := 0; l < k; l++ {
r[i*ldc+j] += a[l*lda+i] * b[l*ldb+j]
}
}
}
case blas.Trans:
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
for l := 0; l < k; l++ {
r[i*ldc+j] += a[l*lda+i] * b[j*ldb+l]
}
}
}
case blas.ConjTrans:
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
for l := 0; l < k; l++ {
r[i*ldc+j] += a[l*lda+i] * cmplx.Conj(b[j*ldb+l])
}
}
}
}
case blas.ConjTrans:
switch tB {
case blas.NoTrans:
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
for l := 0; l < k; l++ {
r[i*ldc+j] += cmplx.Conj(a[l*lda+i]) * b[l*ldb+j]
}
}
}
case blas.Trans:
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
for l := 0; l < k; l++ {
r[i*ldc+j] += cmplx.Conj(a[l*lda+i]) * b[j*ldb+l]
}
}
}
case blas.ConjTrans:
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
for l := 0; l < k; l++ {
r[i*ldc+j] += cmplx.Conj(a[l*lda+i]) * cmplx.Conj(b[j*ldb+l])
}
}
}
}
}
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
r[i*ldc+j] = alpha*r[i*ldc+j] + beta*c[i*ldc+j]
}
}
return r
}

117
blas/testblas/zgemm.go Normal file
View File

@@ -0,0 +1,117 @@
// Copyright ©2019 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testblas
import (
"fmt"
"testing"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/blas"
)
type Zgemmer interface {
Zgemm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int)
}
func ZgemmTest(t *testing.T, impl Zgemmer) {
for _, tA := range []blas.Transpose{blas.NoTrans, blas.Trans, blas.ConjTrans} {
for _, tB := range []blas.Transpose{blas.NoTrans, blas.Trans, blas.ConjTrans} {
name := transString(tA) + "-" + transString(tB)
t.Run(name, func(t *testing.T) {
for _, m := range []int{0, 1, 2, 5, 10} {
for _, n := range []int{0, 1, 2, 5, 10} {
for _, k := range []int{0, 1, 2, 7, 11} {
zgemmTest(t, impl, tA, tB, m, n, k)
}
}
}
})
}
}
}
// transString returns a string representation of blas.Transpose.
func transString(t blas.Transpose) string {
switch t {
case blas.NoTrans:
return "NoTrans"
case blas.Trans:
return "Trans"
case blas.ConjTrans:
return "ConjTrans"
}
return "unknown trans"
}
func zgemmTest(t *testing.T, impl Zgemmer, tA, tB blas.Transpose, m, n, k int) {
const tol = 1e-13
rnd := rand.New(rand.NewSource(1))
rowA, colA := m, k
if tA != blas.NoTrans {
rowA, colA = k, m
}
rowB, colB := k, n
if tB != blas.NoTrans {
rowB, colB = n, k
}
for _, lda := range []int{max(1, colA), colA + 2} {
for _, ldb := range []int{max(1, colB), colB + 3} {
for _, ldc := range []int{max(1, n), n + 4} {
for _, alpha := range []complex128{0, 1, complex(0.7, -0.9)} {
for _, beta := range []complex128{0, 1, complex(1.3, -1.1)} {
// Allocate the matrix A and fill it with random numbers.
a := make([]complex128, rowA*lda)
for i := range a {
a[i] = rndComplex128(rnd)
}
// Create a copy of A.
aCopy := make([]complex128, len(a))
copy(aCopy, a)
// Allocate the matrix B and fill it with random numbers.
b := make([]complex128, rowB*ldb)
for i := range b {
b[i] = rndComplex128(rnd)
}
// Create a copy of B.
bCopy := make([]complex128, len(b))
copy(bCopy, b)
// Allocate the matrix C and fill it with random numbers.
c := make([]complex128, m*ldc)
for i := range c {
c[i] = rndComplex128(rnd)
}
// Compute the expected result using an internal Zgemm implementation.
want := zmm(tA, tB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
// Compute a result using Zgemm.
impl.Zgemm(tA, tB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
prefix := fmt.Sprintf("m=%v,n=%v,k=%v,lda=%v,ldb=%v,ldc=%v,alpha=%v,beta=%v", m, n, k, lda, ldb, ldc, alpha, beta)
if !zsame(a, aCopy) {
t.Errorf("%v: unexpected modification of A", prefix)
continue
}
if !zsame(b, bCopy) {
t.Errorf("%v: unexpected modification of B", prefix)
continue
}
if !zEqualApprox(c, want, tol) {
t.Errorf("%v: unexpected result,\nwant=%v\ngot =%v\n", prefix, want, c)
}
}
}
}
}
}
}