diff --git a/lapack/testlapack/dgetri.go b/lapack/testlapack/dgetri.go index 7e339c1c..4ae08436 100644 --- a/lapack/testlapack/dgetri.go +++ b/lapack/testlapack/dgetri.go @@ -5,7 +5,6 @@ package testlapack import ( - "math" "testing" "golang.org/x/exp/rand" @@ -67,23 +66,8 @@ func DgetriTest(t *testing.T, impl Dgetrier) { // Check that A(inv) * A = I. ans := make([]float64, len(a)) bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, aCopy, lda, a, lda, 0, ans, lda) - isEye := true - for i := 0; i < n; i++ { - for j := 0; j < n; j++ { - if i == j { - // This tolerance is so high because computing matrix inverses - // is very unstable. - if math.Abs(ans[i*lda+j]-1) > 5e-2 { - isEye = false - } - } else { - if math.Abs(ans[i*lda+j]) > 5e-2 { - isEye = false - } - } - } - } - if !isEye { + // The tolerance is so high because computing matrix inverses is very unstable. + if !isIdentity(n, ans, lda, 5e-2) { t.Errorf("Inv(A) * A != I. n = %v, lda = %v", n, lda) } } diff --git a/lapack/testlapack/dlarfg.go b/lapack/testlapack/dlarfg.go index f0cee867..dbed5d4d 100644 --- a/lapack/testlapack/dlarfg.go +++ b/lapack/testlapack/dlarfg.go @@ -104,21 +104,7 @@ func DlarfgTest(t *testing.T, impl Dlarfger) { Data: make([]float64, n*n), } blas64.Gemm(blas.Trans, blas.NoTrans, 1, hmat, hmat, 0, eye) - iseye := true - for i := 0; i < n; i++ { - for j := 0; j < n; j++ { - if i == j { - if math.Abs(eye.Data[i*n+j]-1) > 1e-14 { - iseye = false - } - } else { - if math.Abs(eye.Data[i*n+j]) > 1e-14 { - iseye = false - } - } - } - } - if !iseye { + if !isIdentity(n, eye.Data, n, 1e-14) { t.Errorf("H^T * H is not I %v", eye) } diff --git a/lapack/testlapack/dtrti2.go b/lapack/testlapack/dtrti2.go index 71950a62..40a4b874 100644 --- a/lapack/testlapack/dtrti2.go +++ b/lapack/testlapack/dtrti2.go @@ -5,7 +5,6 @@ package testlapack import ( - "math" "testing" "golang.org/x/exp/rand" @@ -149,23 +148,7 @@ func Dtrti2Test(t *testing.T, impl Dtrti2er) { ans := make([]float64, len(a)) bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda) // Check that ans is the identity matrix. - iseye := true - for i := 0; i < n; i++ { - for j := 0; j < n; j++ { - if i == j { - if math.Abs(ans[i*lda+i]-1) > tol { - iseye = false - break - } - } else { - if math.Abs(ans[i*lda+j]) > tol { - iseye = false - break - } - } - } - } - if !iseye { + if !isIdentity(n, ans, lda, tol) { t.Errorf("inv(A) * A != I. Upper = %v, unit = %v, ans = %v", uplo == blas.Upper, diag == blas.Unit, ans) } } diff --git a/lapack/testlapack/dtrtri.go b/lapack/testlapack/dtrtri.go index 3eaa3f82..4b0345de 100644 --- a/lapack/testlapack/dtrtri.go +++ b/lapack/testlapack/dtrtri.go @@ -5,7 +5,6 @@ package testlapack import ( - "math" "testing" "golang.org/x/exp/rand" @@ -80,23 +79,7 @@ func DtrtriTest(t *testing.T, impl Dtrtrier) { ans := make([]float64, len(a)) bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda) // Check that ans is the identity matrix. - iseye := true - for i := 0; i < n; i++ { - for j := 0; j < n; j++ { - if i == j { - if math.Abs(ans[i*lda+i]-1) > tol { - iseye = false - break - } - } else { - if math.Abs(ans[i*lda+j]) > tol { - iseye = false - break - } - } - } - } - if !iseye { + if !isIdentity(n, ans, lda, tol) { t.Errorf("inv(A) * A != I. Upper = %v, unit = %v, n = %v, lda = %v", uplo == blas.Upper, diag == blas.Unit, n, lda) } diff --git a/lapack/testlapack/general.go b/lapack/testlapack/general.go index 2206be27..d6d97e8f 100644 --- a/lapack/testlapack/general.go +++ b/lapack/testlapack/general.go @@ -1464,3 +1464,28 @@ func constructGSVPresults(n, p, m, k, l int, a, b blas64.General) (zeroA, zeroB return zeroA, zeroB } + +// isIdentity returns whether an n×n matrix A is approximately equal to the +// identity matrix. +func isIdentity(n int, a []float64, lda int, tol float64) bool { + for i := 0; i < n; i++ { + for j := 0; j < n; j++ { + aij := a[i*lda+j] + if math.IsNaN(aij) { + return false + } + if i == j { + if math.Abs(aij-1) > tol { + fmt.Println(i, j, aij) + return false + } + } else { + if math.Abs(aij) > tol { + fmt.Println(i, j, aij) + return false + } + } + } + } + return true +} diff --git a/lapack/testlapack/matgen_test.go b/lapack/testlapack/matgen_test.go index e631a114..0e77ca02 100644 --- a/lapack/testlapack/matgen_test.go +++ b/lapack/testlapack/matgen_test.go @@ -5,7 +5,6 @@ package testlapack import ( - "math" "testing" "golang.org/x/exp/rand" @@ -21,40 +20,19 @@ func TestDlagsy(t *testing.T) { if lda == 0 { lda = max(1, n) } + // D is the identity matrix I. d := make([]float64, n) for i := range d { d[i] = 1 } - a := blas64.General{ - Rows: n, - Cols: n, - Stride: lda, - Data: nanSlice(n * lda), - } - work := make([]float64, a.Rows+a.Cols) - - Dlagsy(a.Rows, 0, d, a.Data, a.Stride, rnd, work) - - isIdentity := true - identityLoop: - for i := 0; i < n; i++ { - for j := 0; j < n; j++ { - aij := a.Data[i*a.Stride+j] - if math.IsNaN(aij) { - isIdentity = false - } - if i == j && math.Abs(aij-1) > tol { - isIdentity = false - } - if i != j && math.Abs(aij) > tol { - isIdentity = false - } - if !isIdentity { - break identityLoop - } - } - } - if !isIdentity { + // Allocate an n×n symmetric matrix A and fill it with NaNs. + a := nanSlice(n * lda) + work := make([]float64, 2*n) + // Compute A = U * D * U^T where U is a random orthogonal matrix. + Dlagsy(n, 0, d, a, lda, rnd, work) + // A should be the identity matrix because + // A = U * D * U^T = U * I * U^T = U * U^T = I. + if !isIdentity(n, a, lda, tol) { t.Errorf("Case n=%v,lda=%v: unexpected result", n, lda) } }