added some functionality

This commit is contained in:
David Neumann
2014-04-28 13:54:52 +02:00
parent 2089c0d4fa
commit b1211f0cc9
8 changed files with 407 additions and 25 deletions

View File

@@ -9,19 +9,22 @@ package clapack
*/
import "C"
import (
"github.com/dane-unltd/lapack"
"github.com/gonum/blas"
"github.com/gonum/blas/dbw"
"github.com/gonum/blas/zbw"
)
type La struct{}
func (La) Dgeqrf(A dbw.General, tau []float64) {
C.LAPACKE_dgeqrf(C.int(A.Order), C.int(A.Rows), C.int(A.Cols),
(*C.double)(&A.Data[0]), C.int(A.Stride), (*C.double)(&tau[0]))
func init() {
_ = lapack.Complex128(La{})
_ = lapack.Float64(La{})
}
func (La) Dormqr(s blas.Side, t blas.Transpose, A dbw.General, tau []float64, B dbw.General) {
func (La) Dgeqrf(o blas.Order, m, n int, a []float64, lda int, tau []float64) {
C.LAPACKE_dgeqrf(C.int(o), C.int(m), C.int(n), (*C.double)(&a[0]), C.int(lda), (*C.double)(&tau[0]))
}
func (La) Dormqr(o blas.Order, s blas.Side, t blas.Transpose, m, n, k int, a []float64, lda int, tau []float64, c []float64, ldc int) {
var cs, ct C.char
if s == blas.Left {
cs = 'l'
@@ -34,24 +37,118 @@ func (La) Dormqr(s blas.Side, t blas.Transpose, A dbw.General, tau []float64, B
ct = 't'
}
C.LAPACKE_dormqr(C.int(A.Order), cs, ct, C.int(B.Rows),
C.int(B.Cols), C.int(A.Cols), (*C.double)(&A.Data[0]),
C.int(A.Stride), (*C.double)(&tau[0]), (*C.double)(&B.Data[0]), C.int(B.Stride))
C.LAPACKE_dormqr(C.int(o), cs, ct, C.int(m),
C.int(n), C.int(k), (*C.double)(&a[0]),
C.int(lda), (*C.double)(&tau[0]), (*C.double)(&c[0]), C.int(ldc))
}
func (La) Zgesvd(jobz byte, A zbw.General, s []float64, U zbw.General, Vt zbw.General) {
func (La) Dgesdd(o blas.Order, job lapack.Job, m, n int, a []float64, lda int, s []float64, u []float64, ldu int, vt []float64, ldvt int) {
pU := (*float64)(nil)
if len(u) > 0 {
pU = &u[0]
}
pVt := (*float64)(nil)
if len(vt) > 0 {
pVt = &vt[0]
}
C.LAPACKE_dgesdd(
C.int(o), C.char(job),
C.int(m), C.int(n), (*C.double)(&a[0]), C.int(lda),
(*C.double)(&s[0]),
(*C.double)(pU), C.int(ldu),
(*C.double)(pVt), C.int(ldvt))
}
func (La) Dgebrd(o blas.Order, m, n int, a []float64, lda int, d, e, tauq, taup []float64) {
C.LAPACKE_dgebrd(
C.int(o), C.int(m), C.int(n), (*C.double)(&a[0]), C.int(lda),
(*C.double)(&d[0]),
(*C.double)(&e[0]),
(*C.double)(&tauq[0]),
(*C.double)(&taup[0]))
}
func (La) Dbdsdc(o blas.Order, uplo blas.Uplo, compq lapack.CompSV, n int,
d, e []float64, u []float64, ldu int, vt []float64, ldvt int, q []float64, iq []int32) {
pU := (*float64)(nil)
if len(u) > 0 {
pU = &u[0]
}
pVt := (*float64)(nil)
if len(vt) > 0 {
pVt = &vt[0]
}
pq := (*float64)(nil)
if len(q) > 0 {
pU = &q[0]
}
piq := (*int32)(nil)
if len(iq) > 0 {
piq = &iq[0]
}
cuplo := C.char('u')
if uplo == blas.Lower {
cuplo = 'l'
}
C.LAPACKE_dbdsdc(C.int(o), cuplo, C.char(compq),
(C.int)(n),
(*C.double)(&d[0]),
(*C.double)(&e[0]),
(*C.double)(pU),
(C.int)(ldu),
(*C.double)(pVt),
(C.int)(ldvt),
(*C.double)(pq),
(*C.int)(piq))
}
func (La) Zgeqrf(o blas.Order, m, n int, a []complex128, lda int, tau []complex128) {
C.LAPACKE_zgeqrf(C.int(o), C.int(m), C.int(n), (*C.complex)(&a[0]), C.int(lda), (*C.complex)(&tau[0]))
}
func (La) Zunmqr(o blas.Order, s blas.Side, t blas.Transpose, m, n, k int, a []complex128, lda int, tau []complex128, c []complex128, ldc int) {
var cs, ct C.char
if s == blas.Left {
cs = 'l'
} else {
cs = 'r'
}
if t == blas.NoTrans {
ct = 'n'
} else {
ct = 'c'
}
C.LAPACKE_zunmqr(C.int(o), cs, ct, C.int(m),
C.int(n), C.int(k), (*C.complex)(&a[0]),
C.int(lda), (*C.complex)(&tau[0]), (*C.complex)(&c[0]), C.int(ldc))
}
func (La) Zgesdd(o blas.Order, job lapack.Job, m, n int, a []complex128, lda int, s []float64, u []complex128, ldu int, vt []complex128, ldvt int) {
pU := (*complex128)(nil)
if len(U.Data) > 0 {
pU = &U.Data[0]
if len(u) > 0 {
pU = &u[0]
}
pVt := (*complex128)(nil)
if len(Vt.Data) > 0 {
pVt = &Vt.Data[0]
if len(vt) > 0 {
pVt = &vt[0]
}
C.LAPACKE_zgesdd(
C.int(A.Order), C.char(jobz),
C.int(A.Rows), C.int(A.Cols), (*C.complex)(&A.Data[0]), C.int(A.Stride),
C.int(o), C.char(job),
C.int(m), C.int(n), (*C.complex)(&a[0]), C.int(lda),
(*C.double)(&s[0]),
(*C.complex)(pU), C.int(U.Stride),
(*C.complex)(pVt), C.int(Vt.Stride))
(*C.complex)(pU), C.int(ldu),
(*C.complex)(pVt), C.int(ldvt))
}
func (La) Zgebrd(o blas.Order, m, n int, a []complex128, lda int, d, e []float64, tauq, taup []complex128) {
C.LAPACKE_zgebrd(
C.int(o),
C.int(m), C.int(n), (*C.complex)(&a[0]), C.int(lda),
(*C.double)(&d[0]),
(*C.double)(&e[0]),
(*C.complex)(&tauq[0]),
(*C.complex)(&taup[0]))
}

View File

@@ -11,7 +11,7 @@ type QRFact struct {
}
func QR(A dbw.General, tau []float64) QRFact {
impl.Dgeqrf(A, tau)
impl.Dgeqrf(A.Order, A.Rows, A.Cols, A.Data, A.Stride, tau)
return QRFact{A, tau}
}
@@ -20,7 +20,13 @@ func (f QRFact) R() dbw.Triangular {
}
func (f QRFact) Solve(B dbw.General) dbw.General {
impl.Dormqr(blas.Left, blas.Trans, f.a, f.tau, B)
if B.Order != f.a.Order {
panic("Order missmatch")
}
if f.a.Cols != B.Cols {
panic("dimension missmatch")
}
impl.Dormqr(B.Order, blas.Left, blas.Trans, f.a.Rows, B.Cols, f.a.Cols, f.a.Data, f.a.Stride, f.tau, B.Data, B.Stride)
B.Rows = f.a.Cols
dbw.Trsm(blas.Left, blas.NoTrans, 1, f.R(), B)
return B

41
dla/svd.go Normal file
View File

@@ -0,0 +1,41 @@
package dla
import (
"github.com/dane-unltd/lapack"
"github.com/gonum/blas"
"github.com/gonum/blas/dbw"
)
func SVD(A dbw.General) (U dbw.General, s []float64, Vt dbw.General) {
m := A.Rows
n := A.Cols
U.Stride = 1
Vt.Stride = 1
if m >= n {
Vt = dbw.NewGeneral(A.Order, n, n, nil)
s = make([]float64, n)
U = A
} else {
U = dbw.NewGeneral(A.Order, m, m, nil)
s = make([]float64, m)
Vt = A
}
impl.Dgesdd(A.Order, lapack.Overwrite, A.Rows, A.Cols, A.Data, A.Stride, s, U.Data, U.Stride, Vt.Data, Vt.Stride)
return
}
func SVDbd(uplo blas.Uplo, d, e []float64) (U dbw.General, s []float64, Vt dbw.General) {
n := len(d)
if len(e) != n {
panic("dimensionality missmatch")
}
U = dbw.NewGeneral(blas.ColMajor, n, n, nil)
Vt = dbw.NewGeneral(blas.ColMajor, n, n, nil)
impl.Dbdsdc(blas.ColMajor, uplo, lapack.Explicit, n, d, e, U.Data, U.Stride, Vt.Data, Vt.Stride, nil, nil)
s = d
return
}

View File

@@ -2,15 +2,38 @@ package lapack
import (
"github.com/gonum/blas"
"github.com/gonum/blas/dbw"
"github.com/gonum/blas/zbw"
)
const None = 'N'
type Job byte
const (
All (Job) = 'A'
Slim (Job) = 'S'
Overwrite (Job) = 'O'
)
type CompSV byte
const (
Compact (CompSV) = 'P'
Explicit (CompSV) = 'I'
)
type Float64 interface {
Dgeqrf(A dbw.General, tau []float64)
Dormqr(s blas.Side, t blas.Transpose, A dbw.General, tau []float64, B dbw.General)
Dgeqrf(o blas.Order, m, n int, a []float64, lda int, tau []float64)
Dormqr(o blas.Order, s blas.Side, t blas.Transpose, m, n, k int, a []float64, lda int, tau []float64, c []float64, ldc int)
Dgesdd(o blas.Order, job Job, m, n int, a []float64, lda int, s []float64, u []float64, ldu int, vt []float64, ldvt int)
Dgebrd(o blas.Order, m, n int, a []float64, lda int, d, e, tauq, taup []float64)
Dbdsdc(o blas.Order, uplo blas.Uplo, compq CompSV, n int, d, e []float64, u []float64, ldu int, vt []float64, ldvt int, q []float64, iq []int32)
}
type Complex128 interface {
Zgesvd(jobu byte, jobvt byte, A zbw.General, s []float64, U zbw.General, Vt zbw.General, superb []float64)
Float64
Zgeqrf(o blas.Order, m, n int, a []complex128, lda int, tau []complex128)
Zunmqr(o blas.Order, s blas.Side, t blas.Transpose, m, n, k int, a []complex128, lda int, tau []complex128, c []complex128, ldc int)
Zgesdd(o blas.Order, job Job, m, n int, a []complex128, lda int, s []float64, u []complex128, ldu int, vt []complex128, ldvt int)
Zgebrd(o blas.Order, m, n int, a []complex128, lda int, d, e []float64, tauq, taup []complex128)
}

9
zla/impl.go Normal file
View File

@@ -0,0 +1,9 @@
package zla
import "github.com/dane-unltd/lapack"
var impl lapack.Complex128
func Register(i lapack.Complex128) {
impl = i
}

33
zla/qr.go Normal file
View File

@@ -0,0 +1,33 @@
package zla
import (
"github.com/gonum/blas"
"github.com/gonum/blas/zbw"
)
type QRFact struct {
a zbw.General
tau []complex128
}
func QR(A zbw.General, tau []complex128) QRFact {
impl.Zgeqrf(A.Order, A.Rows, A.Cols, A.Data, A.Stride, tau)
return QRFact{A, tau}
}
func (f QRFact) R() zbw.Triangular {
return zbw.Ge2Tr(f.a, blas.NonUnit, blas.Upper)
}
func (f QRFact) Solve(B zbw.General) zbw.General {
if B.Order != f.a.Order {
panic("Order missmatch")
}
if f.a.Cols != B.Cols {
panic("dimension missmatch")
}
impl.Zunmqr(B.Order, blas.Left, blas.Trans, f.a.Rows, B.Cols, f.a.Cols, f.a.Data, f.a.Stride, f.tau, B.Data, B.Stride)
B.Rows = f.a.Cols
zbw.Trsm(blas.Left, blas.NoTrans, 1, f.R(), B)
return B
}

79
zla/svd.go Normal file
View File

@@ -0,0 +1,79 @@
package zla
import (
"github.com/dane-unltd/lapack"
"github.com/gonum/blas"
"github.com/gonum/blas/zbw"
)
func SVD(A zbw.General) (U zbw.General, s []float64, Vt zbw.General) {
m := A.Rows
n := A.Cols
U.Stride = 1
Vt.Stride = 1
if m >= n {
Vt = zbw.NewGeneral(A.Order, n, n, nil)
s = make([]float64, n)
U = A
} else {
U = zbw.NewGeneral(A.Order, m, m, nil)
s = make([]float64, m)
Vt = A
}
impl.Zgesdd(A.Order, lapack.Overwrite, A.Rows, A.Cols, A.Data, A.Stride, s, U.Data, U.Stride, Vt.Data, Vt.Stride)
return
}
//Lanczos bidiagonalization with full reorthogonalization
func LanczosBi(L zbw.General, u []complex128, numIter int) (U zbw.General, V zbw.General, a []float64, b []float64) {
m := L.Rows
n := L.Cols
uv := zbw.NewVector(u)
zbw.Scal(complex(1/zbw.Nrm2(uv), 0), uv)
U = zbw.NewGeneral(blas.ColMajor, m, numIter, nil)
V = zbw.NewGeneral(blas.ColMajor, n, numIter, nil)
a = make([]float64, numIter)
b = make([]float64, numIter)
zbw.Copy(uv, U.Col(0))
tr := zbw.NewVector(zbw.Allocate(n))
zbw.Gemv(blas.ConjTrans, 1, L, uv, 0, tr)
a[0] = zbw.Nrm2(tr)
zbw.Copy(tr, V.Col(0))
zbw.Scal(complex(1/a[0], 0), V.Col(0))
tl := zbw.NewVector(zbw.Allocate(m))
for k := 0; k < numIter-1; k++ {
zbw.Copy(U.Col(k), tl)
zbw.Scal(complex(-a[k], 0), tl)
zbw.Gemv(blas.NoTrans, 1, L, V.Col(k), 1, tl)
for i := 0; i <= k; i++ {
zbw.Axpy(-zbw.Dotc(U.Col(i), tl), U.Col(i), tl)
}
b[k] = zbw.Nrm2(tl)
zbw.Copy(tl, U.Col(k+1))
zbw.Scal(complex(1/b[k], 0), U.Col(k+1))
zbw.Copy(V.Col(k), tr)
zbw.Scal(complex(-b[k], 0), tr)
zbw.Gemv(blas.ConjTrans, 1, L, U.Col(k+1), 1, tr)
for i := 0; i <= k; i++ {
zbw.Axpy(-zbw.Dotc(V.Col(i), tr), V.Col(i), tr)
}
a[k+1] = zbw.Nrm2(tr)
zbw.Copy(tr, V.Col(k+1))
zbw.Scal(complex(1/a[k+1], 0), V.Col(k+1))
}
return
}

94
zla/zla_test.go Normal file
View File

@@ -0,0 +1,94 @@
package zla
import (
"fmt"
"math"
"math/rand"
"testing"
"github.com/dane-unltd/lapack/clapack"
"github.com/dane-unltd/lapack/dla"
"github.com/gonum/blas"
"github.com/gonum/blas/cblas"
"github.com/gonum/blas/dbw"
"github.com/gonum/blas/zbw"
)
func init() {
Register(clapack.La{})
dla.Register(clapack.La{})
zbw.Register(cblas.Blas{})
dbw.Register(cblas.Blas{})
}
func fillRandn(a []complex128, mu complex128, sigmaSq float64) {
fact := math.Sqrt(0.5 * sigmaSq)
for i := range a {
a[i] = complex(fact*rand.NormFloat64(), fact*rand.NormFloat64()) + mu
}
}
func TestQR(t *testing.T) {
A := zbw.NewGeneral(blas.ColMajor, 3, 2,
[]complex128{complex(1, 0), complex(2, 0), complex(3, 0),
complex(4, 0), complex(5, 0), complex(6, 0)})
B := zbw.NewGeneral(blas.ColMajor, 3, 2,
[]complex128{complex(1, 0), complex(1, 0), complex(1, 0), complex(2, 0), complex(2, 0), complex(2, 0)})
tau := zbw.Allocate(2)
f := QR(A, tau)
//fmt.Println(B)
f.Solve(B)
//fmt.Println(B)
}
func TestLanczos(t *testing.T) {
A := zbw.NewGeneral(blas.ColMajor, 3, 4, nil)
fillRandn(A.Data, 0, 1)
Acpy := zbw.NewGeneral(blas.ColMajor, 3, 4, nil)
copy(Acpy.Data, A.Data)
u0 := make([]complex128, 3)
fillRandn(u0, 0, 1)
Ul, Vl, a, b := LanczosBi(Acpy, u0, 3)
fmt.Println(a, b)
tmpc := zbw.NewGeneral(blas.ColMajor, 3, 3, nil)
bidic := zbw.NewGeneral(blas.ColMajor, 3, 3, nil)
zbw.Gemm(blas.NoTrans, blas.NoTrans, 1, A, Vl, 0, tmpc)
zbw.Gemm(blas.ConjTrans, blas.NoTrans, 1, Ul, tmpc, 0, bidic)
fmt.Println(bidic)
Ur, s, Vr := dla.SVDbd(blas.Lower, a, b)
tmp := dbw.NewGeneral(blas.ColMajor, 3, 3, nil)
bidi := dbw.NewGeneral(blas.ColMajor, 3, 3, nil)
copy(tmp.Data, Ur.Data)
for i := 0; i < 3; i++ {
dbw.Scal(s[i], tmp.Col(i))
}
dbw.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, Vr, 0, bidi)
fmt.Println(bidi)
/*
_ = Ul
_ = Vl
Uc := zbw.NewGeneral(blas.ColMajor, 3, 3, nil)
zbw.Real2Cmplx(Ur.Data[:3*3], Uc.Data)
fmt.Println(Uc.Data)
U := zbw.NewGeneral(blas.ColMajor, M, K, nil)
zbw.Gemm(blas.NoTrans, blas.NoTrans, 1, U1, Uc, 0, U)
*/
}