testlapack: rewrite test for Dsytrd to remove its dependence on Dsytd2

Dsytd2 is not provided by LAPACKE and using blockedTranslate eliminates the
possibility to test the work slice.
This commit is contained in:
Vladimir Chalupecky
2017-02-04 23:38:07 +01:00
parent 8f019c8eaa
commit ad7b4a7b7f

View File

@@ -10,71 +10,150 @@ import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
"github.com/gonum/blas/blas64"
)
type Dsytrder interface {
Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau, work []float64, lwork int)
Dsytd2er
Dorgqr(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
}
func DsytrdTest(t *testing.T, impl Dsytrder) {
const tol = 1e-13
rnd := rand.New(rand.NewSource(1))
for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
for _, test := range []struct {
for tc, test := range []struct {
n, lda int
}{
{1, 0},
{2, 0},
{3, 0},
{4, 0},
{10, 0},
{50, 0},
{100, 0},
{150, 0},
{300, 0},
{1, 3},
{2, 3},
{3, 7},
{4, 9},
{10, 20},
{50, 70},
{100, 120},
{150, 170},
{300, 320},
} {
for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} {
n := test.n
lda := test.lda
if lda == 0 {
lda = n
}
a := make([]float64, n*lda)
for i := range a {
a[i] = rnd.NormFloat64()
a := randomGeneral(n, n, lda, rnd)
for i := 1; i < n; i++ {
for j := 0; j < i; j++ {
a.Data[i*a.Stride+j] = a.Data[j*a.Stride+i]
}
d2 := make([]float64, n)
e2 := make([]float64, n)
tau2 := make([]float64, n)
}
aCopy := cloneGeneral(a)
aCopy := make([]float64, len(a))
copy(aCopy, a)
impl.Dsytd2(uplo, n, a, lda, d2, e2, tau2)
aAns := make([]float64, len(a))
copy(aAns, a)
d := nanSlice(n)
e := nanSlice(n - 1)
tau := nanSlice(n - 1)
copy(a, aCopy)
d := make([]float64, n)
e := make([]float64, n)
tau := make([]float64, n)
var lwork int
switch wl {
case minimumWork:
lwork = 1
case mediumWork:
work := make([]float64, 1)
impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, -1)
work = make([]float64, int(work[0]))
impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, len(work))
errStr := fmt.Sprintf("upper = %v, n = %v", uplo == blas.Upper, n)
if !floats.EqualApprox(a, aAns, 1e-8) {
t.Errorf("A mismatch: %s", errStr)
impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1)
lwork = (int(work[0]) + 1) / 2
lwork = max(1, lwork)
case optimumWork:
work := make([]float64, 1)
impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1)
lwork = int(work[0])
}
if !floats.EqualApprox(d, d2, 1e-8) {
t.Errorf("D mismatch: %s", errStr)
work := make([]float64, lwork)
impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, lwork)
prefix := fmt.Sprintf("Case #%v: uplo=%v,n=%v,lda=%v,work=%v",
tc, uplo, n, lda, wl)
if !generalOutsideAllNaN(a) {
t.Errorf("%v: out-of-range write to A", prefix)
}
if !floats.EqualApprox(e, e2, 1e-8) {
t.Errorf("E mismatch: %s", errStr)
// Extract Q by doing what Dorgtr does.
q := cloneGeneral(a)
if uplo == blas.Upper {
for j := 0; j < n-1; j++ {
for i := 0; i < j; i++ {
q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j+1]
}
q.Data[(n-1)*q.Stride+j] = 0
}
for i := 0; i < n-1; i++ {
q.Data[i*q.Stride+n-1] = 0
}
q.Data[(n-1)*q.Stride+n-1] = 1
if n > 1 {
work = make([]float64, n-1)
impl.Dorgql(n-1, n-1, n-1, q.Data, q.Stride, tau, work, len(work))
}
} else {
for j := n - 1; j > 0; j-- {
q.Data[j] = 0
for i := j + 1; i < n; i++ {
q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j-1]
}
}
q.Data[0] = 1
for i := 1; i < n; i++ {
q.Data[i*q.Stride] = 0
}
if n > 1 {
work = make([]float64, n-1)
impl.Dorgqr(n-1, n-1, n-1, q.Data[q.Stride+1:], q.Stride, tau, work, len(work))
}
}
if !isOrthonormal(q) {
t.Errorf("%v: Q not orthogonal", prefix)
}
// Contruct symmetric tridiagonal T from d and e.
tMat := zeros(n, n, n)
for i := 0; i < n; i++ {
tMat.Data[i*tMat.Stride+i] = d[i]
}
if uplo == blas.Upper {
for j := 1; j < n; j++ {
tMat.Data[(j-1)*tMat.Stride+j] = e[j-1]
tMat.Data[j*tMat.Stride+j-1] = e[j-1]
}
} else {
for j := 0; j < n-1; j++ {
tMat.Data[(j+1)*tMat.Stride+j] = e[j]
tMat.Data[j*tMat.Stride+j+1] = e[j]
}
}
// Compute Q^T * A * Q.
tmp := zeros(n, n, n)
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aCopy, 0, tmp)
got := zeros(n, n, n)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, got)
// Compare with T.
if !equalApproxGeneral(got, tMat, tol) {
t.Errorf("%v: Q^T*A*Q != T", prefix)
}
if !floats.EqualApprox(tau, tau2, 1e-8) {
t.Errorf("Tau mismatch: %s", errStr)
}
}
}