diff --git a/.travis.yml b/.travis.yml index 1ccbdd4b..cccc6e87 100644 --- a/.travis.yml +++ b/.travis.yml @@ -39,7 +39,7 @@ install: - if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then export CGO_LDFLAGS="-L/usr/lib -lopenblas"; fi - go get github.com/gonum/blas - go get github.com/gonum/matrix/mat64 - - pushd clapack + - pushd cgo - if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then perl genLapack.pl -L/usr/lib -lopenblas; fi - popd diff --git a/README.md b/README.md index e18b6d5f..fe13f07f 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,8 @@ Gonum LAPACK [![Build Status](https://travis-ci.org/gonum/lapack.svg)](https:// ====== A collection of packages to provide LAPACK functionality for the Go programming -language (http://golang.org) - -This is work in progress. Breaking changes are likely to happen. +language (http://golang.org). This provides a partial implementation in native go +and a wrapper using cgo to a c-based implementation. ## Installation diff --git a/clapack/genLapack.pl b/cgo/genLapack.pl old mode 100755 new mode 100644 similarity index 95% rename from clapack/genLapack.pl rename to cgo/genLapack.pl index 8c15042e..07de5ef8 --- a/clapack/genLapack.pl +++ b/cgo/genLapack.pl @@ -57,12 +57,14 @@ const ( rowMajor order = 101 + iota colMajor ) - + func init() { _ = lapack.Complex128(Lapack{}) _ = lapack.Float64(Lapack{}) } +func isZero(ret C.int) bool { return ret == 0 } + EOH $/ = undef; @@ -110,11 +112,18 @@ our %typeConv = ( "LAPACK_C_SELECT2" => "Select2Complex64", "LAPACK_Z_SELECT1" => "Select1Complex128", "LAPACK_Z_SELECT2" => "Select2Complex128", - "void" => "" + "void" => "", + + "lapack_int_return_type" => "bool", + "lapack_int_return" => "isZero", + "float_return_type" => "float32", + "float_return" => "float32", + "double_return_type" => "float64", + "double_return" => "float64" ); foreach my $line (@lines) { - process($line); + process($line); } close($golapack); @@ -162,14 +171,15 @@ sub processProto { $gofunc = ucfirst $func; } - my $GoRet = 
$typeConv{$ret}; + my $GoRet = $typeConv{$ret."_return"}; + my $GoRetType = $typeConv{$ret."_return_type"}; my $complexType = $func; $complexType =~ s/.*_[isd]?([zc]).*/$1/; my ($params,$bp) = processParamToGo($func, $paramList, $complexType); if ($params eq "") { return } - print $golapack "func (Lapack) ".$gofunc."(".$params.") ".$GoRet."{\n"; + print $golapack "func (Lapack) ".$gofunc."(".$params.") ".$GoRetType."{\n"; print $golapack "\t"; if ($ret ne 'void') { print $golapack "\n".$bp."\n"."return ".$GoRet."("; diff --git a/clapack/lapacke.h b/cgo/lapacke.h similarity index 100% rename from clapack/lapacke.h rename to cgo/lapacke.h diff --git a/clapack/lapacke_config.h b/cgo/lapacke_config.h similarity index 100% rename from clapack/lapacke_config.h rename to cgo/lapacke_config.h diff --git a/clapack/lapacke_mangling.h b/cgo/lapacke_mangling.h similarity index 100% rename from clapack/lapacke_mangling.h rename to cgo/lapacke_mangling.h diff --git a/clapack/lapacke_utils.h b/cgo/lapacke_utils.h similarity index 100% rename from clapack/lapacke_utils.h rename to cgo/lapacke_utils.h diff --git a/dla/dla_test.go b/dla/dla_test.go deleted file mode 100644 index 4d976adb..00000000 --- a/dla/dla_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package dla - -import ( - "fmt" - "testing" - - "github.com/gonum/blas" - "github.com/gonum/blas/blas64" - "github.com/gonum/lapack/clapack" - "github.com/gonum/matrix/mat64" -) - -type fm struct { - *mat64.Dense - margin int -} - -func (m fm) Format(fs fmt.State, c rune) { - if c == 'v' && fs.Flag('#') { - fmt.Fprintf(fs, "%#v", m.Dense) - return - } - mat64.Format(m.Dense, m.margin, '.', fs, c) -} - -func init() { - Register(clapack.Lapack{}) -} - -func TestQR(t *testing.T) { - A := blas64.General{ - Rows: 3, - Cols: 2, - Stride: 2, - Data: []float64{1, 2, 3, 4, 5, 6}, - } - B := blas64.General{ - Rows: 3, - Cols: 2, - Stride: 2, - Data: []float64{1, 1, 1, 2, 2, 2}, - } - - tau := make([]float64, 2) - - C := blas64.General{Rows: 2, 
Cols: 2, Stride: 2, Data: make([]float64, 2*2)} - - blas64.Gemm(blas.Trans, blas.NoTrans, 1, A, B, 0, C) - - fmt.Println(C) - - f := QR(A, tau) - - fmt.Println(B) - fmt.Println(f) - - f.Solve(B) - var pm mat64.Dense - pm.SetRawMatrix(B) - fmt.Println(fm{&pm, 0}) -} diff --git a/dla/impl.go b/dla/impl.go deleted file mode 100644 index 3771495f..00000000 --- a/dla/impl.go +++ /dev/null @@ -1,9 +0,0 @@ -package dla - -import "github.com/gonum/lapack" - -var impl lapack.Float64 - -func Register(i lapack.Float64) { - impl = i -} diff --git a/dla/qr.go b/dla/qr.go deleted file mode 100644 index 80904f06..00000000 --- a/dla/qr.go +++ /dev/null @@ -1,40 +0,0 @@ -package dla - -import ( - "github.com/gonum/blas" - "github.com/gonum/blas/blas64" -) - -type QRFact struct { - a blas64.General - tau []float64 -} - -func QR(a blas64.General, tau []float64) QRFact { - impl.Dgeqrf(a.Rows, a.Cols, a.Data, a.Stride, tau) - return QRFact{a: a, tau: tau} -} - -func (f QRFact) R() blas64.Triangular { - n := f.a.Rows - if f.a.Cols < n { - n = f.a.Cols - } - return blas64.Triangular{ - Data: f.a.Data, - N: n, - Stride: f.a.Stride, - Uplo: blas.Upper, - Diag: blas.NonUnit, - } -} - -func (f QRFact) Solve(b blas64.General) blas64.General { - if f.a.Cols != b.Cols { - panic("dimension missmatch") - } - impl.Dormqr(blas.Left, blas.Trans, b.Rows, b.Cols, f.a.Cols, f.a.Data, f.a.Stride, f.tau, b.Data, b.Stride) - b.Rows = f.a.Cols - blas64.Trsm(blas.Left, blas.NoTrans, 1, f.R(), b) - return b -} diff --git a/dla/svd.go b/dla/svd.go deleted file mode 100644 index 4177d383..00000000 --- a/dla/svd.go +++ /dev/null @@ -1,61 +0,0 @@ -package dla - -import ( - "github.com/gonum/blas" - "github.com/gonum/blas/blas64" - "github.com/gonum/lapack" -) - -func SVD(A blas64.General) (U blas64.General, s []float64, Vt blas64.General) { - m := A.Rows - n := A.Cols - U.Stride = 1 - Vt.Stride = 1 - if m >= n { - Vt = blas64.General{ - Rows: n, - Cols: n, - Stride: n, - Data: make([]float64, n*n), - } - s = 
make([]float64, n) - U = A - } else { - U = blas64.General{ - Rows: m, - Cols: m, - Stride: m, - Data: make([]float64, n*n), - } - s = make([]float64, m) - Vt = A - } - - impl.Dgesdd(lapack.Overwrite, A.Rows, A.Cols, A.Data, A.Stride, s, U.Data, U.Stride, Vt.Data, Vt.Stride) - - return -} - -func SVDbd(uplo blas.Uplo, d, e []float64) (U blas64.General, s []float64, Vt blas64.General) { - n := len(d) - if len(e) != n { - panic("dimensionality missmatch") - } - - U = blas64.General{ - Rows: n, - Cols: n, - Stride: n, - Data: make([]float64, n*n), - } - Vt = blas64.General{ - Rows: n, - Cols: n, - Stride: n, - Data: make([]float64, n*n), - } - - impl.Dbdsdc(uplo, lapack.Explicit, n, d, e, U.Data, U.Stride, Vt.Data, Vt.Stride, nil, nil) - s = d - return -} diff --git a/lapack.go b/lapack.go index 4257c241..5bc7b4b2 100644 --- a/lapack.go +++ b/lapack.go @@ -1,8 +1,6 @@ package lapack -import ( - "github.com/gonum/blas" -) +import "github.com/gonum/blas" const None = 'N' @@ -21,19 +19,10 @@ const ( Explicit (CompSV) = 'I' ) +// Float64 defines the float64 interface for the Lapack function. This interface +// contains the functions needed in the gonum suite. 
type Float64 interface { - Dgeqrf(m, n int, a []float64, lda int, tau []float64) int - Dormqr(s blas.Side, t blas.Transpose, m, n, k int, a []float64, lda int, tau []float64, c []float64, ldc int) int - Dgesdd(job Job, m, n int, a []float64, lda int, s []float64, u []float64, ldu int, vt []float64, ldvt int) int - Dgebrd(m, n int, a []float64, lda int, d, e, tauq, taup []float64) int - Dbdsdc(uplo blas.Uplo, compq CompSV, n int, d, e []float64, u []float64, ldu int, vt []float64, ldvt int, q []float64, iq []int32) int + Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool) } -type Complex128 interface { - Float64 - - Zgeqrf(m, n int, a []complex128, lda int, tau []complex128) int - Zunmqr(s blas.Side, t blas.Transpose, m, n, k int, a []complex128, lda int, tau []complex128, c []complex128, ldc int) int - Zgesdd(job Job, m, n int, a []complex128, lda int, s []float64, u []complex128, ldu int, vt []complex128, ldvt int) int - Zgebrd(m, n int, a []complex128, lda int, d, e []float64, tauq, taup []complex128) int -} +type Complex128 interface{} diff --git a/lapack64/lapack64.go b/lapack64/lapack64.go new file mode 100644 index 00000000..fe13d587 --- /dev/null +++ b/lapack64/lapack64.go @@ -0,0 +1,49 @@ +// Copyright ©2015 The gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package lapack64 provides a set of convenient wrapper functions for LAPACK +// calls, as specified in the netlib standard (www.netlib.org). +// +// The native Go routines are used by default, and the Use function can be used +// to set an alternate implementation. +// +// If the type of matrix (General, Symmetric, etc.) is known and fixed, it is +// used in the wrapper signature. In many cases, however, the type of the matrix +// changes during the call to the routine, for example the matrix is symmetric on +// entry and is triangular on exit. 
In these cases the correct types should be checked +// in the documentation. +// +// The full set of Lapack functions is very large, and it is not clear that a +// full implementation is desirable, let alone feasible. Please open up an issue +// if there is a specific function you need and/or are willing to implement. +package lapack64 + +import ( + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" + "github.com/gonum/lapack" + "github.com/gonum/lapack/native" +) + +var lapack64 lapack.Float64 = native.Implementation{} + +// Use sets the LAPACK float64 implementation to be used by subsequent LAPACK calls. +// The default implementation is native.Implementation. +func Use(l lapack.Float64) { + lapack64 = l +} + +// Potrf computes the Cholesky factorization of a. +// A = U^T * U if a.Uplo == blas.Upper +// A = L * L^T if a.Uplo == blas.Lower +// The underlying data between the input matrix and output matrix is shared. +func Potrf(a blas64.Symmetric) (t blas64.Triangular, ok bool) { + ok = lapack64.Dpotrf(a.Uplo, a.N, a.Data, a.Stride) + t.Uplo = a.Uplo + t.N = a.N + t.Data = a.Data + t.Stride = a.Stride + t.Diag = blas.NonUnit + return +} diff --git a/native/dpotf2.go b/native/dpotf2.go new file mode 100644 index 00000000..28a5ca2a --- /dev/null +++ b/native/dpotf2.go @@ -0,0 +1,65 @@ +package native + +import ( + "math" + + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" +) + +// Dpotf2 computes the Cholesky decomposition of the symmetric positive definite +// matrix a. If ul == blas.Upper, then a is stored as an upper-triangular matrix, +// and a = U^T U is stored in place into a. If ul == blas.Lower, then a = L L^T +// is computed and stored in-place into a. If a is not positive definite, false +// is returned. This is the unblocked version of the algorithm.
+func (Implementation) Dpotf2(ul blas.Uplo, n int, a []float64, lda int) (ok bool) { + if ul != blas.Upper && ul != blas.Lower { + panic(badUplo) + } + if n < 0 { + panic(nLT0) + } + if lda < n { + panic(badLdA) + } + if n == 0 { + return true + } + bi := blas64.Implementation() + if ul == blas.Upper { + for j := 0; j < n; j++ { + ajj := a[j*lda+j] + if j != 0 { + ajj -= bi.Ddot(j, a[j:], lda, a[j:], lda) + } + if ajj <= 0 || math.IsNaN(ajj) { + a[j*lda+j] = ajj + return false + } + ajj = math.Sqrt(ajj) + a[j*lda+j] = ajj + if j < n-1 { + bi.Dgemv(blas.Trans, j, n-j-1, -1, a[j+1:], lda, a[j:], lda, 1, a[j*lda+j+1:], 1) + bi.Dscal(n-j-1, 1/ajj, a[j*lda+j+1:], 1) + } + } + return true + } + for j := 0; j < n; j++ { + ajj := a[j*lda+j] + if j != 0 { + ajj -= bi.Ddot(j, a[j*lda:], 1, a[j*lda:], 1) + } + if ajj <= 0 || math.IsNaN(ajj) { + a[j*lda+j] = ajj + return false + } + ajj = math.Sqrt(ajj) + a[j*lda+j] = ajj + if j < n-1 { + bi.Dgemv(blas.NoTrans, n-j-1, j, -1, a[(j+1)*lda:], lda, a[j*lda:], 1, 1, a[(j+1)*lda+j:], lda) + bi.Dscal(n-j-1, 1/ajj, a[(j+1)*lda+j:], lda) + } + } + return true +} diff --git a/native/dpotrf.go b/native/dpotrf.go new file mode 100644 index 00000000..568cc0a1 --- /dev/null +++ b/native/dpotrf.go @@ -0,0 +1,63 @@ +package native + +import ( + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" +) + +// Dpotrf computes the Cholesky decomposition of the symmetric positive definite +// matrix a. If ul == blas.Upper, then a is stored as an upper-triangular matrix, +// and a = U^T U is stored in place into a. If ul == blas.Lower, then a = L L^T +// is computed and stored in-place into a. If a is not positive definite, false +// is returned. This is the blocked version of the algorithm.
+func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool) { + bi := blas64.Implementation() + if ul != blas.Upper && ul != blas.Lower { + panic(badUplo) + } + if n < 0 { + panic(nLT0) + } + if lda < n { + panic(badLdA) + } + if n == 0 { + return true + } + nb := blockSize() + if n <= nb { + return impl.Dpotf2(ul, n, a, lda) + } + if ul == blas.Upper { + for j := 0; j < n; j += nb { + jb := min(nb, n-j) + bi.Dsyrk(blas.Upper, blas.Trans, jb, j, -1, a[j:], lda, 1, a[j*lda+j:], lda) + ok = impl.Dpotf2(blas.Upper, jb, a[j*lda+j:], lda) + if !ok { + return ok + } + if j+jb < n { + bi.Dgemm(blas.Trans, blas.NoTrans, jb, n-j-jb, j, -1, + a[j:], lda, a[j+jb:], lda, 1, a[j*lda+j+jb:], lda) + bi.Dtrsm(blas.Left, blas.Upper, blas.Trans, blas.NonUnit, jb, n-j-jb, 1, + a[j*lda+j:], lda, a[j*lda+j+jb:], lda) + } + } + return true + } + for j := 0; j < n; j += nb { + jb := min(nb, n-j) + bi.Dsyrk(blas.Lower, blas.NoTrans, jb, j, -1, a[j*lda:], lda, 1, a[j*lda+j:], lda) + ok := impl.Dpotf2(blas.Lower, jb, a[j*lda+j:], lda) + if !ok { + return ok + } + if j+jb < n { + bi.Dgemm(blas.NoTrans, blas.Trans, n-j-jb, jb, j, -1, + a[(j+jb)*lda:], lda, a[j*lda:], lda, 1, a[(j+jb)*lda+j:], lda) + bi.Dtrsm(blas.Right, blas.Lower, blas.Trans, blas.NonUnit, n-j-jb, jb, 1, + a[j*lda+j:], lda, a[(j+jb)*lda+j:], lda) + } + } + return true +} diff --git a/native/general.go b/native/general.go new file mode 100644 index 00000000..8ed45786 --- /dev/null +++ b/native/general.go @@ -0,0 +1,31 @@ +package native + +import "github.com/gonum/lapack" + +type Implementation struct{} + +var _ lapack.Float64 = Implementation{} + +const ( + badUplo = "lapack: illegal triangle" + nLT0 = "lapack: n < 0" + badLdA = "lapack: index of a out of range" +) + +func blockSize() int { + return 64 +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} diff --git a/native/lapack_test.go 
b/native/lapack_test.go new file mode 100644 index 00000000..1ae051b0 --- /dev/null +++ b/native/lapack_test.go @@ -0,0 +1,17 @@ +package native + +import ( + "testing" + + "github.com/gonum/lapack/testlapack" +) + +var impl = Implementation{} + +func TestDpotf2(t *testing.T) { + testlapack.Dpotf2Test(t, impl) +} + +func TestDpotrf(t *testing.T) { + testlapack.DpotrfTest(t, impl) +} diff --git a/test-coverage.sh b/test-coverage.sh index 3df03f52..1f276049 100644 --- a/test-coverage.sh +++ b/test-coverage.sh @@ -40,4 +40,3 @@ then fi rm -rf ./profile.out -rm -rf ./acc.out diff --git a/testlapack/dpotf2.go b/testlapack/dpotf2.go new file mode 100644 index 00000000..6bdb85b6 --- /dev/null +++ b/testlapack/dpotf2.go @@ -0,0 +1,103 @@ +package testlapack + +import ( + "testing" + + "github.com/gonum/blas" + "github.com/gonum/floats" +) + +type Dpotf2er interface { + Dpotf2(ul blas.Uplo, n int, a []float64, lda int) (ok bool) +} + +func Dpotf2Test(t *testing.T, impl Dpotf2er) { + for _, test := range []struct { + a [][]float64 + ul blas.Uplo + pos bool + U [][]float64 + }{ + { + a: [][]float64{ + {23, 37, 34, 32}, + {108, 71, 48, 48}, + {109, 109, 67, 58}, + {106, 107, 106, 63}, + }, + pos: true, + U: [][]float64{ + {4.795831523312719, 7.715033320111766, 7.089490077940543, 6.672461249826393}, + {0, 3.387958215439679, -1.976308959006481, -1.026654004678691}, + {0, 0, 3.582364210034111, 2.419258947036024}, + {0, 0, 0, 3.401680257083044}, + }, + }, + } { + testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0]), blas.Upper) + testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0])+5, blas.Upper) + aT := transpose(test.a) + L := transpose(test.U) + testDpotf2(t, impl, test.pos, aT, L, len(test.a[0]), blas.Lower) + testDpotf2(t, impl, test.pos, aT, L, len(test.a[0])+5, blas.Lower) + } +} + +func testDpotf2(t *testing.T, impl Dpotf2er, testPos bool, a, ans [][]float64, stride int, ul blas.Uplo) { + aFlat := flattenTri(a, stride, ul) + ansFlat := flattenTri(ans, 
stride, ul) + pos := impl.Dpotf2(ul, len(a[0]), aFlat, stride) + if pos != testPos { + t.Errorf("Positive definite mismatch: Want %v, Got %v", testPos, pos) + return + } + if testPos && !floats.EqualApprox(ansFlat, aFlat, 1e-14) { + t.Errorf("Result mismatch: Want %v, Got %v", ansFlat, aFlat) + } +} + +// flattenTri flattens a into a row-major slice with the given stride, which must be +// >= len(a[0]). Entries outside the ul triangle, and the padding, get repeatable nonce values. +func flattenTri(a [][]float64, stride int, ul blas.Uplo) []float64 { + m := len(a) + n := len(a[0]) + if stride < n { + panic("bad stride") + } + upper := ul == blas.Upper + v := make([]float64, m*stride) + count := 1000.0 + for i := 0; i < m; i++ { + for j := 0; j < stride; j++ { + if j >= n || (upper && j < i) || (!upper && j > i) { + // not accessed, so give a unique crazy number + v[i*stride+j] = count + count++ + continue + } + v[i*stride+j] = a[i][j] + } + } + return v +} + +func transpose(a [][]float64) [][]float64 { + m := len(a) + n := len(a[0]) + if m != n { + panic("not square") + } + aNew := make([][]float64, m) + for i := 0; i < m; i++ { + aNew[i] = make([]float64, n) + } + for i := 0; i < m; i++ { + if len(a[i]) != n { + panic("bad n size") + } + for j := 0; j < n; j++ { + aNew[j][i] = a[i][j] + } + } + return aNew +} diff --git a/testlapack/dpotrf.go b/testlapack/dpotrf.go new file mode 100644 index 00000000..1ff30411 --- /dev/null +++ b/testlapack/dpotrf.go @@ -0,0 +1,71 @@ +// Copyright ©2015 The gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file.
+ +package testlapack + +import ( + "math/rand" + "testing" + + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" + "github.com/gonum/floats" +) + +type Dpotrfer interface { + Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool) +} + +func DpotrfTest(t *testing.T, impl Dpotrfer) { + bi := blas64.Implementation() + for i, test := range []struct { + n int + }{ + {n: 10}, + {n: 30}, + {n: 63}, + {n: 65}, + {n: 128}, + {n: 1000}, + } { + n := test.n + // Construct a positive-definite symmetric matrix + base := make([]float64, n*n) + for i := range base { + base[i] = rand.Float64() + } + a := make([]float64, len(base)) + bi.Dgemm(blas.Trans, blas.NoTrans, n, n, n, 1, base, n, base, n, 0, a, n) + + aCopy := make([]float64, len(a)) + copy(aCopy, a) + + // Test with Upper + impl.Dpotrf(blas.Upper, n, a, n) + + // zero all the other elements + for i := 0; i < n; i++ { + for j := 0; j < i; j++ { + a[i*n+j] = 0 + } + } + // Multiply u^T * u + ans := make([]float64, len(a)) + bi.Dsyrk(blas.Upper, blas.Trans, n, n, 1, a, n, 0, ans, n) + + match := true + for i := 0; i < n; i++ { + for j := i; j < n; j++ { + if !floats.EqualWithinAbsOrRel(ans[i*n+j], aCopy[i*n+j], 1e-14, 1e-14) { + match = false + } + } + } + if !match { + //fmt.Println(aCopy) + //fmt.Println(ans) + t.Errorf("Case %v: Mismatch for upper", i) + } + } +} diff --git a/zla/impl.go b/zla/impl.go deleted file mode 100644 index 63c6fe05..00000000 --- a/zla/impl.go +++ /dev/null @@ -1,9 +0,0 @@ -package zla - -import "github.com/gonum/lapack" - -var impl lapack.Complex128 - -func Register(i lapack.Complex128) { - impl = i -} diff --git a/zla/qr.go b/zla/qr.go deleted file mode 100644 index cb8b728a..00000000 --- a/zla/qr.go +++ /dev/null @@ -1,40 +0,0 @@ -package zla - -import ( - "github.com/gonum/blas" - "github.com/gonum/blas/cblas128" -) - -type QRFact struct { - a cblas128.General - tau []complex128 -} - -func QR(A cblas128.General, tau []complex128) QRFact { - impl.Zgeqrf(A.Rows, A.Cols, A.Data, 
A.Stride, tau) - return QRFact{A, tau} -} - -func (f QRFact) R() cblas128.Triangular { - n := f.a.Rows - if f.a.Cols < n { - n = f.a.Cols - } - return cblas128.Triangular{ - Data: f.a.Data, - N: n, - Stride: f.a.Stride, - Uplo: blas.Upper, - Diag: blas.NonUnit, - } -} - -func (f QRFact) Solve(B cblas128.General) cblas128.General { - if f.a.Cols != B.Cols { - panic("dimension missmatch") - } - impl.Zunmqr(blas.Left, blas.ConjTrans, f.a.Rows, B.Cols, f.a.Cols, f.a.Data, f.a.Stride, f.tau, B.Data, B.Stride) - B.Rows = f.a.Cols - cblas128.Trsm(blas.Left, blas.NoTrans, 1, f.R(), B) - return B -} diff --git a/zla/svd.go b/zla/svd.go deleted file mode 100644 index 2762e9df..00000000 --- a/zla/svd.go +++ /dev/null @@ -1,86 +0,0 @@ -package zla - -import ( - "github.com/gonum/blas" - "github.com/gonum/blas/cblas128" - "github.com/gonum/lapack" -) - -func SVD(A cblas128.General) (U cblas128.General, s []float64, Vt cblas128.General) { - m := A.Rows - n := A.Cols - U.Stride = 1 - Vt.Stride = 1 - if m >= n { - Vt = cblas128.General{Rows: n, Cols: n, Stride: n, Data: make([]complex128, n*n)} - s = make([]float64, n) - U = A - } else { - U = cblas128.General{Rows: n, Cols: n, Stride: n, Data: make([]complex128, n*n)} - s = make([]float64, m) - Vt = A - } - - impl.Zgesdd(lapack.Overwrite, A.Rows, A.Cols, A.Data, A.Stride, s, U.Data, U.Stride, Vt.Data, Vt.Stride) - - return -} - -func c128col(i int, a cblas128.General) cblas128.Vector { - return cblas128.Vector{ - Inc: a.Stride, - Data: a.Data[i:], - } -} - -//Lanczos bidiagonalization with full reorthogonalization -func LanczosBi(L cblas128.General, u []complex128, numIter int) (U cblas128.General, V cblas128.General, a []float64, b []float64) { - - m := L.Rows - n := L.Cols - - uv := cblas128.Vector{Inc: 1, Data: u} - cblas128.Scal(len(u), complex(1/cblas128.Nrm2(len(u), uv), 0), uv) - - U = cblas128.General{Rows: m, Cols: numIter, Stride: numIter, Data: make([]complex128, m*numIter)} - V = cblas128.General{Rows: n, Cols: 
numIter, Stride: numIter, Data: make([]complex128, n*numIter)} - - a = make([]float64, numIter) - b = make([]float64, numIter) - - cblas128.Copy(len(u), uv, c128col(0, U)) - - tr := cblas128.Vector{Inc: 1, Data: make([]complex128, n)} - cblas128.Gemv(blas.ConjTrans, 1, L, uv, 0, tr) - a[0] = cblas128.Nrm2(n, tr) - cblas128.Copy(n, tr, c128col(0, V)) - cblas128.Scal(n, complex(1/a[0], 0), c128col(0, V)) - - tl := cblas128.Vector{Inc: 1, Data: make([]complex128, m)} - for k := 0; k < numIter-1; k++ { - cblas128.Copy(m, c128col(k, U), tl) - cblas128.Scal(m, complex(-a[k], 0), tl) - cblas128.Gemv(blas.NoTrans, 1, L, c128col(k, V), 1, tl) - - for i := 0; i <= k; i++ { - cblas128.Axpy(m, -cblas128.Dotc(m, c128col(i, U), tl), c128col(i, U), tl) - } - - b[k] = cblas128.Nrm2(m, tl) - cblas128.Copy(m, tl, c128col(k+1, U)) - cblas128.Scal(m, complex(1/b[k], 0), c128col(k+1, U)) - - cblas128.Copy(n, c128col(k, V), tr) - cblas128.Scal(n, complex(-b[k], 0), tr) - cblas128.Gemv(blas.ConjTrans, 1, L, c128col(k+1, U), 1, tr) - - for i := 0; i <= k; i++ { - cblas128.Axpy(n, -cblas128.Dotc(n, c128col(i, V), tr), c128col(i, V), tr) - } - - a[k+1] = cblas128.Nrm2(n, tr) - cblas128.Copy(n, tr, c128col(k+1, V)) - cblas128.Scal(n, complex(1/a[k+1], 0), c128col(k+1, V)) - } - return -} diff --git a/zla/zla_test.go b/zla/zla_test.go deleted file mode 100644 index 204ecf1b..00000000 --- a/zla/zla_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package zla - -import ( - "fmt" - "math" - "math/rand" - "testing" - - "github.com/gonum/blas" - "github.com/gonum/blas/blas64" - "github.com/gonum/blas/cblas128" - "github.com/gonum/lapack/clapack" - "github.com/gonum/lapack/dla" -) - -func init() { - Register(clapack.Lapack{}) - dla.Register(clapack.Lapack{}) -} - -func fillRandn(a []complex128, mu complex128, sigmaSq float64) { - fact := math.Sqrt(0.5 * sigmaSq) - for i := range a { - a[i] = complex(fact*rand.NormFloat64(), fact*rand.NormFloat64()) + mu - } -} - -func TestQR(t *testing.T) { - A := 
cblas128.General{ - Rows: 3, - Cols: 2, - Stride: 2, - Data: []complex128{complex(1, 0), complex(2, 0), complex(3, 0), complex(4, 0), complex(5, 0), complex(6, 0)}, - } - B := cblas128.General{ - Rows: 3, - Cols: 2, - Stride: 2, - Data: []complex128{complex(1, 0), complex(1, 0), complex(1, 0), complex(2, 0), complex(2, 0), complex(2, 0)}, - } - - tau := make([]complex128, 2) - - f := QR(A, tau) - - //fmt.Println(B) - f.Solve(B) - //fmt.Println(B) -} - -func f64col(i int, a blas64.General) blas64.Vector { - return blas64.Vector{ - Inc: a.Stride, - Data: a.Data[i:], - } -} - -func TestLanczos(t *testing.T) { - A := cblas128.General{Rows: 3, Cols: 4, Stride: 4, Data: make([]complex128, 3*4)} - fillRandn(A.Data, 0, 1) - - Acpy := cblas128.General{Rows: 3, Cols: 4, Stride: 4, Data: make([]complex128, 3*4)} - copy(Acpy.Data, A.Data) - - u0 := make([]complex128, 3) - fillRandn(u0, 0, 1) - - Ul, Vl, a, b := LanczosBi(Acpy, u0, 3) - - fmt.Println(a, b) - - tmpc := cblas128.General{Rows: 3, Cols: 3, Stride: 3, Data: make([]complex128, 3*3)} - bidic := cblas128.General{Rows: 3, Cols: 3, Stride: 3, Data: make([]complex128, 3*3)} - - cblas128.Gemm(blas.NoTrans, blas.NoTrans, 1, A, Vl, 0, tmpc) - cblas128.Gemm(blas.ConjTrans, blas.NoTrans, 1, Ul, tmpc, 0, bidic) - - fmt.Println(bidic) - - Ur, s, Vr := dla.SVDbd(blas.Lower, a, b) - - tmp := blas64.General{Rows: 3, Cols: 3, Stride: 3, Data: make([]float64, 3*3)} - bidi := blas64.General{Rows: 3, Cols: 3, Stride: 3, Data: make([]float64, 3*3)} - - copy(tmp.Data, Ur.Data) - for i := 0; i < 3; i++ { - blas64.Scal(3, s[i], f64col(i, tmp)) - } - - blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, Vr, 0, bidi) - - fmt.Println(bidi) - /* - - _ = Ul - _ = Vl - Uc := zbw.NewGeneral( 3, 3, nil) - zbw.Real2Cmplx(Ur.Data[:3*3], Uc.Data) - - fmt.Println(Uc.Data) - - U := zbw.NewGeneral( M, K, nil) - zbw.Gemm(blas.NoTrans, blas.NoTrans, 1, U1, Uc, 0, U) - */ -}