From 578f42c8df25a400f7ddf1ac593befb488949e50 Mon Sep 17 00:00:00 2001 From: btracey Date: Tue, 5 May 2015 22:55:15 -0700 Subject: [PATCH] Rearrange Lapack to be like BLAS. Implement cholesky decomposition Responded to PR comments modified travis file Changed input and output types added back needed types by cgo Fixed perl script so it compiles Changes to genLapack to allow compilation Reinstate test-coverage.sh --- .travis.yml | 2 +- README.md | 5 +- {clapack => cgo}/genLapack.pl | 20 ++++-- {clapack => cgo}/lapacke.h | 0 {clapack => cgo}/lapacke_config.h | 0 {clapack => cgo}/lapacke_mangling.h | 0 {clapack => cgo}/lapacke_utils.h | 0 dla/dla_test.go | 61 ---------------- dla/impl.go | 9 --- dla/qr.go | 40 ----------- dla/svd.go | 61 ---------------- lapack.go | 21 ++---- lapack64/lapack64.go | 49 +++++++++++++ native/dpotf2.go | 65 +++++++++++++++++ native/dpotrf.go | 63 +++++++++++++++++ native/general.go | 31 ++++++++ native/lapack_test.go | 17 +++++ test-coverage.sh | 1 - testlapack/dpotf2.go | 103 +++++++++++++++++++++++++++ testlapack/dpotrf.go | 71 +++++++++++++++++++ zla/impl.go | 9 --- zla/qr.go | 40 ----------- zla/svd.go | 86 ----------------------- zla/zla_test.go | 105 ---------------------------- 24 files changed, 422 insertions(+), 437 deletions(-) rename {clapack => cgo}/genLapack.pl (95%) mode change 100755 => 100644 rename {clapack => cgo}/lapacke.h (100%) rename {clapack => cgo}/lapacke_config.h (100%) rename {clapack => cgo}/lapacke_mangling.h (100%) rename {clapack => cgo}/lapacke_utils.h (100%) delete mode 100644 dla/dla_test.go delete mode 100644 dla/impl.go delete mode 100644 dla/qr.go delete mode 100644 dla/svd.go create mode 100644 lapack64/lapack64.go create mode 100644 native/dpotf2.go create mode 100644 native/dpotrf.go create mode 100644 native/general.go create mode 100644 native/lapack_test.go create mode 100644 testlapack/dpotf2.go create mode 100644 testlapack/dpotrf.go delete mode 100644 zla/impl.go delete mode 100644 zla/qr.go delete mode 100644 zla/svd.go delete mode 100644 zla/zla_test.go diff --git a/.travis.yml b/.travis.yml index 1ccbdd4b..cccc6e87 100644 --- a/.travis.yml +++ b/.travis.yml @@ -39,7 +39,7 @@ install: - if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then export CGO_LDFLAGS="-L/usr/lib -lopenblas"; fi - go get github.com/gonum/blas - go get github.com/gonum/matrix/mat64 - - pushd clapack + - pushd cgo - if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then perl genLapack.pl -L/usr/lib -lopenblas; fi - popd diff --git a/README.md b/README.md index e18b6d5f..fe13f07f 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,8 @@ Gonum LAPACK [![Build Status](https://travis-ci.org/gonum/lapack.svg)](https:// ====== A collection of packages to provide LAPACK functionality for the Go programming -language (http://golang.org) - -This is work in progress. Breaking changes are likely to happen. +language (http://golang.org). This provides a partial implementation in native go +and a wrapper using cgo to a c-based implementation. ## Installation diff --git a/clapack/genLapack.pl b/cgo/genLapack.pl old mode 100755 new mode 100644 similarity index 95% rename from clapack/genLapack.pl rename to cgo/genLapack.pl index 8c15042e..07de5ef8 --- a/clapack/genLapack.pl +++ b/cgo/genLapack.pl @@ -57,12 +57,14 @@ const ( rowMajor order = 101 + iota colMajor ) - + func init() { _ = lapack.Complex128(Lapack{}) _ = lapack.Float64(Lapack{}) } +func isZero(ret C.int) bool { return ret == 0 } + EOH $/ = undef; @@ -110,11 +112,18 @@ our %typeConv = ( "LAPACK_C_SELECT2" => "Select2Complex64", "LAPACK_Z_SELECT1" => "Select1Complex128", "LAPACK_Z_SELECT2" => "Select2Complex128", - "void" => "" + "void" => "", + + "lapack_int_return_type" => "bool", + "lapack_int_return" => "isZero", + "float_return_type" => "float32", + "float_return" => "float32", + "double_return_type" => "float64", + "double_return" => "float64" ); foreach my $line (@lines) { - process($line); + process($line); } close($golapack); @@ -162,14 +171,15 @@ sub processProto { $gofunc = ucfirst $func; } - my $GoRet = $typeConv{$ret}; + my $GoRet = $typeConv{$ret."_return"}; + my $GoRetType = $typeConv{$ret."_return_type"}; my $complexType = $func; $complexType =~ s/.*_[isd]?([zc]).*/$1/; my ($params,$bp) = processParamToGo($func, $paramList, $complexType); if ($params eq "") { return } - print $golapack "func (Lapack) ".$gofunc."(".$params.") ".$GoRet."{\n"; + print $golapack "func (Lapack) ".$gofunc."(".$params.") ".$GoRetType."{\n"; print $golapack "\t"; if ($ret ne 'void') { print $golapack "\n".$bp."\n"."return ".$GoRet."("; diff --git a/clapack/lapacke.h b/cgo/lapacke.h similarity index 100% rename from clapack/lapacke.h rename to cgo/lapacke.h diff --git a/clapack/lapacke_config.h b/cgo/lapacke_config.h similarity index 100% rename from clapack/lapacke_config.h rename to cgo/lapacke_config.h diff --git a/clapack/lapacke_mangling.h b/cgo/lapacke_mangling.h similarity index 100% rename from clapack/lapacke_mangling.h rename to cgo/lapacke_mangling.h diff --git a/clapack/lapacke_utils.h b/cgo/lapacke_utils.h similarity index 100% rename from clapack/lapacke_utils.h rename to cgo/lapacke_utils.h diff --git a/dla/dla_test.go b/dla/dla_test.go deleted file mode 100644 index 4d976adb..00000000 --- a/dla/dla_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package dla - -import ( - "fmt" - "testing" - - "github.com/gonum/blas" - "github.com/gonum/blas/blas64" - "github.com/gonum/lapack/clapack" - "github.com/gonum/matrix/mat64" -) - -type fm struct { - *mat64.Dense - margin int -} - -func (m fm) Format(fs fmt.State, c rune) { - if c == 'v' && fs.Flag('#') { - fmt.Fprintf(fs, "%#v", m.Dense) - return - } - mat64.Format(m.Dense, m.margin, '.', fs, c) -} - -func init() { - Register(clapack.Lapack{}) -} - -func TestQR(t *testing.T) { - A := blas64.General{ - Rows: 3, - Cols: 2, - Stride: 2, - Data: []float64{1, 2, 3, 4, 5, 6}, - } - B := blas64.General{ - Rows: 3, - Cols: 2, - Stride: 2, - Data: []float64{1, 1, 1, 2, 2, 2}, - } - - tau := make([]float64, 2) - - C := blas64.General{Rows: 2, Cols: 2, Stride: 2, Data: make([]float64, 2*2)} - - blas64.Gemm(blas.Trans, blas.NoTrans, 1, A, B, 0, C) - - fmt.Println(C) - - f := QR(A, tau) - - fmt.Println(B) - fmt.Println(f) - - f.Solve(B) - var pm mat64.Dense - pm.SetRawMatrix(B) - fmt.Println(fm{&pm, 0}) -} diff --git a/dla/impl.go b/dla/impl.go deleted file mode 100644 index 3771495f..00000000 --- a/dla/impl.go +++ /dev/null @@ -1,9 +0,0 @@ -package dla - -import "github.com/gonum/lapack" - -var impl lapack.Float64 - -func Register(i lapack.Float64) { - impl = i -} diff --git a/dla/qr.go b/dla/qr.go deleted file mode 100644 index 80904f06..00000000 --- a/dla/qr.go +++ /dev/null @@ -1,40 +0,0 @@ -package dla - -import ( - "github.com/gonum/blas" - "github.com/gonum/blas/blas64" -) - -type QRFact struct { - a blas64.General - tau []float64 -} - -func QR(a blas64.General, tau []float64) QRFact { - impl.Dgeqrf(a.Rows, a.Cols, a.Data, a.Stride, tau) - return QRFact{a: a, tau: tau} -} - -func (f QRFact) R() blas64.Triangular { - n := f.a.Rows - if f.a.Cols < n { - n = f.a.Cols - } - return blas64.Triangular{ - Data: f.a.Data, - N: n, - Stride: f.a.Stride, - Uplo: blas.Upper, - Diag: blas.NonUnit, - } -} - -func (f QRFact) Solve(b blas64.General) blas64.General { - if f.a.Cols != b.Cols { - panic("dimension missmatch") - } - impl.Dormqr(blas.Left, blas.Trans, b.Rows, b.Cols, f.a.Cols, f.a.Data, f.a.Stride, f.tau, b.Data, b.Stride) - b.Rows = f.a.Cols - blas64.Trsm(blas.Left, blas.NoTrans, 1, f.R(), b) - return b -} diff --git a/dla/svd.go b/dla/svd.go deleted file mode 100644 index 4177d383..00000000 --- a/dla/svd.go +++ /dev/null @@ -1,61 +0,0 @@ -package dla - -import ( - "github.com/gonum/blas" - "github.com/gonum/blas/blas64" - "github.com/gonum/lapack" -) - -func SVD(A blas64.General) (U blas64.General, s []float64, Vt blas64.General) { - m := A.Rows - n := A.Cols - U.Stride = 1 - Vt.Stride = 1 - if m >= n { - Vt = blas64.General{ - Rows: n, - Cols: n, - Stride: n, - Data: make([]float64, n*n), - } - s = make([]float64, n) - U = A - } else { - U = blas64.General{ - Rows: m, - Cols: m, - Stride: m, - Data: make([]float64, n*n), - } - s = make([]float64, m) - Vt = A - } - - impl.Dgesdd(lapack.Overwrite, A.Rows, A.Cols, A.Data, A.Stride, s, U.Data, U.Stride, Vt.Data, Vt.Stride) - - return -} - -func SVDbd(uplo blas.Uplo, d, e []float64) (U blas64.General, s []float64, Vt blas64.General) { - n := len(d) - if len(e) != n { - panic("dimensionality missmatch") - } - - U = blas64.General{ - Rows: n, - Cols: n, - Stride: n, - Data: make([]float64, n*n), - } - Vt = blas64.General{ - Rows: n, - Cols: n, - Stride: n, - Data: make([]float64, n*n), - } - - impl.Dbdsdc(uplo, lapack.Explicit, n, d, e, U.Data, U.Stride, Vt.Data, Vt.Stride, nil, nil) - s = d - return -} diff --git a/lapack.go b/lapack.go index 4257c241..5bc7b4b2 100644 --- a/lapack.go +++ b/lapack.go @@ -1,8 +1,6 @@ package lapack -import ( - "github.com/gonum/blas" -) +import "github.com/gonum/blas" const None = 'N' @@ -21,19 +19,10 @@ const ( Explicit (CompSV) = 'I' ) +// Float64 defines the float64 interface for the Lapack function. This interface +// contains the functions needed in the gonum suite. type Float64 interface { - Dgeqrf(m, n int, a []float64, lda int, tau []float64) int - Dormqr(s blas.Side, t blas.Transpose, m, n, k int, a []float64, lda int, tau []float64, c []float64, ldc int) int - Dgesdd(job Job, m, n int, a []float64, lda int, s []float64, u []float64, ldu int, vt []float64, ldvt int) int - Dgebrd(m, n int, a []float64, lda int, d, e, tauq, taup []float64) int - Dbdsdc(uplo blas.Uplo, compq CompSV, n int, d, e []float64, u []float64, ldu int, vt []float64, ldvt int, q []float64, iq []int32) int + Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool) } -type Complex128 interface { - Float64 - - Zgeqrf(m, n int, a []complex128, lda int, tau []complex128) int - Zunmqr(s blas.Side, t blas.Transpose, m, n, k int, a []complex128, lda int, tau []complex128, c []complex128, ldc int) int - Zgesdd(job Job, m, n int, a []complex128, lda int, s []float64, u []complex128, ldu int, vt []complex128, ldvt int) int - Zgebrd(m, n int, a []complex128, lda int, d, e []float64, tauq, taup []complex128) int -} +type Complex128 interface{} diff --git a/lapack64/lapack64.go b/lapack64/lapack64.go new file mode 100644 index 00000000..fe13d587 --- /dev/null +++ b/lapack64/lapack64.go @@ -0,0 +1,49 @@ +// Copyright ©2015 The gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package lapack64 provides a set of convenient wrapper functions for LAPACK +// calls, as specified in the netlib standard (www.netlib.org). +// +// The native Go routines are used by default, and the Use function can be used +// to set an alternate implementation. +// +// If the type of matrix (General, Symmetric, etc.) is known and fixed, it is +// used in the wrapper signature. In many cases, however, the type of the matrix +// changes during the call to the routine, for example the matrix is symmetric on +// entry and is triangular on exit. In these cases the correct types should be checked +// in the documentation. +// +// The full set of Lapack functions is very large, and it is not clear that a +// full implementation is desirable, let alone feasible. Please open up an issue +// if there is a specific function you need and/or are willing to implement. +package lapack64 + +import ( + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" + "github.com/gonum/lapack" + "github.com/gonum/lapack/native" +) + +var lapack64 lapack.Float64 = native.Implementation{} + +// Use sets the LAPACK float64 implementation to be used by subsequent BLAS calls. +// The default implementation is native.Implementation. +func Use(l lapack.Float64) { + lapack64 = l +} + +// Potrf computes the cholesky factorization of a. +// A = U^T * U if ul == blas.Upper +// A = L * L^T if ul == blas.Lower +// The underlying data between the input matrix and output matrix is shared. +func Potrf(a blas64.Symmetric) (t blas64.Triangular, ok bool) { + ok = lapack64.Dpotrf(a.Uplo, a.N, a.Data, a.Stride) + t.Uplo = a.Uplo + t.N = a.N + t.Data = a.Data + t.Stride = a.Stride + t.Diag = blas.NonUnit + return +} diff --git a/native/dpotf2.go b/native/dpotf2.go new file mode 100644 index 00000000..28a5ca2a --- /dev/null +++ b/native/dpotf2.go @@ -0,0 +1,65 @@ +package native + +import ( + "math" + + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" +) + +// Dpotrf computes the cholesky decomposition of the symmetric positive definite +// matrix a. If ul == blas.Upper, then a is stored as an upper-triangular matrix, +// and a = U^T U is stored in place into a. If ul == blas.Lower, then a = L L^T +// is computed and stored in-place into a. If a is not positive definite, false +// is returned. This is the unblocked version of the algorithm. +func (Implementation) Dpotf2(ul blas.Uplo, n int, a []float64, lda int) (ok bool) { + if ul != blas.Upper && ul != blas.Lower { + panic(badUplo) + } + if n < 0 { + panic(nLT0) + } + if lda < n { + panic(badLdA) + } + if n == 0 { + return true + } + bi := blas64.Implementation() + if ul == blas.Upper { + for j := 0; j < n; j++ { + ajj := a[j*lda+j] + if j != 0 { + ajj -= bi.Ddot(j, a[j:], lda, a[j:], lda) + } + if ajj <= 0 || math.IsNaN(ajj) { + a[j*lda+j] = ajj + return false + } + ajj = math.Sqrt(ajj) + a[j*lda+j] = ajj + if j < n-1 { + bi.Dgemv(blas.Trans, j, n-j-1, -1, a[j+1:], lda, a[j:], lda, 1, a[j*lda+j+1:], 1) + bi.Dscal(n-j-1, 1/ajj, a[j*lda+j+1:], 1) + } + } + return true + } + for j := 0; j < n; j++ { + ajj := a[j*lda+j] + if j != 0 { + ajj -= bi.Ddot(j, a[j*lda:], 1, a[j*lda:], 1) + } + if ajj <= 0 || math.IsNaN(ajj) { + a[j*lda+j] = ajj + return false + } + ajj = math.Sqrt(ajj) + a[j*lda+j] = ajj + if j < n-1 { + bi.Dgemv(blas.NoTrans, n-j-1, j, -1, a[(j+1)*lda:], lda, a[j*lda:], 1, 1, a[(j+1)*lda+j:], lda) + bi.Dscal(n-j-1, 1/ajj, a[(j+1)*lda+j:], lda) + } + } + return true +} diff --git a/native/dpotrf.go b/native/dpotrf.go new file mode 100644 index 00000000..568cc0a1 --- /dev/null +++ b/native/dpotrf.go @@ -0,0 +1,63 @@ +package native + +import ( + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" +) + +// Dpotrf computes the cholesky decomposition of the symmetric positive definite +// matrix a. If ul == blas.Upper, then a is stored as an upper-triangular matrix, +// and a = U U^T is stored in place into a. If ul == blas.Lower, then a = L L^T +// is computed and stored in-place into a. If a is not positive definite, false +// is returned. This is the blocked version of the algorithm. +func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool) { + bi := blas64.Implementation() + if ul != blas.Upper && ul != blas.Lower { + panic(badUplo) + } + if n < 0 { + panic(nLT0) + } + if lda < n { + panic(badLdA) + } + if n == 0 { + return true + } + nb := blockSize() + if n <= nb { + return impl.Dpotf2(ul, n, a, lda) + } + if ul == blas.Upper { + for j := 0; j < n; j += nb { + jb := min(nb, n-j) + bi.Dsyrk(blas.Upper, blas.Trans, jb, j, -1, a[j:], lda, 1, a[j*lda+j:], lda) + ok = impl.Dpotf2(blas.Upper, jb, a[j*lda+j:], lda) + if !ok { + return ok + } + if j+jb < n { + bi.Dgemm(blas.Trans, blas.NoTrans, jb, n-j-jb, j, -1, + a[j:], lda, a[j+jb:], lda, 1, a[j*lda+j+jb:], lda) + bi.Dtrsm(blas.Left, blas.Upper, blas.Trans, blas.NonUnit, jb, n-j-jb, 1, + a[j*lda+j:], lda, a[j*lda+j+jb:], lda) + } + } + return true + } + for j := 0; j < n; j += nb { + jb := min(nb, n-j) + bi.Dsyrk(blas.Lower, blas.NoTrans, jb, j, -1, a[j*lda:], lda, 1, a[j*lda+j:], lda) + ok := impl.Dpotf2(blas.Lower, jb, a[j*lda+j:], lda) + if !ok { + return ok + } + if j+jb < n { + bi.Dgemm(blas.NoTrans, blas.Trans, n-j-jb, jb, j, -1, + a[(j+jb)*lda:], lda, a[j*lda:], lda, 1, a[(j+jb)*lda+j:], lda) + bi.Dtrsm(blas.Right, blas.Lower, blas.Trans, blas.NonUnit, n-j-jb, jb, 1, + a[j*lda+j:], lda, a[(j+jb)*lda+j:], lda) + } + } + return true +} diff --git a/native/general.go b/native/general.go new file mode 100644 index 00000000..8ed45786 --- /dev/null +++ b/native/general.go @@ -0,0 +1,31 @@ +package native + +import "github.com/gonum/lapack" + +type Implementation struct{} + +var _ lapack.Float64 = Implementation{} + +const ( + badUplo = "lapack: illegal triangle" + nLT0 = "lapack: n < 0" + badLdA = "lapack: index of a out of range" +) + +func blockSize() int { + return 64 +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} diff --git a/native/lapack_test.go b/native/lapack_test.go new file mode 100644 index 00000000..1ae051b0 --- /dev/null +++ b/native/lapack_test.go @@ -0,0 +1,17 @@ +package native + +import ( + "testing" + + "github.com/gonum/lapack/testlapack" +) + +var impl = Implementation{} + +func TestDpotf2(t *testing.T) { + testlapack.Dpotf2Test(t, impl) +} + +func TestDpotrf(t *testing.T) { + testlapack.DpotrfTest(t, impl) +} diff --git a/test-coverage.sh b/test-coverage.sh index 3df03f52..1f276049 100644 --- a/test-coverage.sh +++ b/test-coverage.sh @@ -40,4 +40,3 @@ then fi rm -rf ./profile.out -rm -rf ./acc.out diff --git a/testlapack/dpotf2.go b/testlapack/dpotf2.go new file mode 100644 index 00000000..6bdb85b6 --- /dev/null +++ b/testlapack/dpotf2.go @@ -0,0 +1,103 @@ +package testlapack + +import ( + "testing" + + "github.com/gonum/blas" + "github.com/gonum/floats" +) + +type Dpotf2er interface { + Dpotf2(ul blas.Uplo, n int, a []float64, lda int) (ok bool) +} + +func Dpotf2Test(t *testing.T, impl Dpotf2er) { + for _, test := range []struct { + a [][]float64 + ul blas.Uplo + pos bool + U [][]float64 + }{ + { + a: [][]float64{ + {23, 37, 34, 32}, + {108, 71, 48, 48}, + {109, 109, 67, 58}, + {106, 107, 106, 63}, + }, + pos: true, + U: [][]float64{ + {4.795831523312719, 7.715033320111766, 7.089490077940543, 6.672461249826393}, + {0, 3.387958215439679, -1.976308959006481, -1.026654004678691}, + {0, 0, 3.582364210034111, 2.419258947036024}, + {0, 0, 0, 3.401680257083044}, + }, + }, + } { + testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0]), blas.Upper) + testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0])+5, blas.Upper) + aT := transpose(test.a) + L := transpose(test.U) + testDpotf2(t, impl, test.pos, aT, L, len(test.a[0]), blas.Lower) + testDpotf2(t, impl, test.pos, aT, L, len(test.a[0])+5, blas.Lower) + } +} + +func testDpotf2(t *testing.T, impl Dpotf2er, testPos bool, a, ans [][]float64, stride int, ul blas.Uplo) { + aFlat := flattenTri(a, stride, ul) + ansFlat := flattenTri(ans, stride, ul) + pos := impl.Dpotf2(ul, len(a[0]), aFlat, stride) + if pos != testPos { + t.Errorf("Positive definite mismatch: Want %v, Got %v", testPos, pos) + return + } + if testPos && !floats.EqualApprox(ansFlat, aFlat, 1e-14) { + t.Errorf("Result mismatch: Want %v, Got %v", ansFlat, aFlat) + } +} + +// flattenTri with a certain stride. stride must be >= dimension. Puts repeatable +// nonce values in non-accessed places +func flattenTri(a [][]float64, stride int, ul blas.Uplo) []float64 { + m := len(a) + n := len(a[0]) + if stride < n { + panic("bad stride") + } + upper := ul == blas.Upper + v := make([]float64, m*stride) + count := 1000.0 + for i := 0; i < m; i++ { + for j := 0; j < stride; j++ { + if j >= n || (upper && j < i) || (!upper && j > i) { + // not accessed, so give a unique crazy number + v[i*stride+j] = count + count++ + continue + } + v[i*stride+j] = a[i][j] + } + } + return v +} + +func transpose(a [][]float64) [][]float64 { + m := len(a) + n := len(a[0]) + if m != n { + panic("not square") + } + aNew := make([][]float64, m) + for i := 0; i < m; i++ { + aNew[i] = make([]float64, n) + } + for i := 0; i < m; i++ { + if len(a[i]) != n { + panic("bad n size") + } + for j := 0; j < n; j++ { + aNew[j][i] = a[i][j] + } + } + return aNew +} diff --git a/testlapack/dpotrf.go b/testlapack/dpotrf.go new file mode 100644 index 00000000..1ff30411 --- /dev/null +++ b/testlapack/dpotrf.go @@ -0,0 +1,71 @@ +// Copyright ©2015 The gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testlapack + +import ( + "math/rand" + "testing" + + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" + "github.com/gonum/floats" +) + +type Dpotrfer interface { + Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool) +} + +func DpotrfTest(t *testing.T, impl Dpotrfer) { + bi := blas64.Implementation() + for i, test := range []struct { + n int + }{ + {n: 10}, + {n: 30}, + {n: 63}, + {n: 65}, + {n: 128}, + {n: 1000}, + } { + n := test.n + // Construct a positive-definite symmetric matrix + base := make([]float64, n*n) + for i := range base { + base[i] = rand.Float64() + } + a := make([]float64, len(base)) + bi.Dgemm(blas.Trans, blas.NoTrans, n, n, n, 1, base, n, base, n, 0, a, n) + + aCopy := make([]float64, len(a)) + copy(aCopy, a) + + // Test with Upper + impl.Dpotrf(blas.Upper, n, a, n) + + // zero all the other elements + for i := 0; i < n; i++ { + for j := 0; j < i; j++ { + a[i*n+j] = 0 + } + } + // Multiply u^T * u + ans := make([]float64, len(a)) + bi.Dsyrk(blas.Upper, blas.Trans, n, n, 1, a, n, 0, ans, n) + + match := true + for i := 0; i < n; i++ { + for j := i; j < n; j++ { + if !floats.EqualWithinAbsOrRel(ans[i*n+j], aCopy[i*n+j], 1e-14, 1e-14) { + match = false + } + } + } + if !match { + //fmt.Println(aCopy) + //fmt.Println(ans) + t.Errorf("Case %v: Mismatch for upper", i) + } + } +} diff --git a/zla/impl.go b/zla/impl.go deleted file mode 100644 index 63c6fe05..00000000 --- a/zla/impl.go +++ /dev/null @@ -1,9 +0,0 @@ -package zla - -import "github.com/gonum/lapack" - -var impl lapack.Complex128 - -func Register(i lapack.Complex128) { - impl = i -} diff --git a/zla/qr.go b/zla/qr.go deleted file mode 100644 index cb8b728a..00000000 --- a/zla/qr.go +++ /dev/null @@ -1,40 +0,0 @@ -package zla - -import ( - "github.com/gonum/blas" - "github.com/gonum/blas/cblas128" -) - -type QRFact struct { - a cblas128.General - tau []complex128 -} - -func QR(A cblas128.General, tau []complex128) QRFact { - impl.Zgeqrf(A.Rows, A.Cols, A.Data, A.Stride, tau) - return QRFact{A, tau} -} - -func (f QRFact) R() cblas128.Triangular { - n := f.a.Rows - if f.a.Cols < n { - n = f.a.Cols - } - return cblas128.Triangular{ - Data: f.a.Data, - N: n, - Stride: f.a.Stride, - Uplo: blas.Upper, - Diag: blas.NonUnit, - } -} - -func (f QRFact) Solve(B cblas128.General) cblas128.General { - if f.a.Cols != B.Cols { - panic("dimension missmatch") - } - impl.Zunmqr(blas.Left, blas.ConjTrans, f.a.Rows, B.Cols, f.a.Cols, f.a.Data, f.a.Stride, f.tau, B.Data, B.Stride) - B.Rows = f.a.Cols - cblas128.Trsm(blas.Left, blas.NoTrans, 1, f.R(), B) - return B -} diff --git a/zla/svd.go b/zla/svd.go deleted file mode 100644 index 2762e9df..00000000 --- a/zla/svd.go +++ /dev/null @@ -1,86 +0,0 @@ -package zla - -import ( - "github.com/gonum/blas" - "github.com/gonum/blas/cblas128" - "github.com/gonum/lapack" -) - -func SVD(A cblas128.General) (U cblas128.General, s []float64, Vt cblas128.General) { - m := A.Rows - n := A.Cols - U.Stride = 1 - Vt.Stride = 1 - if m >= n { - Vt = cblas128.General{Rows: n, Cols: n, Stride: n, Data: make([]complex128, n*n)} - s = make([]float64, n) - U = A - } else { - U = cblas128.General{Rows: n, Cols: n, Stride: n, Data: make([]complex128, n*n)} - s = make([]float64, m) - Vt = A - } - - impl.Zgesdd(lapack.Overwrite, A.Rows, A.Cols, A.Data, A.Stride, s, U.Data, U.Stride, Vt.Data, Vt.Stride) - - return -} - -func c128col(i int, a cblas128.General) cblas128.Vector { - return cblas128.Vector{ - Inc: a.Stride, - Data: a.Data[i:], - } -} - -//Lanczos bidiagonalization with full reorthogonalization -func LanczosBi(L cblas128.General, u []complex128, numIter int) (U cblas128.General, V cblas128.General, a []float64, b []float64) { - - m := L.Rows - n := L.Cols - - uv := cblas128.Vector{Inc: 1, Data: u} - cblas128.Scal(len(u), complex(1/cblas128.Nrm2(len(u), uv), 0), uv) - - U = cblas128.General{Rows: m, Cols: numIter, Stride: numIter, Data: make([]complex128, m*numIter)} - V = cblas128.General{Rows: n, Cols: numIter, Stride: numIter, Data: make([]complex128, n*numIter)} - - a = make([]float64, numIter) - b = make([]float64, numIter) - - cblas128.Copy(len(u), uv, c128col(0, U)) - - tr := cblas128.Vector{Inc: 1, Data: make([]complex128, n)} - cblas128.Gemv(blas.ConjTrans, 1, L, uv, 0, tr) - a[0] = cblas128.Nrm2(n, tr) - cblas128.Copy(n, tr, c128col(0, V)) - cblas128.Scal(n, complex(1/a[0], 0), c128col(0, V)) - - tl := cblas128.Vector{Inc: 1, Data: make([]complex128, m)} - for k := 0; k < numIter-1; k++ { - cblas128.Copy(m, c128col(k, U), tl) - cblas128.Scal(m, complex(-a[k], 0), tl) - cblas128.Gemv(blas.NoTrans, 1, L, c128col(k, V), 1, tl) - - for i := 0; i <= k; i++ { - cblas128.Axpy(m, -cblas128.Dotc(m, c128col(i, U), tl), c128col(i, U), tl) - } - - b[k] = cblas128.Nrm2(m, tl) - cblas128.Copy(m, tl, c128col(k+1, U)) - cblas128.Scal(m, complex(1/b[k], 0), c128col(k+1, U)) - - cblas128.Copy(n, c128col(k, V), tr) - cblas128.Scal(n, complex(-b[k], 0), tr) - cblas128.Gemv(blas.ConjTrans, 1, L, c128col(k+1, U), 1, tr) - - for i := 0; i <= k; i++ { - cblas128.Axpy(n, -cblas128.Dotc(n, c128col(i, V), tr), c128col(i, V), tr) - } - - a[k+1] = cblas128.Nrm2(n, tr) - cblas128.Copy(n, tr, c128col(k+1, V)) - cblas128.Scal(n, complex(1/a[k+1], 0), c128col(k+1, V)) - } - return -} diff --git a/zla/zla_test.go b/zla/zla_test.go deleted file mode 100644 index 204ecf1b..00000000 --- a/zla/zla_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package zla - -import ( - "fmt" - "math" - "math/rand" - "testing" - - "github.com/gonum/blas" - "github.com/gonum/blas/blas64" - "github.com/gonum/blas/cblas128" - "github.com/gonum/lapack/clapack" - "github.com/gonum/lapack/dla" -) - -func init() { - Register(clapack.Lapack{}) - dla.Register(clapack.Lapack{}) -} - -func fillRandn(a []complex128, mu complex128, sigmaSq float64) { - fact := math.Sqrt(0.5 * sigmaSq) - for i := range a { - a[i] = complex(fact*rand.NormFloat64(), fact*rand.NormFloat64()) + mu - } -} - -func TestQR(t *testing.T) { - A := cblas128.General{ - Rows: 3, - Cols: 2, - Stride: 2, - Data: []complex128{complex(1, 0), complex(2, 0), complex(3, 0), complex(4, 0), complex(5, 0), complex(6, 0)}, - } - B := cblas128.General{ - Rows: 3, - Cols: 2, - Stride: 2, - Data: []complex128{complex(1, 0), complex(1, 0), complex(1, 0), complex(2, 0), complex(2, 0), complex(2, 0)}, - } - - tau := make([]complex128, 2) - - f := QR(A, tau) - - //fmt.Println(B) - f.Solve(B) - //fmt.Println(B) -} - -func f64col(i int, a blas64.General) blas64.Vector { - return blas64.Vector{ - Inc: a.Stride, - Data: a.Data[i:], - } -} - -func TestLanczos(t *testing.T) { - A := cblas128.General{Rows: 3, Cols: 4, Stride: 4, Data: make([]complex128, 3*4)} - fillRandn(A.Data, 0, 1) - - Acpy := cblas128.General{Rows: 3, Cols: 4, Stride: 4, Data: make([]complex128, 3*4)} - copy(Acpy.Data, A.Data) - - u0 := make([]complex128, 3) - fillRandn(u0, 0, 1) - - Ul, Vl, a, b := LanczosBi(Acpy, u0, 3) - - fmt.Println(a, b) - - tmpc := cblas128.General{Rows: 3, Cols: 3, Stride: 3, Data: make([]complex128, 3*3)} - bidic := cblas128.General{Rows: 3, Cols: 3, Stride: 3, Data: make([]complex128, 3*3)} - - cblas128.Gemm(blas.NoTrans, blas.NoTrans, 1, A, Vl, 0, tmpc) - cblas128.Gemm(blas.ConjTrans, blas.NoTrans, 1, Ul, tmpc, 0, bidic) - - fmt.Println(bidic) - - Ur, s, Vr := dla.SVDbd(blas.Lower, a, b) - - tmp := blas64.General{Rows: 3, Cols: 3, Stride: 3, Data: make([]float64, 3*3)} - bidi := blas64.General{Rows: 3, Cols: 3, Stride: 3, Data: make([]float64, 3*3)} - - copy(tmp.Data, Ur.Data) - for i := 0; i < 3; i++ { - blas64.Scal(3, s[i], f64col(i, tmp)) - } - - blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, Vr, 0, bidi) - - fmt.Println(bidi) - /* - - _ = Ul - _ = Vl - Uc := zbw.NewGeneral( 3, 3, nil) - zbw.Real2Cmplx(Ur.Data[:3*3], Uc.Data) - - fmt.Println(Uc.Data) - - U := zbw.NewGeneral( M, K, nil) - zbw.Gemm(blas.NoTrans, blas.NoTrans, 1, U1, Uc, 0, U) - */ -}