all: use scalar.Same instead of local same function

This commit is contained in:
Dan Kortschak
2020-08-04 14:18:47 +09:30
parent da72779e7a
commit 0e6fb8d22a
11 changed files with 46 additions and 74 deletions

View File

@@ -33,7 +33,7 @@ func areSlicesSame(t *testing.T, truth, comp []float64, str string) {
ok := len(truth) == len(comp)
if ok {
for i, a := range truth {
if !scalar.EqualWithinAbsOrRel(a, comp[i], EqTolerance, EqTolerance) && !same(a, comp[i]) {
if !scalar.EqualWithinAbsOrRel(a, comp[i], EqTolerance, EqTolerance) && !scalar.Same(a, comp[i]) {
ok = false
break
}
@@ -44,10 +44,6 @@ func areSlicesSame(t *testing.T, truth, comp []float64, str string) {
}
}
func same(a, b float64) bool {
return a == b || (math.IsNaN(a) && math.IsNaN(b))
}
func Panics(fun func()) (b bool) {
defer func() {
err := recover()
@@ -735,7 +731,7 @@ func TestMaxAndIdx(t *testing.T) {
t.Errorf("Wrong index "+test.desc+": got:%d want:%d", ind, test.wantIdx)
}
val := Max(test.in)
if !same(val, test.wantVal) {
if !scalar.Same(val, test.wantVal) {
t.Errorf("Wrong value "+test.desc+": got:%f want:%f", val, test.wantVal)
}
}
@@ -788,7 +784,7 @@ func TestMinAndIdx(t *testing.T) {
t.Errorf("Wrong index "+test.desc+": got:%d want:%d", ind, test.wantIdx)
}
val := Min(test.in)
if !same(val, test.wantVal) {
if !scalar.Same(val, test.wantVal) {
t.Errorf("Wrong value "+test.desc+": got:%f want:%f", val, test.wantVal)
}
}

View File

@@ -8,13 +8,13 @@ package testgraph // import "gonum.org/v1/gonum/graph/testgraph"
import (
"fmt"
"math"
"reflect"
"sort"
"testing"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/floats/scalar"
"gonum.org/v1/gonum/graph"
"gonum.org/v1/gonum/graph/internal/ordered"
"gonum.org/v1/gonum/graph/internal/set"
@@ -1141,12 +1141,12 @@ func Weight(t *testing.T, b Builder) {
if e != nil {
t.Errorf("missing edge weight for existing edge for test %q: (%v)--(%v)", test.name, x.ID(), y.ID())
}
if !same(w, absent) {
if !scalar.Same(w, absent) {
t.Errorf("unexpected absent weight for test %q: got:%v want:%v", test.name, w, absent)
}
case !multi && x.ID() == y.ID():
if !same(w, self) {
if !scalar.Same(w, self) {
t.Errorf("unexpected self weight for test %q: got:%v want:%v", test.name, w, self)
}
@@ -1196,15 +1196,15 @@ func AdjacencyMatrix(t *testing.T, b Builder) {
w, ok := wg.Weight(x.ID(), y.ID())
switch {
case !ok:
if !same(m.At(i, j), absent) {
if !scalar.Same(m.At(i, j), absent) {
t.Errorf("weight mismatch for test %q: (%v)--(%v) matrix=%v graph=%v", test.name, x.ID(), y.ID(), m.At(i, j), w)
}
case x.ID() == y.ID():
if !same(m.At(i, j), self) {
if !scalar.Same(m.At(i, j), self) {
t.Errorf("weight mismatch for test %q: (%v)--(%v) matrix=%v graph=%v", test.name, x.ID(), y.ID(), m.At(i, j), w)
}
default:
if !same(m.At(i, j), w) {
if !scalar.Same(m.At(i, j), w) {
t.Errorf("weight mismatch for test %q: (%v)--(%v) matrix=%v graph=%v", test.name, x.ID(), y.ID(), m.At(i, j), w)
}
}
@@ -2037,10 +2037,6 @@ type edge struct {
f, t, id int64
}
func same(a, b float64) bool {
return (math.IsNaN(a) && math.IsNaN(b)) || a == b
}
func panics(fn func()) (ok bool) {
defer func() {
ok = recover() != nil

View File

@@ -91,14 +91,9 @@ func checkValidIncGuard(t *testing.T, vec []complex128, gdVal complex128, inc, g
}
}
// same tests for nan-aware equality.
func same(a, b float64) bool {
return a == b || (math.IsNaN(a) && math.IsNaN(b))
}
// sameApprox tests for nan-aware equality within tolerance.
func sameApprox(a, b, tol float64) bool {
return same(a, b) || scalar.EqualWithinAbsOrRel(a, b, tol, tol)
return scalar.Same(a, b) || scalar.EqualWithinAbsOrRel(a, b, tol, tol)
}
// sameCmplx tests for nan-aware equality.

View File

@@ -91,14 +91,9 @@ func checkValidIncGuard(t *testing.T, vec []complex64, gdVal complex64, inc, gdL
}
}
// same tests for nan-aware equality.
func same(a, b float32) bool {
return a == b || (math32.IsNaN(a) && math32.IsNaN(b))
}
// sameApprox tests for nan-aware equality within tolerance.
func sameApprox(a, b, tol float32) bool {
return same(a, b) || scalar.EqualWithinAbsOrRel(float64(a), float64(b), float64(tol), float64(tol))
return scalar.Same(float64(a), float64(b)) || scalar.EqualWithinAbsOrRel(float64(a), float64(b), float64(tol), float64(tol))
}
// sameCmplx tests for nan-aware equality.

View File

@@ -9,6 +9,7 @@ import (
"math"
"testing"
"gonum.org/v1/gonum/floats/scalar"
. "gonum.org/v1/gonum/internal/asm/f32"
)
@@ -130,7 +131,7 @@ func TestDdotUnitary(t *testing.T) {
xg, yg := guardVector(test.x, xGdVal, xgLn), guardVector(test.y, yGdVal, ygLn)
x, y := xg[xgLn:len(xg)-xgLn], yg[ygLn:len(yg)-ygLn]
res := DdotUnitary(x, y)
if !same64(res, test.dWant) {
if !scalar.Same(res, test.dWant) {
t.Errorf(msgRes, prefix, res, test.dWant)
}
if !isValidGuard(xg, xGdVal, xgLn) {
@@ -162,7 +163,7 @@ func TestDdotInc(t *testing.T) {
want = test.dWantRev
}
res := DdotInc(x, y, uintptr(test.n), uintptr(inc.x), uintptr(inc.y), uintptr(ix), uintptr(iy))
if !same64(res, want) {
if !scalar.Same(res, want) {
t.Errorf(msgRes, prefix, res, want)
}
checkValidIncGuard(t, xg, xGdVal, inc.x, gdLn)

View File

@@ -30,12 +30,7 @@ func sameApprox(x, y, tol float32) bool {
}
func same(x, y float32) bool {
a, b := float64(x), float64(y)
return a == b || (math.IsNaN(a) && math.IsNaN(b))
}
func same64(a, b float64) bool {
return a == b || (math.IsNaN(a) && math.IsNaN(b))
return scalar.Same(float64(x), float64(y))
}
// sameStrided returns true if the strided vector x contains elements of the

View File

@@ -63,7 +63,7 @@ func equalStrided(ref, x []float64, inc int) bool {
inc = -inc
}
for i, v := range ref {
if !same(x[i*inc], v) {
if !scalar.Same(x[i*inc], v) {
return false
}
}
@@ -99,7 +99,7 @@ func guardVector(vec []float64, gdVal float64, gdLn int) (guarded []float64) {
// isValidGuard will test for violated guards, generated by guardVector.
func isValidGuard(vec []float64, gdVal float64, gdLn int) bool {
for i := 0; i < gdLn; i++ {
if !same(vec[i], gdVal) || !same(vec[len(vec)-1-i], gdVal) {
if !scalar.Same(vec[i], gdVal) || !scalar.Same(vec[len(vec)-1-i], gdVal) {
return false
}
}
@@ -129,7 +129,7 @@ func checkValidIncGuard(t *testing.T, vec []float64, gdVal float64, inc, gdLen i
srcLn := len(vec) - 2*gdLen
for i := range vec {
switch {
case same(vec[i], gdVal):
case scalar.Same(vec[i], gdVal):
// Correct value
case (i-gdLen)%inc == 0 && (i-gdLen)/inc < len(vec):
// Ignore input values
@@ -143,14 +143,9 @@ func checkValidIncGuard(t *testing.T, vec []float64, gdVal float64, inc, gdLen i
}
}
// same tests for nan-aware equality.
func same(a, b float64) bool {
return a == b || (math.IsNaN(a) && math.IsNaN(b))
}
// sameApprox tests for nan-aware equality within tolerance.
func sameApprox(a, b, tol float64) bool {
return same(a, b) || scalar.EqualWithinAbsOrRel(a, b, tol, tol)
return scalar.Same(a, b) || scalar.EqualWithinAbsOrRel(a, b, tol, tol)
}
var ( // Offset sets for testing alignment handling in Unitary assembly functions.

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"testing"
"gonum.org/v1/gonum/floats/scalar"
. "gonum.org/v1/gonum/internal/asm/f64"
)
@@ -114,7 +115,7 @@ func TestAxpyUnitary(t *testing.T) {
x, y := xg[xgLn:len(xg)-xgLn], yg[ygLn:len(yg)-ygLn]
AxpyUnitary(test.alpha, x, y)
for i := range test.want {
if !same(y[i], test.want[i]) {
if !scalar.Same(y[i], test.want[i]) {
t.Errorf(msgVal, prefix, i, y[i], test.want[i])
}
}
@@ -146,7 +147,7 @@ func TestAxpyUnitaryTo(t *testing.T) {
AxpyUnitaryTo(dst, test.alpha, x, y)
for i := range test.want {
if !same(dst[i], test.want[i]) {
if !scalar.Same(dst[i], test.want[i]) {
t.Errorf(msgVal, prefix, i, dst[i], test.want[i])
}
}
@@ -198,7 +199,7 @@ func TestAxpyInc(t *testing.T) {
inc.y = -inc.y
}
for i := range want {
if !same(y[i*inc.y], want[i]) {
if !scalar.Same(y[i*inc.y], want[i]) {
t.Errorf(msgVal, prefix, i, y[iy+i*inc.y], want[i])
}
}
@@ -252,7 +253,7 @@ func TestAxpyIncTo(t *testing.T) {
inc.dst = -inc.dst
}
for i := range want {
if !same(dst[i*inc.dst], want[iW+i*incW]) {
if !scalar.Same(dst[i*inc.dst], want[iW+i*incW]) {
t.Errorf(msgVal, prefix, i, dst[i*inc.dst], want[iW+i*incW])
}
}

View File

@@ -10,6 +10,7 @@ import (
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/floats/scalar"
. "gonum.org/v1/gonum/internal/asm/f64"
)
@@ -87,7 +88,7 @@ func TestScalUnitary(t *testing.T) {
ScalUnitary(test.alpha, x)
for i := range test.want {
if !same(x[i], test.want[i]) {
if !scalar.Same(x[i], test.want[i]) {
t.Errorf(msgVal, prefix, i, x[i], test.want[i])
}
}
@@ -113,7 +114,7 @@ func TestScalUnitaryTo(t *testing.T) {
ScalUnitaryTo(dst, test.alpha, x)
for i := range test.want {
if !same(dst[i], test.want[i]) {
if !scalar.Same(dst[i], test.want[i]) {
t.Errorf(msgVal, prefix, i, dst[i], test.want[i])
}
}
@@ -143,7 +144,7 @@ func TestScalInc(t *testing.T) {
ScalInc(test.alpha, x, uintptr(n), uintptr(incX))
for i := range test.want {
if !same(x[i*incX], test.want[i]) {
if !scalar.Same(x[i*incX], test.want[i]) {
t.Errorf(msgVal, prefix, i, x[i*incX], test.want[i])
}
}
@@ -167,7 +168,7 @@ func TestScalIncTo(t *testing.T) {
ScalIncTo(dst, uintptr(inc.y), test.alpha, x, uintptr(n), uintptr(inc.x))
for i := range test.want {
if !same(dst[i*inc.y], test.want[i]) {
if !scalar.Same(dst[i*inc.y], test.want[i]) {
t.Errorf(msgVal, prefix, i, dst[i*inc.y], test.want[i])
}
}

View File

@@ -7,6 +7,7 @@ package f64_test
import (
"testing"
"gonum.org/v1/gonum/floats/scalar"
. "gonum.org/v1/gonum/internal/asm/f64"
)
@@ -28,7 +29,7 @@ func TestL1Norm(t *testing.T) {
v.x = guardVector(v.x, src_gd, g_ln)
src := v.x[g_ln : len(v.x)-g_ln]
ret := L1Norm(src)
if !same(ret, v.want) {
if !scalar.Same(ret, v.want) {
t.Errorf("Test %d L1Norm error Got: %f Expected: %f", j, ret, v.want)
}
if !isValidGuard(v.x, src_gd, g_ln) {
@@ -56,7 +57,7 @@ func TestL1NormInc(t *testing.T) {
v.x = guardIncVector(v.x, src_gd, v.inc, g_ln)
src := v.x[g_ln : len(v.x)-g_ln]
ret := L1NormInc(src, ln, v.inc)
if !same(ret, v.want) {
if !scalar.Same(ret, v.want) {
t.Errorf("Test %d L1NormInc error Got: %f Expected: %f", j, ret, v.want)
}
checkValidIncGuard(t, v.x, src_gd, v.inc, g_ln)
@@ -109,7 +110,7 @@ func TestAdd(t *testing.T) {
src, dst := v.src[sg_ln:len(v.src)-sg_ln], v.dst[dg_ln:len(v.dst)-dg_ln]
Add(dst, src)
for i := range v.expect {
if !same(dst[i], v.expect[i]) {
if !scalar.Same(dst[i], v.expect[i]) {
t.Errorf("Test %d Add error at %d Got: %v Expected: %v", j, i, dst[i], v.expect[i])
}
}
@@ -159,7 +160,7 @@ func TestAddConst(t *testing.T) {
src := v.src[g_ln : len(v.src)-g_ln]
AddConst(v.alpha, src)
for i := range v.expect {
if !same(src[i], v.expect[i]) {
if !scalar.Same(src[i], v.expect[i]) {
t.Errorf("Test %d AddConst error at %d Got: %v Expected: %v", j, i, src[i], v.expect[i])
}
}
@@ -225,10 +226,10 @@ func TestCumSum(t *testing.T) {
src, dst := v.src[g_ln:len(v.src)-g_ln], v.dst[g_ln:len(v.dst)-g_ln]
ret := CumSum(dst, src)
for i := range v.expect {
if !same(ret[i], v.expect[i]) {
if !scalar.Same(ret[i], v.expect[i]) {
t.Errorf("Test %d CumSum error at %d Got: %v Expected: %v", j, i, ret[i], v.expect[i])
}
if !same(ret[i], dst[i]) {
if !scalar.Same(ret[i], dst[i]) {
t.Errorf("Test %d CumSum ret/dst mismatch %d Ret: %v Dst: %v", j, i, ret[i], dst[i])
}
}
@@ -297,10 +298,10 @@ func TestCumProd(t *testing.T) {
src, dst := v.src[sg_ln:len(v.src)-sg_ln], v.dst[dg_ln:len(v.dst)-dg_ln]
ret := CumProd(dst, src)
for i := range v.expect {
if !same(ret[i], v.expect[i]) {
if !scalar.Same(ret[i], v.expect[i]) {
t.Errorf("Test %d CumProd error at %d Got: %v Expected: %v", j, i, ret[i], v.expect[i])
}
if !same(ret[i], dst[i]) {
if !scalar.Same(ret[i], dst[i]) {
t.Errorf("Test %d CumProd ret/dst mismatch %d Ret: %v Dst: %v", j, i, ret[i], dst[i])
}
}
@@ -364,7 +365,7 @@ func TestDiv(t *testing.T) {
src, dst := v.src[sg_ln:len(v.src)-sg_ln], v.dst[dg_ln:len(v.dst)-dg_ln]
Div(dst, src)
for i := range v.expect {
if !same(dst[i], v.expect[i]) {
if !scalar.Same(dst[i], v.expect[i]) {
t.Errorf("Test %d Div error at %d Got: %v Expected: %v", j, i, dst[i], v.expect[i])
}
}
@@ -432,10 +433,10 @@ func TestDivTo(t *testing.T) {
dst := v.dst[xg_ln : len(v.dst)-xg_ln]
ret := DivTo(dst, x, y)
for i := range v.expect {
if !same(ret[i], v.expect[i]) {
if !scalar.Same(ret[i], v.expect[i]) {
t.Errorf("Test %d DivTo error at %d Got: %v Expected: %v", j, i, ret[i], v.expect[i])
}
if !same(ret[i], dst[i]) {
if !scalar.Same(ret[i], dst[i]) {
t.Errorf("Test %d DivTo ret/dst mismatch %d Ret: %v Dst: %v", j, i, ret[i], dst[i])
}
}
@@ -502,7 +503,7 @@ func TestL1Dist(t *testing.T) {
v.s, v.t = guardVector(v.s, s_gd, sg_ln), guardVector(v.t, t_gd, tg_ln)
s_lc, t_lc := v.s[sg_ln:len(v.s)-sg_ln], v.t[tg_ln:len(v.t)-tg_ln]
ret := L1Dist(s_lc, t_lc)
if !same(ret, v.expect) {
if !scalar.Same(ret, v.expect) {
t.Errorf("Test %d L1Dist error Got: %f Expected: %f", j, ret, v.expect)
}
if !isValidGuard(v.s, s_gd, sg_ln) {
@@ -565,7 +566,7 @@ func TestLinfDist(t *testing.T) {
v.s, v.t = guardVector(v.s, s_gd, sg_ln), guardVector(v.t, t_gd, tg_ln)
s_lc, t_lc := v.s[sg_ln:len(v.s)-sg_ln], v.t[tg_ln:len(v.t)-tg_ln]
ret := LinfDist(s_lc, t_lc)
if !same(ret, v.expect) {
if !scalar.Same(ret, v.expect) {
t.Errorf("Test %d LinfDist error Got: %f Expected: %f", j, ret, v.expect)
}
if !isValidGuard(v.s, s_gd, sg_ln) {
@@ -628,7 +629,7 @@ func TestSum(t *testing.T) {
gsrc := guardVector(v.src, srcGd, gdLn)
src := gsrc[gdLn : len(gsrc)-gdLn]
ret := Sum(src)
if !same(ret, v.expect) {
if !scalar.Same(ret, v.expect) {
t.Errorf("Test %d Sum error Got: %v Expected: %v", j, ret, v.expect)
}
if !isValidGuard(gsrc, srcGd, gdLn) {

View File

@@ -111,10 +111,10 @@ func TestUniformScore(t *testing.T) {
{1.001, math.NaN(), math.NaN()},
} {
score := u.Score(nil, test.x)
if !same(score[0], test.wantMin) {
if !scalar.Same(score[0], test.wantMin) {
t.Errorf("Score[0] mismatch for at %g: got %v, want %g", test.x, score[0], test.wantMin)
}
if !same(score[1], test.wantMax) {
if !scalar.Same(score[1], test.wantMax) {
t.Errorf("Score[1] mismatch for at %g: got %v, want %g", test.x, score[1], test.wantMax)
}
}
@@ -135,7 +135,3 @@ func TestUniformScoreInput(t *testing.T) {
}
}
}
func same(a, b float64) bool {
return a == b || (math.IsNaN(a) && math.IsNaN(b))
}