mirror of
https://github.com/gonum/gonum.git
synced 2025-10-22 06:39:26 +08:00
testlapack: rewrite test for Dsytrd to remove its dependence on Dsytd2
Dsytd2 is not provided by LAPACKE and using blockedTranslate eliminates the possibility to test the work slice.
This commit is contained in:
@@ -10,71 +10,150 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/floats"
|
||||
"github.com/gonum/blas/blas64"
|
||||
)
|
||||
|
||||
type Dsytrder interface {
|
||||
Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau, work []float64, lwork int)
|
||||
Dsytd2er
|
||||
|
||||
Dorgqr(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
|
||||
Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
|
||||
}
|
||||
|
||||
func DsytrdTest(t *testing.T, impl Dsytrder) {
|
||||
const tol = 1e-13
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
|
||||
for _, test := range []struct {
|
||||
for tc, test := range []struct {
|
||||
n, lda int
|
||||
}{
|
||||
{1, 0},
|
||||
{2, 0},
|
||||
{3, 0},
|
||||
{4, 0},
|
||||
{10, 0},
|
||||
{50, 0},
|
||||
{100, 0},
|
||||
{150, 0},
|
||||
{300, 0},
|
||||
|
||||
{1, 3},
|
||||
{2, 3},
|
||||
{3, 7},
|
||||
{4, 9},
|
||||
{10, 20},
|
||||
{50, 70},
|
||||
{100, 120},
|
||||
{150, 170},
|
||||
{300, 320},
|
||||
} {
|
||||
for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
|
||||
for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} {
|
||||
n := test.n
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = n
|
||||
}
|
||||
a := make([]float64, n*lda)
|
||||
for i := range a {
|
||||
a[i] = rnd.NormFloat64()
|
||||
a := randomGeneral(n, n, lda, rnd)
|
||||
for i := 1; i < n; i++ {
|
||||
for j := 0; j < i; j++ {
|
||||
a.Data[i*a.Stride+j] = a.Data[j*a.Stride+i]
|
||||
}
|
||||
d2 := make([]float64, n)
|
||||
e2 := make([]float64, n)
|
||||
tau2 := make([]float64, n)
|
||||
}
|
||||
aCopy := cloneGeneral(a)
|
||||
|
||||
aCopy := make([]float64, len(a))
|
||||
copy(aCopy, a)
|
||||
impl.Dsytd2(uplo, n, a, lda, d2, e2, tau2)
|
||||
aAns := make([]float64, len(a))
|
||||
copy(aAns, a)
|
||||
d := nanSlice(n)
|
||||
e := nanSlice(n - 1)
|
||||
tau := nanSlice(n - 1)
|
||||
|
||||
copy(a, aCopy)
|
||||
d := make([]float64, n)
|
||||
e := make([]float64, n)
|
||||
tau := make([]float64, n)
|
||||
var lwork int
|
||||
switch wl {
|
||||
case minimumWork:
|
||||
lwork = 1
|
||||
case mediumWork:
|
||||
work := make([]float64, 1)
|
||||
impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, -1)
|
||||
work = make([]float64, int(work[0]))
|
||||
impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, len(work))
|
||||
errStr := fmt.Sprintf("upper = %v, n = %v", uplo == blas.Upper, n)
|
||||
if !floats.EqualApprox(a, aAns, 1e-8) {
|
||||
t.Errorf("A mismatch: %s", errStr)
|
||||
impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1)
|
||||
lwork = (int(work[0]) + 1) / 2
|
||||
lwork = max(1, lwork)
|
||||
case optimumWork:
|
||||
work := make([]float64, 1)
|
||||
impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1)
|
||||
lwork = int(work[0])
|
||||
}
|
||||
if !floats.EqualApprox(d, d2, 1e-8) {
|
||||
t.Errorf("D mismatch: %s", errStr)
|
||||
work := make([]float64, lwork)
|
||||
|
||||
impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, lwork)
|
||||
|
||||
prefix := fmt.Sprintf("Case #%v: uplo=%v,n=%v,lda=%v,work=%v",
|
||||
tc, uplo, n, lda, wl)
|
||||
|
||||
if !generalOutsideAllNaN(a) {
|
||||
t.Errorf("%v: out-of-range write to A", prefix)
|
||||
}
|
||||
if !floats.EqualApprox(e, e2, 1e-8) {
|
||||
t.Errorf("E mismatch: %s", errStr)
|
||||
|
||||
// Extract Q by doing what Dorgtr does.
|
||||
q := cloneGeneral(a)
|
||||
if uplo == blas.Upper {
|
||||
for j := 0; j < n-1; j++ {
|
||||
for i := 0; i < j; i++ {
|
||||
q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j+1]
|
||||
}
|
||||
q.Data[(n-1)*q.Stride+j] = 0
|
||||
}
|
||||
for i := 0; i < n-1; i++ {
|
||||
q.Data[i*q.Stride+n-1] = 0
|
||||
}
|
||||
q.Data[(n-1)*q.Stride+n-1] = 1
|
||||
if n > 1 {
|
||||
work = make([]float64, n-1)
|
||||
impl.Dorgql(n-1, n-1, n-1, q.Data, q.Stride, tau, work, len(work))
|
||||
}
|
||||
} else {
|
||||
for j := n - 1; j > 0; j-- {
|
||||
q.Data[j] = 0
|
||||
for i := j + 1; i < n; i++ {
|
||||
q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j-1]
|
||||
}
|
||||
}
|
||||
q.Data[0] = 1
|
||||
for i := 1; i < n; i++ {
|
||||
q.Data[i*q.Stride] = 0
|
||||
}
|
||||
if n > 1 {
|
||||
work = make([]float64, n-1)
|
||||
impl.Dorgqr(n-1, n-1, n-1, q.Data[q.Stride+1:], q.Stride, tau, work, len(work))
|
||||
}
|
||||
}
|
||||
if !isOrthonormal(q) {
|
||||
t.Errorf("%v: Q not orthogonal", prefix)
|
||||
}
|
||||
|
||||
// Contruct symmetric tridiagonal T from d and e.
|
||||
tMat := zeros(n, n, n)
|
||||
for i := 0; i < n; i++ {
|
||||
tMat.Data[i*tMat.Stride+i] = d[i]
|
||||
}
|
||||
if uplo == blas.Upper {
|
||||
for j := 1; j < n; j++ {
|
||||
tMat.Data[(j-1)*tMat.Stride+j] = e[j-1]
|
||||
tMat.Data[j*tMat.Stride+j-1] = e[j-1]
|
||||
}
|
||||
} else {
|
||||
for j := 0; j < n-1; j++ {
|
||||
tMat.Data[(j+1)*tMat.Stride+j] = e[j]
|
||||
tMat.Data[j*tMat.Stride+j+1] = e[j]
|
||||
}
|
||||
}
|
||||
|
||||
// Compute Q^T * A * Q.
|
||||
tmp := zeros(n, n, n)
|
||||
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aCopy, 0, tmp)
|
||||
got := zeros(n, n, n)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, got)
|
||||
|
||||
// Compare with T.
|
||||
if !equalApproxGeneral(got, tMat, tol) {
|
||||
t.Errorf("%v: Q^T*A*Q != T", prefix)
|
||||
}
|
||||
if !floats.EqualApprox(tau, tau2, 1e-8) {
|
||||
t.Errorf("Tau mismatch: %s", errStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user