diff --git a/testlapack/dsytrd.go b/testlapack/dsytrd.go index 6c7f5728..1bd70188 100644 --- a/testlapack/dsytrd.go +++ b/testlapack/dsytrd.go @@ -10,71 +10,150 @@ import ( "testing" "github.com/gonum/blas" - "github.com/gonum/floats" + "github.com/gonum/blas/blas64" ) type Dsytrder interface { Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau, work []float64, lwork int) - Dsytd2er + + Dorgqr(m, n, k int, a []float64, lda int, tau, work []float64, lwork int) + Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int) } func DsytrdTest(t *testing.T, impl Dsytrder) { + const tol = 1e-13 rnd := rand.New(rand.NewSource(1)) - for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { - for _, test := range []struct { - n, lda int - }{ - {10, 0}, - {50, 0}, - {100, 0}, - {150, 0}, - {300, 0}, + for tc, test := range []struct { + n, lda int + }{ + {1, 0}, + {2, 0}, + {3, 0}, + {4, 0}, + {10, 0}, + {50, 0}, + {100, 0}, + {150, 0}, + {300, 0}, - {10, 20}, - {50, 70}, - {100, 120}, - {150, 170}, - {300, 320}, - } { - n := test.n - lda := test.lda - if lda == 0 { - lda = n - } - a := make([]float64, n*lda) - for i := range a { - a[i] = rnd.NormFloat64() - } - d2 := make([]float64, n) - e2 := make([]float64, n) - tau2 := make([]float64, n) + {1, 3}, + {2, 3}, + {3, 7}, + {4, 9}, + {10, 20}, + {50, 70}, + {100, 120}, + {150, 170}, + {300, 320}, + } { + for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { + for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} { + n := test.n + lda := test.lda + if lda == 0 { + lda = n + } + a := randomGeneral(n, n, lda, rnd) + for i := 1; i < n; i++ { + for j := 0; j < i; j++ { + a.Data[i*a.Stride+j] = a.Data[j*a.Stride+i] + } + } + aCopy := cloneGeneral(a) - aCopy := make([]float64, len(a)) - copy(aCopy, a) - impl.Dsytd2(uplo, n, a, lda, d2, e2, tau2) - aAns := make([]float64, len(a)) - copy(aAns, a) + d := nanSlice(n) + e := nanSlice(n - 1) + tau := nanSlice(n - 1) - copy(a, aCopy) - d := make([]float64, n) - e := make([]float64, n) - tau := make([]float64, n) - work := make([]float64, 1) - impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, -1) - work = make([]float64, int(work[0])) - impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, len(work)) - errStr := fmt.Sprintf("upper = %v, n = %v", uplo == blas.Upper, n) - if !floats.EqualApprox(a, aAns, 1e-8) { - t.Errorf("A mismatch: %s", errStr) - } - if !floats.EqualApprox(d, d2, 1e-8) { - t.Errorf("D mismatch: %s", errStr) - } - if !floats.EqualApprox(e, e2, 1e-8) { - t.Errorf("E mismatch: %s", errStr) - } - if !floats.EqualApprox(tau, tau2, 1e-8) { - t.Errorf("Tau mismatch: %s", errStr) + var lwork int + switch wl { + case minimumWork: + lwork = 1 + case mediumWork: + work := make([]float64, 1) + impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1) + lwork = (int(work[0]) + 1) / 2 + lwork = max(1, lwork) + case optimumWork: + work := make([]float64, 1) + impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1) + lwork = int(work[0]) + } + work := make([]float64, lwork) + + impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, lwork) + + prefix := fmt.Sprintf("Case #%v: uplo=%v,n=%v,lda=%v,work=%v", + tc, uplo, n, lda, wl) + + if !generalOutsideAllNaN(a) { + t.Errorf("%v: out-of-range write to A", prefix) + } + + // Extract Q by doing what Dorgtr does. + q := cloneGeneral(a) + if uplo == blas.Upper { + for j := 0; j < n-1; j++ { + for i := 0; i < j; i++ { + q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j+1] + } + q.Data[(n-1)*q.Stride+j] = 0 + } + for i := 0; i < n-1; i++ { + q.Data[i*q.Stride+n-1] = 0 + } + q.Data[(n-1)*q.Stride+n-1] = 1 + if n > 1 { + work = make([]float64, n-1) + impl.Dorgql(n-1, n-1, n-1, q.Data, q.Stride, tau, work, len(work)) + } + } else { + for j := n - 1; j > 0; j-- { + q.Data[j] = 0 + for i := j + 1; i < n; i++ { + q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j-1] + } + } + q.Data[0] = 1 + for i := 1; i < n; i++ { + q.Data[i*q.Stride] = 0 + } + if n > 1 { + work = make([]float64, n-1) + impl.Dorgqr(n-1, n-1, n-1, q.Data[q.Stride+1:], q.Stride, tau, work, len(work)) + } + } + if !isOrthonormal(q) { + t.Errorf("%v: Q not orthogonal", prefix) + } + + // Contruct symmetric tridiagonal T from d and e. + tMat := zeros(n, n, n) + for i := 0; i < n; i++ { + tMat.Data[i*tMat.Stride+i] = d[i] + } + if uplo == blas.Upper { + for j := 1; j < n; j++ { + tMat.Data[(j-1)*tMat.Stride+j] = e[j-1] + tMat.Data[j*tMat.Stride+j-1] = e[j-1] + } + } else { + for j := 0; j < n-1; j++ { + tMat.Data[(j+1)*tMat.Stride+j] = e[j] + tMat.Data[j*tMat.Stride+j+1] = e[j] + } + } + + // Compute Q^T * A * Q. + tmp := zeros(n, n, n) + blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aCopy, 0, tmp) + got := zeros(n, n, n) + blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, got) + + // Compare with T. + if !equalApproxGeneral(got, tMat, tol) { + t.Errorf("%v: Q^T*A*Q != T", prefix) + } } } }