mirror of
https://github.com/gonum/gonum.git
synced 2025-10-24 15:43:07 +08:00
Add LQ factorization to cgo and tests
Responded to PR comments
This commit is contained in:
@@ -77,6 +77,60 @@ func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok
|
|||||||
return clapack.Dpotrf(ul, n, a, lda)
|
return clapack.Dpotrf(ul, n, a, lda)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dgelq2 computes the LQ factorization of the m×n matrix A.
|
||||||
|
//
|
||||||
|
// In an LQ factorization, L is a lower triangular m×n matrix, and Q is an n×n
|
||||||
|
// orthornormal matrix.
|
||||||
|
//
|
||||||
|
// a is modified to contain the information to construct L and Q.
|
||||||
|
// The lower triangle of a contains the matrix L. The upper triangular elements
|
||||||
|
// (not including the diagonal) contain the elementary reflectors. Tau is modified
|
||||||
|
// to contain the reflector scales. tau must have length of at least k = min(m,n)
|
||||||
|
// and this function will panic otherwise.
|
||||||
|
//
|
||||||
|
// See Dgeqr2 for a description of the elementary reflectors and orthonormal
|
||||||
|
// matrix Q. Q is constructed as a product of these elementary reflectors,
|
||||||
|
// Q = H_k ... H_2*H_1.
|
||||||
|
//
|
||||||
|
// Work is temporary storage of length at least m and this function will panic otherwise.
|
||||||
|
func (impl Implementation) Dgelq2(m, n int, a []float64, lda int, tau, work []float64) {
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
if len(tau) < min(m, n) {
|
||||||
|
panic(badTau)
|
||||||
|
}
|
||||||
|
if len(work) < m {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
clapack.Dgelq2(m, n, a, lda, tau)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dgelqf computes the LQ factorization of the m×n matrix A using a blocked
|
||||||
|
// algorithm. See the documentation for Dgelq2 for a description of the
|
||||||
|
// parameters at entry and exit.
|
||||||
|
//
|
||||||
|
// The C interface does not support providing temporary storage. To provide compatibility
|
||||||
|
// with native, lwork == -1 will not run Dgeqrf but will instead write the minimum
|
||||||
|
// work necessary to work[0]. If len(work) < lwork, Dgeqrf will panic.
|
||||||
|
//
|
||||||
|
// tau must have length at least min(m,n), and this function will panic otherwise.
|
||||||
|
func (impl Implementation) Dgelqf(m, n int, a []float64, lda int, tau, work []float64, lwork int) {
|
||||||
|
if lwork == -1 {
|
||||||
|
work[0] = float64(m)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
if len(work) < lwork {
|
||||||
|
panic(shortWork)
|
||||||
|
}
|
||||||
|
if lwork < m {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
if len(tau) < min(m, n) {
|
||||||
|
panic(badTau)
|
||||||
|
}
|
||||||
|
clapack.Dgelqf(m, n, a, lda, tau)
|
||||||
|
}
|
||||||
|
|
||||||
// Dgeqr2 computes a QR factorization of the m×n matrix A.
|
// Dgeqr2 computes a QR factorization of the m×n matrix A.
|
||||||
//
|
//
|
||||||
// In a QR factorization, Q is an m×m orthonormal matrix, and R is an
|
// In a QR factorization, Q is an m×m orthonormal matrix, and R is an
|
||||||
@@ -100,9 +154,6 @@ func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok
|
|||||||
//
|
//
|
||||||
// Work is temporary storage of length at least n and this function will panic otherwise.
|
// Work is temporary storage of length at least n and this function will panic otherwise.
|
||||||
func (impl Implementation) Dgeqr2(m, n int, a []float64, lda int, tau, work []float64) {
|
func (impl Implementation) Dgeqr2(m, n int, a []float64, lda int, tau, work []float64) {
|
||||||
// TODO(btracey): This is oriented such that columns of a are eliminated.
|
|
||||||
// This likely could be re-arranged to take better advantage of row-major
|
|
||||||
// storage.
|
|
||||||
checkMatrix(m, n, a, lda)
|
checkMatrix(m, n, a, lda)
|
||||||
if len(work) < n {
|
if len(work) < n {
|
||||||
panic(badWork)
|
panic(badWork)
|
||||||
@@ -120,7 +171,7 @@ func (impl Implementation) Dgeqr2(m, n int, a []float64, lda int, tau, work []fl
|
|||||||
//
|
//
|
||||||
// The C interface does not support providing temporary storage. To provide compatibility
|
// The C interface does not support providing temporary storage. To provide compatibility
|
||||||
// with native, lwork == -1 will not run Dgeqrf but will instead write the minimum
|
// with native, lwork == -1 will not run Dgeqrf but will instead write the minimum
|
||||||
// work necessary to work[0]. If len(work) < lwork, Dgels will panic.
|
// work necessary to work[0]. If len(work) < lwork, Dgeqrf will panic.
|
||||||
//
|
//
|
||||||
// tau must have length at least min(m,n), and this function will panic otherwise.
|
// tau must have length at least min(m,n), and this function will panic otherwise.
|
||||||
func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) {
|
func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) {
|
||||||
|
|||||||
@@ -16,6 +16,14 @@ func TestDpotrf(t *testing.T) {
|
|||||||
testlapack.DpotrfTest(t, impl)
|
testlapack.DpotrfTest(t, impl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDgelq2(t *testing.T) {
|
||||||
|
testlapack.Dgelq2Test(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDgelqf(t *testing.T) {
|
||||||
|
testlapack.DgelqfTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDgeqr2(t *testing.T) {
|
func TestDgeqr2(t *testing.T) {
|
||||||
testlapack.Dgeqr2Test(t, impl)
|
testlapack.Dgeqr2Test(t, impl)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ type Complex128 interface{}
|
|||||||
|
|
||||||
// Float64 defines the public float64 LAPACK API supported by gonum/lapack.
|
// Float64 defines the public float64 LAPACK API supported by gonum/lapack.
|
||||||
type Float64 interface {
|
type Float64 interface {
|
||||||
|
Dgelqf(m, n int, a []float64, lda int, tau, work []float64, lwork int)
|
||||||
Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int)
|
Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int)
|
||||||
Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
|
Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -67,9 +67,29 @@ func Potrf(a blas64.Symmetric) (t blas64.Triangular, ok bool) {
|
|||||||
//
|
//
|
||||||
// Work is temporary storage, and lwork specifies the usable memory length.
|
// Work is temporary storage, and lwork specifies the usable memory length.
|
||||||
// At minimum, lwork >= m and this function will panic otherwise.
|
// At minimum, lwork >= m and this function will panic otherwise.
|
||||||
// Dgeqrf is a blocked LQ factorization, but the block size is limited
|
// Dgeqrf is a blocked QR factorization, but the block size is limited
|
||||||
// by the temporary space available. If lwork == -1, instead of performing Dgelqf,
|
// by the temporary space available. If lwork == -1, instead of performing Geqrf,
|
||||||
// the optimal work length will be stored into work[0].
|
// the optimal work length will be stored into work[0].
|
||||||
func Geqrf(a blas64.General, tau, work []float64, lwork int) {
|
func Geqrf(a blas64.General, tau, work []float64, lwork int) {
|
||||||
lapack64.Dgeqrf(a.Rows, a.Cols, a.Data, a.Stride, tau, work, lwork)
|
lapack64.Dgeqrf(a.Rows, a.Cols, a.Data, a.Stride, tau, work, lwork)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Gelqf computes the QR factorization of the m×n matrix A using a blocked
|
||||||
|
// algorithm. A is modified to contain the information to construct L and Q.
|
||||||
|
// The lower triangle of a contains the matrix L. The lower triangular elements
|
||||||
|
// (not including the diagonal) contain the elementary reflectors. Tau is modified
|
||||||
|
// to contain the reflector scales. Tau must have length at least min(m,n), and
|
||||||
|
// this function will panic otherwise.
|
||||||
|
//
|
||||||
|
// See Geqrf for a description of the elementary reflectors and orthonormal
|
||||||
|
// matrix Q. Q is constructed as a product of these elementary reflectors,
|
||||||
|
// Q = H_k ... H_2*H_1.
|
||||||
|
//
|
||||||
|
// Work is temporary storage, and lwork specifies the usable memory length.
|
||||||
|
// At minimum, lwork >= m and this function will panic otherwise.
|
||||||
|
// Dgeqrf is a blocked LQ factorization, but the block size is limited
|
||||||
|
// by the temporary space available. If lwork == -1, instead of performing Gelqf,
|
||||||
|
// the optimal work length will be stored into work[0].
|
||||||
|
func Gelqf(a blas64.General, tau, work []float64, lwork int) {
|
||||||
|
lapack64.Dgelqf(a.Rows, a.Cols, a.Data, a.Stride, tau, work, lwork)
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,9 +6,12 @@ package native
|
|||||||
|
|
||||||
import "github.com/gonum/blas"
|
import "github.com/gonum/blas"
|
||||||
|
|
||||||
// Dgelq2 computes the LQ factorization of the m×n matrix a.
|
// Dgelq2 computes the LQ factorization of the m×n matrix A.
|
||||||
//
|
//
|
||||||
// During Dgelq2, a is modified to contain the information to construct Q and L.
|
// In an LQ factorization, L is a lower triangular m×n matrix, and Q is an n×n
|
||||||
|
// orthornormal matrix.
|
||||||
|
//
|
||||||
|
// a is modified to contain the information to construct L and Q.
|
||||||
// The lower triangle of a contains the matrix L. The upper triangular elements
|
// The lower triangle of a contains the matrix L. The upper triangular elements
|
||||||
// (not including the diagonal) contain the elementary reflectors. Tau is modified
|
// (not including the diagonal) contain the elementary reflectors. Tau is modified
|
||||||
// to contain the reflector scales. Tau must have length of at least k = min(m,n)
|
// to contain the reflector scales. Tau must have length of at least k = min(m,n)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
"github.com/gonum/lapack"
|
"github.com/gonum/lapack"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Dgelqf computes the LQ factorization of the m×n matrix a using a blocked
|
// Dgelqf computes the LQ factorization of the m×n matrix A using a blocked
|
||||||
// algorithm. Please see the documentation for Dgelq2 for a description of the
|
// algorithm. Please see the documentation for Dgelq2 for a description of the
|
||||||
// parameters at entry and exit.
|
// parameters at entry and exit.
|
||||||
//
|
//
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ func DgelqfTest(t *testing.T, impl Dgelqfer) {
|
|||||||
impl.Dgelq2(m, n, ans, lda, tau, work)
|
impl.Dgelq2(m, n, ans, lda, tau, work)
|
||||||
// Compute blocked QR with small work.
|
// Compute blocked QR with small work.
|
||||||
impl.Dgelqf(m, n, a, lda, tau, work, len(work))
|
impl.Dgelqf(m, n, a, lda, tau, work, len(work))
|
||||||
if !floats.EqualApprox(ans, a, 1e-14) {
|
if !floats.EqualApprox(ans, a, 1e-12) {
|
||||||
t.Errorf("Case %v, mismatch small work.", c)
|
t.Errorf("Case %v, mismatch small work.", c)
|
||||||
}
|
}
|
||||||
// Try the full length of work.
|
// Try the full length of work.
|
||||||
@@ -83,6 +83,9 @@ func DgelqfTest(t *testing.T, impl Dgelqfer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Try a slightly smaller version of work to test blocking code.
|
// Try a slightly smaller version of work to test blocking code.
|
||||||
|
if len(work) <= m {
|
||||||
|
continue
|
||||||
|
}
|
||||||
work = work[1:]
|
work = work[1:]
|
||||||
lwork--
|
lwork--
|
||||||
copy(a, aCopy)
|
copy(a, aCopy)
|
||||||
|
|||||||
@@ -20,19 +20,19 @@ func DgetrsTest(t *testing.T, impl Dgetrser) {
|
|||||||
n, nrhs, lda, ldb int
|
n, nrhs, lda, ldb int
|
||||||
tol float64
|
tol float64
|
||||||
}{
|
}{
|
||||||
{3, 3, 0, 0, 1e-14},
|
{3, 3, 0, 0, 1e-12},
|
||||||
{3, 3, 0, 0, 1e-14},
|
{3, 3, 0, 0, 1e-12},
|
||||||
{3, 5, 0, 0, 1e-14},
|
{3, 5, 0, 0, 1e-12},
|
||||||
{3, 5, 0, 0, 1e-14},
|
{3, 5, 0, 0, 1e-12},
|
||||||
{5, 3, 0, 0, 1e-14},
|
{5, 3, 0, 0, 1e-12},
|
||||||
{5, 3, 0, 0, 1e-14},
|
{5, 3, 0, 0, 1e-12},
|
||||||
|
|
||||||
{3, 3, 8, 10, 1e-14},
|
{3, 3, 8, 10, 1e-12},
|
||||||
{3, 3, 8, 10, 1e-14},
|
{3, 3, 8, 10, 1e-12},
|
||||||
{3, 5, 8, 10, 1e-14},
|
{3, 5, 8, 10, 1e-12},
|
||||||
{3, 5, 8, 10, 1e-14},
|
{3, 5, 8, 10, 1e-12},
|
||||||
{5, 3, 8, 10, 1e-14},
|
{5, 3, 8, 10, 1e-12},
|
||||||
{5, 3, 8, 10, 1e-14},
|
{5, 3, 8, 10, 1e-12},
|
||||||
|
|
||||||
{300, 300, 0, 0, 1e-10},
|
{300, 300, 0, 0, 1e-10},
|
||||||
{300, 300, 0, 0, 1e-10},
|
{300, 300, 0, 0, 1e-10},
|
||||||
@@ -45,7 +45,7 @@ func DgetrsTest(t *testing.T, impl Dgetrser) {
|
|||||||
{300, 300, 700, 600, 1e-10},
|
{300, 300, 700, 600, 1e-10},
|
||||||
{300, 500, 700, 600, 1e-10},
|
{300, 500, 700, 600, 1e-10},
|
||||||
{300, 500, 700, 600, 1e-10},
|
{300, 500, 700, 600, 1e-10},
|
||||||
{500, 300, 700, 600, 1e-10},
|
{500, 300, 700, 600, 1e-8},
|
||||||
{500, 300, 700, 600, 1e-10},
|
{500, 300, 700, 600, 1e-10},
|
||||||
} {
|
} {
|
||||||
n := test.n
|
n := test.n
|
||||||
|
|||||||
Reference in New Issue
Block a user