lapack/gonum: fix ldwork in Dgeqrf and update its test

This commit is contained in:
Vladimir Chalupecky
2019-01-22 18:34:21 +01:00
committed by Vladimír Chalupecký
parent 08a35caaad
commit dd4cc715c5
2 changed files with 43 additions and 37 deletions

View File

@@ -18,6 +18,7 @@ type Dgeqrfer interface {
}
func DgeqrfTest(t *testing.T, impl Dgeqrfer) {
const tol = 1e-12
rnd := rand.New(rand.NewSource(1))
for c, test := range []struct {
m, n, lda int
@@ -49,48 +50,53 @@ func DgeqrfTest(t *testing.T, impl Dgeqrfer) {
if lda == 0 {
lda = test.n
}
// Allocate m×n matrix A and fill it with random numbers.
a := make([]float64, m*lda)
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
a[i*lda+j] = rnd.Float64()
}
for i := range a {
a[i] = rnd.NormFloat64()
}
// Store a copy of A for later comparison.
aCopy := make([]float64, len(a))
copy(aCopy, a)
// Allocate a slice for scalar factors of elementary reflectors
// and fill it with random numbers.
tau := make([]float64, n)
for i := 0; i < n; i++ {
tau[i] = rnd.Float64()
}
aCopy := make([]float64, len(a))
copy(aCopy, a)
ans := make([]float64, len(a))
copy(ans, a)
work := make([]float64, n)
// Compute unblocked QR.
impl.Dgeqr2(m, n, ans, lda, tau, work)
// Compute blocked QR with small work.
impl.Dgeqrf(m, n, a, lda, tau, work, len(work))
if !floats.EqualApprox(ans, a, 1e-12) {
t.Errorf("Case %v, mismatch small work.", c)
}
// Try the full length of work.
impl.Dgeqrf(m, n, a, lda, tau, work, -1)
lwork := int(work[0])
work = make([]float64, lwork)
copy(a, aCopy)
impl.Dgeqrf(m, n, a, lda, tau, work, lwork)
if !floats.EqualApprox(ans, a, 1e-12) {
t.Errorf("Case %v, mismatch large work.", c)
}
// Try a slightly smaller version of work to test blocking.
if len(work) <= n {
continue
}
work = work[1:]
lwork--
copy(a, aCopy)
impl.Dgeqrf(m, n, a, lda, tau, work, lwork)
if !floats.EqualApprox(ans, a, 1e-12) {
t.Errorf("Case %v, mismatch large work.", c)
// Compute the expected result using unblocked QR algorithm and
// store it in want.
want := make([]float64, len(a))
copy(want, a)
impl.Dgeqr2(m, n, want, lda, tau, make([]float64, n))
for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} {
copy(a, aCopy)
var lwork int
switch wl {
case minimumWork:
lwork = n
case mediumWork:
work := make([]float64, 1)
impl.Dgeqrf(m, n, a, lda, tau, work, -1)
lwork = int(work[0]) - 2*n
case optimumWork:
work := make([]float64, 1)
impl.Dgeqrf(m, n, a, lda, tau, work, -1)
lwork = int(work[0])
}
work := make([]float64, lwork)
// Compute the QR factorization of A.
impl.Dgeqrf(m, n, a, lda, tau, work, len(work))
// Compare the result with Dgeqr2.
if !floats.EqualApprox(want, a, tol) {
t.Errorf("Case %v, workspace %v, unexpected result.", c, wl)
}
}
}
}