Make gonum/lapack packages use blas64 and cblas128

This commit is contained in:
kortschak
2015-01-12 14:39:38 +10:30
parent f795a38819
commit 2c251e7b5d
7 changed files with 165 additions and 129 deletions

View File

@@ -6,7 +6,7 @@ language (http://golang.org)
This is work in progress. Breaking changes are likely to happen.
## Installation
## Installation
```
go get github.com/gonum/blas
@@ -50,47 +50,6 @@ The linker flags (i.e. path to the BLAS library and library name) might have to
The recommended (free) option for good performance on both linux and darwin is OpenBLAS.
### blas/dbw
Experimental wrapper for the float64 part of the lapack interface.
You have to register an implementation before you can use the LAPACK functions:
```go
package main
import (
"fmt"
"github.com/gonum/blas/cblas"
"github.com/gonum/blas/dbw"
"github.com/gonum/lapack/clapack"
"github.com/gonum/lapack/dla"
)
func init() {
dbw.Register(cblas.Blas{})
dla.Register(clapack.Lapack{})
}
func main() {
A := dbw.NewGeneral(3, 2, []float64{1, 2, 3, 4, 5, 6})
B := dbw.NewGeneral(3, 2, []float64{1, 2, 1, 2, 1, 2})
tau := dbw.Allocate(2)
f := dla.QR(A, tau)
f.Solve(B)
fmt.Println(B.Data)
}
```
### blas/zbw
Experimental wrapper for the complex128 part of the lapack interface.
## Issues
If you find any bugs, feel free to file an issue on the GitHub issue tracker. Discussions on API changes, added features, code review, or similar requests are preferred on the gonum-dev Google Group.

View File

@@ -5,39 +5,47 @@ import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/blas/cblas"
"github.com/gonum/blas/dbw"
"github.com/gonum/blas/blas64"
"github.com/gonum/lapack/clapack"
"github.com/gonum/matrix/mat64"
)
type fm struct {
mat64.Matrix
*mat64.Dense
margin int
}
func (m fm) Format(fs fmt.State, c rune) {
if c == 'v' && fs.Flag('#') {
fmt.Fprintf(fs, "%#v", m.Matrix)
fmt.Fprintf(fs, "%#v", m.Dense)
return
}
mat64.Format(m.Matrix, m.margin, '.', fs, c)
mat64.Format(m.Dense, m.margin, '.', fs, c)
}
func init() {
Register(clapack.Lapack{})
dbw.Register(cblas.Blas{})
}
func TestQR(t *testing.T) {
A := dbw.NewGeneral(3, 2, []float64{1, 2, 3, 4, 5, 6})
B := dbw.NewGeneral(3, 2, []float64{1, 1, 1, 2, 2, 2})
A := blas64.General{
Rows: 3,
Cols: 2,
Stride: 2,
Data: []float64{1, 2, 3, 4, 5, 6},
}
B := blas64.General{
Rows: 3,
Cols: 2,
Stride: 2,
Data: []float64{1, 1, 1, 2, 2, 2},
}
tau := dbw.Allocate(2)
tau := make([]float64, 2)
C := dbw.NewGeneral(2, 2, nil)
C := blas64.General{Rows: 2, Cols: 2, Stride: 2, Data: make([]float64, 2*2)}
dbw.Gemm(blas.Trans, blas.NoTrans, 1, A, B, 0, C)
blas64.Gemm(blas.Trans, blas.NoTrans, 1, A, B, 0, C)
fmt.Println(C)
@@ -47,5 +55,7 @@ func TestQR(t *testing.T) {
fmt.Println(f)
f.Solve(B)
fmt.Println(fm{B, 0})
var pm mat64.Dense
pm.SetRawMatrix(B)
fmt.Println(fm{&pm, 0})
}

View File

@@ -2,29 +2,39 @@ package dla
import (
"github.com/gonum/blas"
"github.com/gonum/blas/dbw"
"github.com/gonum/blas/blas64"
)
type QRFact struct {
a dbw.General
a blas64.General
tau []float64
}
func QR(A dbw.General, tau []float64) QRFact {
impl.Dgeqrf(A.Rows, A.Cols, A.Data, A.Stride, tau)
return QRFact{A, tau}
func QR(a blas64.General, tau []float64) QRFact {
impl.Dgeqrf(a.Rows, a.Cols, a.Data, a.Stride, tau)
return QRFact{a: a, tau: tau}
}
func (f QRFact) R() dbw.Triangular {
return dbw.Ge2Tr(f.a, blas.NonUnit, blas.Upper)
func (f QRFact) R() blas64.Triangular {
n := f.a.Rows
if f.a.Cols < n {
n = f.a.Cols
}
return blas64.Triangular{
Data: f.a.Data,
N: n,
Stride: f.a.Stride,
Uplo: blas.Upper,
Diag: blas.NonUnit,
}
}
func (f QRFact) Solve(B dbw.General) dbw.General {
if f.a.Cols != B.Cols {
func (f QRFact) Solve(b blas64.General) blas64.General {
if f.a.Cols != b.Cols {
panic("dimension missmatch")
}
impl.Dormqr(blas.Left, blas.Trans, B.Rows, B.Cols, f.a.Cols, f.a.Data, f.a.Stride, f.tau, B.Data, B.Stride)
B.Rows = f.a.Cols
dbw.Trsm(blas.Left, blas.NoTrans, 1, f.R(), B)
return B
impl.Dormqr(blas.Left, blas.Trans, b.Rows, b.Cols, f.a.Cols, f.a.Data, f.a.Stride, f.tau, b.Data, b.Stride)
b.Rows = f.a.Cols
blas64.Trsm(blas.Left, blas.NoTrans, 1, f.R(), b)
return b
}

View File

@@ -2,21 +2,31 @@ package dla
import (
"github.com/gonum/blas"
"github.com/gonum/blas/dbw"
"github.com/gonum/blas/blas64"
"github.com/gonum/lapack"
)
func SVD(A dbw.General) (U dbw.General, s []float64, Vt dbw.General) {
func SVD(A blas64.General) (U blas64.General, s []float64, Vt blas64.General) {
m := A.Rows
n := A.Cols
U.Stride = 1
Vt.Stride = 1
if m >= n {
Vt = dbw.NewGeneral(n, n, nil)
Vt = blas64.General{
Rows: n,
Cols: n,
Stride: n,
Data: make([]float64, n*n),
}
s = make([]float64, n)
U = A
} else {
U = dbw.NewGeneral(m, m, nil)
U = blas64.General{
Rows: m,
Cols: m,
Stride: m,
Data: make([]float64, n*n),
}
s = make([]float64, m)
Vt = A
}
@@ -26,14 +36,24 @@ func SVD(A dbw.General) (U dbw.General, s []float64, Vt dbw.General) {
return
}
func SVDbd(uplo blas.Uplo, d, e []float64) (U dbw.General, s []float64, Vt dbw.General) {
func SVDbd(uplo blas.Uplo, d, e []float64) (U blas64.General, s []float64, Vt blas64.General) {
n := len(d)
if len(e) != n {
panic("dimensionality missmatch")
}
U = dbw.NewGeneral(n, n, nil)
Vt = dbw.NewGeneral(n, n, nil)
U = blas64.General{
Rows: n,
Cols: n,
Stride: n,
Data: make([]float64, n*n),
}
Vt = blas64.General{
Rows: n,
Cols: n,
Stride: n,
Data: make([]float64, n*n),
}
impl.Dbdsdc(uplo, lapack.Explicit, n, d, e, U.Data, U.Stride, Vt.Data, Vt.Stride, nil, nil)
s = d

View File

@@ -1,30 +1,42 @@
//+build cblas
package zla
import (
"github.com/gonum/blas"
"github.com/gonum/blas/zbw"
"github.com/gonum/blas/cblas128"
)
type QRFact struct {
a zbw.General
a cblas128.General
tau []complex128
}
func QR(A zbw.General, tau []complex128) QRFact {
func QR(A cblas128.General, tau []complex128) QRFact {
impl.Zgeqrf(A.Rows, A.Cols, A.Data, A.Stride, tau)
return QRFact{A, tau}
}
func (f QRFact) R() zbw.Triangular {
return zbw.Ge2Tr(f.a, blas.NonUnit, blas.Upper)
func (f QRFact) R() cblas128.Triangular {
n := f.a.Rows
if f.a.Cols < n {
n = f.a.Cols
}
return cblas128.Triangular{
Data: f.a.Data,
N: n,
Stride: f.a.Stride,
Uplo: blas.Upper,
Diag: blas.NonUnit,
}
}
func (f QRFact) Solve(B zbw.General) zbw.General {
func (f QRFact) Solve(B cblas128.General) cblas128.General {
if f.a.Cols != B.Cols {
panic("dimension missmatch")
}
impl.Zunmqr(blas.Left, blas.ConjTrans, f.a.Rows, B.Cols, f.a.Cols, f.a.Data, f.a.Stride, f.tau, B.Data, B.Stride)
B.Rows = f.a.Cols
zbw.Trsm(blas.Left, blas.NoTrans, 1, f.R(), B)
cblas128.Trsm(blas.Left, blas.NoTrans, 1, f.R(), B)
return B
}

View File

@@ -1,22 +1,24 @@
//+build cblas
package zla
import (
"github.com/gonum/blas"
"github.com/gonum/blas/zbw"
"github.com/gonum/blas/cblas128"
"github.com/gonum/lapack"
)
func SVD(A zbw.General) (U zbw.General, s []float64, Vt zbw.General) {
func SVD(A cblas128.General) (U cblas128.General, s []float64, Vt cblas128.General) {
m := A.Rows
n := A.Cols
U.Stride = 1
Vt.Stride = 1
if m >= n {
Vt = zbw.NewGeneral(n, n, nil)
Vt = cblas128.General{Rows: n, Cols: n, Stride: n, Data: make([]complex128, n*n)}
s = make([]float64, n)
U = A
} else {
U = zbw.NewGeneral(m, m, nil)
U = cblas128.General{Rows: n, Cols: n, Stride: n, Data: make([]complex128, n*n)}
s = make([]float64, m)
Vt = A
}
@@ -26,54 +28,61 @@ func SVD(A zbw.General) (U zbw.General, s []float64, Vt zbw.General) {
return
}
// c128col returns a strided vector view into a: its data begins at element i
// of a.Data and its increment is a.Stride. When a is stored row-major (Stride
// is the distance between consecutive rows), this addresses column i. The
// result shares a's backing slice, so writes through it modify a.
func c128col(i int, a cblas128.General) cblas128.Vector {
	var col cblas128.Vector
	col.Data = a.Data[i:]
	col.Inc = a.Stride
	return col
}
// LanczosBi performs Lanczos bidiagonalization with full reorthogonalization.
func LanczosBi(L zbw.General, u []complex128, numIter int) (U zbw.General, V zbw.General, a []float64, b []float64) {
func LanczosBi(L cblas128.General, u []complex128, numIter int) (U cblas128.General, V cblas128.General, a []float64, b []float64) {
m := L.Rows
n := L.Cols
uv := zbw.NewVector(u)
zbw.Scal(complex(1/zbw.Nrm2(uv), 0), uv)
uv := cblas128.Vector{Inc: 1, Data: u}
cblas128.Scal(len(u), complex(1/cblas128.Nrm2(len(u), uv), 0), uv)
U = zbw.NewGeneral(m, numIter, nil)
V = zbw.NewGeneral(n, numIter, nil)
U = cblas128.General{Rows: m, Cols: numIter, Stride: numIter, Data: make([]complex128, m*numIter)}
V = cblas128.General{Rows: n, Cols: numIter, Stride: numIter, Data: make([]complex128, n*numIter)}
a = make([]float64, numIter)
b = make([]float64, numIter)
zbw.Copy(uv, U.Col(0))
cblas128.Copy(len(u), uv, c128col(0, U))
tr := zbw.NewVector(zbw.Allocate(n))
zbw.Gemv(blas.ConjTrans, 1, L, uv, 0, tr)
a[0] = zbw.Nrm2(tr)
zbw.Copy(tr, V.Col(0))
zbw.Scal(complex(1/a[0], 0), V.Col(0))
tr := cblas128.Vector{Inc: 1, Data: make([]complex128, n)}
cblas128.Gemv(blas.ConjTrans, 1, L, uv, 0, tr)
a[0] = cblas128.Nrm2(n, tr)
cblas128.Copy(n, tr, c128col(0, V))
cblas128.Scal(n, complex(1/a[0], 0), c128col(0, V))
tl := zbw.NewVector(zbw.Allocate(m))
tl := cblas128.Vector{Inc: 1, Data: make([]complex128, m)}
for k := 0; k < numIter-1; k++ {
zbw.Copy(U.Col(k), tl)
zbw.Scal(complex(-a[k], 0), tl)
zbw.Gemv(blas.NoTrans, 1, L, V.Col(k), 1, tl)
cblas128.Copy(m, c128col(k, U), tl)
cblas128.Scal(m, complex(-a[k], 0), tl)
cblas128.Gemv(blas.NoTrans, 1, L, c128col(k, V), 1, tl)
for i := 0; i <= k; i++ {
zbw.Axpy(-zbw.Dotc(U.Col(i), tl), U.Col(i), tl)
cblas128.Axpy(m, -cblas128.Dotc(m, c128col(i, U), tl), c128col(i, U), tl)
}
b[k] = zbw.Nrm2(tl)
zbw.Copy(tl, U.Col(k+1))
zbw.Scal(complex(1/b[k], 0), U.Col(k+1))
b[k] = cblas128.Nrm2(m, tl)
cblas128.Copy(m, tl, c128col(k+1, U))
cblas128.Scal(m, complex(1/b[k], 0), c128col(k+1, U))
zbw.Copy(V.Col(k), tr)
zbw.Scal(complex(-b[k], 0), tr)
zbw.Gemv(blas.ConjTrans, 1, L, U.Col(k+1), 1, tr)
cblas128.Copy(n, c128col(k, V), tr)
cblas128.Scal(n, complex(-b[k], 0), tr)
cblas128.Gemv(blas.ConjTrans, 1, L, c128col(k+1, U), 1, tr)
for i := 0; i <= k; i++ {
zbw.Axpy(-zbw.Dotc(V.Col(i), tr), V.Col(i), tr)
cblas128.Axpy(n, -cblas128.Dotc(n, c128col(i, V), tr), c128col(i, V), tr)
}
a[k+1] = zbw.Nrm2(tr)
zbw.Copy(tr, V.Col(k+1))
zbw.Scal(complex(1/a[k+1], 0), V.Col(k+1))
a[k+1] = cblas128.Nrm2(n, tr)
cblas128.Copy(n, tr, c128col(k+1, V))
cblas128.Scal(n, complex(1/a[k+1], 0), c128col(k+1, V))
}
return
}

View File

@@ -1,3 +1,5 @@
//+build cblas
package zla
import (
@@ -7,9 +9,8 @@ import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/blas/cblas"
"github.com/gonum/blas/dbw"
"github.com/gonum/blas/zbw"
"github.com/gonum/blas/blas64"
"github.com/gonum/blas/cblas128"
"github.com/gonum/lapack/clapack"
"github.com/gonum/lapack/dla"
)
@@ -17,8 +18,6 @@ import (
func init() {
Register(clapack.Lapack{})
dla.Register(clapack.Lapack{})
zbw.Register(cblas.Blas{})
dbw.Register(cblas.Blas{})
}
func fillRandn(a []complex128, mu complex128, sigmaSq float64) {
@@ -29,10 +28,20 @@ func fillRandn(a []complex128, mu complex128, sigmaSq float64) {
}
func TestQR(t *testing.T) {
A := zbw.NewGeneral(3, 2, []complex128{complex(1, 0), complex(2, 0), complex(3, 0), complex(4, 0), complex(5, 0), complex(6, 0)})
B := zbw.NewGeneral(3, 2, []complex128{complex(1, 0), complex(1, 0), complex(1, 0), complex(2, 0), complex(2, 0), complex(2, 0)})
A := cblas128.General{
Rows: 3,
Cols: 2,
Stride: 2,
Data: []complex128{complex(1, 0), complex(2, 0), complex(3, 0), complex(4, 0), complex(5, 0), complex(6, 0)},
}
B := cblas128.General{
Rows: 3,
Cols: 2,
Stride: 2,
Data: []complex128{complex(1, 0), complex(1, 0), complex(1, 0), complex(2, 0), complex(2, 0), complex(2, 0)},
}
tau := zbw.Allocate(2)
tau := make([]complex128, 2)
f := QR(A, tau)
@@ -41,11 +50,18 @@ func TestQR(t *testing.T) {
//fmt.Println(B)
}
// f64col returns a strided vector view into a: its data begins at element i
// of a.Data and its increment is a.Stride. When a is stored row-major (Stride
// is the distance between consecutive rows), this addresses column i. The
// result shares a's backing slice, so writes through it modify a.
func f64col(i int, a blas64.General) blas64.Vector {
	var col blas64.Vector
	col.Data = a.Data[i:]
	col.Inc = a.Stride
	return col
}
func TestLanczos(t *testing.T) {
A := zbw.NewGeneral(3, 4, nil)
A := cblas128.General{Rows: 3, Cols: 4, Stride: 4, Data: make([]complex128, 3*4)}
fillRandn(A.Data, 0, 1)
Acpy := zbw.NewGeneral(3, 4, nil)
Acpy := cblas128.General{Rows: 3, Cols: 4, Stride: 4, Data: make([]complex128, 3*4)}
copy(Acpy.Data, A.Data)
u0 := make([]complex128, 3)
@@ -55,25 +71,25 @@ func TestLanczos(t *testing.T) {
fmt.Println(a, b)
tmpc := zbw.NewGeneral(3, 3, nil)
bidic := zbw.NewGeneral(3, 3, nil)
tmpc := cblas128.General{Rows: 3, Cols: 3, Stride: 3, Data: make([]complex128, 3*3)}
bidic := cblas128.General{Rows: 3, Cols: 3, Stride: 3, Data: make([]complex128, 3*3)}
zbw.Gemm(blas.NoTrans, blas.NoTrans, 1, A, Vl, 0, tmpc)
zbw.Gemm(blas.ConjTrans, blas.NoTrans, 1, Ul, tmpc, 0, bidic)
cblas128.Gemm(blas.NoTrans, blas.NoTrans, 1, A, Vl, 0, tmpc)
cblas128.Gemm(blas.ConjTrans, blas.NoTrans, 1, Ul, tmpc, 0, bidic)
fmt.Println(bidic)
Ur, s, Vr := dla.SVDbd(blas.Lower, a, b)
tmp := dbw.NewGeneral(3, 3, nil)
bidi := dbw.NewGeneral(3, 3, nil)
tmp := blas64.General{Rows: 3, Cols: 3, Stride: 3, Data: make([]float64, 3*3)}
bidi := blas64.General{Rows: 3, Cols: 3, Stride: 3, Data: make([]float64, 3*3)}
copy(tmp.Data, Ur.Data)
for i := 0; i < 3; i++ {
dbw.Scal(s[i], tmp.Col(i))
blas64.Scal(3, s[i], f64col(i, tmp))
}
dbw.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, Vr, 0, bidi)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, Vr, 0, bidi)
fmt.Println(bidi)
/*