mirror of
https://github.com/gonum/gonum.git
synced 2025-10-20 21:59:25 +08:00
added some functionality
This commit is contained in:
@@ -9,19 +9,22 @@ package clapack
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"github.com/dane-unltd/lapack"
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/dbw"
|
||||
"github.com/gonum/blas/zbw"
|
||||
)
|
||||
|
||||
type La struct{}
|
||||
|
||||
func (La) Dgeqrf(A dbw.General, tau []float64) {
|
||||
C.LAPACKE_dgeqrf(C.int(A.Order), C.int(A.Rows), C.int(A.Cols),
|
||||
(*C.double)(&A.Data[0]), C.int(A.Stride), (*C.double)(&tau[0]))
|
||||
func init() {
|
||||
_ = lapack.Complex128(La{})
|
||||
_ = lapack.Float64(La{})
|
||||
}
|
||||
|
||||
func (La) Dormqr(s blas.Side, t blas.Transpose, A dbw.General, tau []float64, B dbw.General) {
|
||||
func (La) Dgeqrf(o blas.Order, m, n int, a []float64, lda int, tau []float64) {
|
||||
C.LAPACKE_dgeqrf(C.int(o), C.int(m), C.int(n), (*C.double)(&a[0]), C.int(lda), (*C.double)(&tau[0]))
|
||||
}
|
||||
|
||||
func (La) Dormqr(o blas.Order, s blas.Side, t blas.Transpose, m, n, k int, a []float64, lda int, tau []float64, c []float64, ldc int) {
|
||||
var cs, ct C.char
|
||||
if s == blas.Left {
|
||||
cs = 'l'
|
||||
@@ -34,24 +37,118 @@ func (La) Dormqr(s blas.Side, t blas.Transpose, A dbw.General, tau []float64, B
|
||||
ct = 't'
|
||||
}
|
||||
|
||||
C.LAPACKE_dormqr(C.int(A.Order), cs, ct, C.int(B.Rows),
|
||||
C.int(B.Cols), C.int(A.Cols), (*C.double)(&A.Data[0]),
|
||||
C.int(A.Stride), (*C.double)(&tau[0]), (*C.double)(&B.Data[0]), C.int(B.Stride))
|
||||
C.LAPACKE_dormqr(C.int(o), cs, ct, C.int(m),
|
||||
C.int(n), C.int(k), (*C.double)(&a[0]),
|
||||
C.int(lda), (*C.double)(&tau[0]), (*C.double)(&c[0]), C.int(ldc))
|
||||
}
|
||||
|
||||
func (La) Zgesvd(jobz byte, A zbw.General, s []float64, U zbw.General, Vt zbw.General) {
|
||||
func (La) Dgesdd(o blas.Order, job lapack.Job, m, n int, a []float64, lda int, s []float64, u []float64, ldu int, vt []float64, ldvt int) {
|
||||
pU := (*float64)(nil)
|
||||
if len(u) > 0 {
|
||||
pU = &u[0]
|
||||
}
|
||||
pVt := (*float64)(nil)
|
||||
if len(vt) > 0 {
|
||||
pVt = &vt[0]
|
||||
}
|
||||
C.LAPACKE_dgesdd(
|
||||
C.int(o), C.char(job),
|
||||
C.int(m), C.int(n), (*C.double)(&a[0]), C.int(lda),
|
||||
(*C.double)(&s[0]),
|
||||
(*C.double)(pU), C.int(ldu),
|
||||
(*C.double)(pVt), C.int(ldvt))
|
||||
}
|
||||
|
||||
func (La) Dgebrd(o blas.Order, m, n int, a []float64, lda int, d, e, tauq, taup []float64) {
|
||||
C.LAPACKE_dgebrd(
|
||||
C.int(o), C.int(m), C.int(n), (*C.double)(&a[0]), C.int(lda),
|
||||
(*C.double)(&d[0]),
|
||||
(*C.double)(&e[0]),
|
||||
(*C.double)(&tauq[0]),
|
||||
(*C.double)(&taup[0]))
|
||||
}
|
||||
|
||||
func (La) Dbdsdc(o blas.Order, uplo blas.Uplo, compq lapack.CompSV, n int,
|
||||
d, e []float64, u []float64, ldu int, vt []float64, ldvt int, q []float64, iq []int32) {
|
||||
pU := (*float64)(nil)
|
||||
if len(u) > 0 {
|
||||
pU = &u[0]
|
||||
}
|
||||
pVt := (*float64)(nil)
|
||||
if len(vt) > 0 {
|
||||
pVt = &vt[0]
|
||||
}
|
||||
pq := (*float64)(nil)
|
||||
if len(q) > 0 {
|
||||
pU = &q[0]
|
||||
}
|
||||
piq := (*int32)(nil)
|
||||
if len(iq) > 0 {
|
||||
piq = &iq[0]
|
||||
}
|
||||
|
||||
cuplo := C.char('u')
|
||||
if uplo == blas.Lower {
|
||||
cuplo = 'l'
|
||||
}
|
||||
|
||||
C.LAPACKE_dbdsdc(C.int(o), cuplo, C.char(compq),
|
||||
(C.int)(n),
|
||||
(*C.double)(&d[0]),
|
||||
(*C.double)(&e[0]),
|
||||
(*C.double)(pU),
|
||||
(C.int)(ldu),
|
||||
(*C.double)(pVt),
|
||||
(C.int)(ldvt),
|
||||
(*C.double)(pq),
|
||||
(*C.int)(piq))
|
||||
}
|
||||
|
||||
func (La) Zgeqrf(o blas.Order, m, n int, a []complex128, lda int, tau []complex128) {
|
||||
C.LAPACKE_zgeqrf(C.int(o), C.int(m), C.int(n), (*C.complex)(&a[0]), C.int(lda), (*C.complex)(&tau[0]))
|
||||
}
|
||||
|
||||
func (La) Zunmqr(o blas.Order, s blas.Side, t blas.Transpose, m, n, k int, a []complex128, lda int, tau []complex128, c []complex128, ldc int) {
|
||||
var cs, ct C.char
|
||||
if s == blas.Left {
|
||||
cs = 'l'
|
||||
} else {
|
||||
cs = 'r'
|
||||
}
|
||||
if t == blas.NoTrans {
|
||||
ct = 'n'
|
||||
} else {
|
||||
ct = 'c'
|
||||
}
|
||||
|
||||
C.LAPACKE_zunmqr(C.int(o), cs, ct, C.int(m),
|
||||
C.int(n), C.int(k), (*C.complex)(&a[0]),
|
||||
C.int(lda), (*C.complex)(&tau[0]), (*C.complex)(&c[0]), C.int(ldc))
|
||||
}
|
||||
|
||||
func (La) Zgesdd(o blas.Order, job lapack.Job, m, n int, a []complex128, lda int, s []float64, u []complex128, ldu int, vt []complex128, ldvt int) {
|
||||
pU := (*complex128)(nil)
|
||||
if len(U.Data) > 0 {
|
||||
pU = &U.Data[0]
|
||||
if len(u) > 0 {
|
||||
pU = &u[0]
|
||||
}
|
||||
pVt := (*complex128)(nil)
|
||||
if len(Vt.Data) > 0 {
|
||||
pVt = &Vt.Data[0]
|
||||
if len(vt) > 0 {
|
||||
pVt = &vt[0]
|
||||
}
|
||||
C.LAPACKE_zgesdd(
|
||||
C.int(A.Order), C.char(jobz),
|
||||
C.int(A.Rows), C.int(A.Cols), (*C.complex)(&A.Data[0]), C.int(A.Stride),
|
||||
C.int(o), C.char(job),
|
||||
C.int(m), C.int(n), (*C.complex)(&a[0]), C.int(lda),
|
||||
(*C.double)(&s[0]),
|
||||
(*C.complex)(pU), C.int(U.Stride),
|
||||
(*C.complex)(pVt), C.int(Vt.Stride))
|
||||
(*C.complex)(pU), C.int(ldu),
|
||||
(*C.complex)(pVt), C.int(ldvt))
|
||||
}
|
||||
|
||||
func (La) Zgebrd(o blas.Order, m, n int, a []complex128, lda int, d, e []float64, tauq, taup []complex128) {
|
||||
C.LAPACKE_zgebrd(
|
||||
C.int(o),
|
||||
C.int(m), C.int(n), (*C.complex)(&a[0]), C.int(lda),
|
||||
(*C.double)(&d[0]),
|
||||
(*C.double)(&e[0]),
|
||||
(*C.complex)(&tauq[0]),
|
||||
(*C.complex)(&taup[0]))
|
||||
}
|
||||
|
@@ -11,7 +11,7 @@ type QRFact struct {
|
||||
}
|
||||
|
||||
func QR(A dbw.General, tau []float64) QRFact {
|
||||
impl.Dgeqrf(A, tau)
|
||||
impl.Dgeqrf(A.Order, A.Rows, A.Cols, A.Data, A.Stride, tau)
|
||||
return QRFact{A, tau}
|
||||
}
|
||||
|
||||
@@ -20,7 +20,13 @@ func (f QRFact) R() dbw.Triangular {
|
||||
}
|
||||
|
||||
func (f QRFact) Solve(B dbw.General) dbw.General {
|
||||
impl.Dormqr(blas.Left, blas.Trans, f.a, f.tau, B)
|
||||
if B.Order != f.a.Order {
|
||||
panic("Order missmatch")
|
||||
}
|
||||
if f.a.Cols != B.Cols {
|
||||
panic("dimension missmatch")
|
||||
}
|
||||
impl.Dormqr(B.Order, blas.Left, blas.Trans, f.a.Rows, B.Cols, f.a.Cols, f.a.Data, f.a.Stride, f.tau, B.Data, B.Stride)
|
||||
B.Rows = f.a.Cols
|
||||
dbw.Trsm(blas.Left, blas.NoTrans, 1, f.R(), B)
|
||||
return B
|
41
dla/svd.go
Normal file
41
dla/svd.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package dla
|
||||
|
||||
import (
|
||||
"github.com/dane-unltd/lapack"
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/dbw"
|
||||
)
|
||||
|
||||
func SVD(A dbw.General) (U dbw.General, s []float64, Vt dbw.General) {
|
||||
m := A.Rows
|
||||
n := A.Cols
|
||||
U.Stride = 1
|
||||
Vt.Stride = 1
|
||||
if m >= n {
|
||||
Vt = dbw.NewGeneral(A.Order, n, n, nil)
|
||||
s = make([]float64, n)
|
||||
U = A
|
||||
} else {
|
||||
U = dbw.NewGeneral(A.Order, m, m, nil)
|
||||
s = make([]float64, m)
|
||||
Vt = A
|
||||
}
|
||||
|
||||
impl.Dgesdd(A.Order, lapack.Overwrite, A.Rows, A.Cols, A.Data, A.Stride, s, U.Data, U.Stride, Vt.Data, Vt.Stride)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func SVDbd(uplo blas.Uplo, d, e []float64) (U dbw.General, s []float64, Vt dbw.General) {
|
||||
n := len(d)
|
||||
if len(e) != n {
|
||||
panic("dimensionality missmatch")
|
||||
}
|
||||
|
||||
U = dbw.NewGeneral(blas.ColMajor, n, n, nil)
|
||||
Vt = dbw.NewGeneral(blas.ColMajor, n, n, nil)
|
||||
|
||||
impl.Dbdsdc(blas.ColMajor, uplo, lapack.Explicit, n, d, e, U.Data, U.Stride, Vt.Data, Vt.Stride, nil, nil)
|
||||
s = d
|
||||
return
|
||||
}
|
33
lapack.go
33
lapack.go
@@ -2,15 +2,38 @@ package lapack
|
||||
|
||||
import (
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/dbw"
|
||||
"github.com/gonum/blas/zbw"
|
||||
)
|
||||
|
||||
const None = 'N'
|
||||
|
||||
type Job byte
|
||||
|
||||
const (
|
||||
All (Job) = 'A'
|
||||
Slim (Job) = 'S'
|
||||
Overwrite (Job) = 'O'
|
||||
)
|
||||
|
||||
type CompSV byte
|
||||
|
||||
const (
|
||||
Compact (CompSV) = 'P'
|
||||
Explicit (CompSV) = 'I'
|
||||
)
|
||||
|
||||
type Float64 interface {
|
||||
Dgeqrf(A dbw.General, tau []float64)
|
||||
Dormqr(s blas.Side, t blas.Transpose, A dbw.General, tau []float64, B dbw.General)
|
||||
Dgeqrf(o blas.Order, m, n int, a []float64, lda int, tau []float64)
|
||||
Dormqr(o blas.Order, s blas.Side, t blas.Transpose, m, n, k int, a []float64, lda int, tau []float64, c []float64, ldc int)
|
||||
Dgesdd(o blas.Order, job Job, m, n int, a []float64, lda int, s []float64, u []float64, ldu int, vt []float64, ldvt int)
|
||||
Dgebrd(o blas.Order, m, n int, a []float64, lda int, d, e, tauq, taup []float64)
|
||||
Dbdsdc(o blas.Order, uplo blas.Uplo, compq CompSV, n int, d, e []float64, u []float64, ldu int, vt []float64, ldvt int, q []float64, iq []int32)
|
||||
}
|
||||
|
||||
type Complex128 interface {
|
||||
Zgesvd(jobu byte, jobvt byte, A zbw.General, s []float64, U zbw.General, Vt zbw.General, superb []float64)
|
||||
Float64
|
||||
|
||||
Zgeqrf(o blas.Order, m, n int, a []complex128, lda int, tau []complex128)
|
||||
Zunmqr(o blas.Order, s blas.Side, t blas.Transpose, m, n, k int, a []complex128, lda int, tau []complex128, c []complex128, ldc int)
|
||||
Zgesdd(o blas.Order, job Job, m, n int, a []complex128, lda int, s []float64, u []complex128, ldu int, vt []complex128, ldvt int)
|
||||
Zgebrd(o blas.Order, m, n int, a []complex128, lda int, d, e []float64, tauq, taup []complex128)
|
||||
}
|
||||
|
9
zla/impl.go
Normal file
9
zla/impl.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package zla
|
||||
|
||||
import "github.com/dane-unltd/lapack"
|
||||
|
||||
var impl lapack.Complex128
|
||||
|
||||
func Register(i lapack.Complex128) {
|
||||
impl = i
|
||||
}
|
33
zla/qr.go
Normal file
33
zla/qr.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package zla
|
||||
|
||||
import (
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/zbw"
|
||||
)
|
||||
|
||||
type QRFact struct {
|
||||
a zbw.General
|
||||
tau []complex128
|
||||
}
|
||||
|
||||
func QR(A zbw.General, tau []complex128) QRFact {
|
||||
impl.Zgeqrf(A.Order, A.Rows, A.Cols, A.Data, A.Stride, tau)
|
||||
return QRFact{A, tau}
|
||||
}
|
||||
|
||||
func (f QRFact) R() zbw.Triangular {
|
||||
return zbw.Ge2Tr(f.a, blas.NonUnit, blas.Upper)
|
||||
}
|
||||
|
||||
func (f QRFact) Solve(B zbw.General) zbw.General {
|
||||
if B.Order != f.a.Order {
|
||||
panic("Order missmatch")
|
||||
}
|
||||
if f.a.Cols != B.Cols {
|
||||
panic("dimension missmatch")
|
||||
}
|
||||
impl.Zunmqr(B.Order, blas.Left, blas.Trans, f.a.Rows, B.Cols, f.a.Cols, f.a.Data, f.a.Stride, f.tau, B.Data, B.Stride)
|
||||
B.Rows = f.a.Cols
|
||||
zbw.Trsm(blas.Left, blas.NoTrans, 1, f.R(), B)
|
||||
return B
|
||||
}
|
79
zla/svd.go
Normal file
79
zla/svd.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package zla
|
||||
|
||||
import (
|
||||
"github.com/dane-unltd/lapack"
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/zbw"
|
||||
)
|
||||
|
||||
func SVD(A zbw.General) (U zbw.General, s []float64, Vt zbw.General) {
|
||||
m := A.Rows
|
||||
n := A.Cols
|
||||
U.Stride = 1
|
||||
Vt.Stride = 1
|
||||
if m >= n {
|
||||
Vt = zbw.NewGeneral(A.Order, n, n, nil)
|
||||
s = make([]float64, n)
|
||||
U = A
|
||||
} else {
|
||||
U = zbw.NewGeneral(A.Order, m, m, nil)
|
||||
s = make([]float64, m)
|
||||
Vt = A
|
||||
}
|
||||
|
||||
impl.Zgesdd(A.Order, lapack.Overwrite, A.Rows, A.Cols, A.Data, A.Stride, s, U.Data, U.Stride, Vt.Data, Vt.Stride)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
//Lanczos bidiagonalization with full reorthogonalization
|
||||
func LanczosBi(L zbw.General, u []complex128, numIter int) (U zbw.General, V zbw.General, a []float64, b []float64) {
|
||||
|
||||
m := L.Rows
|
||||
n := L.Cols
|
||||
|
||||
uv := zbw.NewVector(u)
|
||||
zbw.Scal(complex(1/zbw.Nrm2(uv), 0), uv)
|
||||
|
||||
U = zbw.NewGeneral(blas.ColMajor, m, numIter, nil)
|
||||
V = zbw.NewGeneral(blas.ColMajor, n, numIter, nil)
|
||||
|
||||
a = make([]float64, numIter)
|
||||
b = make([]float64, numIter)
|
||||
|
||||
zbw.Copy(uv, U.Col(0))
|
||||
|
||||
tr := zbw.NewVector(zbw.Allocate(n))
|
||||
zbw.Gemv(blas.ConjTrans, 1, L, uv, 0, tr)
|
||||
a[0] = zbw.Nrm2(tr)
|
||||
zbw.Copy(tr, V.Col(0))
|
||||
zbw.Scal(complex(1/a[0], 0), V.Col(0))
|
||||
|
||||
tl := zbw.NewVector(zbw.Allocate(m))
|
||||
for k := 0; k < numIter-1; k++ {
|
||||
zbw.Copy(U.Col(k), tl)
|
||||
zbw.Scal(complex(-a[k], 0), tl)
|
||||
zbw.Gemv(blas.NoTrans, 1, L, V.Col(k), 1, tl)
|
||||
|
||||
for i := 0; i <= k; i++ {
|
||||
zbw.Axpy(-zbw.Dotc(U.Col(i), tl), U.Col(i), tl)
|
||||
}
|
||||
|
||||
b[k] = zbw.Nrm2(tl)
|
||||
zbw.Copy(tl, U.Col(k+1))
|
||||
zbw.Scal(complex(1/b[k], 0), U.Col(k+1))
|
||||
|
||||
zbw.Copy(V.Col(k), tr)
|
||||
zbw.Scal(complex(-b[k], 0), tr)
|
||||
zbw.Gemv(blas.ConjTrans, 1, L, U.Col(k+1), 1, tr)
|
||||
|
||||
for i := 0; i <= k; i++ {
|
||||
zbw.Axpy(-zbw.Dotc(V.Col(i), tr), V.Col(i), tr)
|
||||
}
|
||||
|
||||
a[k+1] = zbw.Nrm2(tr)
|
||||
zbw.Copy(tr, V.Col(k+1))
|
||||
zbw.Scal(complex(1/a[k+1], 0), V.Col(k+1))
|
||||
}
|
||||
return
|
||||
}
|
94
zla/zla_test.go
Normal file
94
zla/zla_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package zla
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/dane-unltd/lapack/clapack"
|
||||
"github.com/dane-unltd/lapack/dla"
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/cblas"
|
||||
"github.com/gonum/blas/dbw"
|
||||
"github.com/gonum/blas/zbw"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Register(clapack.La{})
|
||||
dla.Register(clapack.La{})
|
||||
zbw.Register(cblas.Blas{})
|
||||
dbw.Register(cblas.Blas{})
|
||||
}
|
||||
|
||||
func fillRandn(a []complex128, mu complex128, sigmaSq float64) {
|
||||
fact := math.Sqrt(0.5 * sigmaSq)
|
||||
for i := range a {
|
||||
a[i] = complex(fact*rand.NormFloat64(), fact*rand.NormFloat64()) + mu
|
||||
}
|
||||
}
|
||||
|
||||
func TestQR(t *testing.T) {
|
||||
A := zbw.NewGeneral(blas.ColMajor, 3, 2,
|
||||
[]complex128{complex(1, 0), complex(2, 0), complex(3, 0),
|
||||
complex(4, 0), complex(5, 0), complex(6, 0)})
|
||||
B := zbw.NewGeneral(blas.ColMajor, 3, 2,
|
||||
[]complex128{complex(1, 0), complex(1, 0), complex(1, 0), complex(2, 0), complex(2, 0), complex(2, 0)})
|
||||
|
||||
tau := zbw.Allocate(2)
|
||||
|
||||
f := QR(A, tau)
|
||||
|
||||
//fmt.Println(B)
|
||||
f.Solve(B)
|
||||
//fmt.Println(B)
|
||||
}
|
||||
|
||||
func TestLanczos(t *testing.T) {
|
||||
A := zbw.NewGeneral(blas.ColMajor, 3, 4, nil)
|
||||
fillRandn(A.Data, 0, 1)
|
||||
|
||||
Acpy := zbw.NewGeneral(blas.ColMajor, 3, 4, nil)
|
||||
copy(Acpy.Data, A.Data)
|
||||
|
||||
u0 := make([]complex128, 3)
|
||||
fillRandn(u0, 0, 1)
|
||||
|
||||
Ul, Vl, a, b := LanczosBi(Acpy, u0, 3)
|
||||
|
||||
fmt.Println(a, b)
|
||||
|
||||
tmpc := zbw.NewGeneral(blas.ColMajor, 3, 3, nil)
|
||||
bidic := zbw.NewGeneral(blas.ColMajor, 3, 3, nil)
|
||||
|
||||
zbw.Gemm(blas.NoTrans, blas.NoTrans, 1, A, Vl, 0, tmpc)
|
||||
zbw.Gemm(blas.ConjTrans, blas.NoTrans, 1, Ul, tmpc, 0, bidic)
|
||||
|
||||
fmt.Println(bidic)
|
||||
|
||||
Ur, s, Vr := dla.SVDbd(blas.Lower, a, b)
|
||||
|
||||
tmp := dbw.NewGeneral(blas.ColMajor, 3, 3, nil)
|
||||
bidi := dbw.NewGeneral(blas.ColMajor, 3, 3, nil)
|
||||
|
||||
copy(tmp.Data, Ur.Data)
|
||||
for i := 0; i < 3; i++ {
|
||||
dbw.Scal(s[i], tmp.Col(i))
|
||||
}
|
||||
|
||||
dbw.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, Vr, 0, bidi)
|
||||
|
||||
fmt.Println(bidi)
|
||||
/*
|
||||
|
||||
_ = Ul
|
||||
_ = Vl
|
||||
Uc := zbw.NewGeneral(blas.ColMajor, 3, 3, nil)
|
||||
zbw.Real2Cmplx(Ur.Data[:3*3], Uc.Data)
|
||||
|
||||
fmt.Println(Uc.Data)
|
||||
|
||||
U := zbw.NewGeneral(blas.ColMajor, M, K, nil)
|
||||
zbw.Gemm(blas.NoTrans, blas.NoTrans, 1, U1, Uc, 0, U)
|
||||
*/
|
||||
}
|
Reference in New Issue
Block a user