From 77e56a422f9c8ed4d2ab9ea915fc0188c52bdd2d Mon Sep 17 00:00:00 2001 From: Vladimir Chalupecky Date: Fri, 18 Nov 2016 16:21:41 +0100 Subject: [PATCH] testlapack: rewrite test for Dorgql The test was using LAPACK routines that are not exposed by LAPACKE and therefore the cgo implementation could not be tested. --- testlapack/dorgql.go | 131 +++++++++++++++++++++++++++++++------------ 1 file changed, 96 insertions(+), 35 deletions(-) diff --git a/testlapack/dorgql.go b/testlapack/dorgql.go index 52ad21ba..591c74cd 100644 --- a/testlapack/dorgql.go +++ b/testlapack/dorgql.go @@ -5,57 +5,118 @@ package testlapack import ( + "fmt" "math/rand" "testing" - "github.com/gonum/floats" + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" ) type Dorgqler interface { Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int) - Dorg2ler + + Dlarfger } func DorgqlTest(t *testing.T, impl Dorgqler) { + const tol = 1e-14 + + type Dorg2ler interface { + Dorg2l(m, n, k int, a []float64, lda int, tau, work []float64) + } + dorg2ler, hasDorg2l := impl.(Dorg2ler) + rnd := rand.New(rand.NewSource(1)) - for _, test := range []struct { - m, n, k, lda int - }{ - {5, 4, 3, 0}, - {100, 100, 100, 0}, - {200, 100, 50, 0}, - {200, 200, 50, 0}, - } { - m := test.m - n := test.n - k := test.k - lda := test.lda - if lda == 0 { - lda = n - } - a := make([]float64, m*lda) - for i := range a { - a[i] = rnd.NormFloat64() - } - tau := nanSlice(min(m, n)) - work := nanSlice(max(m, n)) + for _, m := range []int{0, 1, 2, 3, 4, 5, 7, 10, 15, 30, 50, 150} { + for _, extra := range []int{0, 11} { + for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} { + n := rnd.Intn(m + 1) + k := rnd.Intn(n + 1) + if m == 0 || n == 0 { + m = 0 + n = 0 + k = 0 + } - impl.Dgeql2(m, n, a, lda, tau, work) + // Generate k elementary reflectors in the last + // k columns of A. + a := nanGeneral(m, n, n+extra) + tau := make([]float64, k) + for l := 0; l < k; l++ { + jj := m - k + l + v := randomSlice(jj, rnd) + _, tau[l] = impl.Dlarfg(len(v)+1, rnd.NormFloat64(), v, 1) + j := n - k + l + for i := 0; i < jj; i++ { + a.Data[i*a.Stride+j] = v[i] + } + } + aCopy := cloneGeneral(a) - aCopy := make([]float64, len(a)) - copy(aCopy, a) + // Compute the full matrix Q by forming the + // Householder reflectors explicitly. + q := eye(m, m) + qCopy := eye(m, m) + for l := 0; l < k; l++ { + h := eye(m, m) + jj := m - k + l + j := n - k + l + v := blas64.Vector{1, make([]float64, m)} + for i := 0; i < jj; i++ { + v.Data[i] = a.Data[i*a.Stride+j] + } + v.Data[jj] = 1 + blas64.Ger(-tau[l], v, v, h) + copy(qCopy.Data, q.Data) + blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, qCopy, 0, q) + } + // View the last n columns of Q as 'want'. + want := blas64.General{ + Rows: m, + Cols: n, + Stride: q.Stride, + Data: q.Data[m-n:], + } - impl.Dorg2l(m, n, k, a, lda, tau, work) - ans := make([]float64, len(a)) - copy(ans, a) + var lwork int + switch wl { + case minimumWork: + lwork = max(1, n) + case mediumWork: + work := make([]float64, 1) + impl.Dorgql(m, n, k, nil, a.Stride, nil, work, -1) + lwork = (int(work[0]) + n) / 2 + lwork = max(1, lwork) + case optimumWork: + work := make([]float64, 1) + impl.Dorgql(m, n, k, nil, a.Stride, nil, work, -1) + lwork = int(work[0]) + } + work := make([]float64, lwork) - impl.Dorgql(m, n, k, a, lda, tau, work, -1) - work = make([]float64, int(work[0])) - copy(a, aCopy) - impl.Dorgql(m, n, k, a, lda, tau, work, len(work)) + // Compute the last n columns of Q by a call to + // Dorgql. + impl.Dorgql(m, n, k, a.Data, a.Stride, tau, work, len(work)) - if !floats.EqualApprox(a, ans, 1e-8) { - t.Errorf("Answer mismatch. m = %v, n = %v, k = %v", m, n, k) + prefix := fmt.Sprintf("Case m=%v,n=%v,k=%v,wl=%v", m, n, k, wl) + if !generalOutsideAllNaN(a) { + t.Errorf("%v: out-of-range write to A", prefix) + } + if !equalApproxGeneral(want, a, tol) { + t.Errorf("%v: unexpected Q", prefix) + } + + // Compute the last n columns of Q by a call to + // Dorg2l and check that we get the same result. + if !hasDorg2l { + continue + } + dorg2ler.Dorg2l(m, n, k, aCopy.Data, aCopy.Stride, tau, work) + if !equalApproxGeneral(aCopy, a, tol) { + t.Errorf("%v: mismatch between Dorgql and Dorg2l", prefix) + } + } } } }