mirror of
https://github.com/gonum/gonum.git
synced 2025-10-21 14:19:35 +08:00
Handle NaN inputs in CDF
When sort.Float64sAreSorted is called, it panics on NaN inputs. The way the code is written, there seems to be an intent that if the input x contains a NaN, then the function should return a NaN. I switched the order of the conditional statements so that this became possible. I also modified the test cases so that cases with NaN could be evaluated.
This commit is contained in:
6
stat.go
6
stat.go
@@ -66,12 +66,12 @@ func CDF(q float64, c CumulantKind, x, weights []float64) float64 {
|
|||||||
if weights != nil && len(x) != len(weights) {
|
if weights != nil && len(x) != len(weights) {
|
||||||
panic("stat: slice length mismatch")
|
panic("stat: slice length mismatch")
|
||||||
}
|
}
|
||||||
if !sort.Float64sAreSorted(x) {
|
|
||||||
panic("x data are not sorted")
|
|
||||||
}
|
|
||||||
if floats.HasNaN(x) {
|
if floats.HasNaN(x) {
|
||||||
return math.NaN()
|
return math.NaN()
|
||||||
}
|
}
|
||||||
|
if !sort.Float64sAreSorted(x) {
|
||||||
|
panic("x data are not sorted")
|
||||||
|
}
|
||||||
|
|
||||||
if q < x[0] {
|
if q < x[0] {
|
||||||
return 0
|
return 0
|
||||||
|
15
stat_test.go
15
stat_test.go
@@ -755,6 +755,11 @@ func TestCDF(t *testing.T) {
|
|||||||
weights: []float64{1, 1, 1, 1, 1},
|
weights: []float64{1, 1, 1, 1, 1},
|
||||||
ans: [][]float64{{0, 0, 0.2, 0.2, 0.4, 0.6, 0.6, 0.8, 1, 1}},
|
ans: [][]float64{{0, 0, 0.2, 0.2, 0.4, 0.6, 0.6, 0.8, 1, 1}},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
q: []float64{0, 0.9, 1},
|
||||||
|
x: []float64{math.NaN()},
|
||||||
|
ans: [][]float64{{math.NaN(), math.NaN(), math.NaN()}},
|
||||||
|
},
|
||||||
} {
|
} {
|
||||||
copyX := make([]float64, len(test.x))
|
copyX := make([]float64, len(test.x))
|
||||||
copy(copyX, test.x)
|
copy(copyX, test.x)
|
||||||
@@ -766,13 +771,13 @@ func TestCDF(t *testing.T) {
|
|||||||
for j, q := range test.q {
|
for j, q := range test.q {
|
||||||
for k, kind := range cumulantKinds {
|
for k, kind := range cumulantKinds {
|
||||||
v := CDF(q, kind, test.x, test.weights)
|
v := CDF(q, kind, test.x, test.weights)
|
||||||
if !floats.Equal(copyX, test.x) {
|
if !floats.Equal(copyX, test.x) && !math.IsNaN(v) {
|
||||||
t.Errorf("x changed for case %d kind %d percentile %v", i, k, q)
|
t.Errorf("x changed for case %d kind %d percentile %v", i, k, q)
|
||||||
}
|
}
|
||||||
if !floats.Equal(copyW, test.weights) {
|
if !floats.Equal(copyW, test.weights) {
|
||||||
t.Errorf("x changed for case %d kind %d percentile %v", i, k, q)
|
t.Errorf("x changed for case %d kind %d percentile %v", i, k, q)
|
||||||
}
|
}
|
||||||
if v != test.ans[k][j] {
|
if v != test.ans[k][j] && !(math.IsNaN(v) && math.IsNaN(test.ans[k][j])) {
|
||||||
t.Errorf("mismatch case %d kind %d percentile %v. Expected: %v, found: %v", i, k, q, test.ans[k][j], v)
|
t.Errorf("mismatch case %d kind %d percentile %v. Expected: %v, found: %v", i, k, q, test.ans[k][j], v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -800,12 +805,6 @@ func TestCDF(t *testing.T) {
|
|||||||
kind: Empirical,
|
kind: Empirical,
|
||||||
x: []float64{3, 2, 1},
|
x: []float64{3, 2, 1},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "x has a NaN",
|
|
||||||
q: 1.5,
|
|
||||||
kind: Empirical,
|
|
||||||
x: []float64{1, 2, math.NaN()},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "unknown CumulantKind",
|
name: "unknown CumulantKind",
|
||||||
q: 1.5,
|
q: 1.5,
|
||||||
|
Reference in New Issue
Block a user