diff --git a/clapack/clapack.go b/clapack/clapack.go index cb38738e..df77c45c 100644 --- a/clapack/clapack.go +++ b/clapack/clapack.go @@ -9,19 +9,22 @@ package clapack */ import "C" import ( + "github.com/dane-unltd/lapack" "github.com/gonum/blas" - "github.com/gonum/blas/dbw" - "github.com/gonum/blas/zbw" ) type La struct{} -func (La) Dgeqrf(A dbw.General, tau []float64) { - C.LAPACKE_dgeqrf(C.int(A.Order), C.int(A.Rows), C.int(A.Cols), - (*C.double)(&A.Data[0]), C.int(A.Stride), (*C.double)(&tau[0])) +func init() { + _ = lapack.Complex128(La{}) + _ = lapack.Float64(La{}) } -func (La) Dormqr(s blas.Side, t blas.Transpose, A dbw.General, tau []float64, B dbw.General) { +func (La) Dgeqrf(o blas.Order, m, n int, a []float64, lda int, tau []float64) { + C.LAPACKE_dgeqrf(C.int(o), C.int(m), C.int(n), (*C.double)(&a[0]), C.int(lda), (*C.double)(&tau[0])) +} + +func (La) Dormqr(o blas.Order, s blas.Side, t blas.Transpose, m, n, k int, a []float64, lda int, tau []float64, c []float64, ldc int) { var cs, ct C.char if s == blas.Left { cs = 'l' @@ -34,24 +37,118 @@ func (La) Dormqr(s blas.Side, t blas.Transpose, A dbw.General, tau []float64, B ct = 't' } - C.LAPACKE_dormqr(C.int(A.Order), cs, ct, C.int(B.Rows), - C.int(B.Cols), C.int(A.Cols), (*C.double)(&A.Data[0]), - C.int(A.Stride), (*C.double)(&tau[0]), (*C.double)(&B.Data[0]), C.int(B.Stride)) + C.LAPACKE_dormqr(C.int(o), cs, ct, C.int(m), + C.int(n), C.int(k), (*C.double)(&a[0]), + C.int(lda), (*C.double)(&tau[0]), (*C.double)(&c[0]), C.int(ldc)) } -func (La) Zgesvd(jobz byte, A zbw.General, s []float64, U zbw.General, Vt zbw.General) { +func (La) Dgesdd(o blas.Order, job lapack.Job, m, n int, a []float64, lda int, s []float64, u []float64, ldu int, vt []float64, ldvt int) { + pU := (*float64)(nil) + if len(u) > 0 { + pU = &u[0] + } + pVt := (*float64)(nil) + if len(vt) > 0 { + pVt = &vt[0] + } + C.LAPACKE_dgesdd( + C.int(o), C.char(job), + C.int(m), C.int(n), (*C.double)(&a[0]), C.int(lda), + (*C.double)(&s[0]), + (*C.double)(pU), C.int(ldu), + (*C.double)(pVt), C.int(ldvt)) +} + +func (La) Dgebrd(o blas.Order, m, n int, a []float64, lda int, d, e, tauq, taup []float64) { + C.LAPACKE_dgebrd( + C.int(o), C.int(m), C.int(n), (*C.double)(&a[0]), C.int(lda), + (*C.double)(&d[0]), + (*C.double)(&e[0]), + (*C.double)(&tauq[0]), + (*C.double)(&taup[0])) +} + +func (La) Dbdsdc(o blas.Order, uplo blas.Uplo, compq lapack.CompSV, n int, + d, e []float64, u []float64, ldu int, vt []float64, ldvt int, q []float64, iq []int32) { + pU := (*float64)(nil) + if len(u) > 0 { + pU = &u[0] + } + pVt := (*float64)(nil) + if len(vt) > 0 { + pVt = &vt[0] + } + pq := (*float64)(nil) + if len(q) > 0 { + pU = &q[0] + } + piq := (*int32)(nil) + if len(iq) > 0 { + piq = &iq[0] + } + + cuplo := C.char('u') + if uplo == blas.Lower { + cuplo = 'l' + } + + C.LAPACKE_dbdsdc(C.int(o), cuplo, C.char(compq), + (C.int)(n), + (*C.double)(&d[0]), + (*C.double)(&e[0]), + (*C.double)(pU), + (C.int)(ldu), + (*C.double)(pVt), + (C.int)(ldvt), + (*C.double)(pq), + (*C.int)(piq)) +} + +func (La) Zgeqrf(o blas.Order, m, n int, a []complex128, lda int, tau []complex128) { + C.LAPACKE_zgeqrf(C.int(o), C.int(m), C.int(n), (*C.complex)(&a[0]), C.int(lda), (*C.complex)(&tau[0])) +} + +func (La) Zunmqr(o blas.Order, s blas.Side, t blas.Transpose, m, n, k int, a []complex128, lda int, tau []complex128, c []complex128, ldc int) { + var cs, ct C.char + if s == blas.Left { + cs = 'l' + } else { + cs = 'r' + } + if t == blas.NoTrans { + ct = 'n' + } else { + ct = 'c' + } + + C.LAPACKE_zunmqr(C.int(o), cs, ct, C.int(m), + C.int(n), C.int(k), (*C.complex)(&a[0]), + C.int(lda), (*C.complex)(&tau[0]), (*C.complex)(&c[0]), C.int(ldc)) +} + +func (La) Zgesdd(o blas.Order, job lapack.Job, m, n int, a []complex128, lda int, s []float64, u []complex128, ldu int, vt []complex128, ldvt int) { pU := (*complex128)(nil) - if len(U.Data) > 0 { - pU = &U.Data[0] + if len(u) > 0 { + pU = &u[0] } pVt := (*complex128)(nil) - if len(Vt.Data) > 0 { - pVt = &Vt.Data[0] + if len(vt) > 0 { + pVt = &vt[0] } C.LAPACKE_zgesdd( - C.int(A.Order), C.char(jobz), - C.int(A.Rows), C.int(A.Cols), (*C.complex)(&A.Data[0]), C.int(A.Stride), + C.int(o), C.char(job), + C.int(m), C.int(n), (*C.complex)(&a[0]), C.int(lda), (*C.double)(&s[0]), - (*C.complex)(pU), C.int(U.Stride), - (*C.complex)(pVt), C.int(Vt.Stride)) + (*C.complex)(pU), C.int(ldu), + (*C.complex)(pVt), C.int(ldvt)) +} + +func (La) Zgebrd(o blas.Order, m, n int, a []complex128, lda int, d, e []float64, tauq, taup []complex128) { + C.LAPACKE_zgebrd( + C.int(o), + C.int(m), C.int(n), (*C.complex)(&a[0]), C.int(lda), + (*C.double)(&d[0]), + (*C.double)(&e[0]), + (*C.complex)(&tauq[0]), + (*C.complex)(&taup[0])) } diff --git a/dla/dqr.go b/dla/qr.go similarity index 58% rename from dla/dqr.go rename to dla/qr.go index 67f10f65..07aa1e22 100644 --- a/dla/dqr.go +++ b/dla/qr.go @@ -11,7 +11,7 @@ type QRFact struct { } func QR(A dbw.General, tau []float64) QRFact { - impl.Dgeqrf(A, tau) + impl.Dgeqrf(A.Order, A.Rows, A.Cols, A.Data, A.Stride, tau) return QRFact{A, tau} } @@ -20,7 +20,13 @@ func (f QRFact) R() dbw.Triangular { } func (f QRFact) Solve(B dbw.General) dbw.General { - impl.Dormqr(blas.Left, blas.Trans, f.a, f.tau, B) + if B.Order != f.a.Order { + panic("Order missmatch") + } + if f.a.Cols != B.Cols { + panic("dimension missmatch") + } + impl.Dormqr(B.Order, blas.Left, blas.Trans, f.a.Rows, B.Cols, f.a.Cols, f.a.Data, f.a.Stride, f.tau, B.Data, B.Stride) B.Rows = f.a.Cols dbw.Trsm(blas.Left, blas.NoTrans, 1, f.R(), B) return B diff --git a/dla/svd.go b/dla/svd.go new file mode 100644 index 00000000..3a077882 --- /dev/null +++ b/dla/svd.go @@ -0,0 +1,41 @@ +package dla + +import ( + "github.com/dane-unltd/lapack" + "github.com/gonum/blas" + "github.com/gonum/blas/dbw" +) + +func SVD(A dbw.General) (U dbw.General, s []float64, Vt dbw.General) { + m := A.Rows + n := A.Cols + U.Stride = 1 + Vt.Stride = 1 + if m >= n { + Vt = dbw.NewGeneral(A.Order, n, n, nil) + s = make([]float64, n) + U = A + } else { + U = dbw.NewGeneral(A.Order, m, m, nil) + s = make([]float64, m) + Vt = A + } + + impl.Dgesdd(A.Order, lapack.Overwrite, A.Rows, A.Cols, A.Data, A.Stride, s, U.Data, U.Stride, Vt.Data, Vt.Stride) + + return +} + +func SVDbd(uplo blas.Uplo, d, e []float64) (U dbw.General, s []float64, Vt dbw.General) { + n := len(d) + if len(e) != n { + panic("dimensionality missmatch") + } + + U = dbw.NewGeneral(blas.ColMajor, n, n, nil) + Vt = dbw.NewGeneral(blas.ColMajor, n, n, nil) + + impl.Dbdsdc(blas.ColMajor, uplo, lapack.Explicit, n, d, e, U.Data, U.Stride, Vt.Data, Vt.Stride, nil, nil) + s = d + return +} diff --git a/lapack.go b/lapack.go index ae0d8621..32b831ee 100644 --- a/lapack.go +++ b/lapack.go @@ -2,15 +2,38 @@ package lapack import ( "github.com/gonum/blas" - "github.com/gonum/blas/dbw" - "github.com/gonum/blas/zbw" +) + +const None = 'N' + +type Job byte + +const ( + All (Job) = 'A' + Slim (Job) = 'S' + Overwrite (Job) = 'O' +) + +type CompSV byte + +const ( + Compact (CompSV) = 'P' + Explicit (CompSV) = 'I' ) type Float64 interface { - Dgeqrf(A dbw.General, tau []float64) - Dormqr(s blas.Side, t blas.Transpose, A dbw.General, tau []float64, B dbw.General) + Dgeqrf(o blas.Order, m, n int, a []float64, lda int, tau []float64) + Dormqr(o blas.Order, s blas.Side, t blas.Transpose, m, n, k int, a []float64, lda int, tau []float64, c []float64, ldc int) + Dgesdd(o blas.Order, job Job, m, n int, a []float64, lda int, s []float64, u []float64, ldu int, vt []float64, ldvt int) + Dgebrd(o blas.Order, m, n int, a []float64, lda int, d, e, tauq, taup []float64) + Dbdsdc(o blas.Order, uplo blas.Uplo, compq CompSV, n int, d, e []float64, u []float64, ldu int, vt []float64, ldvt int, q []float64, iq []int32) } type Complex128 interface { - Zgesvd(jobu byte, jobvt byte, A zbw.General, s []float64, U zbw.General, Vt zbw.General, superb []float64) + Float64 + + Zgeqrf(o blas.Order, m, n int, a []complex128, lda int, tau []complex128) + Zunmqr(o blas.Order, s blas.Side, t blas.Transpose, m, n, k int, a []complex128, lda int, tau []complex128, c []complex128, ldc int) + Zgesdd(o blas.Order, job Job, m, n int, a []complex128, lda int, s []float64, u []complex128, ldu int, vt []complex128, ldvt int) + Zgebrd(o blas.Order, m, n int, a []complex128, lda int, d, e []float64, tauq, taup []complex128) } diff --git a/zla/impl.go b/zla/impl.go new file mode 100644 index 00000000..9bd5fb35 --- /dev/null +++ b/zla/impl.go @@ -0,0 +1,9 @@ +package zla + +import "github.com/dane-unltd/lapack" + +var impl lapack.Complex128 + +func Register(i lapack.Complex128) { + impl = i +} diff --git a/zla/qr.go b/zla/qr.go new file mode 100644 index 00000000..1ef89949 --- /dev/null +++ b/zla/qr.go @@ -0,0 +1,33 @@ +package zla + +import ( + "github.com/gonum/blas" + "github.com/gonum/blas/zbw" +) + +type QRFact struct { + a zbw.General + tau []complex128 +} + +func QR(A zbw.General, tau []complex128) QRFact { + impl.Zgeqrf(A.Order, A.Rows, A.Cols, A.Data, A.Stride, tau) + return QRFact{A, tau} +} + +func (f QRFact) R() zbw.Triangular { + return zbw.Ge2Tr(f.a, blas.NonUnit, blas.Upper) +} + +func (f QRFact) Solve(B zbw.General) zbw.General { + if B.Order != f.a.Order { + panic("Order missmatch") + } + if f.a.Cols != B.Cols { + panic("dimension missmatch") + } + impl.Zunmqr(B.Order, blas.Left, blas.Trans, f.a.Rows, B.Cols, f.a.Cols, f.a.Data, f.a.Stride, f.tau, B.Data, B.Stride) + B.Rows = f.a.Cols + zbw.Trsm(blas.Left, blas.NoTrans, 1, f.R(), B) + return B +} diff --git a/zla/svd.go b/zla/svd.go new file mode 100644 index 00000000..e3e59f4f --- /dev/null +++ b/zla/svd.go @@ -0,0 +1,79 @@ +package zla + +import ( + "github.com/dane-unltd/lapack" + "github.com/gonum/blas" + "github.com/gonum/blas/zbw" +) + +func SVD(A zbw.General) (U zbw.General, s []float64, Vt zbw.General) { + m := A.Rows + n := A.Cols + U.Stride = 1 + Vt.Stride = 1 + if m >= n { + Vt = zbw.NewGeneral(A.Order, n, n, nil) + s = make([]float64, n) + U = A + } else { + U = zbw.NewGeneral(A.Order, m, m, nil) + s = make([]float64, m) + Vt = A + } + + impl.Zgesdd(A.Order, lapack.Overwrite, A.Rows, A.Cols, A.Data, A.Stride, s, U.Data, U.Stride, Vt.Data, Vt.Stride) + + return +} + +//Lanczos bidiagonalization with full reorthogonalization +func LanczosBi(L zbw.General, u []complex128, numIter int) (U zbw.General, V zbw.General, a []float64, b []float64) { + + m := L.Rows + n := L.Cols + + uv := zbw.NewVector(u) + zbw.Scal(complex(1/zbw.Nrm2(uv), 0), uv) + + U = zbw.NewGeneral(blas.ColMajor, m, numIter, nil) + V = zbw.NewGeneral(blas.ColMajor, n, numIter, nil) + + a = make([]float64, numIter) + b = make([]float64, numIter) + + zbw.Copy(uv, U.Col(0)) + + tr := zbw.NewVector(zbw.Allocate(n)) + zbw.Gemv(blas.ConjTrans, 1, L, uv, 0, tr) + a[0] = zbw.Nrm2(tr) + zbw.Copy(tr, V.Col(0)) + zbw.Scal(complex(1/a[0], 0), V.Col(0)) + + tl := zbw.NewVector(zbw.Allocate(m)) + for k := 0; k < numIter-1; k++ { + zbw.Copy(U.Col(k), tl) + zbw.Scal(complex(-a[k], 0), tl) + zbw.Gemv(blas.NoTrans, 1, L, V.Col(k), 1, tl) + + for i := 0; i <= k; i++ { + zbw.Axpy(-zbw.Dotc(U.Col(i), tl), U.Col(i), tl) + } + + b[k] = zbw.Nrm2(tl) + zbw.Copy(tl, U.Col(k+1)) + zbw.Scal(complex(1/b[k], 0), U.Col(k+1)) + + zbw.Copy(V.Col(k), tr) + zbw.Scal(complex(-b[k], 0), tr) + zbw.Gemv(blas.ConjTrans, 1, L, U.Col(k+1), 1, tr) + + for i := 0; i <= k; i++ { + zbw.Axpy(-zbw.Dotc(V.Col(i), tr), V.Col(i), tr) + } + + a[k+1] = zbw.Nrm2(tr) + zbw.Copy(tr, V.Col(k+1)) + zbw.Scal(complex(1/a[k+1], 0), V.Col(k+1)) + } + return +} diff --git a/zla/zla_test.go b/zla/zla_test.go new file mode 100644 index 00000000..31d36260 --- /dev/null +++ b/zla/zla_test.go @@ -0,0 +1,94 @@ +package zla + +import ( + "fmt" + "math" + "math/rand" + "testing" + + "github.com/dane-unltd/lapack/clapack" + "github.com/dane-unltd/lapack/dla" + "github.com/gonum/blas" + "github.com/gonum/blas/cblas" + "github.com/gonum/blas/dbw" + "github.com/gonum/blas/zbw" +) + +func init() { + Register(clapack.La{}) + dla.Register(clapack.La{}) + zbw.Register(cblas.Blas{}) + dbw.Register(cblas.Blas{}) +} + +func fillRandn(a []complex128, mu complex128, sigmaSq float64) { + fact := math.Sqrt(0.5 * sigmaSq) + for i := range a { + a[i] = complex(fact*rand.NormFloat64(), fact*rand.NormFloat64()) + mu + } +} + +func TestQR(t *testing.T) { + A := zbw.NewGeneral(blas.ColMajor, 3, 2, + []complex128{complex(1, 0), complex(2, 0), complex(3, 0), + complex(4, 0), complex(5, 0), complex(6, 0)}) + B := zbw.NewGeneral(blas.ColMajor, 3, 2, + []complex128{complex(1, 0), complex(1, 0), complex(1, 0), complex(2, 0), complex(2, 0), complex(2, 0)}) + + tau := zbw.Allocate(2) + + f := QR(A, tau) + + //fmt.Println(B) + f.Solve(B) + //fmt.Println(B) +} + +func TestLanczos(t *testing.T) { + A := zbw.NewGeneral(blas.ColMajor, 3, 4, nil) + fillRandn(A.Data, 0, 1) + + Acpy := zbw.NewGeneral(blas.ColMajor, 3, 4, nil) + copy(Acpy.Data, A.Data) + + u0 := make([]complex128, 3) + fillRandn(u0, 0, 1) + + Ul, Vl, a, b := LanczosBi(Acpy, u0, 3) + + fmt.Println(a, b) + + tmpc := zbw.NewGeneral(blas.ColMajor, 3, 3, nil) + bidic := zbw.NewGeneral(blas.ColMajor, 3, 3, nil) + + zbw.Gemm(blas.NoTrans, blas.NoTrans, 1, A, Vl, 0, tmpc) + zbw.Gemm(blas.ConjTrans, blas.NoTrans, 1, Ul, tmpc, 0, bidic) + + fmt.Println(bidic) + + Ur, s, Vr := dla.SVDbd(blas.Lower, a, b) + + tmp := dbw.NewGeneral(blas.ColMajor, 3, 3, nil) + bidi := dbw.NewGeneral(blas.ColMajor, 3, 3, nil) + + copy(tmp.Data, Ur.Data) + for i := 0; i < 3; i++ { + dbw.Scal(s[i], tmp.Col(i)) + } + + dbw.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, Vr, 0, bidi) + + fmt.Println(bidi) + /* + + _ = Ul + _ = Vl + Uc := zbw.NewGeneral(blas.ColMajor, 3, 3, nil) + zbw.Real2Cmplx(Ur.Data[:3*3], Uc.Data) + + fmt.Println(Uc.Data) + + U := zbw.NewGeneral(blas.ColMajor, M, K, nil) + zbw.Gemm(blas.NoTrans, blas.NoTrans, 1, U1, Uc, 0, U) + */ +}