mirror of
https://github.com/gonum/gonum.git
synced 2025-10-21 14:19:35 +08:00
testlapack: rewrite test for Dsytrd to remove its dependence on Dsytd2
Dsytd2 is not provided by LAPACKE and using blockedTranslate eliminates the possibility to test the work slice.
This commit is contained in:
@@ -10,71 +10,150 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/floats"
|
||||
"github.com/gonum/blas/blas64"
|
||||
)
|
||||
|
||||
type Dsytrder interface {
|
||||
Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau, work []float64, lwork int)
|
||||
Dsytd2er
|
||||
|
||||
Dorgqr(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
|
||||
Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
|
||||
}
|
||||
|
||||
func DsytrdTest(t *testing.T, impl Dsytrder) {
|
||||
const tol = 1e-13
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
|
||||
for _, test := range []struct {
|
||||
n, lda int
|
||||
}{
|
||||
{10, 0},
|
||||
{50, 0},
|
||||
{100, 0},
|
||||
{150, 0},
|
||||
{300, 0},
|
||||
for tc, test := range []struct {
|
||||
n, lda int
|
||||
}{
|
||||
{1, 0},
|
||||
{2, 0},
|
||||
{3, 0},
|
||||
{4, 0},
|
||||
{10, 0},
|
||||
{50, 0},
|
||||
{100, 0},
|
||||
{150, 0},
|
||||
{300, 0},
|
||||
|
||||
{10, 20},
|
||||
{50, 70},
|
||||
{100, 120},
|
||||
{150, 170},
|
||||
{300, 320},
|
||||
} {
|
||||
n := test.n
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = n
|
||||
}
|
||||
a := make([]float64, n*lda)
|
||||
for i := range a {
|
||||
a[i] = rnd.NormFloat64()
|
||||
}
|
||||
d2 := make([]float64, n)
|
||||
e2 := make([]float64, n)
|
||||
tau2 := make([]float64, n)
|
||||
{1, 3},
|
||||
{2, 3},
|
||||
{3, 7},
|
||||
{4, 9},
|
||||
{10, 20},
|
||||
{50, 70},
|
||||
{100, 120},
|
||||
{150, 170},
|
||||
{300, 320},
|
||||
} {
|
||||
for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
|
||||
for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} {
|
||||
n := test.n
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = n
|
||||
}
|
||||
a := randomGeneral(n, n, lda, rnd)
|
||||
for i := 1; i < n; i++ {
|
||||
for j := 0; j < i; j++ {
|
||||
a.Data[i*a.Stride+j] = a.Data[j*a.Stride+i]
|
||||
}
|
||||
}
|
||||
aCopy := cloneGeneral(a)
|
||||
|
||||
aCopy := make([]float64, len(a))
|
||||
copy(aCopy, a)
|
||||
impl.Dsytd2(uplo, n, a, lda, d2, e2, tau2)
|
||||
aAns := make([]float64, len(a))
|
||||
copy(aAns, a)
|
||||
d := nanSlice(n)
|
||||
e := nanSlice(n - 1)
|
||||
tau := nanSlice(n - 1)
|
||||
|
||||
copy(a, aCopy)
|
||||
d := make([]float64, n)
|
||||
e := make([]float64, n)
|
||||
tau := make([]float64, n)
|
||||
work := make([]float64, 1)
|
||||
impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, -1)
|
||||
work = make([]float64, int(work[0]))
|
||||
impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, len(work))
|
||||
errStr := fmt.Sprintf("upper = %v, n = %v", uplo == blas.Upper, n)
|
||||
if !floats.EqualApprox(a, aAns, 1e-8) {
|
||||
t.Errorf("A mismatch: %s", errStr)
|
||||
}
|
||||
if !floats.EqualApprox(d, d2, 1e-8) {
|
||||
t.Errorf("D mismatch: %s", errStr)
|
||||
}
|
||||
if !floats.EqualApprox(e, e2, 1e-8) {
|
||||
t.Errorf("E mismatch: %s", errStr)
|
||||
}
|
||||
if !floats.EqualApprox(tau, tau2, 1e-8) {
|
||||
t.Errorf("Tau mismatch: %s", errStr)
|
||||
var lwork int
|
||||
switch wl {
|
||||
case minimumWork:
|
||||
lwork = 1
|
||||
case mediumWork:
|
||||
work := make([]float64, 1)
|
||||
impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1)
|
||||
lwork = (int(work[0]) + 1) / 2
|
||||
lwork = max(1, lwork)
|
||||
case optimumWork:
|
||||
work := make([]float64, 1)
|
||||
impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1)
|
||||
lwork = int(work[0])
|
||||
}
|
||||
work := make([]float64, lwork)
|
||||
|
||||
impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, lwork)
|
||||
|
||||
prefix := fmt.Sprintf("Case #%v: uplo=%v,n=%v,lda=%v,work=%v",
|
||||
tc, uplo, n, lda, wl)
|
||||
|
||||
if !generalOutsideAllNaN(a) {
|
||||
t.Errorf("%v: out-of-range write to A", prefix)
|
||||
}
|
||||
|
||||
// Extract Q by doing what Dorgtr does.
|
||||
q := cloneGeneral(a)
|
||||
if uplo == blas.Upper {
|
||||
for j := 0; j < n-1; j++ {
|
||||
for i := 0; i < j; i++ {
|
||||
q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j+1]
|
||||
}
|
||||
q.Data[(n-1)*q.Stride+j] = 0
|
||||
}
|
||||
for i := 0; i < n-1; i++ {
|
||||
q.Data[i*q.Stride+n-1] = 0
|
||||
}
|
||||
q.Data[(n-1)*q.Stride+n-1] = 1
|
||||
if n > 1 {
|
||||
work = make([]float64, n-1)
|
||||
impl.Dorgql(n-1, n-1, n-1, q.Data, q.Stride, tau, work, len(work))
|
||||
}
|
||||
} else {
|
||||
for j := n - 1; j > 0; j-- {
|
||||
q.Data[j] = 0
|
||||
for i := j + 1; i < n; i++ {
|
||||
q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j-1]
|
||||
}
|
||||
}
|
||||
q.Data[0] = 1
|
||||
for i := 1; i < n; i++ {
|
||||
q.Data[i*q.Stride] = 0
|
||||
}
|
||||
if n > 1 {
|
||||
work = make([]float64, n-1)
|
||||
impl.Dorgqr(n-1, n-1, n-1, q.Data[q.Stride+1:], q.Stride, tau, work, len(work))
|
||||
}
|
||||
}
|
||||
if !isOrthonormal(q) {
|
||||
t.Errorf("%v: Q not orthogonal", prefix)
|
||||
}
|
||||
|
||||
// Contruct symmetric tridiagonal T from d and e.
|
||||
tMat := zeros(n, n, n)
|
||||
for i := 0; i < n; i++ {
|
||||
tMat.Data[i*tMat.Stride+i] = d[i]
|
||||
}
|
||||
if uplo == blas.Upper {
|
||||
for j := 1; j < n; j++ {
|
||||
tMat.Data[(j-1)*tMat.Stride+j] = e[j-1]
|
||||
tMat.Data[j*tMat.Stride+j-1] = e[j-1]
|
||||
}
|
||||
} else {
|
||||
for j := 0; j < n-1; j++ {
|
||||
tMat.Data[(j+1)*tMat.Stride+j] = e[j]
|
||||
tMat.Data[j*tMat.Stride+j+1] = e[j]
|
||||
}
|
||||
}
|
||||
|
||||
// Compute Q^T * A * Q.
|
||||
tmp := zeros(n, n, n)
|
||||
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aCopy, 0, tmp)
|
||||
got := zeros(n, n, n)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, got)
|
||||
|
||||
// Compare with T.
|
||||
if !equalApproxGeneral(got, tMat, tol) {
|
||||
t.Errorf("%v: Q^T*A*Q != T", prefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user