diff --git a/cgo/lapack.go b/cgo/lapack.go index cc5e0ef4..c76dcb16 100644 --- a/cgo/lapack.go +++ b/cgo/lapack.go @@ -77,6 +77,71 @@ func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok return clapack.Dpotrf(ul, n, a, lda) } +// Dgeqr2 computes a QR factorization of the m×n matrix A. +// +// In a QR factorization, Q is an m×m orthonormal matrix, and R is an +// upper triangular m×n matrix. +// +// During Dgeqr2, a is modified to contain the information to construct Q and R. +// The upper triangle of a contains the matrix R. The lower triangular elements +// (not including the diagonal) contain the elementary reflectors. Tau is modified +// to contain the reflector scales. Tau must have length at least k = min(m,n), and +// this function will panic otherwise. +// +// The ith elementary reflector can be explicitly constructed by first extracting +// the +// v[j] = 0 j < i +// v[j] = i j == i +// v[j] = a[i*lda+j] j > i +// and computing h_i = I - tau[i] * v * v^T. +// +// The orthonormal matrix Q can be constucted from a product of these elementary +// reflectors, Q = H_1*H_2 ... H_k, where k = min(m,n). +// +// Work is temporary storage of length at least n and this function will panic otherwise. +func (impl Implementation) Dgeqr2(m, n int, a []float64, lda int, tau, work []float64) { + // TODO(btracey): This is oriented such that columns of a are eliminated. + // This likely could be re-arranged to take better advantage of row-major + // storage. + checkMatrix(m, n, a, lda) + if len(work) < n { + panic(badWork) + } + k := min(m, n) + if len(tau) < k { + panic(badTau) + } + clapack.Dgeqr2(m, n, a, lda, tau) +} + +// Dgeqrf computes the QR factorization of the m×n matrix A using a blocked +// algorithm. Please see the documentation for Dgeqr2 for a description of the +// parameters at entry and exit. +// +// The C interface does not support providing temporary storage. To provide compatibility +// with native, lwork == -1 will not run Dgeqrf but will instead write the minimum +// work necessary to work[0]. If len(work) < lwork, Dgels will panic. +// +// tau must be at least len min(m,n), and this function will panic otherwise. +func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) { + if lwork == -1 { + work[0] = float64(n) + return + } + checkMatrix(m, n, a, lda) + if len(work) < lwork { + panic(shortWork) + } + if lwork < n { + panic(badWork) + } + k := min(m, n) + if len(tau) < k { + panic(badTau) + } + clapack.Dgeqrf(m, n, a, lda, tau) +} + // Dgetf2 computes the LU decomposition of the m×n matrix A. // The LU decomposition is a factorization of a into // A = P * L * U diff --git a/cgo/lapack_test.go b/cgo/lapack_test.go index 7a69d7e2..9f96c5a7 100644 --- a/cgo/lapack_test.go +++ b/cgo/lapack_test.go @@ -16,6 +16,14 @@ func TestDpotrf(t *testing.T) { testlapack.DpotrfTest(t, impl) } +func TestDgeqr2(t *testing.T) { + testlapack.Dgeqr2Test(t, impl) +} + +func TestDgeqrf(t *testing.T) { + testlapack.DgeqrfTest(t, impl) +} + func TestDgetf2(t *testing.T) { testlapack.Dgetf2Test(t, impl) } diff --git a/lapack.go b/lapack.go index f6fbb511..4ca3372b 100644 --- a/lapack.go +++ b/lapack.go @@ -23,6 +23,7 @@ type Complex128 interface{} // Float64 defines the public float64 LAPACK API supported by gonum/lapack. type Float64 interface { + Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool) } diff --git a/lapack64/lapack64.go b/lapack64/lapack64.go index fe13d587..c2d58eb1 100644 --- a/lapack64/lapack64.go +++ b/lapack64/lapack64.go @@ -47,3 +47,29 @@ func Potrf(a blas64.Symmetric) (t blas64.Triangular, ok bool) { t.Diag = blas.NonUnit return } + +// Geqrf computes the QR factorization of the m×n matrix A using a blocked +// algorithm. During Dgeqr2, A is modified to contain the information to construct Q and R. +// The upper triangle of a contains the matrix R. The lower triangular elements +// (not including the diagonal) contain the elementary reflectors. Tau is modified +// to contain the reflector scales. Tau must have length at least k = min(m,n), and +// this function will panic otherwise. +// +// The ith elementary reflector can be explicitly constructed by first extracting +// the +// v[j] = 0 j < i +// v[j] = i j == i +// v[j] = a[i*lda+j] j > i +// and computing h_i = I - tau[i] * v * v^T. +// +// The orthonormal matrix Q can be constucted from a product of these elementary +// reflectors, Q = H_1*H_2 ... H_k, where k = min(m,n). +// +// Work is temporary storage, and lwork specifies the usable memory length. +// At minimum, lwork >= m and this function will panic otherwise. +// Dgeqrf is a blocked LQ factorization, but the block size is limited +// by the temporary space available. If lwork == -1, instead of performing Dgelqf, +// the optimal work length will be stored into work[0]. +func Geqrf(a blas64.General, tau, work []float64, lwork int) { + lapack64.Dgeqrf(a.Rows, a.Cols, a.Data, a.Stride, tau, work, lwork) +} diff --git a/native/dgeqr2.go b/native/dgeqr2.go index efae4a77..a1bb4d74 100644 --- a/native/dgeqr2.go +++ b/native/dgeqr2.go @@ -6,7 +6,7 @@ package native import "github.com/gonum/blas" -// Dgeqr2 computes a QR factorization of the m×n matrix a. +// Dgeqr2 computes a QR factorization of the m×n matrix A. // // In a QR factorization, Q is an m×m orthonormal matrix, and R is an // upper triangular m×n matrix. diff --git a/testlapack/dgeqrf.go b/testlapack/dgeqrf.go index c5964150..cbd5ba44 100644 --- a/testlapack/dgeqrf.go +++ b/testlapack/dgeqrf.go @@ -66,7 +66,7 @@ func DgeqrfTest(t *testing.T, impl Dgeqrfer) { impl.Dgeqr2(m, n, ans, lda, tau, work) // Compute blocked QR with small work. impl.Dgeqrf(m, n, a, lda, tau, work, len(work)) - if !floats.EqualApprox(ans, a, 1e-14) { + if !floats.EqualApprox(ans, a, 1e-12) { t.Errorf("Case %v, mismatch small work.", c) } // Try the full length of work. @@ -80,6 +80,9 @@ func DgeqrfTest(t *testing.T, impl Dgeqrfer) { } // Try a slightly smaller version of work to test blocking. + if len(work) <= n { + continue + } work = work[1:] lwork-- copy(a, aCopy)