Merge pull request #32 from gonum/qrlqsolve

Add cgo and lapack64 functions for performing a QR and LQ solve from …
This commit is contained in:
Brendan Tracey
2015-08-11 08:14:19 -06:00
8 changed files with 207 additions and 13 deletions

View File

@@ -327,3 +327,114 @@ func (impl Implementation) Dgetrs(trans blas.Transpose, n, nrhs int, a []float64
}
clapack.Dgetrs(trans, n, nrhs, a, lda, ipiv32, b, ldb)
}
// Dormlq applies the orthogonal matrix Q of an LQ factorization to the matrix C
// in place. Q is defined by the slices a and tau, as returned from Dgelqf.
// C = Q * C if side == blas.Left and trans == blas.NoTrans
// C = Q^T * C if side == blas.Left and trans == blas.Trans
// C = C * Q if side == blas.Right and trans == blas.NoTrans
// C = C * Q^T if side == blas.Right and trans == blas.Trans
// C is m×n. A is checked as a k×m matrix if side == blas.Left and as a k×n
// matrix if side == blas.Right.
//
// Dormlq panics if side is neither blas.Left nor blas.Right, if trans is
// neither blas.Trans nor blas.NoTrans, or if len(tau) < k. tau contains the
// Householder scales.
//
// work is never handed to clapack.Dormlq; it is only used to answer a
// workspace query. If lwork == -1, no multiplication is performed and the
// minimum work length is stored into work[0]. Otherwise lwork must be at least
// n if side == blas.Left and at least m if side == blas.Right, and Dormlq
// panics if it is smaller.
func (impl Implementation) Dormlq(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) {
	if !(side == blas.Left || side == blas.Right) {
		panic(badSide)
	}
	if !(trans == blas.Trans || trans == blas.NoTrans) {
		panic(badTrans)
	}

	// aCols is the column count A is checked against; minWork is the smallest
	// accepted lwork. Both depend only on which side Q is applied from.
	aCols, minWork := n, m
	if side == blas.Left {
		aCols, minWork = m, n
	}

	checkMatrix(k, aCols, a, lda)
	checkMatrix(m, n, c, ldc)
	if k > len(tau) {
		panic(badTau)
	}

	// Workspace query: report the minimum length and do no work.
	if lwork == -1 {
		work[0] = float64(minWork)
		return
	}
	if minWork > lwork {
		panic(badWork)
	}
	clapack.Dormlq(side, trans, m, n, k, a, lda, tau, c, ldc)
}
// Dormqr multiplies the matrix C by the orthogonal matrix Q defined by the
// slices a and tau. a and tau are as returned from Dgeqrf.
// C = Q * C if side == blas.Left and trans == blas.NoTrans
// C = Q^T * C if side == blas.Left and trans == blas.Trans
// C = C * Q if side == blas.Right and trans == blas.NoTrans
// C = C * Q^T if side == blas.Right and trans == blas.Trans
// C is m×n. If side == blas.Left, A is a matrix of size m×k, and otherwise
// A is of size n×k.
//
// tau contains the Householder scales and must have length at least k, and
// this function will panic otherwise.
//
// The C interface does not support providing temporary storage, so work is
// never passed to clapack.Dormqr. To provide compatibility with native,
// lwork == -1 will not run Dormqr but will instead write the minimum work
// length to work[0]: n if side == blas.Left and m otherwise. For any other
// value, Dormqr panics if lwork is smaller than that minimum.
//
// side and trans are not validated here; a side other than blas.Left is
// treated like blas.Right for the dimension and workspace checks.
func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) {
	left := side == blas.Left
	if left {
		checkMatrix(m, k, a, lda)
	} else {
		checkMatrix(n, k, a, lda)
	}
	checkMatrix(m, n, c, ldc)
	if len(tau) < k {
		panic(badTau)
	}

	// The workspace query and the workspace check must agree: the minimum is
	// the dimension of C that Q does not act on.
	minWork := m
	if left {
		minWork = n
	}
	if lwork == -1 {
		work[0] = float64(minWork)
		return
	}
	if lwork < minWork {
		panic(badWork)
	}
	clapack.Dormqr(side, trans, m, n, k, a, lda, tau, c, ldc)
}
// Dtrtrs solves a triangular system of the form A * X = B or A^T * X = B by
// passing its arguments unchanged to clapack.Dtrtrs. The returned ok reports
// whether the solve completed successfully. If A is singular, no solve is
// performed.
func (impl Implementation) Dtrtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, nrhs int, a []float64, lda int, b []float64, ldb int) (ok bool) {
	ok = clapack.Dtrtrs(uplo, trans, diag, n, nrhs, a, lda, b, ldb)
	return ok
}

View File

@@ -7,6 +7,7 @@ package cgo
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/lapack/testlapack"
)
@@ -47,3 +48,29 @@ func TestDgetrf(t *testing.T) {
// TestDgetrs runs the shared testlapack.DgetrsTest suite against the
// package-level impl value.
func TestDgetrs(t *testing.T) {
testlapack.DgetrsTest(t, impl)
}
// blockedTranslate exposes some blocked C routines under the names of the
// corresponding unblocked algorithms for testing, as several of the unblocked
// algorithms are not defined by the C interface.
//
// It embeds Implementation, so every Implementation method remains available
// on a blockedTranslate value alongside the translated ones.
type blockedTranslate struct {
Implementation
}
// Dorm2r presents the blocked Dormqr under the unblocked Dorm2r signature.
// The whole of work is offered as workspace, so lwork is len(work).
//
// The call goes through the embedded Implementation of the receiver rather
// than a package-level value, so the wrapper always exercises the value it
// was constructed with.
func (d blockedTranslate) Dorm2r(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64) {
	d.Implementation.Dormqr(side, trans, m, n, k, a, lda, tau, c, ldc, work, len(work))
}
// Dorml2 presents the blocked Dormlq under the unblocked Dorml2 signature.
// The whole of work is offered as workspace, so lwork is len(work).
//
// The call goes through the embedded Implementation of the receiver rather
// than a package-level value, so the wrapper always exercises the value it
// was constructed with.
func (d blockedTranslate) Dorml2(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64) {
	d.Implementation.Dormlq(side, trans, m, n, k, a, lda, tau, c, ldc, work, len(work))
}
func TestDormqr(t *testing.T) {
testlapack.Dorm2rTest(t, blockedTranslate{impl})
}
/*
// Test disabled because of bug in c interface. Leaving stub for easy reproducer.
func TestDormlq(t *testing.T) {
testlapack.Dorml2Test(t, blockedTranslate{impl})
}
*/

View File

@@ -26,9 +26,12 @@ type Float64 interface {
Dgels(trans blas.Transpose, m, n, nrhs int, a []float64, lda int, b []float64, ldb int, work []float64, lwork int) bool
Dgelqf(m, n int, a []float64, lda int, tau, work []float64, lwork int)
Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int)
Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
Dgetrf(m, n int, a []float64, lda int, ipiv []int) (ok bool)
Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int)
Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
Dormlq(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
Dtrtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, nrhs int, a []float64, lda int, b []float64, ldb int) (ok bool)
}
// Direct specifies the direction of the multiplication for the Householder matrix.

View File

@@ -162,3 +162,53 @@ func Getrf(a blas64.General, ipiv []int) bool {
func Getrs(trans blas.Transpose, a blas64.General, b blas64.General, ipiv []int) {
	// The system order is taken from a's column count and the number of
	// right-hand sides from b's column count.
	n, nrhs := a.Cols, b.Cols
	lapack64.Dgetrs(trans, n, nrhs, a.Data, a.Stride, ipiv, b.Data, b.Stride)
}
// Ormlq multiplies the matrix C by the orthogonal matrix Q defined by
// A and tau. A and tau are as returned from Gelqf.
// C = Q * C if side == blas.Left and trans == blas.NoTrans
// C = Q^T * C if side == blas.Left and trans == blas.Trans
// C = C * Q if side == blas.Right and trans == blas.NoTrans
// C = C * Q^T if side == blas.Right and trans == blas.Trans
// The dimensions are read from the arguments: m and n are c.Rows and c.Cols,
// and the number of Householder reflectors k is a.Rows, so A is expected to be
// k×m if side == blas.Left and k×n if side == blas.Right.
//
// work is temporary storage, and lwork specifies the usable memory length.
// Both are forwarded unchanged to lapack64.Dormlq, which determines the
// minimum acceptable lwork and panics if it is not met. If lwork == -1,
// instead of performing Ormlq, the work length to use is stored into work[0].
//
// tau contains the Householder scales and must have length at least k, and
// this function will panic otherwise.
func Ormlq(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64, c blas64.General, work []float64, lwork int) {
	m, n := c.Rows, c.Cols
	k := a.Rows
	lapack64.Dormlq(side, trans, m, n, k, a.Data, a.Stride, tau, c.Data, c.Stride, work, lwork)
}
// Ormqr multiplies the matrix C by the orthogonal matrix Q defined by
// A and tau. A and tau are as returned from Geqrf.
// C = Q * C if side == blas.Left and trans == blas.NoTrans
// C = Q^T * C if side == blas.Left and trans == blas.Trans
// C = C * Q if side == blas.Right and trans == blas.NoTrans
// C = C * Q^T if side == blas.Right and trans == blas.Trans
// The dimensions are read from the arguments: m and n are c.Rows and c.Cols,
// and the number of Householder reflectors k is a.Cols, so A is expected to be
// m×k if side == blas.Left and n×k if side == blas.Right.
//
// tau contains the Householder scales and must have length at least k, and
// this function will panic otherwise.
//
// work is temporary storage, and lwork specifies the usable memory length.
// Both are forwarded unchanged to lapack64.Dormqr, which determines the
// minimum acceptable lwork and panics if it is not met. If lwork == -1,
// instead of performing Ormqr, the work length to use is stored into work[0].
func Ormqr(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64, c blas64.General, work []float64, lwork int) {
	m, n := c.Rows, c.Cols
	k := a.Cols
	lapack64.Dormqr(side, trans, m, n, k, a.Data, a.Stride, tau, c.Data, c.Stride, work, lwork)
}
// Trtrs solves a triangular system of the form A * X = B or A^T * X = B,
// taking the order of A from a.N and the number of right-hand sides from
// b.Cols. Trtrs returns whether the solve completed successfully. If A is
// singular, no solve is performed.
func Trtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, a blas64.Triangular, b blas64.General) (ok bool) {
	n, nrhs := a.N, b.Cols
	ok = lapack64.Dtrtrs(uplo, trans, diag, n, nrhs, a.Data, a.Stride, b.Data, b.Stride)
	return ok
}

View File

@@ -6,7 +6,7 @@ package native
import "github.com/gonum/blas"
// Dorm2r multiplies a general matrix c by an orthogonal matrix from a QR factorization
// Dorm2r multiplies a general matrix C by an orthogonal matrix from a QR factorization
// determined by Dgeqrf.
// C = Q * C if side == blas.Left and trans == blas.NoTrans
// C = Q^T * C if side == blas.Left and trans == blas.Trans

View File

@@ -9,14 +9,14 @@ import (
"github.com/gonum/lapack"
)
// Dormlq multiplies the matrix c by the othogonal matrix q defined by the
// Dormlq multiplies the matrix C by the orthogonal matrix Q defined by the
// slices a and tau. A and tau are as returned from Dgelqf.
// C = Q * C if side == blas.Left and trans == blas.NoTrans
// C = Q^T * C if side == blas.Left and trans == blas.Trans
// C = C * Q if side == blas.Right and trans == blas.NoTrans
// C = C * Q^T if side == blas.Right and trans == blas.Trans
// If side == blas.Left, a is a matrix of side k×m, and if side == blas.Right
// a is of size k×n. This uses a blocked algorithm.
// If side == blas.Left, A is a matrix of side k×m, and if side == blas.Right
// A is of size k×n. This uses a blocked algorithm.
//
// Work is temporary storage, and lwork specifies the usable memory length.
// At minimum, lwork >= m if side == blas.Left and lwork >= n if side == blas.Right,
@@ -25,7 +25,7 @@ import (
// by the temporary space available. If lwork == -1, instead of performing Dormlq,
// the optimal work length will be stored into work[0].
//
// Tau contains the householder scales and must have length at least k, and
// tau contains the householder scales and must have length at least k, and
// this function will panic otherwise.
func (impl Implementation) Dormlq(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) {
if side != blas.Left && side != blas.Right {

View File

@@ -9,14 +9,14 @@ import (
"github.com/gonum/lapack"
)
// Dormqr multiplies the matrix c by the othogonal matrix q defined by the
// Dormqr multiplies the matrix C by the orthogonal matrix Q defined by the
// slices a and tau. A and tau are as returned from Dgeqrf.
// C = Q * C if side == blas.Left and trans == blas.NoTrans
// C = Q^T * C if side == blas.Left and trans == blas.Trans
// C = C * Q if side == blas.Right and trans == blas.NoTrans
// C = C * Q^T if side == blas.Right and trans == blas.Trans
// If side == blas.Left, a is a matrix of side k×m, and if side == blas.Right
// a is of size k×n. This uses a blocked algorithm.
// If side == blas.Left, A is a matrix of side k×m, and if side == blas.Right
// A is of size k×n. This uses a blocked algorithm.
//
// Work is temporary storage, and lwork specifies the usable memory length.
// At minimum, lwork >= m if side == blas.Left and lwork >= n if side == blas.Right,
@@ -25,7 +25,7 @@ import (
// by the temporary space available. If lwork == -1, instead of performing Dormqr,
// the optimal work length will be stored into work[0].
//
// Tau contains the householder scales and must have length at least k, and
// tau contains the householder scales and must have length at least k, and
// this function will panic otherwise.
func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) {
left := side == blas.Left
@@ -37,6 +37,10 @@ func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k
}
checkMatrix(m, n, c, ldc)
if len(tau) < k {
panic(badTau)
}
const nbmax = 64
nw := n
if side == blas.Right {

View File

@@ -9,9 +9,8 @@ import (
"github.com/gonum/blas/blas64"
)
// Dtrtrs solves a triangular system of the form a * x = b or a^T * x = b. Dtrtrs
// checks for singularity in a. If a is singular, false is returned and no solve
// is performed. True is returned otherwise.
// Dtrtrs solves a triangular system of the form A * X = B or A^T * X = B. Dtrtrs
// returns whether the solve completed successfully. If A is singular, no solve is performed.
func (impl Implementation) Dtrtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, nrhs int, a []float64, lda int, b []float64, ldb int) (ok bool) {
nounit := diag == blas.NonUnit
if n == 0 {