mirror of
https://github.com/gonum/gonum.git
synced 2025-10-21 06:09:26 +08:00
ROC function
Produces a ROC curve either for all possible cutoffs, or for n equally spaced cutoffs.
This commit is contained in:
60
stat_test.go
60
stat_test.go
@@ -7,6 +7,7 @@ package stat
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/floats"
|
||||
@@ -1264,6 +1265,65 @@ func TestSortWeighted(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortWeightedLabeled(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
x []float64
|
||||
l []bool
|
||||
w []float64
|
||||
ansx []float64
|
||||
ansl []bool
|
||||
answ []float64
|
||||
}{
|
||||
{
|
||||
x: []float64{8, 3, 7, 8, 4},
|
||||
ansx: []float64{3, 4, 7, 8, 8},
|
||||
},
|
||||
{
|
||||
x: []float64{8, 3, 7, 8, 4},
|
||||
w: []float64{.5, 1, 1, .5, 1},
|
||||
ansx: []float64{3, 4, 7, 8, 8},
|
||||
answ: []float64{1, 1, 1, .5, .5},
|
||||
},
|
||||
{
|
||||
x: []float64{8, 3, 7, 8, 4},
|
||||
l: []bool{false, false, true, false, true},
|
||||
ansx: []float64{3, 4, 7, 8, 8},
|
||||
ansl: []bool{false, true, true, false, false},
|
||||
},
|
||||
{
|
||||
x: []float64{8, 3, 7, 8, 4},
|
||||
l: []bool{false, false, true, false, true},
|
||||
w: []float64{.5, 1, 1, .5, 1},
|
||||
ansx: []float64{3, 4, 7, 8, 8},
|
||||
ansl: []bool{false, true, true, false, false},
|
||||
answ: []float64{1, 1, 1, .5, .5},
|
||||
},
|
||||
} {
|
||||
SortWeightedLabeled(test.x, test.l, test.w)
|
||||
if !floats.Same(test.x, test.ansx) {
|
||||
t.Errorf("SortWeightedLabelled mismatch case %d. Expected x %v, Found x %v", i, test.ansx, test.x)
|
||||
}
|
||||
if (test.l != nil) && !reflect.DeepEqual(test.l, test.ansl) {
|
||||
t.Errorf("SortWeightedLabelled mismatch case %d. Expected l %v, Found l %v", i, test.ansl, test.l)
|
||||
}
|
||||
if (test.w != nil) && !floats.Same(test.w, test.answ) {
|
||||
t.Errorf("SortWeightedLabelled mismatch case %d. Expected w %v, Found w %v", i, test.answ, test.w)
|
||||
}
|
||||
}
|
||||
if !Panics(func() { SortWeightedLabeled(make([]float64, 3), make([]bool, 2), make([]float64, 3)) }) {
|
||||
t.Errorf("SortWeighted did not panic with x, labels length mismatch")
|
||||
}
|
||||
if !Panics(func() { SortWeightedLabeled(make([]float64, 3), make([]bool, 2), nil) }) {
|
||||
t.Errorf("SortWeighted did not panic with x, labels length mismatch")
|
||||
}
|
||||
if !Panics(func() { SortWeightedLabeled(make([]float64, 3), make([]bool, 3), make([]float64, 2)) }) {
|
||||
t.Errorf("SortWeighted did not panic with x, weights length mismatch")
|
||||
}
|
||||
if !Panics(func() { SortWeightedLabeled(make([]float64, 3), nil, make([]float64, 2)) }) {
|
||||
t.Errorf("SortWeighted did not panic with x, weights length mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVariance(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
x []float64
|
||||
|
Reference in New Issue
Block a user