Add QR factorization to lapack64 interface.

This commit is contained in:
btracey
2015-08-03 10:21:37 -06:00
parent a5bda2fc24
commit 0331cab04a
6 changed files with 105 additions and 2 deletions

View File

@@ -77,6 +77,71 @@ func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok
return clapack.Dpotrf(ul, n, a, lda)
}
// Dgeqr2 computes a QR factorization of the m×n matrix A.
//
// In a QR factorization, Q is an m×m orthonormal matrix, and R is an
// upper triangular m×n matrix.
//
// During Dgeqr2, a is modified to contain the information to construct Q and R.
// The upper triangle of a contains the matrix R. The lower triangular elements
// (not including the diagonal) contain the elementary reflectors. Tau is modified
// to contain the reflector scales. Tau must have length at least k = min(m,n), and
// this function will panic otherwise.
//
// The ith elementary reflector can be explicitly constructed by first extracting
// the
// v[j] = 0 j < i
// v[j] = i j == i
// v[j] = a[i*lda+j] j > i
// and computing h_i = I - tau[i] * v * v^T.
//
// The orthonormal matrix Q can be constucted from a product of these elementary
// reflectors, Q = H_1*H_2 ... H_k, where k = min(m,n).
//
// Work is temporary storage of length at least n and this function will panic otherwise.
func (impl Implementation) Dgeqr2(m, n int, a []float64, lda int, tau, work []float64) {
// TODO(btracey): This is oriented such that columns of a are eliminated.
// This likely could be re-arranged to take better advantage of row-major
// storage.
checkMatrix(m, n, a, lda)
if len(work) < n {
panic(badWork)
}
k := min(m, n)
if len(tau) < k {
panic(badTau)
}
clapack.Dgeqr2(m, n, a, lda, tau)
}
// Dgeqrf computes the QR factorization of the m×n matrix A using a blocked
// algorithm. Please see the documentation for Dgeqr2 for a description of the
// parameters at entry and exit.
//
// The C interface does not support providing temporary storage. To provide compatibility
// with native, lwork == -1 will not run Dgeqrf but will instead write the minimum
// work necessary to work[0]. If len(work) < lwork, Dgels will panic.
//
// tau must be at least len min(m,n), and this function will panic otherwise.
func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) {
if lwork == -1 {
work[0] = float64(n)
return
}
checkMatrix(m, n, a, lda)
if len(work) < lwork {
panic(shortWork)
}
if lwork < n {
panic(badWork)
}
k := min(m, n)
if len(tau) < k {
panic(badTau)
}
clapack.Dgeqrf(m, n, a, lda, tau)
}
// Dgetf2 computes the LU decomposition of the m×n matrix A.
// The LU decomposition is a factorization of a into
// A = P * L * U

View File

@@ -16,6 +16,14 @@ func TestDpotrf(t *testing.T) {
testlapack.DpotrfTest(t, impl)
}
func TestDgeqr2(t *testing.T) {
testlapack.Dgeqr2Test(t, impl)
}
func TestDgeqrf(t *testing.T) {
testlapack.DgeqrfTest(t, impl)
}
func TestDgetf2(t *testing.T) {
testlapack.Dgetf2Test(t, impl)
}

View File

@@ -23,6 +23,7 @@ type Complex128 interface{}
// Float64 defines the public float64 LAPACK API supported by gonum/lapack.
type Float64 interface {
Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int)
Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
}

View File

@@ -47,3 +47,29 @@ func Potrf(a blas64.Symmetric) (t blas64.Triangular, ok bool) {
t.Diag = blas.NonUnit
return
}
// Geqrf computes the QR factorization of the m×n matrix A using a blocked
// algorithm. During Dgeqr2, A is modified to contain the information to construct Q and R.
// The upper triangle of a contains the matrix R. The lower triangular elements
// (not including the diagonal) contain the elementary reflectors. Tau is modified
// to contain the reflector scales. Tau must have length at least k = min(m,n), and
// this function will panic otherwise.
//
// The ith elementary reflector can be explicitly constructed by first extracting
// the
// v[j] = 0 j < i
// v[j] = i j == i
// v[j] = a[i*lda+j] j > i
// and computing h_i = I - tau[i] * v * v^T.
//
// The orthonormal matrix Q can be constucted from a product of these elementary
// reflectors, Q = H_1*H_2 ... H_k, where k = min(m,n).
//
// Work is temporary storage, and lwork specifies the usable memory length.
// At minimum, lwork >= m and this function will panic otherwise.
// Dgeqrf is a blocked LQ factorization, but the block size is limited
// by the temporary space available. If lwork == -1, instead of performing Dgelqf,
// the optimal work length will be stored into work[0].
func Geqrf(a blas64.General, tau, work []float64, lwork int) {
lapack64.Dgeqrf(a.Rows, a.Cols, a.Data, a.Stride, tau, work, lwork)
}

View File

@@ -6,7 +6,7 @@ package native
import "github.com/gonum/blas"
// Dgeqr2 computes a QR factorization of the m×n matrix a.
// Dgeqr2 computes a QR factorization of the m×n matrix A.
//
// In a QR factorization, Q is an m×m orthonormal matrix, and R is an
// upper triangular m×n matrix.

View File

@@ -66,7 +66,7 @@ func DgeqrfTest(t *testing.T, impl Dgeqrfer) {
impl.Dgeqr2(m, n, ans, lda, tau, work)
// Compute blocked QR with small work.
impl.Dgeqrf(m, n, a, lda, tau, work, len(work))
if !floats.EqualApprox(ans, a, 1e-14) {
if !floats.EqualApprox(ans, a, 1e-12) {
t.Errorf("Case %v, mismatch small work.", c)
}
// Try the full length of work.
@@ -80,6 +80,9 @@ func DgeqrfTest(t *testing.T, impl Dgeqrfer) {
}
// Try a slightly smaller version of work to test blocking.
if len(work) <= n {
continue
}
work = work[1:]
lwork--
copy(a, aCopy)