all: use cscalar.Same instead of local same function

This commit is contained in:
Dan Kortschak
2020-08-04 14:46:15 +09:30
parent 4f194cd672
commit dba48453fd
6 changed files with 29 additions and 37 deletions

View File

@@ -11,7 +11,7 @@ import (
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/cmplxs"
"gonum.org/v1/gonum/cmplxs/cscalar"
"gonum.org/v1/gonum/floats/scalar"
)
@@ -47,7 +47,7 @@ func guardVector(vec []complex128, gdVal complex128, gdLn int) (guarded []comple
// isValidGuard will test for violated guards, generated by guardVector.
func isValidGuard(vec []complex128, gdVal complex128, gdLn int) bool {
for i := 0; i < gdLn; i++ {
if !sameCmplx(vec[i], gdVal) || !sameCmplx(vec[len(vec)-1-i], gdVal) {
if !cscalar.Same(vec[i], gdVal) || !cscalar.Same(vec[len(vec)-1-i], gdVal) {
return false
}
}
@@ -77,7 +77,7 @@ func checkValidIncGuard(t *testing.T, vec []complex128, gdVal complex128, inc, g
srcLn := len(vec) - 2*gdLen
for i := range vec {
switch {
case sameCmplx(vec[i], gdVal):
case cscalar.Same(vec[i], gdVal):
// Correct value
case (i-gdLen)%inc == 0 && (i-gdLen)/inc < len(vec):
// Ignore input values
@@ -96,14 +96,9 @@ func sameApprox(a, b, tol float64) bool {
return scalar.Same(a, b) || scalar.EqualWithinAbsOrRel(a, b, tol, tol)
}
// sameCmplx tests for nan-aware equality.
func sameCmplx(a, b complex128) bool {
return a == b || (cmplx.IsNaN(a) && cmplx.IsNaN(b))
}
// sameCmplxApprox tests for nan-aware equality within tolerance.
func sameCmplxApprox(a, b complex128, tol float64) bool {
return sameCmplx(a, b) || cmplxs.EqualWithinAbsOrRel(a, b, tol, tol)
return cscalar.Same(a, b) || cscalar.EqualWithinAbsOrRel(a, b, tol, tol)
}
var ( // Offset sets for testing alignment handling in Unitary assembly functions.

View File

@@ -7,6 +7,8 @@ package c128
import (
"fmt"
"testing"
"gonum.org/v1/gonum/cmplxs/cscalar"
)
var dotTests = []struct {
@@ -75,7 +77,7 @@ func TestDotcUnitary(t *testing.T) {
xg, yg := guardVector(test.x, gd, xgLn), guardVector(test.y, gd, ygLn)
x, y := xg[xgLn:len(xg)-xgLn], yg[ygLn:len(yg)-ygLn]
res := DotcUnitary(x, y)
if !same(res, test.wantc) {
if !cscalar.Same(res, test.wantc) {
t.Errorf(msgVal, prefix, i, res, test.wantc)
}
if !isValidGuard(xg, gd, xgLn) {
@@ -107,7 +109,7 @@ func TestDotcInc(t *testing.T) {
if inc.x*inc.y > 0 {
want = test.wantc
}
if !same(res, want) {
if !cscalar.Same(res, want) {
t.Errorf(msgVal, prefix, i, res, want)
t.Error(x, y)
}
@@ -126,7 +128,7 @@ func TestDotuUnitary(t *testing.T) {
xg, yg := guardVector(test.x, gd, xgLn), guardVector(test.y, gd, ygLn)
x, y := xg[xgLn:len(xg)-xgLn], yg[ygLn:len(yg)-ygLn]
res := DotuUnitary(x, y)
if !same(res, test.wantu) {
if !cscalar.Same(res, test.wantu) {
t.Errorf(msgVal, prefix, i, res, test.wantu)
}
if !isValidGuard(xg, gd, xgLn) {
@@ -158,7 +160,7 @@ func TestDotuInc(t *testing.T) {
if inc.x*inc.y > 0 {
want = test.wantu
}
if !same(res, want) {
if !cscalar.Same(res, want) {
t.Errorf(msgVal, prefix, i, res, want)
}
checkValidIncGuard(t, xg, gd, inc.x, gdLn)

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"testing"
"gonum.org/v1/gonum/cmplxs/cscalar"
. "gonum.org/v1/gonum/internal/asm/c128"
)
@@ -57,7 +58,7 @@ func TestAdd(t *testing.T) {
src, dst := v.src[sg_ln:len(v.src)-sg_ln], v.dst[dg_ln:len(v.dst)-dg_ln]
Add(dst, src)
for i := range v.expect {
if !sameCmplx(dst[i], v.expect[i]) {
if !cscalar.Same(dst[i], v.expect[i]) {
t.Errorf("Test %d Add error at %d Got: %v Expected: %v", j, i, dst[i], v.expect[i])
}
}
@@ -107,7 +108,7 @@ func TestAddConst(t *testing.T) {
src := v.src[g_ln : len(v.src)-g_ln]
AddConst(v.alpha, src)
for i := range v.expect {
if !sameCmplx(src[i], v.expect[i]) {
if !cscalar.Same(src[i], v.expect[i]) {
t.Errorf("Test %d AddConst error at %d Got: %v Expected: %v", j, i, src[i], v.expect[i])
}
}
@@ -334,10 +335,10 @@ func TestCumSum(t *testing.T) {
src, dst := v.src[g_ln:len(v.src)-g_ln], v.dst[g_ln:len(v.dst)-g_ln]
ret := CumSum(dst, src)
for i := range v.expect {
if !sameCmplx(ret[i], v.expect[i]) {
if !cscalar.Same(ret[i], v.expect[i]) {
t.Errorf("Test %d CumSum error at %d Got: %v Expected: %v", j, i, ret[i], v.expect[i])
}
if !sameCmplx(ret[i], dst[i]) {
if !cscalar.Same(ret[i], dst[i]) {
t.Errorf("Test %d CumSum ret/dst mismatch %d Ret: %v Dst: %v", j, i, ret[i], dst[i])
}
}
@@ -406,10 +407,10 @@ func TestCumProd(t *testing.T) {
src, dst := v.src[sg_ln:len(v.src)-sg_ln], v.dst[dg_ln:len(v.dst)-dg_ln]
ret := CumProd(dst, src)
for i := range v.expect {
if !sameCmplx(ret[i], v.expect[i]) {
if !cscalar.Same(ret[i], v.expect[i]) {
t.Errorf("Test %d CumProd error at %d Got: %v Expected: %v", j, i, ret[i], v.expect[i])
}
if !sameCmplx(ret[i], dst[i]) {
if !cscalar.Same(ret[i], dst[i]) {
t.Errorf("Test %d CumProd ret/dst mismatch %d Ret: %v Dst: %v", j, i, ret[i], dst[i])
}
}
@@ -548,7 +549,7 @@ func TestDivTo(t *testing.T) {
if !sameCmplxApprox(ret[i], v.expect[i], tol) {
t.Errorf("Test %d DivTo error at %d Got: %v Expected: %v", j, i, ret[i], v.expect[i])
}
if !sameCmplx(ret[i], dst[i]) {
if !cscalar.Same(ret[i], dst[i]) {
t.Errorf("Test %d DivTo ret/dst mismatch %d Ret: %v Dst: %v", j, i, ret[i], dst[i])
}
}
@@ -613,7 +614,7 @@ func TestDscalUnitary(t *testing.T) {
DscalUnitary(test.alpha, x)
for i := range test.want {
if !sameCmplx(x[i], test.want[i]) {
if !cscalar.Same(x[i], test.want[i]) {
t.Errorf(msgVal, prefix, i, x[i], test.want[i])
}
}
@@ -637,7 +638,7 @@ func TestDscalInc(t *testing.T) {
DscalInc(test.alpha, x, uintptr(n), uintptr(incX))
for i := range test.want {
if !sameCmplx(x[i*incX], test.want[i]) {
if !cscalar.Same(x[i*incX], test.want[i]) {
t.Errorf(msgVal, prefix, i, x[i*incX], test.want[i])
}
}
@@ -695,7 +696,7 @@ func TestScalUnitary(t *testing.T) {
ScalUnitary(test.alpha, x)
for i := range test.want {
if !sameCmplx(x[i], test.want[i]) {
if !cscalar.Same(x[i], test.want[i]) {
t.Errorf(msgVal, prefix, i, x[i], test.want[i])
}
}
@@ -719,7 +720,7 @@ func TestScalInc(t *testing.T) {
ScalInc(test.alpha, x, uintptr(n), uintptr(inc))
for i := range test.want {
if !sameCmplx(x[i*inc], test.want[i]) {
if !cscalar.Same(x[i*inc], test.want[i]) {
t.Errorf(msgVal, prefix, i, x[i*inc], test.want[i])
}
}
@@ -779,7 +780,7 @@ func TestSum(t *testing.T) {
gsrc := guardVector(v.src, srcGd, gdLn)
src := gsrc[gdLn : len(gsrc)-gdLn]
ret := Sum(src)
if !sameCmplx(ret, v.expect) {
if !cscalar.Same(ret, v.expect) {
t.Errorf("Test %d Sum error Got: %v Expected: %v", j, ret, v.expect)
}
if !isValidGuard(gsrc, srcGd, gdLn) {

View File

@@ -19,13 +19,6 @@ var (
benchSink complex128
)
func same(x, y complex128) bool {
return (x == y ||
math.IsNaN(real(x)) && math.IsNaN(real(y)) && imag(x) == imag(y) ||
math.IsNaN(imag(y)) && math.IsNaN(imag(x)) && real(y) == real(x) ||
math.IsNaN(real(x)) && math.IsNaN(real(y)) && math.IsNaN(imag(y)) && math.IsNaN(imag(x)))
}
func guardVector(vec []complex128, guard_val complex128, guard_len int) (guarded []complex128) {
guarded = make([]complex128, len(vec)+guard_len*2)
copy(guarded[guard_len:], vec)

View File

@@ -9,7 +9,7 @@ import (
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/cmplxs"
"gonum.org/v1/gonum/cmplxs/cscalar"
"gonum.org/v1/gonum/floats/scalar"
"gonum.org/v1/gonum/internal/cmplx64"
"gonum.org/v1/gonum/internal/math32"
@@ -98,12 +98,12 @@ func sameApprox(a, b, tol float32) bool {
// sameCmplx tests for nan-aware equality.
func sameCmplx(a, b complex64) bool {
return a == b || (cmplx64.IsNaN(a) && cmplx64.IsNaN(b))
return cscalar.Same(complex128(a), complex128(b))
}
// sameCmplxApprox tests for nan-aware equality within tolerance.
func sameCmplxApprox(a, b complex64, tol float32) bool {
return sameCmplx(a, b) || cmplxs.EqualWithinAbsOrRel(complex128(a), complex128(b), float64(tol), float64(tol))
return sameCmplx(a, b) || cscalar.EqualWithinAbsOrRel(complex128(a), complex128(b), float64(tol), float64(tol))
}
var ( // Offset sets for testing alignment handling in Unitary assembly functions.

View File

@@ -5,12 +5,13 @@
package c64
import (
"math/cmplx"
"testing"
"gonum.org/v1/gonum/cmplxs/cscalar"
)
func same(x, y complex64) bool {
return x == y || (cmplx.IsNaN(complex128(x)) && cmplx.IsNaN(complex128(y)))
return cscalar.Same(complex128(x), complex128(y))
}
func guardVector(vec []complex64, gdVal complex64, gdLen int) (guarded []complex64) {