diff --git a/blas/testblas/common.go b/blas/testblas/common.go index e7509105..8038bf57 100644 --- a/blas/testblas/common.go +++ b/blas/testblas/common.go @@ -254,22 +254,46 @@ func allPairs(x, y []int) [][2]int { return p } +func sameFloat64(a, b float64) bool { + return a == b || math.IsNaN(a) && math.IsNaN(b) +} + +func sameComplex128(x, y complex128) bool { + return sameFloat64(real(x), real(y)) && sameFloat64(imag(x), imag(y)) +} + func zsame(x, y []complex128) bool { if len(x) != len(y) { return false } for i, v := range x { w := y[i] - if math.IsNaN(real(v)) && math.IsNaN(imag(v)) && math.IsNaN(real(w)) && math.IsNaN(imag(w)) { + if !sameComplex128(v, w) { + return false + } + } + return true +} + +// zEqualApprox returns whether vectors x and y with stride inc +// are approximately equal within tol. Elements at non-strided +// positions must be same in both x and y. +func zEqualApprox(x, y []complex128, inc int, tol float64) bool { + if len(x) != len(y) { + return false + } + if inc < 0 { + inc = -inc + } + for i, v := range x { + w := y[i] + if i%inc == 0 { + if cmplx.Abs(v-w) > tol { + return false + } continue } - if math.IsNaN(real(v)) && math.IsNaN(real(w)) && imag(v) == imag(w) { - continue - } - if math.IsNaN(imag(v)) && math.IsNaN(imag(w)) && real(v) == real(w) { - continue - } - if v != w { + if !sameComplex128(v, w) { return false } }