diff --git a/blas/gonum/cmplx.go b/blas/gonum/cmplx.go index 7af67e49..6e750067 100644 --- a/blas/gonum/cmplx.go +++ b/blas/gonum/cmplx.go @@ -138,9 +138,6 @@ func (Implementation) Cher2k(ul blas.Uplo, t blas.Transpose, n, k int, alpha com func (Implementation) Zsymm(s blas.Side, ul blas.Uplo, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) { panic(noComplex) } -func (Implementation) Zsyrk(ul blas.Uplo, t blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, beta complex128, c []complex128, ldc int) { - panic(noComplex) -} func (Implementation) Zsyr2k(ul blas.Uplo, t blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) { panic(noComplex) } diff --git a/blas/gonum/level3cmplx128.go b/blas/gonum/level3cmplx128.go index 664a4e35..a3175ad1 100644 --- a/blas/gonum/level3cmplx128.go +++ b/blas/gonum/level3cmplx128.go @@ -8,6 +8,7 @@ import ( "math/cmplx" "gonum.org/v1/gonum/blas" + "gonum.org/v1/gonum/internal/asm/c128" ) var _ blas.Complex128Level3 = Implementation{} @@ -252,3 +253,146 @@ func (Implementation) Zgemm(tA, tB blas.Transpose, m, n, k int, alpha complex128 } } } + +// Zsyrk performs one of the symmetric rank-k operations +// C = alpha*A*A^T + beta*C if trans == blas.NoTrans +// C = alpha*A^T*A + beta*C if trans == blas.Trans +// where alpha and beta are scalars, C is an n×n symmetric matrix and A is +// an n×k matrix in the first case and a k×n matrix in the second case. +func (Implementation) Zsyrk(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, beta complex128, c []complex128, ldc int) { + var rowA, colA int + switch trans { + default: + panic(badTranspose) + case blas.NoTrans: + rowA, colA = n, k + case blas.Trans: + rowA, colA = k, n + } + switch { + case uplo != blas.Lower && uplo != blas.Upper: + panic(badUplo) + case n < 0: + panic(nLT0) + case k < 0: + panic(kLT0) + case lda < max(1, colA): + panic(badLdA) + case ldc < max(1, n): + panic(badLdC) + } + + // Quick return if possible. + if n == 0 { + return + } + + // For zero matrix size the following slice length checks are trivially satisfied. + if len(a) < (rowA-1)*lda+colA { + panic(shortA) + } + if len(c) < (n-1)*ldc+n { + panic(shortC) + } + + // Quick return if possible. + if (alpha == 0 || k == 0) && beta == 1 { + return + } + + if alpha == 0 { + if uplo == blas.Upper { + if beta == 0 { + for i := 0; i < n; i++ { + ci := c[i*ldc+i : i*ldc+n] + for j := range ci { + ci[j] = 0 + } + } + } else { + for i := 0; i < n; i++ { + ci := c[i*ldc+i : i*ldc+n] + c128.ScalUnitary(beta, ci) + } + } + } else { + if beta == 0 { + for i := 0; i < n; i++ { + ci := c[i*ldc : i*ldc+i+1] + for j := range ci { + ci[j] = 0 + } + } + } else { + for i := 0; i < n; i++ { + ci := c[i*ldc : i*ldc+i+1] + c128.ScalUnitary(beta, ci) + } + } + } + return + } + + if trans == blas.NoTrans { + // Form C = alpha*A*A^T + beta*C. + if uplo == blas.Upper { + for i := 0; i < n; i++ { + ci := c[i*ldc+i : i*ldc+n] + ai := a[i*lda : i*lda+k] + for jc, cij := range ci { + j := i + jc + ci[jc] = beta*cij + alpha*c128.DotuUnitary(ai, a[j*lda:j*lda+k]) + } + } + } else { + for i := 0; i < n; i++ { + ci := c[i*ldc : i*ldc+i+1] + ai := a[i*lda : i*lda+k] + for j, cij := range ci { + ci[j] = beta*cij + alpha*c128.DotuUnitary(ai, a[j*lda:j*lda+k]) + } + } + } + } else { + // Form C = alpha*A^T*A + beta*C. + if uplo == blas.Upper { + for i := 0; i < n; i++ { + ci := c[i*ldc+i : i*ldc+n] + if beta == 0 { + for jc := range ci { + ci[jc] = 0 + } + } else if beta != 1 { + for jc := range ci { + ci[jc] *= beta + } + } + for j := 0; j < k; j++ { + aji := a[j*lda+i] + if aji != 0 { + c128.AxpyUnitary(alpha*aji, a[j*lda+i:j*lda+n], ci) + } + } + } + } else { + for i := 0; i < n; i++ { + ci := c[i*ldc : i*ldc+i+1] + if beta == 0 { + for j := range ci { + ci[j] = 0 + } + } else if beta != 1 { + for j := range ci { + ci[j] *= beta + } + } + for j := 0; j < k; j++ { + aji := a[j*lda+i] + if aji != 0 { + c128.AxpyUnitary(alpha*aji, a[j*lda:j*lda+i+1], ci) + } + } + } + } + } +} diff --git a/blas/gonum/level3cmplx128_test.go b/blas/gonum/level3cmplx128_test.go index 417434a8..772fc293 100644 --- a/blas/gonum/level3cmplx128_test.go +++ b/blas/gonum/level3cmplx128_test.go @@ -13,3 +13,7 @@ import ( func TestZgemm(t *testing.T) { testblas.ZgemmTest(t, impl) } + +func TestZsyrk(t *testing.T) { + testblas.ZsyrkTest(t, impl) +} diff --git a/blas/testblas/common.go b/blas/testblas/common.go index 89b36695..d34a0fd8 100644 --- a/blas/testblas/common.go +++ b/blas/testblas/common.go @@ -652,3 +652,55 @@ func zmm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, l } return r } + +// transString returns a string representation of blas.Transpose. +func transString(t blas.Transpose) string { + switch t { + case blas.NoTrans: + return "NoTrans" + case blas.Trans: + return "Trans" + case blas.ConjTrans: + return "ConjTrans" + } + return "unknown trans" +} + +// uploString returns a string representation of blas.Uplo. +func uploString(uplo blas.Uplo) string { + switch uplo { + case blas.Lower: + return "Lower" + case blas.Upper: + return "Upper" + } + return "unknown uplo" +} + +// zSameLowerTri returns whether n×n matrices A and B are same under the diagonal. +func zSameLowerTri(n int, a []complex128, lda int, b []complex128, ldb int) bool { + for i := 1; i < n; i++ { + for j := 0; j < i; j++ { + aij := a[i*lda+j] + bij := b[i*ldb+j] + if !sameComplex128(aij, bij) { + return false + } + } + } + return true +} + +// zSameUpperTri returns whether n×n matrices A and B are same above the diagonal. +func zSameUpperTri(n int, a []complex128, lda int, b []complex128, ldb int) bool { + for i := 0; i < n-1; i++ { + for j := i + 1; j < n; j++ { + aij := a[i*lda+j] + bij := b[i*ldb+j] + if !sameComplex128(aij, bij) { + return false + } + } + } + return true +} diff --git a/blas/testblas/zgemm.go b/blas/testblas/zgemm.go index c5c4e30d..f119b207 100644 --- a/blas/testblas/zgemm.go +++ b/blas/testblas/zgemm.go @@ -33,19 +33,6 @@ func ZgemmTest(t *testing.T, impl Zgemmer) { } } -// transString returns a string representation of blas.Transpose. -func transString(t blas.Transpose) string { - switch t { - case blas.NoTrans: - return "NoTrans" - case blas.Trans: - return "Trans" - case blas.ConjTrans: - return "ConjTrans" - } - return "unknown trans" -} - func zgemmTest(t *testing.T, impl Zgemmer, tA, tB blas.Transpose, m, n, k int) { const tol = 1e-13 diff --git a/blas/testblas/zsyrk.go b/blas/testblas/zsyrk.go new file mode 100644 index 00000000..bc124b2e --- /dev/null +++ b/blas/testblas/zsyrk.go @@ -0,0 +1,135 @@ +// Copyright ©2019 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testblas + +import ( + "fmt" + "testing" + + "golang.org/x/exp/rand" + + "gonum.org/v1/gonum/blas" +) + +type Zsyrker interface { + Zsyrk(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, beta complex128, c []complex128, ldc int) +} + +func ZsyrkTest(t *testing.T, impl Zsyrker) { + for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { + for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} { + name := uploString(uplo) + "-" + transString(trans) + t.Run(name, func(t *testing.T) { + for _, n := range []int{0, 1, 2, 3, 4, 5} { + for _, k := range []int{0, 1, 2, 3, 4, 5, 7} { + zsyrkTest(t, impl, uplo, trans, n, k) + } + } + }) + } + } +} + +func zsyrkTest(t *testing.T, impl Zsyrker, uplo blas.Uplo, trans blas.Transpose, n, k int) { + const tol = 1e-13 + + rnd := rand.New(rand.NewSource(1)) + + rowA, colA := n, k + if trans == blas.Trans { + rowA, colA = k, n + } + for _, lda := range []int{max(1, colA), colA + 2} { + for _, ldc := range []int{max(1, n), n + 4} { + for _, alpha := range []complex128{0, 1, complex(0.7, -0.9)} { + for _, beta := range []complex128{0, 1, complex(1.3, -1.1)} { + // Allocate the matrix A and fill it with random numbers. + a := make([]complex128, rowA*lda) + for i := range a { + a[i] = rndComplex128(rnd) + } + // Create a copy of A for checking that + // Zsyrk does not modify A. + aCopy := make([]complex128, len(a)) + copy(aCopy, a) + + // Allocate the matrix C and fill it with random numbers. + c := make([]complex128, n*ldc) + for i := range c { + c[i] = rndComplex128(rnd) + } + // Create a copy of C for checking that + // Zsyrk does not modify its triangle + // opposite to uplo. + cCopy := make([]complex128, len(c)) + copy(cCopy, c) + // Create a copy of C expanded into a + // full symmetric matrix for computing + // the expected result using zmm. + cSym := make([]complex128, len(c)) + copy(cSym, c) + if uplo == blas.Upper { + for i := 0; i < n-1; i++ { + for j := i + 1; j < n; j++ { + cSym[j*ldc+i] = cSym[i*ldc+j] + } + } + } else { + for i := 1; i < n; i++ { + for j := 0; j < i; j++ { + cSym[j*ldc+i] = cSym[i*ldc+j] + } + } + } + + // Compute the expected result using an internal Zgemm implementation. + var want []complex128 + if trans == blas.NoTrans { + want = zmm(blas.NoTrans, blas.Trans, n, n, k, alpha, a, lda, a, lda, beta, cSym, ldc) + } else { + want = zmm(blas.Trans, blas.NoTrans, n, n, k, alpha, a, lda, a, lda, beta, cSym, ldc) + } + + // Compute the result using Zsyrk. + impl.Zsyrk(uplo, trans, n, k, alpha, a, lda, beta, c, ldc) + + prefix := fmt.Sprintf("n=%v,k=%v,lda=%v,ldc=%v,alpha=%v,beta=%v", n, k, lda, ldc, alpha, beta) + + if !zsame(a, aCopy) { + t.Errorf("%v: unexpected modification of A", prefix) + continue + } + if uplo == blas.Upper && !zSameLowerTri(n, c, ldc, cCopy, ldc) { + t.Errorf("%v: unexpected modification in lower triangle of C", prefix) + continue + } + if uplo == blas.Lower && !zSameUpperTri(n, c, ldc, cCopy, ldc) { + t.Errorf("%v: unexpected modification in upper triangle of C", prefix) + continue + } + + // Expand C into a full symmetric matrix + // for comparison with the result from zmm. + if uplo == blas.Upper { + for i := 0; i < n-1; i++ { + for j := i + 1; j < n; j++ { + c[j*ldc+i] = c[i*ldc+j] + } + } + } else { + for i := 1; i < n; i++ { + for j := 0; j < i; j++ { + c[j*ldc+i] = c[i*ldc+j] + } + } + } + if !zEqualApprox(c, want, tol) { + t.Errorf("%v: unexpected result", prefix) + } + } + } + } + } +}