mirror of
https://github.com/gonum/gonum.git
synced 2025-10-20 05:54:41 +08:00
Merge pull request #32 from gonum/qrlqsolve
Add cgo and lapack64 functions for performing a QR and LQ solve from …
This commit is contained in:
111
cgo/lapack.go
111
cgo/lapack.go
@@ -327,3 +327,114 @@ func (impl Implementation) Dgetrs(trans blas.Transpose, n, nrhs int, a []float64
|
||||
}
|
||||
clapack.Dgetrs(trans, n, nrhs, a, lda, ipiv32, b, ldb)
|
||||
}
|
||||
|
||||
// Dormlq multiplies the matrix C by the othogonal matrix Q defined by the
|
||||
// slices a and tau. A and tau are as returned from Dgelqf.
|
||||
// C = Q * C if side == blas.Left and trans == blas.NoTrans
|
||||
// C = Q^T * C if side == blas.Left and trans == blas.Trans
|
||||
// C = C * Q if side == blas.Right and trans == blas.NoTrans
|
||||
// C = C * Q^T if side == blas.Right and trans == blas.Trans
|
||||
// If side == blas.Left, A is a matrix of side k×m, and if side == blas.Right
|
||||
// A is of size k×n. This uses a blocked algorithm.
|
||||
//
|
||||
// Work is temporary storage, and lwork specifies the usable memory length.
|
||||
// At minimum, lwork >= m if side == blas.Left and lwork >= n if side == blas.Right,
|
||||
// and this function will panic otherwise.
|
||||
// Dormlq uses a block algorithm, but the block size is limited
|
||||
// by the temporary space available. If lwork == -1, instead of performing Dormlq,
|
||||
// the optimal work length will be stored into work[0].
|
||||
//
|
||||
// tau contains the householder scales and must have length at least k, and
|
||||
// this function will panic otherwise.
|
||||
func (impl Implementation) Dormlq(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) {
|
||||
if side != blas.Left && side != blas.Right {
|
||||
panic(badSide)
|
||||
}
|
||||
if trans != blas.Trans && trans != blas.NoTrans {
|
||||
panic(badTrans)
|
||||
}
|
||||
left := side == blas.Left
|
||||
if left {
|
||||
checkMatrix(k, m, a, lda)
|
||||
} else {
|
||||
checkMatrix(k, n, a, lda)
|
||||
}
|
||||
checkMatrix(m, n, c, ldc)
|
||||
if len(tau) < k {
|
||||
panic(badTau)
|
||||
}
|
||||
if lwork == -1 {
|
||||
if left {
|
||||
work[0] = float64(n)
|
||||
return
|
||||
}
|
||||
work[0] = float64(m)
|
||||
return
|
||||
}
|
||||
if left {
|
||||
if lwork < n {
|
||||
panic(badWork)
|
||||
}
|
||||
} else {
|
||||
if lwork < m {
|
||||
panic(badWork)
|
||||
}
|
||||
}
|
||||
clapack.Dormlq(side, trans, m, n, k, a, lda, tau, c, ldc)
|
||||
}
|
||||
|
||||
// Dormqr multiplies the matrix C by the othogonal matrix Q defined by the
|
||||
// slices a and tau. a and tau are as returned from Dgeqrf.
|
||||
// C = Q * C if side == blas.Left and trans == blas.NoTrans
|
||||
// C = Q^T * C if side == blas.Left and trans == blas.Trans
|
||||
// C = C * Q if side == blas.Right and trans == blas.NoTrans
|
||||
// C = C * Q^T if side == blas.Right and trans == blas.Trans
|
||||
// If side == blas.Left, A is a matrix of side k×m, and if side == blas.Right
|
||||
// A is of size k×n. This uses a blocked algorithm.
|
||||
//
|
||||
// tau contains the householder scales and must have length at least k, and
|
||||
// this function will panic otherwise.
|
||||
//
|
||||
// The C interface does not support providing temporary storage. To provide compatibility
|
||||
// with native, lwork == -1 will not run Dgeqrf but will instead write the minimum
|
||||
// work necessary to work[0]. If len(work) < lwork, Dgeqrf will panic.
|
||||
func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) {
|
||||
left := side == blas.Left
|
||||
if left {
|
||||
checkMatrix(m, k, a, lda)
|
||||
} else {
|
||||
checkMatrix(n, k, a, lda)
|
||||
}
|
||||
checkMatrix(m, n, c, ldc)
|
||||
|
||||
if len(tau) < k {
|
||||
panic(badTau)
|
||||
}
|
||||
|
||||
if lwork == -1 {
|
||||
if left {
|
||||
work[0] = float64(m)
|
||||
return
|
||||
}
|
||||
work[0] = float64(n)
|
||||
return
|
||||
}
|
||||
|
||||
if left {
|
||||
if lwork < n {
|
||||
panic(badWork)
|
||||
}
|
||||
} else {
|
||||
if lwork < m {
|
||||
panic(badWork)
|
||||
}
|
||||
}
|
||||
|
||||
clapack.Dormqr(side, trans, m, n, k, a, lda, tau, c, ldc)
|
||||
}
|
||||
|
||||
// Dtrtrs solves a triangular system of the form A * X = B or A^T * X = B. Dtrtrs
|
||||
// returns whether the solve completed successfully. If A is singular, no solve is performed.
|
||||
func (impl Implementation) Dtrtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, nrhs int, a []float64, lda int, b []float64, ldb int) (ok bool) {
|
||||
return clapack.Dtrtrs(uplo, trans, diag, n, nrhs, a, lda, b, ldb)
|
||||
}
|
||||
|
@@ -7,6 +7,7 @@ package cgo
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/lapack/testlapack"
|
||||
)
|
||||
|
||||
@@ -47,3 +48,29 @@ func TestDgetrf(t *testing.T) {
|
||||
func TestDgetrs(t *testing.T) {
|
||||
testlapack.DgetrsTest(t, impl)
|
||||
}
|
||||
|
||||
// blockedTranslate transforms some blocked C calls to be the unblocked algorithms
|
||||
// for testing, as several of the unblocked algorithms are not defined by the C
|
||||
// interface.
|
||||
type blockedTranslate struct {
|
||||
Implementation
|
||||
}
|
||||
|
||||
func (d blockedTranslate) Dorm2r(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64) {
|
||||
impl.Dormqr(side, trans, m, n, k, a, lda, tau, c, ldc, work, len(work))
|
||||
}
|
||||
|
||||
func (d blockedTranslate) Dorml2(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64) {
|
||||
impl.Dormlq(side, trans, m, n, k, a, lda, tau, c, ldc, work, len(work))
|
||||
}
|
||||
|
||||
func TestDormqr(t *testing.T) {
|
||||
testlapack.Dorm2rTest(t, blockedTranslate{impl})
|
||||
}
|
||||
|
||||
/*
|
||||
// Test disabled because of bug in c interface. Leaving stub for easy reproducer.
|
||||
func TestDormlq(t *testing.T) {
|
||||
testlapack.Dorml2Test(t, blockedTranslate{impl})
|
||||
}
|
||||
*/
|
||||
|
@@ -26,9 +26,12 @@ type Float64 interface {
|
||||
Dgels(trans blas.Transpose, m, n, nrhs int, a []float64, lda int, b []float64, ldb int, work []float64, lwork int) bool
|
||||
Dgelqf(m, n int, a []float64, lda int, tau, work []float64, lwork int)
|
||||
Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int)
|
||||
Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
|
||||
Dgetrf(m, n int, a []float64, lda int, ipiv []int) (ok bool)
|
||||
Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int)
|
||||
Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
|
||||
Dormlq(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
|
||||
Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
|
||||
Dtrtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, nrhs int, a []float64, lda int, b []float64, ldb int) (ok bool)
|
||||
}
|
||||
|
||||
// Direct specifies the direction of the multiplication for the Householder matrix.
|
||||
|
@@ -162,3 +162,53 @@ func Getrf(a blas64.General, ipiv []int) bool {
|
||||
func Getrs(trans blas.Transpose, a blas64.General, b blas64.General, ipiv []int) {
|
||||
lapack64.Dgetrs(trans, a.Cols, b.Cols, a.Data, a.Stride, ipiv, b.Data, b.Stride)
|
||||
}
|
||||
|
||||
// Ormlq multiplies the matrix C by the othogonal matrix Q defined by
|
||||
// A and tau. A and tau are as returned from Gelqf.
|
||||
// C = Q * C if side == blas.Left and trans == blas.NoTrans
|
||||
// C = Q^T * C if side == blas.Left and trans == blas.Trans
|
||||
// C = C * Q if side == blas.Right and trans == blas.NoTrans
|
||||
// C = C * Q^T if side == blas.Right and trans == blas.Trans
|
||||
// If side == blas.Left, A is a matrix of side k×m, and if side == blas.Right
|
||||
// A is of size k×n. This uses a blocked algorithm.
|
||||
//
|
||||
// Work is temporary storage, and lwork specifies the usable memory length.
|
||||
// At minimum, lwork >= m if side == blas.Left and lwork >= n if side == blas.Right,
|
||||
// and this function will panic otherwise.
|
||||
// Ormlq uses a block algorithm, but the block size is limited
|
||||
// by the temporary space available. If lwork == -1, instead of performing Ormlq,
|
||||
// the optimal work length will be stored into work[0].
|
||||
//
|
||||
// Tau contains the householder scales and must have length at least k, and
|
||||
// this function will panic otherwise.
|
||||
func Ormlq(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64, c blas64.General, work []float64, lwork int) {
|
||||
lapack64.Dormlq(side, trans, c.Rows, c.Cols, a.Rows, a.Data, a.Stride, tau, c.Data, c.Stride, work, lwork)
|
||||
}
|
||||
|
||||
// Ormqr multiplies the matrix C by the othogonal matrix Q defined by
|
||||
// A and tau. A and tau are as returned from Geqrf.
|
||||
// C = Q * C if side == blas.Left and trans == blas.NoTrans
|
||||
// C = Q^T * C if side == blas.Left and trans == blas.Trans
|
||||
// C = C * Q if side == blas.Right and trans == blas.NoTrans
|
||||
// C = C * Q^T if side == blas.Right and trans == blas.Trans
|
||||
// If side == blas.Left, A is a matrix of side k×m, and if side == blas.Right
|
||||
// A is of size k×n. This uses a blocked algorithm.
|
||||
//
|
||||
// tau contains the householder scales and must have length at least k, and
|
||||
// this function will panic otherwise.
|
||||
//
|
||||
// Work is temporary storage, and lwork specifies the usable memory length.
|
||||
// At minimum, lwork >= m if side == blas.Left and lwork >= n if side == blas.Right,
|
||||
// and this function will panic otherwise.
|
||||
// Ormqr uses a block algorithm, but the block size is limited
|
||||
// by the temporary space available. If lwork == -1, instead of performing Ormqr,
|
||||
// the optimal work length will be stored into work[0].
|
||||
func Ormqr(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64, c blas64.General, work []float64, lwork int) {
|
||||
lapack64.Dormqr(side, trans, c.Rows, c.Cols, a.Cols, a.Data, a.Stride, tau, c.Data, c.Stride, work, lwork)
|
||||
}
|
||||
|
||||
// Trtrs solves a triangular system of the form A * X = B or A^T * X = B. Trtrs
|
||||
// returns whether the solve completed successfully. If A is singular, no solve is performed.
|
||||
func Trtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, a blas64.Triangular, b blas64.General) (ok bool) {
|
||||
return lapack64.Dtrtrs(uplo, trans, diag, a.N, b.Cols, a.Data, a.Stride, b.Data, b.Stride)
|
||||
}
|
||||
|
@@ -6,7 +6,7 @@ package native
|
||||
|
||||
import "github.com/gonum/blas"
|
||||
|
||||
// Dorm2r multiplies a general matrix c by an orthogonal matrix from a QR factorization
|
||||
// Dorm2r multiplies a general matrix C by an orthogonal matrix from a QR factorization
|
||||
// determined by Dgeqrf.
|
||||
// C = Q * C if side == blas.Left and trans == blas.NoTrans
|
||||
// C = Q^T * C if side == blas.Left and trans == blas.Trans
|
||||
|
@@ -9,14 +9,14 @@ import (
|
||||
"github.com/gonum/lapack"
|
||||
)
|
||||
|
||||
// Dormlq multiplies the matrix c by the othogonal matrix q defined by the
|
||||
// Dormlq multiplies the matrix C by the othogonal matrix Q defined by the
|
||||
// slices a and tau. A and tau are as returned from Dgelqf.
|
||||
// C = Q * C if side == blas.Left and trans == blas.NoTrans
|
||||
// C = Q^T * C if side == blas.Left and trans == blas.Trans
|
||||
// C = C * Q if side == blas.Right and trans == blas.NoTrans
|
||||
// C = C * Q^T if side == blas.Right and trans == blas.Trans
|
||||
// If side == blas.Left, a is a matrix of side k×m, and if side == blas.Right
|
||||
// a is of size k×n. This uses a blocked algorithm.
|
||||
// If side == blas.Left, A is a matrix of side k×m, and if side == blas.Right
|
||||
// A is of size k×n. This uses a blocked algorithm.
|
||||
//
|
||||
// Work is temporary storage, and lwork specifies the usable memory length.
|
||||
// At minimum, lwork >= m if side == blas.Left and lwork >= n if side == blas.Right,
|
||||
@@ -25,7 +25,7 @@ import (
|
||||
// by the temporary space available. If lwork == -1, instead of performing Dormlq,
|
||||
// the optimal work length will be stored into work[0].
|
||||
//
|
||||
// Tau contains the householder scales and must have length at least k, and
|
||||
// tau contains the householder scales and must have length at least k, and
|
||||
// this function will panic otherwise.
|
||||
func (impl Implementation) Dormlq(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) {
|
||||
if side != blas.Left && side != blas.Right {
|
||||
|
@@ -9,14 +9,14 @@ import (
|
||||
"github.com/gonum/lapack"
|
||||
)
|
||||
|
||||
// Dormqr multiplies the matrix c by the othogonal matrix q defined by the
|
||||
// Dormqr multiplies the matrix C by the othogonal matrix Q defined by the
|
||||
// slices a and tau. A and tau are as returned from Dgeqrf.
|
||||
// C = Q * C if side == blas.Left and trans == blas.NoTrans
|
||||
// C = Q^T * C if side == blas.Left and trans == blas.Trans
|
||||
// C = C * Q if side == blas.Right and trans == blas.NoTrans
|
||||
// C = C * Q^T if side == blas.Right and trans == blas.Trans
|
||||
// If side == blas.Left, a is a matrix of side k×m, and if side == blas.Right
|
||||
// a is of size k×n. This uses a blocked algorithm.
|
||||
// If side == blas.Left, A is a matrix of side k×m, and if side == blas.Right
|
||||
// A is of size k×n. This uses a blocked algorithm.
|
||||
//
|
||||
// Work is temporary storage, and lwork specifies the usable memory length.
|
||||
// At minimum, lwork >= m if side == blas.Left and lwork >= n if side == blas.Right,
|
||||
@@ -25,7 +25,7 @@ import (
|
||||
// by the temporary space available. If lwork == -1, instead of performing Dormqr,
|
||||
// the optimal work length will be stored into work[0].
|
||||
//
|
||||
// Tau contains the householder scales and must have length at least k, and
|
||||
// tau contains the householder scales and must have length at least k, and
|
||||
// this function will panic otherwise.
|
||||
func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) {
|
||||
left := side == blas.Left
|
||||
@@ -37,6 +37,10 @@ func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k
|
||||
}
|
||||
checkMatrix(m, n, c, ldc)
|
||||
|
||||
if len(tau) < k {
|
||||
panic(badTau)
|
||||
}
|
||||
|
||||
const nbmax = 64
|
||||
nw := n
|
||||
if side == blas.Right {
|
||||
|
@@ -9,9 +9,8 @@ import (
|
||||
"github.com/gonum/blas/blas64"
|
||||
)
|
||||
|
||||
// Dtrtrs solves a triangular system of the form a * x = b or a^T * x = b. Dtrtrs
|
||||
// checks for singularity in a. If a is singular, false is returned and no solve
|
||||
// is performed. True is returned otherwise.
|
||||
// Dtrtrs solves a triangular system of the form A * X = B or A^T * X = B. Dtrtrs
|
||||
// returns whether the solve completed successfully. If A is singular, no solve is performed.
|
||||
func (impl Implementation) Dtrtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, nrhs int, a []float64, lda int, b []float64, ldb int) (ok bool) {
|
||||
nounit := diag == blas.NonUnit
|
||||
if n == 0 {
|
||||
|
Reference in New Issue
Block a user