mirror of
https://github.com/gonum/gonum.git
synced 2025-10-06 07:37:03 +08:00
stat: invert class semantics and return cutoffs
This commit is contained in:
57
stat/roc.go
57
stat/roc.go
@@ -12,7 +12,9 @@ import (
|
||||
// ROC returns paired false positive rate (FPR) and true positive rate
|
||||
// (TPR) values corresponding to cutoff points on the receiver operator
|
||||
// characteristic (ROC) curve obtained when y is treated as a binary
|
||||
// classifier for classes with weights.
|
||||
// classifier for classes with weights. The cutoff thresholds used to
|
||||
// calculate the ROC are returned in thresh such that tpr[i] and fpr[i]
|
||||
// are the true and false positive rates for y >= thresh[i].
|
||||
//
|
||||
// The input y and cutoffs must be sorted, and values in y must correspond
|
||||
// to values in classes and weights. SortWeightedLabeled can be used to
|
||||
@@ -34,7 +36,7 @@ import (
|
||||
//
|
||||
// More details about ROC curves are available at
|
||||
// https://en.wikipedia.org/wiki/Receiver_operating_characteristic
|
||||
func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr []float64) {
|
||||
func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr, thresh []float64) {
|
||||
if len(y) != len(classes) {
|
||||
panic("stat: slice length mismatch")
|
||||
}
|
||||
@@ -48,47 +50,50 @@ func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr []fl
|
||||
panic("stat: cutoff values must be sorted ascending")
|
||||
}
|
||||
if len(y) == 0 {
|
||||
return nil, nil
|
||||
return nil, nil, nil
|
||||
}
|
||||
var bin int
|
||||
if len(cutoffs) == 0 {
|
||||
if cutoffs == nil || cap(cutoffs) < len(y)+1 {
|
||||
cutoffs = make([]float64, len(y)+1)
|
||||
} else {
|
||||
cutoffs = cutoffs[:len(y)+1]
|
||||
}
|
||||
cutoffs[0] = math.Nextafter(y[0], y[0]-1)
|
||||
// Choose all possible cutoffs but remove duplicate values
|
||||
// in y.
|
||||
for i, u := range y {
|
||||
if i != 0 && u != y[i-1] {
|
||||
cutoffs[0] = math.Inf(-1)
|
||||
// Choose all possible cutoffs for unique values in y.
|
||||
bin := 1
|
||||
cutoffs[bin] = y[0]
|
||||
for i, u := range y[1:] {
|
||||
if u == y[i] {
|
||||
continue
|
||||
}
|
||||
bin++
|
||||
cutoffs[bin] = u
|
||||
}
|
||||
cutoffs[bin+1] = u
|
||||
}
|
||||
cutoffs = cutoffs[0 : bin+2]
|
||||
cutoffs = cutoffs[:bin+1]
|
||||
} else {
|
||||
// Don't mutate the provided cutoffs.
|
||||
tmp := cutoffs
|
||||
cutoffs = make([]float64, len(cutoffs))
|
||||
copy(cutoffs, tmp)
|
||||
}
|
||||
|
||||
tpr = make([]float64, len(cutoffs))
|
||||
fpr = make([]float64, len(cutoffs))
|
||||
bin = 0
|
||||
var bin int
|
||||
var nPos, nNeg float64
|
||||
for i, u := range classes {
|
||||
// Update the bin until it matches the next y value
|
||||
// skipping empty bins.
|
||||
for bin < len(cutoffs) && y[i] > cutoffs[bin] {
|
||||
if bin == len(cutoffs)-1 {
|
||||
break
|
||||
}
|
||||
for bin < len(cutoffs)-1 && y[i] > cutoffs[bin] {
|
||||
bin++
|
||||
tpr[bin] = tpr[bin-1]
|
||||
fpr[bin] = fpr[bin-1]
|
||||
}
|
||||
var posWeight, negWeight float64 = 0, 1
|
||||
posWeight, negWeight := 1.0, 0.0
|
||||
if weights != nil {
|
||||
negWeight = weights[i]
|
||||
posWeight = weights[i]
|
||||
}
|
||||
if u {
|
||||
if !u {
|
||||
posWeight, negWeight = negWeight, posWeight
|
||||
}
|
||||
nPos += posWeight
|
||||
@@ -103,8 +108,18 @@ func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr []fl
|
||||
invPos := 1 / nPos
|
||||
for i := range tpr {
|
||||
tpr[i] *= invPos
|
||||
tpr[i] = 1 - tpr[i]
|
||||
fpr[i] *= invNeg
|
||||
fpr[i] = 1 - fpr[i]
|
||||
}
|
||||
for i, j := 0, len(tpr)-1; i < j; i, j = i+1, j-1 {
|
||||
tpr[i], tpr[j] = tpr[j], tpr[i]
|
||||
fpr[i], fpr[j] = fpr[j], fpr[i]
|
||||
}
|
||||
for i, j := 1, len(cutoffs)-1; i < j; i, j = i+1, j-1 {
|
||||
cutoffs[i], cutoffs[j] = cutoffs[j], cutoffs[i]
|
||||
}
|
||||
cutoffs[0] = math.Inf(1)
|
||||
|
||||
return tpr, fpr
|
||||
return tpr, fpr, cutoffs
|
||||
}
|
||||
|
@@ -15,67 +15,83 @@ import (
|
||||
|
||||
func ExampleROC_weighted() {
|
||||
y := []float64{0, 3, 5, 6, 7.5, 8}
|
||||
classes := []bool{true, false, true, false, false, false}
|
||||
classes := []bool{false, true, false, true, true, true}
|
||||
weights := []float64{4, 1, 6, 3, 2, 2}
|
||||
|
||||
tpr, fpr := stat.ROC(nil, y, classes, weights)
|
||||
tpr, fpr, _ := stat.ROC(nil, y, classes, weights)
|
||||
fmt.Printf("true positive rate: %v\n", tpr)
|
||||
fmt.Printf("false positive rate: %v\n", fpr)
|
||||
|
||||
// Output:
|
||||
// true positive rate: [0 0.4 0.4 1 1 1 1]
|
||||
// false positive rate: [0 0 0.125 0.125 0.5 0.75 1]
|
||||
// true positive rate: [0 0.25 0.5 0.875 0.875 1 1]
|
||||
// false positive rate: [0 0 0 0 0.6 0.6 1]
|
||||
}
|
||||
|
||||
func ExampleROC_unweighted() {
|
||||
y := []float64{0, 3, 5, 6, 7.5, 8}
|
||||
classes := []bool{true, false, true, false, false, false}
|
||||
classes := []bool{false, true, false, true, true, true}
|
||||
|
||||
tpr, fpr := stat.ROC(nil, y, classes, nil)
|
||||
tpr, fpr, _ := stat.ROC(nil, y, classes, nil)
|
||||
fmt.Printf("true positive rate: %v\n", tpr)
|
||||
fmt.Printf("false positive rate: %v\n", fpr)
|
||||
|
||||
// Output:
|
||||
// true positive rate: [0 0.5 0.5 1 1 1 1]
|
||||
// false positive rate: [0 0 0.25 0.25 0.5 0.75 1]
|
||||
// true positive rate: [0 0.25 0.5 0.75 0.75 1 1]
|
||||
// false positive rate: [0 0 0 0 0.5 0.5 1]
|
||||
}
|
||||
|
||||
func ExampleROC_threshold() {
|
||||
y := []float64{0.1, 0.4, 0.35, 0.8}
|
||||
classes := []bool{false, false, true, true}
|
||||
stat.SortWeightedLabeled(y, classes, nil)
|
||||
|
||||
tpr, fpr, thresh := stat.ROC(nil, y, classes, nil)
|
||||
fmt.Printf("true positive rate: %v\n", tpr)
|
||||
fmt.Printf("false positive rate: %v\n", fpr)
|
||||
fmt.Printf("cutoff thresholds: %v\n", thresh)
|
||||
|
||||
// Output:
|
||||
// true positive rate: [0 0.5 0.5 1 1]
|
||||
// false positive rate: [0 0 0.5 0.5 1]
|
||||
// cutoff thresholds: [+Inf 0.8 0.4 0.35 0.1]
|
||||
}
|
||||
|
||||
func ExampleROC_unsorted() {
|
||||
y := []float64{8, 7.5, 6, 5, 3, 0}
|
||||
classes := []bool{false, false, false, true, false, true}
|
||||
classes := []bool{true, true, true, false, true, false}
|
||||
weights := []float64{2, 2, 3, 6, 1, 4}
|
||||
|
||||
stat.SortWeightedLabeled(y, classes, weights)
|
||||
|
||||
tpr, fpr := stat.ROC(nil, y, classes, weights)
|
||||
tpr, fpr, _ := stat.ROC(nil, y, classes, weights)
|
||||
fmt.Printf("true positive rate: %v\n", tpr)
|
||||
fmt.Printf("false positive rate: %v\n", fpr)
|
||||
|
||||
// Output:
|
||||
// true positive rate: [0 0.4 0.4 1 1 1 1]
|
||||
// false positive rate: [0 0 0.125 0.125 0.5 0.75 1]
|
||||
// true positive rate: [0 0.25 0.5 0.875 0.875 1 1]
|
||||
// false positive rate: [0 0 0 0 0.6 0.6 1]
|
||||
}
|
||||
|
||||
func ExampleROC_knownCutoffs() {
|
||||
y := []float64{8, 7.5, 6, 5, 3, 0}
|
||||
classes := []bool{false, false, false, true, false, true}
|
||||
classes := []bool{true, true, true, false, true, false}
|
||||
weights := []float64{2, 2, 3, 6, 1, 4}
|
||||
cutoffs := []float64{-1, 3, 4}
|
||||
|
||||
stat.SortWeightedLabeled(y, classes, weights)
|
||||
|
||||
tpr, fpr := stat.ROC(cutoffs, y, classes, weights)
|
||||
tpr, fpr, _ := stat.ROC(cutoffs, y, classes, weights)
|
||||
fmt.Printf("true positive rate: %v\n", tpr)
|
||||
fmt.Printf("false positive rate: %v\n", fpr)
|
||||
|
||||
// Output:
|
||||
// true positive rate: [0 0.4 0.4]
|
||||
// false positive rate: [0 0.125 0.125]
|
||||
// true positive rate: [0.875 0.875 1]
|
||||
// false positive rate: [0.6 0.6 1]
|
||||
}
|
||||
|
||||
func ExampleROC_equallySpacedCutoffs() {
|
||||
y := []float64{8, 7.5, 6, 5, 3, 0}
|
||||
classes := []bool{false, false, false, true, false, true}
|
||||
classes := []bool{true, true, true, false, true, true}
|
||||
weights := []float64{2, 2, 3, 6, 1, 4}
|
||||
n := 9
|
||||
|
||||
@@ -83,20 +99,20 @@ func ExampleROC_equallySpacedCutoffs() {
|
||||
cutoffs := make([]float64, n)
|
||||
floats.Span(cutoffs, math.Nextafter(y[0], y[0]-1), y[len(y)-1])
|
||||
|
||||
tpr, fpr := stat.ROC(cutoffs, y, classes, weights)
|
||||
fmt.Printf("true positive rate: %v\n", tpr)
|
||||
fmt.Printf("false positive rate: %v\n", fpr)
|
||||
tpr, fpr, _ := stat.ROC(cutoffs, y, classes, weights)
|
||||
fmt.Printf("true positive rate: %.3v\n", tpr)
|
||||
fmt.Printf("false positive rate: %.3v\n", fpr)
|
||||
|
||||
// Output:
|
||||
// true positive rate: [0 0.4 0.4 0.4 0.4 1 1 1 1]
|
||||
// false positive rate: [0 0 0 0.125 0.125 0.125 0.5 0.5 1]
|
||||
// true positive rate: [0 0.333 0.333 0.583 0.583 0.583 0.667 0.667 1]
|
||||
// false positive rate: [0 0 0 0 1 1 1 1 1]
|
||||
}
|
||||
|
||||
func ExampleROC_aUC() {
|
||||
y := []float64{0.1, 0.35, 0.4, 0.8}
|
||||
classes := []bool{true, false, true, false}
|
||||
|
||||
tpr, fpr := stat.ROC(nil, y, classes, nil)
|
||||
tpr, fpr, _ := stat.ROC(nil, y, classes, nil)
|
||||
|
||||
// Compute Area Under Curve.
|
||||
auc := integrate.Trapezoidal(fpr, tpr)
|
||||
@@ -105,7 +121,7 @@ func ExampleROC_aUC() {
|
||||
fmt.Printf("auc: %v\n", auc)
|
||||
|
||||
// Output:
|
||||
// true positive rate: [0 0.5 0.5 1 1]
|
||||
// false positive rate: [0 0 0.5 0.5 1]
|
||||
// auc: 0.75
|
||||
// true positive rate: [0 0 0.5 0.5 1]
|
||||
// false positive rate: [0 0.5 0.5 1 1]
|
||||
// auc: 0.25
|
||||
}
|
||||
|
183
stat/roc_test.go
183
stat/roc_test.go
@@ -11,7 +11,6 @@ import (
|
||||
"gonum.org/v1/gonum/floats"
|
||||
)
|
||||
|
||||
// Test cases were calculated manually.
|
||||
func TestROC(t *testing.T) {
|
||||
cases := []struct {
|
||||
y []float64
|
||||
@@ -20,160 +19,206 @@ func TestROC(t *testing.T) {
|
||||
cutoffs []float64
|
||||
wantTPR []float64
|
||||
wantFPR []float64
|
||||
wantThresh []float64
|
||||
}{
|
||||
// Test cases were calculated using sklearn metrics.roc_curve when
|
||||
// cutoffs is nil. Where cutoffs is not nil, a visual inspection is
|
||||
// used.
|
||||
// Some differences exist between unweighted ROCs from our function
|
||||
// and metrics.roc_curve which appears to use integer cutoffs in that
|
||||
// case. sklearn also appears to do some magic that trims leading zeros
|
||||
// sometimes.
|
||||
{ // 0
|
||||
y: []float64{0, 3, 5, 6, 7.5, 8},
|
||||
c: []bool{true, false, true, false, false, false},
|
||||
wantTPR: []float64{0, 0.5, 0.5, 1, 1, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0.25, 0.25, 0.5, 0.75, 1},
|
||||
c: []bool{false, true, false, true, true, true},
|
||||
wantTPR: []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 1},
|
||||
wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
|
||||
},
|
||||
{ // 1
|
||||
y: []float64{0, 3, 5, 6, 7.5, 8},
|
||||
c: []bool{true, false, true, false, false, false},
|
||||
c: []bool{false, true, false, true, true, true},
|
||||
w: []float64{4, 1, 6, 3, 2, 2},
|
||||
wantTPR: []float64{0, 0.4, 0.4, 1, 1, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0.125, 0.125, 0.5, 0.75, 1},
|
||||
wantTPR: []float64{0, 0.25, 0.5, 0.875, 0.875, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0, 0, 0.6, 0.6, 1},
|
||||
wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
|
||||
},
|
||||
{ // 2
|
||||
y: []float64{0, 3, 5, 6, 7.5, 8},
|
||||
c: []bool{true, false, true, false, false, false},
|
||||
c: []bool{false, true, false, true, true, true},
|
||||
cutoffs: []float64{-1, 2, 4, 6, 8},
|
||||
wantTPR: []float64{0, 0.5, 0.5, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0.25, 0.5, 1},
|
||||
wantTPR: []float64{0, 0.5, 0.75, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0.5, 0.5, 1},
|
||||
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
|
||||
},
|
||||
{ // 3
|
||||
y: []float64{0, 3, 5, 6, 7.5, 8},
|
||||
c: []bool{true, false, true, false, false, false},
|
||||
c: []bool{false, true, false, true, true, true},
|
||||
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
|
||||
wantTPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 1, 1, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0, 0.25, 0.25, 0.25, 0.5, 0.5, 1},
|
||||
wantTPR: []float64{0, 0.5, 0.5, 0.75, 0.75, 0.75, 1, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 0.5, 0.5, 1},
|
||||
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
|
||||
},
|
||||
{ // 4
|
||||
y: []float64{0, 3, 5, 6, 7.5, 8},
|
||||
c: []bool{true, false, true, false, false, false},
|
||||
c: []bool{false, true, false, true, true, true},
|
||||
w: []float64{4, 1, 6, 3, 2, 2},
|
||||
cutoffs: []float64{-1, 2, 4, 6, 8},
|
||||
wantTPR: []float64{0, 0.4, 0.4, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0.125, 0.5, 1},
|
||||
wantTPR: []float64{0, 0.5, 0.875, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0.6, 0.6, 1},
|
||||
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
|
||||
},
|
||||
{ // 5
|
||||
y: []float64{0, 3, 5, 6, 7.5, 8},
|
||||
c: []bool{true, false, true, false, false, false},
|
||||
c: []bool{false, true, false, true, true, true},
|
||||
w: []float64{4, 1, 6, 3, 2, 2},
|
||||
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
|
||||
wantTPR: []float64{0, 0.4, 0.4, 0.4, 0.4, 1, 1, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0, 0.125, 0.125, 0.125, 0.5, 0.5, 1},
|
||||
wantTPR: []float64{0, 0.5, 0.5, 0.875, 0.875, 0.875, 1, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0, 0, 0.6, 0.6, 0.6, 0.6, 1},
|
||||
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
|
||||
},
|
||||
{ // 6
|
||||
y: []float64{0, 3, 6, 6, 6, 8},
|
||||
c: []bool{true, false, true, false, false, false},
|
||||
wantTPR: []float64{0, 0.5, 0.5, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0.25, 0.75, 1},
|
||||
c: []bool{false, true, false, true, true, true},
|
||||
wantTPR: []float64{0, 0.25, 0.75, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0.5, 0.5, 1},
|
||||
wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
|
||||
},
|
||||
{ // 7
|
||||
y: []float64{0, 3, 6, 6, 6, 8},
|
||||
c: []bool{true, false, true, false, false, false},
|
||||
c: []bool{false, true, false, true, true, true},
|
||||
w: []float64{4, 1, 6, 3, 2, 2},
|
||||
wantTPR: []float64{0, 0.4, 0.4, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0.125, 0.75, 1},
|
||||
wantTPR: []float64{0, 0.25, 0.875, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0.6, 0.6, 1},
|
||||
wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
|
||||
},
|
||||
{ // 8
|
||||
y: []float64{0, 3, 6, 6, 6, 8},
|
||||
c: []bool{true, false, true, false, false, false},
|
||||
c: []bool{false, true, false, true, true, true},
|
||||
cutoffs: []float64{-1, 2, 4, 6, 8},
|
||||
wantTPR: []float64{0, 0.5, 0.5, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0.25, 0.75, 1},
|
||||
wantTPR: []float64{0, 0.25, 0.75, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0.5, 0.5, 1},
|
||||
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
|
||||
},
|
||||
{ // 9
|
||||
y: []float64{0, 3, 6, 6, 6, 8},
|
||||
c: []bool{true, false, true, false, false, false},
|
||||
c: []bool{false, true, false, true, true, true},
|
||||
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
|
||||
wantTPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 0.5, 1, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0, 0.25, 0.25, 0.25, 0.75, 0.75, 1},
|
||||
wantTPR: []float64{0, 0.25, 0.25, 0.75, 0.75, 0.75, 1, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 1},
|
||||
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
|
||||
},
|
||||
{ // 10
|
||||
y: []float64{0, 3, 6, 6, 6, 8},
|
||||
c: []bool{true, false, true, false, false, false},
|
||||
c: []bool{false, true, false, true, true, true},
|
||||
w: []float64{4, 1, 6, 3, 2, 2},
|
||||
cutoffs: []float64{-1, 2, 4, 6, 8},
|
||||
wantTPR: []float64{0, 0.4, 0.4, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0.125, 0.75, 1},
|
||||
wantTPR: []float64{0, 0.25, 0.875, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0.6, 0.6, 1},
|
||||
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
|
||||
},
|
||||
{ // 11
|
||||
y: []float64{0, 3, 6, 6, 6, 8},
|
||||
c: []bool{true, false, true, false, false, false},
|
||||
c: []bool{false, true, false, true, true, true},
|
||||
w: []float64{4, 1, 6, 3, 2, 2},
|
||||
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
|
||||
wantTPR: []float64{0, 0.4, 0.4, 0.4, 0.4, 0.4, 1, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0, 0.125, 0.125, 0.125, 0.75, 0.75, 1},
|
||||
wantTPR: []float64{0, 0.25, 0.25, 0.875, 0.875, 0.875, 1, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 1},
|
||||
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
|
||||
},
|
||||
{ // 12
|
||||
y: []float64{1, 2},
|
||||
c: []bool{true, true},
|
||||
wantTPR: []float64{0, 0.5, 1},
|
||||
wantFPR: []float64{math.NaN(), math.NaN(), math.NaN()},
|
||||
y: []float64{0.1, 0.35, 0.4, 0.8},
|
||||
c: []bool{true, false, true, false},
|
||||
wantTPR: []float64{0, 0, 0.5, 0.5, 1},
|
||||
wantFPR: []float64{0, 0.5, 0.5, 1, 1},
|
||||
wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
|
||||
},
|
||||
{ // 13
|
||||
y: []float64{1, 2},
|
||||
c: []bool{true, true},
|
||||
cutoffs: []float64{-1, 2},
|
||||
wantTPR: []float64{0, 1},
|
||||
wantFPR: []float64{math.NaN(), math.NaN()},
|
||||
y: []float64{0.1, 0.35, 0.4, 0.8},
|
||||
c: []bool{false, false, true, true},
|
||||
wantTPR: []float64{0, 0.5, 1, 1, 1},
|
||||
wantFPR: []float64{0, 0, 0, 0.5, 1},
|
||||
wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
|
||||
},
|
||||
{ // 14
|
||||
y: []float64{1, 2},
|
||||
c: []bool{true, true},
|
||||
cutoffs: []float64{0, 1.2, 1.4, 1.6, 1.8, 2},
|
||||
wantTPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 1},
|
||||
wantFPR: []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()},
|
||||
y: []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10},
|
||||
c: []bool{false, true, false, false, true, true, false},
|
||||
cutoffs: []float64{-1, 2.5, 5, 7.5, 10},
|
||||
wantTPR: []float64{0, 0, 0, 0, 1},
|
||||
wantFPR: []float64{0, 0.25, 0.25, 0.25, 1},
|
||||
wantThresh: []float64{math.Inf(1), 10, 7.5, 5, 2.5},
|
||||
},
|
||||
{ // 15
|
||||
y: []float64{1},
|
||||
c: []bool{true},
|
||||
wantTPR: []float64{0, 1},
|
||||
wantFPR: []float64{math.NaN(), math.NaN()},
|
||||
y: []float64{1, 2},
|
||||
c: []bool{false, false},
|
||||
wantTPR: []float64{math.NaN(), math.NaN(), math.NaN()},
|
||||
wantFPR: []float64{0, 0.5, 1},
|
||||
wantThresh: []float64{math.Inf(1), 2, 1},
|
||||
},
|
||||
{ // 16
|
||||
y: []float64{1},
|
||||
c: []bool{true},
|
||||
cutoffs: []float64{-1, 1},
|
||||
wantTPR: []float64{0, 1},
|
||||
wantFPR: []float64{math.NaN(), math.NaN()},
|
||||
y: []float64{1, 2},
|
||||
c: []bool{false, false},
|
||||
cutoffs: []float64{-1, 2},
|
||||
wantTPR: []float64{math.NaN(), math.NaN()},
|
||||
wantFPR: []float64{0, 1},
|
||||
wantThresh: []float64{math.Inf(1), 2},
|
||||
},
|
||||
{ // 17
|
||||
y: []float64{1, 2},
|
||||
c: []bool{false, false},
|
||||
cutoffs: []float64{0, 1.2, 1.4, 1.6, 1.8, 2},
|
||||
wantTPR: []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()},
|
||||
wantFPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 1},
|
||||
wantThresh: []float64{math.Inf(1), 2, 1.8, 1.6, 1.4, 1.2},
|
||||
},
|
||||
{ // 18
|
||||
y: []float64{1},
|
||||
c: []bool{false},
|
||||
wantTPR: []float64{math.NaN(), math.NaN()},
|
||||
wantFPR: []float64{0, 1},
|
||||
},
|
||||
{ // 18
|
||||
y: []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10},
|
||||
c: []bool{true, false, true, true, false, false, true},
|
||||
cutoffs: []float64{-1, 2.5, 5, 7.5, 10},
|
||||
wantTPR: []float64{0, 0.75, 0.75, 0.75, 1},
|
||||
wantFPR: []float64{0, 1, 1, 1, 1},
|
||||
wantThresh: []float64{math.Inf(1), 1},
|
||||
},
|
||||
{ // 19
|
||||
y: []float64{1},
|
||||
c: []bool{false},
|
||||
cutoffs: []float64{-1, 1},
|
||||
wantTPR: []float64{math.NaN(), math.NaN()},
|
||||
wantFPR: []float64{0, 1},
|
||||
wantThresh: []float64{math.Inf(1), 1},
|
||||
},
|
||||
{ // 20
|
||||
y: []float64{1},
|
||||
c: []bool{true},
|
||||
wantTPR: []float64{0, 1},
|
||||
wantFPR: []float64{math.NaN(), math.NaN()},
|
||||
wantThresh: []float64{math.Inf(1), 1},
|
||||
},
|
||||
{ // 21
|
||||
y: []float64{},
|
||||
c: []bool{},
|
||||
wantTPR: nil,
|
||||
wantFPR: nil,
|
||||
wantThresh: nil,
|
||||
},
|
||||
{ // 20
|
||||
{ // 22
|
||||
y: []float64{},
|
||||
c: []bool{},
|
||||
cutoffs: []float64{-1, 2.5, 5, 7.5, 10},
|
||||
wantTPR: nil,
|
||||
wantFPR: nil,
|
||||
wantThresh: nil,
|
||||
},
|
||||
}
|
||||
for i, test := range cases {
|
||||
gotTPR, gotFPR := ROC(test.cutoffs, test.y, test.c, test.w)
|
||||
gotTPR, gotFPR, gotThresh := ROC(test.cutoffs, test.y, test.c, test.w)
|
||||
if !floats.Same(gotTPR, test.wantTPR) {
|
||||
t.Errorf("%d: unexpected TPR got:%v want:%v", i, gotTPR, test.wantTPR)
|
||||
}
|
||||
if !floats.Same(gotFPR, test.wantFPR) {
|
||||
t.Errorf("%d: unexpected FPR got:%v want:%v", i, gotFPR, test.wantFPR)
|
||||
}
|
||||
if !floats.Same(gotThresh, test.wantThresh) {
|
||||
t.Errorf("%d: unexpected thresholds got:%#v want:%v", i, gotThresh, test.wantThresh)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user