all: use scalar.Same instead of local same function

This commit is contained in:
Dan Kortschak
2020-08-04 14:18:47 +09:30
parent da72779e7a
commit 0e6fb8d22a
11 changed files with 46 additions and 74 deletions

View File

@@ -33,7 +33,7 @@ func areSlicesSame(t *testing.T, truth, comp []float64, str string) {
ok := len(truth) == len(comp) ok := len(truth) == len(comp)
if ok { if ok {
for i, a := range truth { for i, a := range truth {
if !scalar.EqualWithinAbsOrRel(a, comp[i], EqTolerance, EqTolerance) && !same(a, comp[i]) { if !scalar.EqualWithinAbsOrRel(a, comp[i], EqTolerance, EqTolerance) && !scalar.Same(a, comp[i]) {
ok = false ok = false
break break
} }
@@ -44,10 +44,6 @@ func areSlicesSame(t *testing.T, truth, comp []float64, str string) {
} }
} }
func same(a, b float64) bool {
return a == b || (math.IsNaN(a) && math.IsNaN(b))
}
func Panics(fun func()) (b bool) { func Panics(fun func()) (b bool) {
defer func() { defer func() {
err := recover() err := recover()
@@ -735,7 +731,7 @@ func TestMaxAndIdx(t *testing.T) {
t.Errorf("Wrong index "+test.desc+": got:%d want:%d", ind, test.wantIdx) t.Errorf("Wrong index "+test.desc+": got:%d want:%d", ind, test.wantIdx)
} }
val := Max(test.in) val := Max(test.in)
if !same(val, test.wantVal) { if !scalar.Same(val, test.wantVal) {
t.Errorf("Wrong value "+test.desc+": got:%f want:%f", val, test.wantVal) t.Errorf("Wrong value "+test.desc+": got:%f want:%f", val, test.wantVal)
} }
} }
@@ -788,7 +784,7 @@ func TestMinAndIdx(t *testing.T) {
t.Errorf("Wrong index "+test.desc+": got:%d want:%d", ind, test.wantIdx) t.Errorf("Wrong index "+test.desc+": got:%d want:%d", ind, test.wantIdx)
} }
val := Min(test.in) val := Min(test.in)
if !same(val, test.wantVal) { if !scalar.Same(val, test.wantVal) {
t.Errorf("Wrong value "+test.desc+": got:%f want:%f", val, test.wantVal) t.Errorf("Wrong value "+test.desc+": got:%f want:%f", val, test.wantVal)
} }
} }

View File

@@ -8,13 +8,13 @@ package testgraph // import "gonum.org/v1/gonum/graph/testgraph"
import ( import (
"fmt" "fmt"
"math"
"reflect" "reflect"
"sort" "sort"
"testing" "testing"
"golang.org/x/exp/rand" "golang.org/x/exp/rand"
"gonum.org/v1/gonum/floats/scalar"
"gonum.org/v1/gonum/graph" "gonum.org/v1/gonum/graph"
"gonum.org/v1/gonum/graph/internal/ordered" "gonum.org/v1/gonum/graph/internal/ordered"
"gonum.org/v1/gonum/graph/internal/set" "gonum.org/v1/gonum/graph/internal/set"
@@ -1141,12 +1141,12 @@ func Weight(t *testing.T, b Builder) {
if e != nil { if e != nil {
t.Errorf("missing edge weight for existing edge for test %q: (%v)--(%v)", test.name, x.ID(), y.ID()) t.Errorf("missing edge weight for existing edge for test %q: (%v)--(%v)", test.name, x.ID(), y.ID())
} }
if !same(w, absent) { if !scalar.Same(w, absent) {
t.Errorf("unexpected absent weight for test %q: got:%v want:%v", test.name, w, absent) t.Errorf("unexpected absent weight for test %q: got:%v want:%v", test.name, w, absent)
} }
case !multi && x.ID() == y.ID(): case !multi && x.ID() == y.ID():
if !same(w, self) { if !scalar.Same(w, self) {
t.Errorf("unexpected self weight for test %q: got:%v want:%v", test.name, w, self) t.Errorf("unexpected self weight for test %q: got:%v want:%v", test.name, w, self)
} }
@@ -1196,15 +1196,15 @@ func AdjacencyMatrix(t *testing.T, b Builder) {
w, ok := wg.Weight(x.ID(), y.ID()) w, ok := wg.Weight(x.ID(), y.ID())
switch { switch {
case !ok: case !ok:
if !same(m.At(i, j), absent) { if !scalar.Same(m.At(i, j), absent) {
t.Errorf("weight mismatch for test %q: (%v)--(%v) matrix=%v graph=%v", test.name, x.ID(), y.ID(), m.At(i, j), w) t.Errorf("weight mismatch for test %q: (%v)--(%v) matrix=%v graph=%v", test.name, x.ID(), y.ID(), m.At(i, j), w)
} }
case x.ID() == y.ID(): case x.ID() == y.ID():
if !same(m.At(i, j), self) { if !scalar.Same(m.At(i, j), self) {
t.Errorf("weight mismatch for test %q: (%v)--(%v) matrix=%v graph=%v", test.name, x.ID(), y.ID(), m.At(i, j), w) t.Errorf("weight mismatch for test %q: (%v)--(%v) matrix=%v graph=%v", test.name, x.ID(), y.ID(), m.At(i, j), w)
} }
default: default:
if !same(m.At(i, j), w) { if !scalar.Same(m.At(i, j), w) {
t.Errorf("weight mismatch for test %q: (%v)--(%v) matrix=%v graph=%v", test.name, x.ID(), y.ID(), m.At(i, j), w) t.Errorf("weight mismatch for test %q: (%v)--(%v) matrix=%v graph=%v", test.name, x.ID(), y.ID(), m.At(i, j), w)
} }
} }
@@ -2037,10 +2037,6 @@ type edge struct {
f, t, id int64 f, t, id int64
} }
func same(a, b float64) bool {
return (math.IsNaN(a) && math.IsNaN(b)) || a == b
}
func panics(fn func()) (ok bool) { func panics(fn func()) (ok bool) {
defer func() { defer func() {
ok = recover() != nil ok = recover() != nil

View File

@@ -91,14 +91,9 @@ func checkValidIncGuard(t *testing.T, vec []complex128, gdVal complex128, inc, g
} }
} }
// same tests for nan-aware equality.
func same(a, b float64) bool {
return a == b || (math.IsNaN(a) && math.IsNaN(b))
}
// sameApprox tests for nan-aware equality within tolerance. // sameApprox tests for nan-aware equality within tolerance.
func sameApprox(a, b, tol float64) bool { func sameApprox(a, b, tol float64) bool {
return same(a, b) || scalar.EqualWithinAbsOrRel(a, b, tol, tol) return scalar.Same(a, b) || scalar.EqualWithinAbsOrRel(a, b, tol, tol)
} }
// sameCmplx tests for nan-aware equality. // sameCmplx tests for nan-aware equality.

View File

@@ -91,14 +91,9 @@ func checkValidIncGuard(t *testing.T, vec []complex64, gdVal complex64, inc, gdL
} }
} }
// same tests for nan-aware equality.
func same(a, b float32) bool {
return a == b || (math32.IsNaN(a) && math32.IsNaN(b))
}
// sameApprox tests for nan-aware equality within tolerance. // sameApprox tests for nan-aware equality within tolerance.
func sameApprox(a, b, tol float32) bool { func sameApprox(a, b, tol float32) bool {
return same(a, b) || scalar.EqualWithinAbsOrRel(float64(a), float64(b), float64(tol), float64(tol)) return scalar.Same(float64(a), float64(b)) || scalar.EqualWithinAbsOrRel(float64(a), float64(b), float64(tol), float64(tol))
} }
// sameCmplx tests for nan-aware equality. // sameCmplx tests for nan-aware equality.

View File

@@ -9,6 +9,7 @@ import (
"math" "math"
"testing" "testing"
"gonum.org/v1/gonum/floats/scalar"
. "gonum.org/v1/gonum/internal/asm/f32" . "gonum.org/v1/gonum/internal/asm/f32"
) )
@@ -130,7 +131,7 @@ func TestDdotUnitary(t *testing.T) {
xg, yg := guardVector(test.x, xGdVal, xgLn), guardVector(test.y, yGdVal, ygLn) xg, yg := guardVector(test.x, xGdVal, xgLn), guardVector(test.y, yGdVal, ygLn)
x, y := xg[xgLn:len(xg)-xgLn], yg[ygLn:len(yg)-ygLn] x, y := xg[xgLn:len(xg)-xgLn], yg[ygLn:len(yg)-ygLn]
res := DdotUnitary(x, y) res := DdotUnitary(x, y)
if !same64(res, test.dWant) { if !scalar.Same(res, test.dWant) {
t.Errorf(msgRes, prefix, res, test.dWant) t.Errorf(msgRes, prefix, res, test.dWant)
} }
if !isValidGuard(xg, xGdVal, xgLn) { if !isValidGuard(xg, xGdVal, xgLn) {
@@ -162,7 +163,7 @@ func TestDdotInc(t *testing.T) {
want = test.dWantRev want = test.dWantRev
} }
res := DdotInc(x, y, uintptr(test.n), uintptr(inc.x), uintptr(inc.y), uintptr(ix), uintptr(iy)) res := DdotInc(x, y, uintptr(test.n), uintptr(inc.x), uintptr(inc.y), uintptr(ix), uintptr(iy))
if !same64(res, want) { if !scalar.Same(res, want) {
t.Errorf(msgRes, prefix, res, want) t.Errorf(msgRes, prefix, res, want)
} }
checkValidIncGuard(t, xg, xGdVal, inc.x, gdLn) checkValidIncGuard(t, xg, xGdVal, inc.x, gdLn)

View File

@@ -30,12 +30,7 @@ func sameApprox(x, y, tol float32) bool {
} }
func same(x, y float32) bool { func same(x, y float32) bool {
a, b := float64(x), float64(y) return scalar.Same(float64(x), float64(y))
return a == b || (math.IsNaN(a) && math.IsNaN(b))
}
func same64(a, b float64) bool {
return a == b || (math.IsNaN(a) && math.IsNaN(b))
} }
// sameStrided returns true if the strided vector x contains elements of the // sameStrided returns true if the strided vector x contains elements of the

View File

@@ -63,7 +63,7 @@ func equalStrided(ref, x []float64, inc int) bool {
inc = -inc inc = -inc
} }
for i, v := range ref { for i, v := range ref {
if !same(x[i*inc], v) { if !scalar.Same(x[i*inc], v) {
return false return false
} }
} }
@@ -99,7 +99,7 @@ func guardVector(vec []float64, gdVal float64, gdLn int) (guarded []float64) {
// isValidGuard will test for violated guards, generated by guardVector. // isValidGuard will test for violated guards, generated by guardVector.
func isValidGuard(vec []float64, gdVal float64, gdLn int) bool { func isValidGuard(vec []float64, gdVal float64, gdLn int) bool {
for i := 0; i < gdLn; i++ { for i := 0; i < gdLn; i++ {
if !same(vec[i], gdVal) || !same(vec[len(vec)-1-i], gdVal) { if !scalar.Same(vec[i], gdVal) || !scalar.Same(vec[len(vec)-1-i], gdVal) {
return false return false
} }
} }
@@ -129,7 +129,7 @@ func checkValidIncGuard(t *testing.T, vec []float64, gdVal float64, inc, gdLen i
srcLn := len(vec) - 2*gdLen srcLn := len(vec) - 2*gdLen
for i := range vec { for i := range vec {
switch { switch {
case same(vec[i], gdVal): case scalar.Same(vec[i], gdVal):
// Correct value // Correct value
case (i-gdLen)%inc == 0 && (i-gdLen)/inc < len(vec): case (i-gdLen)%inc == 0 && (i-gdLen)/inc < len(vec):
// Ignore input values // Ignore input values
@@ -143,14 +143,9 @@ func checkValidIncGuard(t *testing.T, vec []float64, gdVal float64, inc, gdLen i
} }
} }
// same tests for nan-aware equality.
func same(a, b float64) bool {
return a == b || (math.IsNaN(a) && math.IsNaN(b))
}
// sameApprox tests for nan-aware equality within tolerance. // sameApprox tests for nan-aware equality within tolerance.
func sameApprox(a, b, tol float64) bool { func sameApprox(a, b, tol float64) bool {
return same(a, b) || scalar.EqualWithinAbsOrRel(a, b, tol, tol) return scalar.Same(a, b) || scalar.EqualWithinAbsOrRel(a, b, tol, tol)
} }
var ( // Offset sets for testing alignment handling in Unitary assembly functions. var ( // Offset sets for testing alignment handling in Unitary assembly functions.

View File

@@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"testing" "testing"
"gonum.org/v1/gonum/floats/scalar"
. "gonum.org/v1/gonum/internal/asm/f64" . "gonum.org/v1/gonum/internal/asm/f64"
) )
@@ -114,7 +115,7 @@ func TestAxpyUnitary(t *testing.T) {
x, y := xg[xgLn:len(xg)-xgLn], yg[ygLn:len(yg)-ygLn] x, y := xg[xgLn:len(xg)-xgLn], yg[ygLn:len(yg)-ygLn]
AxpyUnitary(test.alpha, x, y) AxpyUnitary(test.alpha, x, y)
for i := range test.want { for i := range test.want {
if !same(y[i], test.want[i]) { if !scalar.Same(y[i], test.want[i]) {
t.Errorf(msgVal, prefix, i, y[i], test.want[i]) t.Errorf(msgVal, prefix, i, y[i], test.want[i])
} }
} }
@@ -146,7 +147,7 @@ func TestAxpyUnitaryTo(t *testing.T) {
AxpyUnitaryTo(dst, test.alpha, x, y) AxpyUnitaryTo(dst, test.alpha, x, y)
for i := range test.want { for i := range test.want {
if !same(dst[i], test.want[i]) { if !scalar.Same(dst[i], test.want[i]) {
t.Errorf(msgVal, prefix, i, dst[i], test.want[i]) t.Errorf(msgVal, prefix, i, dst[i], test.want[i])
} }
} }
@@ -198,7 +199,7 @@ func TestAxpyInc(t *testing.T) {
inc.y = -inc.y inc.y = -inc.y
} }
for i := range want { for i := range want {
if !same(y[i*inc.y], want[i]) { if !scalar.Same(y[i*inc.y], want[i]) {
t.Errorf(msgVal, prefix, i, y[iy+i*inc.y], want[i]) t.Errorf(msgVal, prefix, i, y[iy+i*inc.y], want[i])
} }
} }
@@ -252,7 +253,7 @@ func TestAxpyIncTo(t *testing.T) {
inc.dst = -inc.dst inc.dst = -inc.dst
} }
for i := range want { for i := range want {
if !same(dst[i*inc.dst], want[iW+i*incW]) { if !scalar.Same(dst[i*inc.dst], want[iW+i*incW]) {
t.Errorf(msgVal, prefix, i, dst[i*inc.dst], want[iW+i*incW]) t.Errorf(msgVal, prefix, i, dst[i*inc.dst], want[iW+i*incW])
} }
} }

View File

@@ -10,6 +10,7 @@ import (
"golang.org/x/exp/rand" "golang.org/x/exp/rand"
"gonum.org/v1/gonum/floats/scalar"
. "gonum.org/v1/gonum/internal/asm/f64" . "gonum.org/v1/gonum/internal/asm/f64"
) )
@@ -87,7 +88,7 @@ func TestScalUnitary(t *testing.T) {
ScalUnitary(test.alpha, x) ScalUnitary(test.alpha, x)
for i := range test.want { for i := range test.want {
if !same(x[i], test.want[i]) { if !scalar.Same(x[i], test.want[i]) {
t.Errorf(msgVal, prefix, i, x[i], test.want[i]) t.Errorf(msgVal, prefix, i, x[i], test.want[i])
} }
} }
@@ -113,7 +114,7 @@ func TestScalUnitaryTo(t *testing.T) {
ScalUnitaryTo(dst, test.alpha, x) ScalUnitaryTo(dst, test.alpha, x)
for i := range test.want { for i := range test.want {
if !same(dst[i], test.want[i]) { if !scalar.Same(dst[i], test.want[i]) {
t.Errorf(msgVal, prefix, i, dst[i], test.want[i]) t.Errorf(msgVal, prefix, i, dst[i], test.want[i])
} }
} }
@@ -143,7 +144,7 @@ func TestScalInc(t *testing.T) {
ScalInc(test.alpha, x, uintptr(n), uintptr(incX)) ScalInc(test.alpha, x, uintptr(n), uintptr(incX))
for i := range test.want { for i := range test.want {
if !same(x[i*incX], test.want[i]) { if !scalar.Same(x[i*incX], test.want[i]) {
t.Errorf(msgVal, prefix, i, x[i*incX], test.want[i]) t.Errorf(msgVal, prefix, i, x[i*incX], test.want[i])
} }
} }
@@ -167,7 +168,7 @@ func TestScalIncTo(t *testing.T) {
ScalIncTo(dst, uintptr(inc.y), test.alpha, x, uintptr(n), uintptr(inc.x)) ScalIncTo(dst, uintptr(inc.y), test.alpha, x, uintptr(n), uintptr(inc.x))
for i := range test.want { for i := range test.want {
if !same(dst[i*inc.y], test.want[i]) { if !scalar.Same(dst[i*inc.y], test.want[i]) {
t.Errorf(msgVal, prefix, i, dst[i*inc.y], test.want[i]) t.Errorf(msgVal, prefix, i, dst[i*inc.y], test.want[i])
} }
} }

View File

@@ -7,6 +7,7 @@ package f64_test
import ( import (
"testing" "testing"
"gonum.org/v1/gonum/floats/scalar"
. "gonum.org/v1/gonum/internal/asm/f64" . "gonum.org/v1/gonum/internal/asm/f64"
) )
@@ -28,7 +29,7 @@ func TestL1Norm(t *testing.T) {
v.x = guardVector(v.x, src_gd, g_ln) v.x = guardVector(v.x, src_gd, g_ln)
src := v.x[g_ln : len(v.x)-g_ln] src := v.x[g_ln : len(v.x)-g_ln]
ret := L1Norm(src) ret := L1Norm(src)
if !same(ret, v.want) { if !scalar.Same(ret, v.want) {
t.Errorf("Test %d L1Norm error Got: %f Expected: %f", j, ret, v.want) t.Errorf("Test %d L1Norm error Got: %f Expected: %f", j, ret, v.want)
} }
if !isValidGuard(v.x, src_gd, g_ln) { if !isValidGuard(v.x, src_gd, g_ln) {
@@ -56,7 +57,7 @@ func TestL1NormInc(t *testing.T) {
v.x = guardIncVector(v.x, src_gd, v.inc, g_ln) v.x = guardIncVector(v.x, src_gd, v.inc, g_ln)
src := v.x[g_ln : len(v.x)-g_ln] src := v.x[g_ln : len(v.x)-g_ln]
ret := L1NormInc(src, ln, v.inc) ret := L1NormInc(src, ln, v.inc)
if !same(ret, v.want) { if !scalar.Same(ret, v.want) {
t.Errorf("Test %d L1NormInc error Got: %f Expected: %f", j, ret, v.want) t.Errorf("Test %d L1NormInc error Got: %f Expected: %f", j, ret, v.want)
} }
checkValidIncGuard(t, v.x, src_gd, v.inc, g_ln) checkValidIncGuard(t, v.x, src_gd, v.inc, g_ln)
@@ -109,7 +110,7 @@ func TestAdd(t *testing.T) {
src, dst := v.src[sg_ln:len(v.src)-sg_ln], v.dst[dg_ln:len(v.dst)-dg_ln] src, dst := v.src[sg_ln:len(v.src)-sg_ln], v.dst[dg_ln:len(v.dst)-dg_ln]
Add(dst, src) Add(dst, src)
for i := range v.expect { for i := range v.expect {
if !same(dst[i], v.expect[i]) { if !scalar.Same(dst[i], v.expect[i]) {
t.Errorf("Test %d Add error at %d Got: %v Expected: %v", j, i, dst[i], v.expect[i]) t.Errorf("Test %d Add error at %d Got: %v Expected: %v", j, i, dst[i], v.expect[i])
} }
} }
@@ -159,7 +160,7 @@ func TestAddConst(t *testing.T) {
src := v.src[g_ln : len(v.src)-g_ln] src := v.src[g_ln : len(v.src)-g_ln]
AddConst(v.alpha, src) AddConst(v.alpha, src)
for i := range v.expect { for i := range v.expect {
if !same(src[i], v.expect[i]) { if !scalar.Same(src[i], v.expect[i]) {
t.Errorf("Test %d AddConst error at %d Got: %v Expected: %v", j, i, src[i], v.expect[i]) t.Errorf("Test %d AddConst error at %d Got: %v Expected: %v", j, i, src[i], v.expect[i])
} }
} }
@@ -225,10 +226,10 @@ func TestCumSum(t *testing.T) {
src, dst := v.src[g_ln:len(v.src)-g_ln], v.dst[g_ln:len(v.dst)-g_ln] src, dst := v.src[g_ln:len(v.src)-g_ln], v.dst[g_ln:len(v.dst)-g_ln]
ret := CumSum(dst, src) ret := CumSum(dst, src)
for i := range v.expect { for i := range v.expect {
if !same(ret[i], v.expect[i]) { if !scalar.Same(ret[i], v.expect[i]) {
t.Errorf("Test %d CumSum error at %d Got: %v Expected: %v", j, i, ret[i], v.expect[i]) t.Errorf("Test %d CumSum error at %d Got: %v Expected: %v", j, i, ret[i], v.expect[i])
} }
if !same(ret[i], dst[i]) { if !scalar.Same(ret[i], dst[i]) {
t.Errorf("Test %d CumSum ret/dst mismatch %d Ret: %v Dst: %v", j, i, ret[i], dst[i]) t.Errorf("Test %d CumSum ret/dst mismatch %d Ret: %v Dst: %v", j, i, ret[i], dst[i])
} }
} }
@@ -297,10 +298,10 @@ func TestCumProd(t *testing.T) {
src, dst := v.src[sg_ln:len(v.src)-sg_ln], v.dst[dg_ln:len(v.dst)-dg_ln] src, dst := v.src[sg_ln:len(v.src)-sg_ln], v.dst[dg_ln:len(v.dst)-dg_ln]
ret := CumProd(dst, src) ret := CumProd(dst, src)
for i := range v.expect { for i := range v.expect {
if !same(ret[i], v.expect[i]) { if !scalar.Same(ret[i], v.expect[i]) {
t.Errorf("Test %d CumProd error at %d Got: %v Expected: %v", j, i, ret[i], v.expect[i]) t.Errorf("Test %d CumProd error at %d Got: %v Expected: %v", j, i, ret[i], v.expect[i])
} }
if !same(ret[i], dst[i]) { if !scalar.Same(ret[i], dst[i]) {
t.Errorf("Test %d CumProd ret/dst mismatch %d Ret: %v Dst: %v", j, i, ret[i], dst[i]) t.Errorf("Test %d CumProd ret/dst mismatch %d Ret: %v Dst: %v", j, i, ret[i], dst[i])
} }
} }
@@ -364,7 +365,7 @@ func TestDiv(t *testing.T) {
src, dst := v.src[sg_ln:len(v.src)-sg_ln], v.dst[dg_ln:len(v.dst)-dg_ln] src, dst := v.src[sg_ln:len(v.src)-sg_ln], v.dst[dg_ln:len(v.dst)-dg_ln]
Div(dst, src) Div(dst, src)
for i := range v.expect { for i := range v.expect {
if !same(dst[i], v.expect[i]) { if !scalar.Same(dst[i], v.expect[i]) {
t.Errorf("Test %d Div error at %d Got: %v Expected: %v", j, i, dst[i], v.expect[i]) t.Errorf("Test %d Div error at %d Got: %v Expected: %v", j, i, dst[i], v.expect[i])
} }
} }
@@ -432,10 +433,10 @@ func TestDivTo(t *testing.T) {
dst := v.dst[xg_ln : len(v.dst)-xg_ln] dst := v.dst[xg_ln : len(v.dst)-xg_ln]
ret := DivTo(dst, x, y) ret := DivTo(dst, x, y)
for i := range v.expect { for i := range v.expect {
if !same(ret[i], v.expect[i]) { if !scalar.Same(ret[i], v.expect[i]) {
t.Errorf("Test %d DivTo error at %d Got: %v Expected: %v", j, i, ret[i], v.expect[i]) t.Errorf("Test %d DivTo error at %d Got: %v Expected: %v", j, i, ret[i], v.expect[i])
} }
if !same(ret[i], dst[i]) { if !scalar.Same(ret[i], dst[i]) {
t.Errorf("Test %d DivTo ret/dst mismatch %d Ret: %v Dst: %v", j, i, ret[i], dst[i]) t.Errorf("Test %d DivTo ret/dst mismatch %d Ret: %v Dst: %v", j, i, ret[i], dst[i])
} }
} }
@@ -502,7 +503,7 @@ func TestL1Dist(t *testing.T) {
v.s, v.t = guardVector(v.s, s_gd, sg_ln), guardVector(v.t, t_gd, tg_ln) v.s, v.t = guardVector(v.s, s_gd, sg_ln), guardVector(v.t, t_gd, tg_ln)
s_lc, t_lc := v.s[sg_ln:len(v.s)-sg_ln], v.t[tg_ln:len(v.t)-tg_ln] s_lc, t_lc := v.s[sg_ln:len(v.s)-sg_ln], v.t[tg_ln:len(v.t)-tg_ln]
ret := L1Dist(s_lc, t_lc) ret := L1Dist(s_lc, t_lc)
if !same(ret, v.expect) { if !scalar.Same(ret, v.expect) {
t.Errorf("Test %d L1Dist error Got: %f Expected: %f", j, ret, v.expect) t.Errorf("Test %d L1Dist error Got: %f Expected: %f", j, ret, v.expect)
} }
if !isValidGuard(v.s, s_gd, sg_ln) { if !isValidGuard(v.s, s_gd, sg_ln) {
@@ -565,7 +566,7 @@ func TestLinfDist(t *testing.T) {
v.s, v.t = guardVector(v.s, s_gd, sg_ln), guardVector(v.t, t_gd, tg_ln) v.s, v.t = guardVector(v.s, s_gd, sg_ln), guardVector(v.t, t_gd, tg_ln)
s_lc, t_lc := v.s[sg_ln:len(v.s)-sg_ln], v.t[tg_ln:len(v.t)-tg_ln] s_lc, t_lc := v.s[sg_ln:len(v.s)-sg_ln], v.t[tg_ln:len(v.t)-tg_ln]
ret := LinfDist(s_lc, t_lc) ret := LinfDist(s_lc, t_lc)
if !same(ret, v.expect) { if !scalar.Same(ret, v.expect) {
t.Errorf("Test %d LinfDist error Got: %f Expected: %f", j, ret, v.expect) t.Errorf("Test %d LinfDist error Got: %f Expected: %f", j, ret, v.expect)
} }
if !isValidGuard(v.s, s_gd, sg_ln) { if !isValidGuard(v.s, s_gd, sg_ln) {
@@ -628,7 +629,7 @@ func TestSum(t *testing.T) {
gsrc := guardVector(v.src, srcGd, gdLn) gsrc := guardVector(v.src, srcGd, gdLn)
src := gsrc[gdLn : len(gsrc)-gdLn] src := gsrc[gdLn : len(gsrc)-gdLn]
ret := Sum(src) ret := Sum(src)
if !same(ret, v.expect) { if !scalar.Same(ret, v.expect) {
t.Errorf("Test %d Sum error Got: %v Expected: %v", j, ret, v.expect) t.Errorf("Test %d Sum error Got: %v Expected: %v", j, ret, v.expect)
} }
if !isValidGuard(gsrc, srcGd, gdLn) { if !isValidGuard(gsrc, srcGd, gdLn) {

View File

@@ -111,10 +111,10 @@ func TestUniformScore(t *testing.T) {
{1.001, math.NaN(), math.NaN()}, {1.001, math.NaN(), math.NaN()},
} { } {
score := u.Score(nil, test.x) score := u.Score(nil, test.x)
if !same(score[0], test.wantMin) { if !scalar.Same(score[0], test.wantMin) {
t.Errorf("Score[0] mismatch for at %g: got %v, want %g", test.x, score[0], test.wantMin) t.Errorf("Score[0] mismatch for at %g: got %v, want %g", test.x, score[0], test.wantMin)
} }
if !same(score[1], test.wantMax) { if !scalar.Same(score[1], test.wantMax) {
t.Errorf("Score[1] mismatch for at %g: got %v, want %g", test.x, score[1], test.wantMax) t.Errorf("Score[1] mismatch for at %g: got %v, want %g", test.x, score[1], test.wantMax)
} }
} }
@@ -135,7 +135,3 @@ func TestUniformScoreInput(t *testing.T) {
} }
} }
} }
func same(a, b float64) bool {
return a == b || (math.IsNaN(a) && math.IsNaN(b))
}