mirror of
https://github.com/gonum/gonum.git
synced 2025-10-20 21:59:25 +08:00
Rearrange Lapack to be like BLAS. Implement cholesky decomposition
Responded to PR comments modified travis file Changed input and output types added back needed types by cgo Fixed perl script so it compiles Changes to genLapack to allow compilation Reinstate test-coverage.sh
This commit is contained in:
@@ -39,7 +39,7 @@ install:
|
||||
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then export CGO_LDFLAGS="-L/usr/lib -lopenblas"; fi
|
||||
- go get github.com/gonum/blas
|
||||
- go get github.com/gonum/matrix/mat64
|
||||
- pushd clapack
|
||||
- pushd cgo
|
||||
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then perl genLapack.pl -L/usr/lib -lopenblas; fi
|
||||
- popd
|
||||
|
||||
|
@@ -2,9 +2,8 @@ Gonum LAPACK [](https://
|
||||
======
|
||||
|
||||
A collection of packages to provide LAPACK functionality for the Go programming
|
||||
language (http://golang.org)
|
||||
|
||||
This is work in progress. Breaking changes are likely to happen.
|
||||
language (http://golang.org). This provides a partial implementation in native go
|
||||
and a wrapper using cgo to a c-based implementation.
|
||||
|
||||
## Installation
|
||||
|
||||
|
16
clapack/genLapack.pl → cgo/genLapack.pl
Executable file → Normal file
16
clapack/genLapack.pl → cgo/genLapack.pl
Executable file → Normal file
@@ -63,6 +63,8 @@ func init() {
|
||||
_ = lapack.Float64(Lapack{})
|
||||
}
|
||||
|
||||
func isZero(ret C.int) bool { return ret == 0 }
|
||||
|
||||
EOH
|
||||
|
||||
$/ = undef;
|
||||
@@ -110,7 +112,14 @@ our %typeConv = (
|
||||
"LAPACK_C_SELECT2" => "Select2Complex64",
|
||||
"LAPACK_Z_SELECT1" => "Select1Complex128",
|
||||
"LAPACK_Z_SELECT2" => "Select2Complex128",
|
||||
"void" => ""
|
||||
"void" => "",
|
||||
|
||||
"lapack_int_return_type" => "bool",
|
||||
"lapack_int_return" => "isZero",
|
||||
"float_return_type" => "float32",
|
||||
"float_return" => "float32",
|
||||
"double_return_type" => "float64",
|
||||
"double_return" => "float64"
|
||||
);
|
||||
|
||||
foreach my $line (@lines) {
|
||||
@@ -162,14 +171,15 @@ sub processProto {
|
||||
$gofunc = ucfirst $func;
|
||||
}
|
||||
|
||||
my $GoRet = $typeConv{$ret};
|
||||
my $GoRet = $typeConv{$ret."_return"};
|
||||
my $GoRetType = $typeConv{$ret."_return_type"};
|
||||
my $complexType = $func;
|
||||
$complexType =~ s/.*_[isd]?([zc]).*/$1/;
|
||||
my ($params,$bp) = processParamToGo($func, $paramList, $complexType);
|
||||
if ($params eq "") {
|
||||
return
|
||||
}
|
||||
print $golapack "func (Lapack) ".$gofunc."(".$params.") ".$GoRet."{\n";
|
||||
print $golapack "func (Lapack) ".$gofunc."(".$params.") ".$GoRetType."{\n";
|
||||
print $golapack "\t";
|
||||
if ($ret ne 'void') {
|
||||
print $golapack "\n".$bp."\n"."return ".$GoRet."(";
|
@@ -1,61 +0,0 @@
|
||||
package dla
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
"github.com/gonum/lapack/clapack"
|
||||
"github.com/gonum/matrix/mat64"
|
||||
)
|
||||
|
||||
type fm struct {
|
||||
*mat64.Dense
|
||||
margin int
|
||||
}
|
||||
|
||||
func (m fm) Format(fs fmt.State, c rune) {
|
||||
if c == 'v' && fs.Flag('#') {
|
||||
fmt.Fprintf(fs, "%#v", m.Dense)
|
||||
return
|
||||
}
|
||||
mat64.Format(m.Dense, m.margin, '.', fs, c)
|
||||
}
|
||||
|
||||
func init() {
|
||||
Register(clapack.Lapack{})
|
||||
}
|
||||
|
||||
func TestQR(t *testing.T) {
|
||||
A := blas64.General{
|
||||
Rows: 3,
|
||||
Cols: 2,
|
||||
Stride: 2,
|
||||
Data: []float64{1, 2, 3, 4, 5, 6},
|
||||
}
|
||||
B := blas64.General{
|
||||
Rows: 3,
|
||||
Cols: 2,
|
||||
Stride: 2,
|
||||
Data: []float64{1, 1, 1, 2, 2, 2},
|
||||
}
|
||||
|
||||
tau := make([]float64, 2)
|
||||
|
||||
C := blas64.General{Rows: 2, Cols: 2, Stride: 2, Data: make([]float64, 2*2)}
|
||||
|
||||
blas64.Gemm(blas.Trans, blas.NoTrans, 1, A, B, 0, C)
|
||||
|
||||
fmt.Println(C)
|
||||
|
||||
f := QR(A, tau)
|
||||
|
||||
fmt.Println(B)
|
||||
fmt.Println(f)
|
||||
|
||||
f.Solve(B)
|
||||
var pm mat64.Dense
|
||||
pm.SetRawMatrix(B)
|
||||
fmt.Println(fm{&pm, 0})
|
||||
}
|
@@ -1,9 +0,0 @@
|
||||
package dla
|
||||
|
||||
import "github.com/gonum/lapack"
|
||||
|
||||
var impl lapack.Float64
|
||||
|
||||
func Register(i lapack.Float64) {
|
||||
impl = i
|
||||
}
|
40
dla/qr.go
40
dla/qr.go
@@ -1,40 +0,0 @@
|
||||
package dla
|
||||
|
||||
import (
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
)
|
||||
|
||||
type QRFact struct {
|
||||
a blas64.General
|
||||
tau []float64
|
||||
}
|
||||
|
||||
func QR(a blas64.General, tau []float64) QRFact {
|
||||
impl.Dgeqrf(a.Rows, a.Cols, a.Data, a.Stride, tau)
|
||||
return QRFact{a: a, tau: tau}
|
||||
}
|
||||
|
||||
func (f QRFact) R() blas64.Triangular {
|
||||
n := f.a.Rows
|
||||
if f.a.Cols < n {
|
||||
n = f.a.Cols
|
||||
}
|
||||
return blas64.Triangular{
|
||||
Data: f.a.Data,
|
||||
N: n,
|
||||
Stride: f.a.Stride,
|
||||
Uplo: blas.Upper,
|
||||
Diag: blas.NonUnit,
|
||||
}
|
||||
}
|
||||
|
||||
func (f QRFact) Solve(b blas64.General) blas64.General {
|
||||
if f.a.Cols != b.Cols {
|
||||
panic("dimension missmatch")
|
||||
}
|
||||
impl.Dormqr(blas.Left, blas.Trans, b.Rows, b.Cols, f.a.Cols, f.a.Data, f.a.Stride, f.tau, b.Data, b.Stride)
|
||||
b.Rows = f.a.Cols
|
||||
blas64.Trsm(blas.Left, blas.NoTrans, 1, f.R(), b)
|
||||
return b
|
||||
}
|
61
dla/svd.go
61
dla/svd.go
@@ -1,61 +0,0 @@
|
||||
package dla
|
||||
|
||||
import (
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
"github.com/gonum/lapack"
|
||||
)
|
||||
|
||||
func SVD(A blas64.General) (U blas64.General, s []float64, Vt blas64.General) {
|
||||
m := A.Rows
|
||||
n := A.Cols
|
||||
U.Stride = 1
|
||||
Vt.Stride = 1
|
||||
if m >= n {
|
||||
Vt = blas64.General{
|
||||
Rows: n,
|
||||
Cols: n,
|
||||
Stride: n,
|
||||
Data: make([]float64, n*n),
|
||||
}
|
||||
s = make([]float64, n)
|
||||
U = A
|
||||
} else {
|
||||
U = blas64.General{
|
||||
Rows: m,
|
||||
Cols: m,
|
||||
Stride: m,
|
||||
Data: make([]float64, n*n),
|
||||
}
|
||||
s = make([]float64, m)
|
||||
Vt = A
|
||||
}
|
||||
|
||||
impl.Dgesdd(lapack.Overwrite, A.Rows, A.Cols, A.Data, A.Stride, s, U.Data, U.Stride, Vt.Data, Vt.Stride)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func SVDbd(uplo blas.Uplo, d, e []float64) (U blas64.General, s []float64, Vt blas64.General) {
|
||||
n := len(d)
|
||||
if len(e) != n {
|
||||
panic("dimensionality missmatch")
|
||||
}
|
||||
|
||||
U = blas64.General{
|
||||
Rows: n,
|
||||
Cols: n,
|
||||
Stride: n,
|
||||
Data: make([]float64, n*n),
|
||||
}
|
||||
Vt = blas64.General{
|
||||
Rows: n,
|
||||
Cols: n,
|
||||
Stride: n,
|
||||
Data: make([]float64, n*n),
|
||||
}
|
||||
|
||||
impl.Dbdsdc(uplo, lapack.Explicit, n, d, e, U.Data, U.Stride, Vt.Data, Vt.Stride, nil, nil)
|
||||
s = d
|
||||
return
|
||||
}
|
21
lapack.go
21
lapack.go
@@ -1,8 +1,6 @@
|
||||
package lapack
|
||||
|
||||
import (
|
||||
"github.com/gonum/blas"
|
||||
)
|
||||
import "github.com/gonum/blas"
|
||||
|
||||
const None = 'N'
|
||||
|
||||
@@ -21,19 +19,10 @@ const (
|
||||
Explicit (CompSV) = 'I'
|
||||
)
|
||||
|
||||
// Float64 defines the float64 interface for the Lapack function. This interface
|
||||
// contains the functions needed in the gonum suite.
|
||||
type Float64 interface {
|
||||
Dgeqrf(m, n int, a []float64, lda int, tau []float64) int
|
||||
Dormqr(s blas.Side, t blas.Transpose, m, n, k int, a []float64, lda int, tau []float64, c []float64, ldc int) int
|
||||
Dgesdd(job Job, m, n int, a []float64, lda int, s []float64, u []float64, ldu int, vt []float64, ldvt int) int
|
||||
Dgebrd(m, n int, a []float64, lda int, d, e, tauq, taup []float64) int
|
||||
Dbdsdc(uplo blas.Uplo, compq CompSV, n int, d, e []float64, u []float64, ldu int, vt []float64, ldvt int, q []float64, iq []int32) int
|
||||
Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
|
||||
}
|
||||
|
||||
type Complex128 interface {
|
||||
Float64
|
||||
|
||||
Zgeqrf(m, n int, a []complex128, lda int, tau []complex128) int
|
||||
Zunmqr(s blas.Side, t blas.Transpose, m, n, k int, a []complex128, lda int, tau []complex128, c []complex128, ldc int) int
|
||||
Zgesdd(job Job, m, n int, a []complex128, lda int, s []float64, u []complex128, ldu int, vt []complex128, ldvt int) int
|
||||
Zgebrd(m, n int, a []complex128, lda int, d, e []float64, tauq, taup []complex128) int
|
||||
}
|
||||
type Complex128 interface{}
|
||||
|
49
lapack64/lapack64.go
Normal file
49
lapack64/lapack64.go
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package lapack64 provides a set of convenient wrapper functions for LAPACK
|
||||
// calls, as specified in the netlib standard (www.netlib.org).
|
||||
//
|
||||
// The native Go routines are used by default, and the Use function can be used
|
||||
// to set an alternate implementation.
|
||||
//
|
||||
// If the type of matrix (General, Symmetric, etc.) is known and fixed, it is
|
||||
// used in the wrapper signature. In many cases, however, the type of the matrix
|
||||
// changes during the call to the routine, for example the matrix is symmetric on
|
||||
// entry and is triangular on exit. In these cases the correct types should be checked
|
||||
// in the documentation.
|
||||
//
|
||||
// The full set of Lapack functions is very large, and it is not clear that a
|
||||
// full implementation is desirable, let alone feasible. Please open up an issue
|
||||
// if there is a specific function you need and/or are willing to implement.
|
||||
package lapack64
|
||||
|
||||
import (
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
"github.com/gonum/lapack"
|
||||
"github.com/gonum/lapack/native"
|
||||
)
|
||||
|
||||
var lapack64 lapack.Float64 = native.Implementation{}
|
||||
|
||||
// Use sets the LAPACK float64 implementation to be used by subsequent BLAS calls.
|
||||
// The default implementation is native.Implementation.
|
||||
func Use(l lapack.Float64) {
|
||||
lapack64 = l
|
||||
}
|
||||
|
||||
// Potrf computes the cholesky factorization of a.
|
||||
// A = U^T * U if ul == blas.Upper
|
||||
// A = L * L^T if ul == blas.Lower
|
||||
// The underlying data between the input matrix and output matrix is shared.
|
||||
func Potrf(a blas64.Symmetric) (t blas64.Triangular, ok bool) {
|
||||
ok = lapack64.Dpotrf(a.Uplo, a.N, a.Data, a.Stride)
|
||||
t.Uplo = a.Uplo
|
||||
t.N = a.N
|
||||
t.Data = a.Data
|
||||
t.Stride = a.Stride
|
||||
t.Diag = blas.NonUnit
|
||||
return
|
||||
}
|
65
native/dpotf2.go
Normal file
65
native/dpotf2.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package native
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
)
|
||||
|
||||
// Dpotrf computes the cholesky decomposition of the symmetric positive definite
|
||||
// matrix a. If ul == blas.Upper, then a is stored as an upper-triangular matrix,
|
||||
// and a = U^T U is stored in place into a. If ul == blas.Lower, then a = L L^T
|
||||
// is computed and stored in-place into a. If a is not positive definite, false
|
||||
// is returned. This is the unblocked version of the algorithm.
|
||||
func (Implementation) Dpotf2(ul blas.Uplo, n int, a []float64, lda int) (ok bool) {
|
||||
if ul != blas.Upper && ul != blas.Lower {
|
||||
panic(badUplo)
|
||||
}
|
||||
if n < 0 {
|
||||
panic(nLT0)
|
||||
}
|
||||
if lda < n {
|
||||
panic(badLdA)
|
||||
}
|
||||
if n == 0 {
|
||||
return true
|
||||
}
|
||||
bi := blas64.Implementation()
|
||||
if ul == blas.Upper {
|
||||
for j := 0; j < n; j++ {
|
||||
ajj := a[j*lda+j]
|
||||
if j != 0 {
|
||||
ajj -= bi.Ddot(j, a[j:], lda, a[j:], lda)
|
||||
}
|
||||
if ajj <= 0 || math.IsNaN(ajj) {
|
||||
a[j*lda+j] = ajj
|
||||
return false
|
||||
}
|
||||
ajj = math.Sqrt(ajj)
|
||||
a[j*lda+j] = ajj
|
||||
if j < n-1 {
|
||||
bi.Dgemv(blas.Trans, j, n-j-1, -1, a[j+1:], lda, a[j:], lda, 1, a[j*lda+j+1:], 1)
|
||||
bi.Dscal(n-j-1, 1/ajj, a[j*lda+j+1:], 1)
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
for j := 0; j < n; j++ {
|
||||
ajj := a[j*lda+j]
|
||||
if j != 0 {
|
||||
ajj -= bi.Ddot(j, a[j*lda:], 1, a[j*lda:], 1)
|
||||
}
|
||||
if ajj <= 0 || math.IsNaN(ajj) {
|
||||
a[j*lda+j] = ajj
|
||||
return false
|
||||
}
|
||||
ajj = math.Sqrt(ajj)
|
||||
a[j*lda+j] = ajj
|
||||
if j < n-1 {
|
||||
bi.Dgemv(blas.NoTrans, n-j-1, j, -1, a[(j+1)*lda:], lda, a[j*lda:], 1, 1, a[(j+1)*lda+j:], lda)
|
||||
bi.Dscal(n-j-1, 1/ajj, a[(j+1)*lda+j:], lda)
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
63
native/dpotrf.go
Normal file
63
native/dpotrf.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package native
|
||||
|
||||
import (
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
)
|
||||
|
||||
// Dpotrf computes the cholesky decomposition of the symmetric positive definite
|
||||
// matrix a. If ul == blas.Upper, then a is stored as an upper-triangular matrix,
|
||||
// and a = U U^T is stored in place into a. If ul == blas.Lower, then a = L L^T
|
||||
// is computed and stored in-place into a. If a is not positive definite, false
|
||||
// is returned. This is the blocked version of the algorithm.
|
||||
func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool) {
|
||||
bi := blas64.Implementation()
|
||||
if ul != blas.Upper && ul != blas.Lower {
|
||||
panic(badUplo)
|
||||
}
|
||||
if n < 0 {
|
||||
panic(nLT0)
|
||||
}
|
||||
if lda < n {
|
||||
panic(badLdA)
|
||||
}
|
||||
if n == 0 {
|
||||
return true
|
||||
}
|
||||
nb := blockSize()
|
||||
if n <= nb {
|
||||
return impl.Dpotf2(ul, n, a, lda)
|
||||
}
|
||||
if ul == blas.Upper {
|
||||
for j := 0; j < n; j += nb {
|
||||
jb := min(nb, n-j)
|
||||
bi.Dsyrk(blas.Upper, blas.Trans, jb, j, -1, a[j:], lda, 1, a[j*lda+j:], lda)
|
||||
ok = impl.Dpotf2(blas.Upper, jb, a[j*lda+j:], lda)
|
||||
if !ok {
|
||||
return ok
|
||||
}
|
||||
if j+jb < n {
|
||||
bi.Dgemm(blas.Trans, blas.NoTrans, jb, n-j-jb, j, -1,
|
||||
a[j:], lda, a[j+jb:], lda, 1, a[j*lda+j+jb:], lda)
|
||||
bi.Dtrsm(blas.Left, blas.Upper, blas.Trans, blas.NonUnit, jb, n-j-jb, 1,
|
||||
a[j*lda+j:], lda, a[j*lda+j+jb:], lda)
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
for j := 0; j < n; j += nb {
|
||||
jb := min(nb, n-j)
|
||||
bi.Dsyrk(blas.Lower, blas.NoTrans, jb, j, -1, a[j*lda:], lda, 1, a[j*lda+j:], lda)
|
||||
ok := impl.Dpotf2(blas.Lower, jb, a[j*lda+j:], lda)
|
||||
if !ok {
|
||||
return ok
|
||||
}
|
||||
if j+jb < n {
|
||||
bi.Dgemm(blas.NoTrans, blas.Trans, n-j-jb, jb, j, -1,
|
||||
a[(j+jb)*lda:], lda, a[j*lda:], lda, 1, a[(j+jb)*lda+j:], lda)
|
||||
bi.Dtrsm(blas.Right, blas.Lower, blas.Trans, blas.NonUnit, n-j-jb, jb, 1,
|
||||
a[j*lda+j:], lda, a[(j+jb)*lda+j:], lda)
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
31
native/general.go
Normal file
31
native/general.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package native
|
||||
|
||||
import "github.com/gonum/lapack"
|
||||
|
||||
type Implementation struct{}
|
||||
|
||||
var _ lapack.Float64 = Implementation{}
|
||||
|
||||
const (
|
||||
badUplo = "lapack: illegal triangle"
|
||||
nLT0 = "lapack: n < 0"
|
||||
badLdA = "lapack: index of a out of range"
|
||||
)
|
||||
|
||||
func blockSize() int {
|
||||
return 64
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
17
native/lapack_test.go
Normal file
17
native/lapack_test.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package native
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/lapack/testlapack"
|
||||
)
|
||||
|
||||
var impl = Implementation{}
|
||||
|
||||
func TestDpotf2(t *testing.T) {
|
||||
testlapack.Dpotf2Test(t, impl)
|
||||
}
|
||||
|
||||
func TestDpotrf(t *testing.T) {
|
||||
testlapack.DpotrfTest(t, impl)
|
||||
}
|
@@ -40,4 +40,3 @@ then
|
||||
fi
|
||||
|
||||
rm -rf ./profile.out
|
||||
rm -rf ./acc.out
|
||||
|
103
testlapack/dpotf2.go
Normal file
103
testlapack/dpotf2.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/floats"
|
||||
)
|
||||
|
||||
type Dpotf2er interface {
|
||||
Dpotf2(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
|
||||
}
|
||||
|
||||
func Dpotf2Test(t *testing.T, impl Dpotf2er) {
|
||||
for _, test := range []struct {
|
||||
a [][]float64
|
||||
ul blas.Uplo
|
||||
pos bool
|
||||
U [][]float64
|
||||
}{
|
||||
{
|
||||
a: [][]float64{
|
||||
{23, 37, 34, 32},
|
||||
{108, 71, 48, 48},
|
||||
{109, 109, 67, 58},
|
||||
{106, 107, 106, 63},
|
||||
},
|
||||
pos: true,
|
||||
U: [][]float64{
|
||||
{4.795831523312719, 7.715033320111766, 7.089490077940543, 6.672461249826393},
|
||||
{0, 3.387958215439679, -1.976308959006481, -1.026654004678691},
|
||||
{0, 0, 3.582364210034111, 2.419258947036024},
|
||||
{0, 0, 0, 3.401680257083044},
|
||||
},
|
||||
},
|
||||
} {
|
||||
testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0]), blas.Upper)
|
||||
testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0])+5, blas.Upper)
|
||||
aT := transpose(test.a)
|
||||
L := transpose(test.U)
|
||||
testDpotf2(t, impl, test.pos, aT, L, len(test.a[0]), blas.Lower)
|
||||
testDpotf2(t, impl, test.pos, aT, L, len(test.a[0])+5, blas.Lower)
|
||||
}
|
||||
}
|
||||
|
||||
func testDpotf2(t *testing.T, impl Dpotf2er, testPos bool, a, ans [][]float64, stride int, ul blas.Uplo) {
|
||||
aFlat := flattenTri(a, stride, ul)
|
||||
ansFlat := flattenTri(ans, stride, ul)
|
||||
pos := impl.Dpotf2(ul, len(a[0]), aFlat, stride)
|
||||
if pos != testPos {
|
||||
t.Errorf("Positive definite mismatch: Want %v, Got %v", testPos, pos)
|
||||
return
|
||||
}
|
||||
if testPos && !floats.EqualApprox(ansFlat, aFlat, 1e-14) {
|
||||
t.Errorf("Result mismatch: Want %v, Got %v", ansFlat, aFlat)
|
||||
}
|
||||
}
|
||||
|
||||
// flattenTri with a certain stride. stride must be >= dimension. Puts repeatable
|
||||
// nonce values in non-accessed places
|
||||
func flattenTri(a [][]float64, stride int, ul blas.Uplo) []float64 {
|
||||
m := len(a)
|
||||
n := len(a[0])
|
||||
if stride < n {
|
||||
panic("bad stride")
|
||||
}
|
||||
upper := ul == blas.Upper
|
||||
v := make([]float64, m*stride)
|
||||
count := 1000.0
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < stride; j++ {
|
||||
if j >= n || (upper && j < i) || (!upper && j > i) {
|
||||
// not accessed, so give a unique crazy number
|
||||
v[i*stride+j] = count
|
||||
count++
|
||||
continue
|
||||
}
|
||||
v[i*stride+j] = a[i][j]
|
||||
}
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func transpose(a [][]float64) [][]float64 {
|
||||
m := len(a)
|
||||
n := len(a[0])
|
||||
if m != n {
|
||||
panic("not square")
|
||||
}
|
||||
aNew := make([][]float64, m)
|
||||
for i := 0; i < m; i++ {
|
||||
aNew[i] = make([]float64, n)
|
||||
}
|
||||
for i := 0; i < m; i++ {
|
||||
if len(a[i]) != n {
|
||||
panic("bad n size")
|
||||
}
|
||||
for j := 0; j < n; j++ {
|
||||
aNew[j][i] = a[i][j]
|
||||
}
|
||||
}
|
||||
return aNew
|
||||
}
|
71
testlapack/dpotrf.go
Normal file
71
testlapack/dpotrf.go
Normal file
@@ -0,0 +1,71 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
"github.com/gonum/floats"
|
||||
)
|
||||
|
||||
type Dpotrfer interface {
|
||||
Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
|
||||
}
|
||||
|
||||
func DpotrfTest(t *testing.T, impl Dpotrfer) {
|
||||
bi := blas64.Implementation()
|
||||
for i, test := range []struct {
|
||||
n int
|
||||
}{
|
||||
{n: 10},
|
||||
{n: 30},
|
||||
{n: 63},
|
||||
{n: 65},
|
||||
{n: 128},
|
||||
{n: 1000},
|
||||
} {
|
||||
n := test.n
|
||||
// Construct a positive-definite symmetric matrix
|
||||
base := make([]float64, n*n)
|
||||
for i := range base {
|
||||
base[i] = rand.Float64()
|
||||
}
|
||||
a := make([]float64, len(base))
|
||||
bi.Dgemm(blas.Trans, blas.NoTrans, n, n, n, 1, base, n, base, n, 0, a, n)
|
||||
|
||||
aCopy := make([]float64, len(a))
|
||||
copy(aCopy, a)
|
||||
|
||||
// Test with Upper
|
||||
impl.Dpotrf(blas.Upper, n, a, n)
|
||||
|
||||
// zero all the other elements
|
||||
for i := 0; i < n; i++ {
|
||||
for j := 0; j < i; j++ {
|
||||
a[i*n+j] = 0
|
||||
}
|
||||
}
|
||||
// Multiply u^T * u
|
||||
ans := make([]float64, len(a))
|
||||
bi.Dsyrk(blas.Upper, blas.Trans, n, n, 1, a, n, 0, ans, n)
|
||||
|
||||
match := true
|
||||
for i := 0; i < n; i++ {
|
||||
for j := i; j < n; j++ {
|
||||
if !floats.EqualWithinAbsOrRel(ans[i*n+j], aCopy[i*n+j], 1e-14, 1e-14) {
|
||||
match = false
|
||||
}
|
||||
}
|
||||
}
|
||||
if !match {
|
||||
//fmt.Println(aCopy)
|
||||
//fmt.Println(ans)
|
||||
t.Errorf("Case %v: Mismatch for upper", i)
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,9 +0,0 @@
|
||||
package zla
|
||||
|
||||
import "github.com/gonum/lapack"
|
||||
|
||||
var impl lapack.Complex128
|
||||
|
||||
func Register(i lapack.Complex128) {
|
||||
impl = i
|
||||
}
|
40
zla/qr.go
40
zla/qr.go
@@ -1,40 +0,0 @@
|
||||
package zla
|
||||
|
||||
import (
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/cblas128"
|
||||
)
|
||||
|
||||
type QRFact struct {
|
||||
a cblas128.General
|
||||
tau []complex128
|
||||
}
|
||||
|
||||
func QR(A cblas128.General, tau []complex128) QRFact {
|
||||
impl.Zgeqrf(A.Rows, A.Cols, A.Data, A.Stride, tau)
|
||||
return QRFact{A, tau}
|
||||
}
|
||||
|
||||
func (f QRFact) R() cblas128.Triangular {
|
||||
n := f.a.Rows
|
||||
if f.a.Cols < n {
|
||||
n = f.a.Cols
|
||||
}
|
||||
return cblas128.Triangular{
|
||||
Data: f.a.Data,
|
||||
N: n,
|
||||
Stride: f.a.Stride,
|
||||
Uplo: blas.Upper,
|
||||
Diag: blas.NonUnit,
|
||||
}
|
||||
}
|
||||
|
||||
func (f QRFact) Solve(B cblas128.General) cblas128.General {
|
||||
if f.a.Cols != B.Cols {
|
||||
panic("dimension missmatch")
|
||||
}
|
||||
impl.Zunmqr(blas.Left, blas.ConjTrans, f.a.Rows, B.Cols, f.a.Cols, f.a.Data, f.a.Stride, f.tau, B.Data, B.Stride)
|
||||
B.Rows = f.a.Cols
|
||||
cblas128.Trsm(blas.Left, blas.NoTrans, 1, f.R(), B)
|
||||
return B
|
||||
}
|
86
zla/svd.go
86
zla/svd.go
@@ -1,86 +0,0 @@
|
||||
package zla
|
||||
|
||||
import (
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/cblas128"
|
||||
"github.com/gonum/lapack"
|
||||
)
|
||||
|
||||
func SVD(A cblas128.General) (U cblas128.General, s []float64, Vt cblas128.General) {
|
||||
m := A.Rows
|
||||
n := A.Cols
|
||||
U.Stride = 1
|
||||
Vt.Stride = 1
|
||||
if m >= n {
|
||||
Vt = cblas128.General{Rows: n, Cols: n, Stride: n, Data: make([]complex128, n*n)}
|
||||
s = make([]float64, n)
|
||||
U = A
|
||||
} else {
|
||||
U = cblas128.General{Rows: n, Cols: n, Stride: n, Data: make([]complex128, n*n)}
|
||||
s = make([]float64, m)
|
||||
Vt = A
|
||||
}
|
||||
|
||||
impl.Zgesdd(lapack.Overwrite, A.Rows, A.Cols, A.Data, A.Stride, s, U.Data, U.Stride, Vt.Data, Vt.Stride)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func c128col(i int, a cblas128.General) cblas128.Vector {
|
||||
return cblas128.Vector{
|
||||
Inc: a.Stride,
|
||||
Data: a.Data[i:],
|
||||
}
|
||||
}
|
||||
|
||||
//Lanczos bidiagonalization with full reorthogonalization
|
||||
func LanczosBi(L cblas128.General, u []complex128, numIter int) (U cblas128.General, V cblas128.General, a []float64, b []float64) {
|
||||
|
||||
m := L.Rows
|
||||
n := L.Cols
|
||||
|
||||
uv := cblas128.Vector{Inc: 1, Data: u}
|
||||
cblas128.Scal(len(u), complex(1/cblas128.Nrm2(len(u), uv), 0), uv)
|
||||
|
||||
U = cblas128.General{Rows: m, Cols: numIter, Stride: numIter, Data: make([]complex128, m*numIter)}
|
||||
V = cblas128.General{Rows: n, Cols: numIter, Stride: numIter, Data: make([]complex128, n*numIter)}
|
||||
|
||||
a = make([]float64, numIter)
|
||||
b = make([]float64, numIter)
|
||||
|
||||
cblas128.Copy(len(u), uv, c128col(0, U))
|
||||
|
||||
tr := cblas128.Vector{Inc: 1, Data: make([]complex128, n)}
|
||||
cblas128.Gemv(blas.ConjTrans, 1, L, uv, 0, tr)
|
||||
a[0] = cblas128.Nrm2(n, tr)
|
||||
cblas128.Copy(n, tr, c128col(0, V))
|
||||
cblas128.Scal(n, complex(1/a[0], 0), c128col(0, V))
|
||||
|
||||
tl := cblas128.Vector{Inc: 1, Data: make([]complex128, m)}
|
||||
for k := 0; k < numIter-1; k++ {
|
||||
cblas128.Copy(m, c128col(k, U), tl)
|
||||
cblas128.Scal(m, complex(-a[k], 0), tl)
|
||||
cblas128.Gemv(blas.NoTrans, 1, L, c128col(k, V), 1, tl)
|
||||
|
||||
for i := 0; i <= k; i++ {
|
||||
cblas128.Axpy(m, -cblas128.Dotc(m, c128col(i, U), tl), c128col(i, U), tl)
|
||||
}
|
||||
|
||||
b[k] = cblas128.Nrm2(m, tl)
|
||||
cblas128.Copy(m, tl, c128col(k+1, U))
|
||||
cblas128.Scal(m, complex(1/b[k], 0), c128col(k+1, U))
|
||||
|
||||
cblas128.Copy(n, c128col(k, V), tr)
|
||||
cblas128.Scal(n, complex(-b[k], 0), tr)
|
||||
cblas128.Gemv(blas.ConjTrans, 1, L, c128col(k+1, U), 1, tr)
|
||||
|
||||
for i := 0; i <= k; i++ {
|
||||
cblas128.Axpy(n, -cblas128.Dotc(n, c128col(i, V), tr), c128col(i, V), tr)
|
||||
}
|
||||
|
||||
a[k+1] = cblas128.Nrm2(n, tr)
|
||||
cblas128.Copy(n, tr, c128col(k+1, V))
|
||||
cblas128.Scal(n, complex(1/a[k+1], 0), c128col(k+1, V))
|
||||
}
|
||||
return
|
||||
}
|
105
zla/zla_test.go
105
zla/zla_test.go
@@ -1,105 +0,0 @@
|
||||
package zla
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
"github.com/gonum/blas/cblas128"
|
||||
"github.com/gonum/lapack/clapack"
|
||||
"github.com/gonum/lapack/dla"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Register(clapack.Lapack{})
|
||||
dla.Register(clapack.Lapack{})
|
||||
}
|
||||
|
||||
func fillRandn(a []complex128, mu complex128, sigmaSq float64) {
|
||||
fact := math.Sqrt(0.5 * sigmaSq)
|
||||
for i := range a {
|
||||
a[i] = complex(fact*rand.NormFloat64(), fact*rand.NormFloat64()) + mu
|
||||
}
|
||||
}
|
||||
|
||||
func TestQR(t *testing.T) {
|
||||
A := cblas128.General{
|
||||
Rows: 3,
|
||||
Cols: 2,
|
||||
Stride: 2,
|
||||
Data: []complex128{complex(1, 0), complex(2, 0), complex(3, 0), complex(4, 0), complex(5, 0), complex(6, 0)},
|
||||
}
|
||||
B := cblas128.General{
|
||||
Rows: 3,
|
||||
Cols: 2,
|
||||
Stride: 2,
|
||||
Data: []complex128{complex(1, 0), complex(1, 0), complex(1, 0), complex(2, 0), complex(2, 0), complex(2, 0)},
|
||||
}
|
||||
|
||||
tau := make([]complex128, 2)
|
||||
|
||||
f := QR(A, tau)
|
||||
|
||||
//fmt.Println(B)
|
||||
f.Solve(B)
|
||||
//fmt.Println(B)
|
||||
}
|
||||
|
||||
func f64col(i int, a blas64.General) blas64.Vector {
|
||||
return blas64.Vector{
|
||||
Inc: a.Stride,
|
||||
Data: a.Data[i:],
|
||||
}
|
||||
}
|
||||
|
||||
func TestLanczos(t *testing.T) {
|
||||
A := cblas128.General{Rows: 3, Cols: 4, Stride: 4, Data: make([]complex128, 3*4)}
|
||||
fillRandn(A.Data, 0, 1)
|
||||
|
||||
Acpy := cblas128.General{Rows: 3, Cols: 4, Stride: 4, Data: make([]complex128, 3*4)}
|
||||
copy(Acpy.Data, A.Data)
|
||||
|
||||
u0 := make([]complex128, 3)
|
||||
fillRandn(u0, 0, 1)
|
||||
|
||||
Ul, Vl, a, b := LanczosBi(Acpy, u0, 3)
|
||||
|
||||
fmt.Println(a, b)
|
||||
|
||||
tmpc := cblas128.General{Rows: 3, Cols: 3, Stride: 3, Data: make([]complex128, 3*3)}
|
||||
bidic := cblas128.General{Rows: 3, Cols: 3, Stride: 3, Data: make([]complex128, 3*3)}
|
||||
|
||||
cblas128.Gemm(blas.NoTrans, blas.NoTrans, 1, A, Vl, 0, tmpc)
|
||||
cblas128.Gemm(blas.ConjTrans, blas.NoTrans, 1, Ul, tmpc, 0, bidic)
|
||||
|
||||
fmt.Println(bidic)
|
||||
|
||||
Ur, s, Vr := dla.SVDbd(blas.Lower, a, b)
|
||||
|
||||
tmp := blas64.General{Rows: 3, Cols: 3, Stride: 3, Data: make([]float64, 3*3)}
|
||||
bidi := blas64.General{Rows: 3, Cols: 3, Stride: 3, Data: make([]float64, 3*3)}
|
||||
|
||||
copy(tmp.Data, Ur.Data)
|
||||
for i := 0; i < 3; i++ {
|
||||
blas64.Scal(3, s[i], f64col(i, tmp))
|
||||
}
|
||||
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, Vr, 0, bidi)
|
||||
|
||||
fmt.Println(bidi)
|
||||
/*
|
||||
|
||||
_ = Ul
|
||||
_ = Vl
|
||||
Uc := zbw.NewGeneral( 3, 3, nil)
|
||||
zbw.Real2Cmplx(Ur.Data[:3*3], Uc.Data)
|
||||
|
||||
fmt.Println(Uc.Data)
|
||||
|
||||
U := zbw.NewGeneral( M, K, nil)
|
||||
zbw.Gemm(blas.NoTrans, blas.NoTrans, 1, U1, Uc, 0, U)
|
||||
*/
|
||||
}
|
Reference in New Issue
Block a user