mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 23:26:52 +08:00
testlapack: add isIdentity helper
This commit is contained in:

committed by
Vladimír Chalupecký

parent
27d556d1f9
commit
87489715e5
@@ -5,7 +5,6 @@
|
|||||||
package testlapack
|
package testlapack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"golang.org/x/exp/rand"
|
"golang.org/x/exp/rand"
|
||||||
@@ -67,23 +66,8 @@ func DgetriTest(t *testing.T, impl Dgetrier) {
|
|||||||
// Check that A(inv) * A = I.
|
// Check that A(inv) * A = I.
|
||||||
ans := make([]float64, len(a))
|
ans := make([]float64, len(a))
|
||||||
bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, aCopy, lda, a, lda, 0, ans, lda)
|
bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, aCopy, lda, a, lda, 0, ans, lda)
|
||||||
isEye := true
|
// The tolerance is so high because computing matrix inverses is very unstable.
|
||||||
for i := 0; i < n; i++ {
|
if !isIdentity(n, ans, lda, 5e-2) {
|
||||||
for j := 0; j < n; j++ {
|
|
||||||
if i == j {
|
|
||||||
// This tolerance is so high because computing matrix inverses
|
|
||||||
// is very unstable.
|
|
||||||
if math.Abs(ans[i*lda+j]-1) > 5e-2 {
|
|
||||||
isEye = false
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if math.Abs(ans[i*lda+j]) > 5e-2 {
|
|
||||||
isEye = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !isEye {
|
|
||||||
t.Errorf("Inv(A) * A != I. n = %v, lda = %v", n, lda)
|
t.Errorf("Inv(A) * A != I. n = %v, lda = %v", n, lda)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -104,21 +104,7 @@ func DlarfgTest(t *testing.T, impl Dlarfger) {
|
|||||||
Data: make([]float64, n*n),
|
Data: make([]float64, n*n),
|
||||||
}
|
}
|
||||||
blas64.Gemm(blas.Trans, blas.NoTrans, 1, hmat, hmat, 0, eye)
|
blas64.Gemm(blas.Trans, blas.NoTrans, 1, hmat, hmat, 0, eye)
|
||||||
iseye := true
|
if !isIdentity(n, eye.Data, n, 1e-14) {
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
for j := 0; j < n; j++ {
|
|
||||||
if i == j {
|
|
||||||
if math.Abs(eye.Data[i*n+j]-1) > 1e-14 {
|
|
||||||
iseye = false
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if math.Abs(eye.Data[i*n+j]) > 1e-14 {
|
|
||||||
iseye = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !iseye {
|
|
||||||
t.Errorf("H^T * H is not I %v", eye)
|
t.Errorf("H^T * H is not I %v", eye)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -5,7 +5,6 @@
|
|||||||
package testlapack
|
package testlapack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"golang.org/x/exp/rand"
|
"golang.org/x/exp/rand"
|
||||||
@@ -149,23 +148,7 @@ func Dtrti2Test(t *testing.T, impl Dtrti2er) {
|
|||||||
ans := make([]float64, len(a))
|
ans := make([]float64, len(a))
|
||||||
bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda)
|
bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda)
|
||||||
// Check that ans is the identity matrix.
|
// Check that ans is the identity matrix.
|
||||||
iseye := true
|
if !isIdentity(n, ans, lda, tol) {
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
for j := 0; j < n; j++ {
|
|
||||||
if i == j {
|
|
||||||
if math.Abs(ans[i*lda+i]-1) > tol {
|
|
||||||
iseye = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if math.Abs(ans[i*lda+j]) > tol {
|
|
||||||
iseye = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !iseye {
|
|
||||||
t.Errorf("inv(A) * A != I. Upper = %v, unit = %v, ans = %v", uplo == blas.Upper, diag == blas.Unit, ans)
|
t.Errorf("inv(A) * A != I. Upper = %v, unit = %v, ans = %v", uplo == blas.Upper, diag == blas.Unit, ans)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -5,7 +5,6 @@
|
|||||||
package testlapack
|
package testlapack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"golang.org/x/exp/rand"
|
"golang.org/x/exp/rand"
|
||||||
@@ -80,23 +79,7 @@ func DtrtriTest(t *testing.T, impl Dtrtrier) {
|
|||||||
ans := make([]float64, len(a))
|
ans := make([]float64, len(a))
|
||||||
bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda)
|
bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda)
|
||||||
// Check that ans is the identity matrix.
|
// Check that ans is the identity matrix.
|
||||||
iseye := true
|
if !isIdentity(n, ans, lda, tol) {
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
for j := 0; j < n; j++ {
|
|
||||||
if i == j {
|
|
||||||
if math.Abs(ans[i*lda+i]-1) > tol {
|
|
||||||
iseye = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if math.Abs(ans[i*lda+j]) > tol {
|
|
||||||
iseye = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !iseye {
|
|
||||||
t.Errorf("inv(A) * A != I. Upper = %v, unit = %v, n = %v, lda = %v",
|
t.Errorf("inv(A) * A != I. Upper = %v, unit = %v, n = %v, lda = %v",
|
||||||
uplo == blas.Upper, diag == blas.Unit, n, lda)
|
uplo == blas.Upper, diag == blas.Unit, n, lda)
|
||||||
}
|
}
|
||||||
|
@@ -1464,3 +1464,28 @@ func constructGSVPresults(n, p, m, k, l int, a, b blas64.General) (zeroA, zeroB
|
|||||||
|
|
||||||
return zeroA, zeroB
|
return zeroA, zeroB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isIdentity returns whether an n×n matrix A is approximately equal to the
|
||||||
|
// identity matrix.
|
||||||
|
func isIdentity(n int, a []float64, lda int, tol float64) bool {
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
for j := 0; j < n; j++ {
|
||||||
|
aij := a[i*lda+j]
|
||||||
|
if math.IsNaN(aij) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if i == j {
|
||||||
|
if math.Abs(aij-1) > tol {
|
||||||
|
fmt.Println(i, j, aij)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if math.Abs(aij) > tol {
|
||||||
|
fmt.Println(i, j, aij)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
@@ -5,7 +5,6 @@
|
|||||||
package testlapack
|
package testlapack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"golang.org/x/exp/rand"
|
"golang.org/x/exp/rand"
|
||||||
@@ -21,40 +20,19 @@ func TestDlagsy(t *testing.T) {
|
|||||||
if lda == 0 {
|
if lda == 0 {
|
||||||
lda = max(1, n)
|
lda = max(1, n)
|
||||||
}
|
}
|
||||||
|
// D is the identity matrix I.
|
||||||
d := make([]float64, n)
|
d := make([]float64, n)
|
||||||
for i := range d {
|
for i := range d {
|
||||||
d[i] = 1
|
d[i] = 1
|
||||||
}
|
}
|
||||||
a := blas64.General{
|
// Allocate an n×n symmetric matrix A and fill it with NaNs.
|
||||||
Rows: n,
|
a := nanSlice(n * lda)
|
||||||
Cols: n,
|
work := make([]float64, 2*n)
|
||||||
Stride: lda,
|
// Compute A = U * D * U^T where U is a random orthogonal matrix.
|
||||||
Data: nanSlice(n * lda),
|
Dlagsy(n, 0, d, a, lda, rnd, work)
|
||||||
}
|
// A should be the identity matrix because
|
||||||
work := make([]float64, a.Rows+a.Cols)
|
// A = U * D * U^T = U * I * U^T = U * U^T = I.
|
||||||
|
if !isIdentity(n, a, lda, tol) {
|
||||||
Dlagsy(a.Rows, 0, d, a.Data, a.Stride, rnd, work)
|
|
||||||
|
|
||||||
isIdentity := true
|
|
||||||
identityLoop:
|
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
for j := 0; j < n; j++ {
|
|
||||||
aij := a.Data[i*a.Stride+j]
|
|
||||||
if math.IsNaN(aij) {
|
|
||||||
isIdentity = false
|
|
||||||
}
|
|
||||||
if i == j && math.Abs(aij-1) > tol {
|
|
||||||
isIdentity = false
|
|
||||||
}
|
|
||||||
if i != j && math.Abs(aij) > tol {
|
|
||||||
isIdentity = false
|
|
||||||
}
|
|
||||||
if !isIdentity {
|
|
||||||
break identityLoop
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !isIdentity {
|
|
||||||
t.Errorf("Case n=%v,lda=%v: unexpected result", n, lda)
|
t.Errorf("Case n=%v,lda=%v: unexpected result", n, lda)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user