stat: invert class semantics and return cutoffs

This commit is contained in:
Dan Kortschak
2019-04-15 20:19:50 +09:30
parent e080754f81
commit 50ee437d7b
3 changed files with 234 additions and 158 deletions

View File

@@ -12,7 +12,9 @@ import (
// ROC returns paired false positive rate (FPR) and true positive rate // ROC returns paired false positive rate (FPR) and true positive rate
// (TPR) values corresponding to cutoff points on the receiver operator // (TPR) values corresponding to cutoff points on the receiver operator
// characteristic (ROC) curve obtained when y is treated as a binary // characteristic (ROC) curve obtained when y is treated as a binary
// classifier for classes with weights. // classifier for classes with weights. The cutoff thresholds used to
// calculate the ROC are returned in thresh such that tpr[i] and fpr[i]
// are the true and false positive rates for y >= thresh[i].
// //
// The input y and cutoffs must be sorted, and values in y must correspond // The input y and cutoffs must be sorted, and values in y must correspond
// to values in classes and weights. SortWeightedLabeled can be used to // to values in classes and weights. SortWeightedLabeled can be used to
@@ -34,7 +36,7 @@ import (
// //
// More details about ROC curves are available at // More details about ROC curves are available at
// https://en.wikipedia.org/wiki/Receiver_operating_characteristic // https://en.wikipedia.org/wiki/Receiver_operating_characteristic
func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr []float64) { func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr, thresh []float64) {
if len(y) != len(classes) { if len(y) != len(classes) {
panic("stat: slice length mismatch") panic("stat: slice length mismatch")
} }
@@ -48,47 +50,50 @@ func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr []fl
panic("stat: cutoff values must be sorted ascending") panic("stat: cutoff values must be sorted ascending")
} }
if len(y) == 0 { if len(y) == 0 {
return nil, nil return nil, nil, nil
} }
var bin int
if len(cutoffs) == 0 { if len(cutoffs) == 0 {
if cutoffs == nil || cap(cutoffs) < len(y)+1 { if cutoffs == nil || cap(cutoffs) < len(y)+1 {
cutoffs = make([]float64, len(y)+1) cutoffs = make([]float64, len(y)+1)
} else { } else {
cutoffs = cutoffs[:len(y)+1] cutoffs = cutoffs[:len(y)+1]
} }
cutoffs[0] = math.Nextafter(y[0], y[0]-1) cutoffs[0] = math.Inf(-1)
// Choose all possible cutoffs but remove duplicate values // Choose all possible cutoffs for unique values in y.
// in y. bin := 1
for i, u := range y { cutoffs[bin] = y[0]
if i != 0 && u != y[i-1] { for i, u := range y[1:] {
bin++ if u == y[i] {
continue
} }
cutoffs[bin+1] = u bin++
cutoffs[bin] = u
} }
cutoffs = cutoffs[0 : bin+2] cutoffs = cutoffs[:bin+1]
} else {
// Don't mutate the provided cutoffs.
tmp := cutoffs
cutoffs = make([]float64, len(cutoffs))
copy(cutoffs, tmp)
} }
tpr = make([]float64, len(cutoffs)) tpr = make([]float64, len(cutoffs))
fpr = make([]float64, len(cutoffs)) fpr = make([]float64, len(cutoffs))
bin = 0 var bin int
var nPos, nNeg float64 var nPos, nNeg float64
for i, u := range classes { for i, u := range classes {
// Update the bin until it matches the next y value // Update the bin until it matches the next y value
// skipping empty bins. // skipping empty bins.
for bin < len(cutoffs) && y[i] > cutoffs[bin] { for bin < len(cutoffs)-1 && y[i] > cutoffs[bin] {
if bin == len(cutoffs)-1 {
break
}
bin++ bin++
tpr[bin] = tpr[bin-1] tpr[bin] = tpr[bin-1]
fpr[bin] = fpr[bin-1] fpr[bin] = fpr[bin-1]
} }
var posWeight, negWeight float64 = 0, 1 posWeight, negWeight := 1.0, 0.0
if weights != nil { if weights != nil {
negWeight = weights[i] posWeight = weights[i]
} }
if u { if !u {
posWeight, negWeight = negWeight, posWeight posWeight, negWeight = negWeight, posWeight
} }
nPos += posWeight nPos += posWeight
@@ -103,8 +108,18 @@ func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr []fl
invPos := 1 / nPos invPos := 1 / nPos
for i := range tpr { for i := range tpr {
tpr[i] *= invPos tpr[i] *= invPos
tpr[i] = 1 - tpr[i]
fpr[i] *= invNeg fpr[i] *= invNeg
fpr[i] = 1 - fpr[i]
} }
for i, j := 0, len(tpr)-1; i < j; i, j = i+1, j-1 {
tpr[i], tpr[j] = tpr[j], tpr[i]
fpr[i], fpr[j] = fpr[j], fpr[i]
}
for i, j := 1, len(cutoffs)-1; i < j; i, j = i+1, j-1 {
cutoffs[i], cutoffs[j] = cutoffs[j], cutoffs[i]
}
cutoffs[0] = math.Inf(1)
return tpr, fpr return tpr, fpr, cutoffs
} }

View File

@@ -15,67 +15,83 @@ import (
func ExampleROC_weighted() { func ExampleROC_weighted() {
y := []float64{0, 3, 5, 6, 7.5, 8} y := []float64{0, 3, 5, 6, 7.5, 8}
classes := []bool{true, false, true, false, false, false} classes := []bool{false, true, false, true, true, true}
weights := []float64{4, 1, 6, 3, 2, 2} weights := []float64{4, 1, 6, 3, 2, 2}
tpr, fpr := stat.ROC(nil, y, classes, weights) tpr, fpr, _ := stat.ROC(nil, y, classes, weights)
fmt.Printf("true positive rate: %v\n", tpr) fmt.Printf("true positive rate: %v\n", tpr)
fmt.Printf("false positive rate: %v\n", fpr) fmt.Printf("false positive rate: %v\n", fpr)
// Output: // Output:
// true positive rate: [0 0.4 0.4 1 1 1 1] // true positive rate: [0 0.25 0.5 0.875 0.875 1 1]
// false positive rate: [0 0 0.125 0.125 0.5 0.75 1] // false positive rate: [0 0 0 0 0.6 0.6 1]
} }
func ExampleROC_unweighted() { func ExampleROC_unweighted() {
y := []float64{0, 3, 5, 6, 7.5, 8} y := []float64{0, 3, 5, 6, 7.5, 8}
classes := []bool{true, false, true, false, false, false} classes := []bool{false, true, false, true, true, true}
tpr, fpr := stat.ROC(nil, y, classes, nil) tpr, fpr, _ := stat.ROC(nil, y, classes, nil)
fmt.Printf("true positive rate: %v\n", tpr) fmt.Printf("true positive rate: %v\n", tpr)
fmt.Printf("false positive rate: %v\n", fpr) fmt.Printf("false positive rate: %v\n", fpr)
// Output: // Output:
// true positive rate: [0 0.5 0.5 1 1 1 1] // true positive rate: [0 0.25 0.5 0.75 0.75 1 1]
// false positive rate: [0 0 0.25 0.25 0.5 0.75 1] // false positive rate: [0 0 0 0 0.5 0.5 1]
}
func ExampleROC_threshold() {
y := []float64{0.1, 0.4, 0.35, 0.8}
classes := []bool{false, false, true, true}
stat.SortWeightedLabeled(y, classes, nil)
tpr, fpr, thresh := stat.ROC(nil, y, classes, nil)
fmt.Printf("true positive rate: %v\n", tpr)
fmt.Printf("false positive rate: %v\n", fpr)
fmt.Printf("cutoff thresholds: %v\n", thresh)
// Output:
// true positive rate: [0 0.5 0.5 1 1]
// false positive rate: [0 0 0.5 0.5 1]
// cutoff thresholds: [+Inf 0.8 0.4 0.35 0.1]
} }
func ExampleROC_unsorted() { func ExampleROC_unsorted() {
y := []float64{8, 7.5, 6, 5, 3, 0} y := []float64{8, 7.5, 6, 5, 3, 0}
classes := []bool{false, false, false, true, false, true} classes := []bool{true, true, true, false, true, false}
weights := []float64{2, 2, 3, 6, 1, 4} weights := []float64{2, 2, 3, 6, 1, 4}
stat.SortWeightedLabeled(y, classes, weights) stat.SortWeightedLabeled(y, classes, weights)
tpr, fpr := stat.ROC(nil, y, classes, weights) tpr, fpr, _ := stat.ROC(nil, y, classes, weights)
fmt.Printf("true positive rate: %v\n", tpr) fmt.Printf("true positive rate: %v\n", tpr)
fmt.Printf("false positive rate: %v\n", fpr) fmt.Printf("false positive rate: %v\n", fpr)
// Output: // Output:
// true positive rate: [0 0.4 0.4 1 1 1 1] // true positive rate: [0 0.25 0.5 0.875 0.875 1 1]
// false positive rate: [0 0 0.125 0.125 0.5 0.75 1] // false positive rate: [0 0 0 0 0.6 0.6 1]
} }
func ExampleROC_knownCutoffs() { func ExampleROC_knownCutoffs() {
y := []float64{8, 7.5, 6, 5, 3, 0} y := []float64{8, 7.5, 6, 5, 3, 0}
classes := []bool{false, false, false, true, false, true} classes := []bool{true, true, true, false, true, false}
weights := []float64{2, 2, 3, 6, 1, 4} weights := []float64{2, 2, 3, 6, 1, 4}
cutoffs := []float64{-1, 3, 4} cutoffs := []float64{-1, 3, 4}
stat.SortWeightedLabeled(y, classes, weights) stat.SortWeightedLabeled(y, classes, weights)
tpr, fpr := stat.ROC(cutoffs, y, classes, weights) tpr, fpr, _ := stat.ROC(cutoffs, y, classes, weights)
fmt.Printf("true positive rate: %v\n", tpr) fmt.Printf("true positive rate: %v\n", tpr)
fmt.Printf("false positive rate: %v\n", fpr) fmt.Printf("false positive rate: %v\n", fpr)
// Output: // Output:
// true positive rate: [0 0.4 0.4] // true positive rate: [0.875 0.875 1]
// false positive rate: [0 0.125 0.125] // false positive rate: [0.6 0.6 1]
} }
func ExampleROC_equallySpacedCutoffs() { func ExampleROC_equallySpacedCutoffs() {
y := []float64{8, 7.5, 6, 5, 3, 0} y := []float64{8, 7.5, 6, 5, 3, 0}
classes := []bool{false, false, false, true, false, true} classes := []bool{true, true, true, false, true, true}
weights := []float64{2, 2, 3, 6, 1, 4} weights := []float64{2, 2, 3, 6, 1, 4}
n := 9 n := 9
@@ -83,20 +99,20 @@ func ExampleROC_equallySpacedCutoffs() {
cutoffs := make([]float64, n) cutoffs := make([]float64, n)
floats.Span(cutoffs, math.Nextafter(y[0], y[0]-1), y[len(y)-1]) floats.Span(cutoffs, math.Nextafter(y[0], y[0]-1), y[len(y)-1])
tpr, fpr := stat.ROC(cutoffs, y, classes, weights) tpr, fpr, _ := stat.ROC(cutoffs, y, classes, weights)
fmt.Printf("true positive rate: %v\n", tpr) fmt.Printf("true positive rate: %.3v\n", tpr)
fmt.Printf("false positive rate: %v\n", fpr) fmt.Printf("false positive rate: %.3v\n", fpr)
// Output: // Output:
// true positive rate: [0 0.4 0.4 0.4 0.4 1 1 1 1] // true positive rate: [0 0.333 0.333 0.583 0.583 0.583 0.667 0.667 1]
// false positive rate: [0 0 0 0.125 0.125 0.125 0.5 0.5 1] // false positive rate: [0 0 0 0 1 1 1 1 1]
} }
func ExampleROC_aUC() { func ExampleROC_aUC() {
y := []float64{0.1, 0.35, 0.4, 0.8} y := []float64{0.1, 0.35, 0.4, 0.8}
classes := []bool{true, false, true, false} classes := []bool{true, false, true, false}
tpr, fpr := stat.ROC(nil, y, classes, nil) tpr, fpr, _ := stat.ROC(nil, y, classes, nil)
// Compute Area Under Curve. // Compute Area Under Curve.
auc := integrate.Trapezoidal(fpr, tpr) auc := integrate.Trapezoidal(fpr, tpr)
@@ -105,7 +121,7 @@ func ExampleROC_aUC() {
fmt.Printf("auc: %v\n", auc) fmt.Printf("auc: %v\n", auc)
// Output: // Output:
// true positive rate: [0 0.5 0.5 1 1] // true positive rate: [0 0 0.5 0.5 1]
// false positive rate: [0 0 0.5 0.5 1] // false positive rate: [0 0.5 0.5 1 1]
// auc: 0.75 // auc: 0.25
} }

View File

@@ -11,169 +11,214 @@ import (
"gonum.org/v1/gonum/floats" "gonum.org/v1/gonum/floats"
) )
// Test cases were calculated manually.
func TestROC(t *testing.T) { func TestROC(t *testing.T) {
cases := []struct { cases := []struct {
y []float64 y []float64
c []bool c []bool
w []float64 w []float64
cutoffs []float64 cutoffs []float64
wantTPR []float64 wantTPR []float64
wantFPR []float64 wantFPR []float64
wantThresh []float64
}{ }{
// Test cases were calculated using sklearn metrics.roc_curve when
// cutoffs is nil. Where cutoffs is not nil, a visual inspection is
// used.
// Some differences exist between unweighted ROCs from our function
// and metrics.roc_curve which appears to use integer cutoffs in that
// case. sklearn also appears to do some magic that trims leading zeros
// sometimes.
{ // 0 { // 0
y: []float64{0, 3, 5, 6, 7.5, 8}, y: []float64{0, 3, 5, 6, 7.5, 8},
c: []bool{true, false, true, false, false, false}, c: []bool{false, true, false, true, true, true},
wantTPR: []float64{0, 0.5, 0.5, 1, 1, 1, 1}, wantTPR: []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1},
wantFPR: []float64{0, 0, 0.25, 0.25, 0.5, 0.75, 1}, wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 1},
wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
}, },
{ // 1 { // 1
y: []float64{0, 3, 5, 6, 7.5, 8}, y: []float64{0, 3, 5, 6, 7.5, 8},
c: []bool{true, false, true, false, false, false}, c: []bool{false, true, false, true, true, true},
w: []float64{4, 1, 6, 3, 2, 2}, w: []float64{4, 1, 6, 3, 2, 2},
wantTPR: []float64{0, 0.4, 0.4, 1, 1, 1, 1}, wantTPR: []float64{0, 0.25, 0.5, 0.875, 0.875, 1, 1},
wantFPR: []float64{0, 0, 0.125, 0.125, 0.5, 0.75, 1}, wantFPR: []float64{0, 0, 0, 0, 0.6, 0.6, 1},
wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
}, },
{ // 2 { // 2
y: []float64{0, 3, 5, 6, 7.5, 8}, y: []float64{0, 3, 5, 6, 7.5, 8},
c: []bool{true, false, true, false, false, false}, c: []bool{false, true, false, true, true, true},
cutoffs: []float64{-1, 2, 4, 6, 8}, cutoffs: []float64{-1, 2, 4, 6, 8},
wantTPR: []float64{0, 0.5, 0.5, 1, 1}, wantTPR: []float64{0, 0.5, 0.75, 1, 1},
wantFPR: []float64{0, 0, 0.25, 0.5, 1}, wantFPR: []float64{0, 0, 0.5, 0.5, 1},
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
}, },
{ // 3 { // 3
y: []float64{0, 3, 5, 6, 7.5, 8}, y: []float64{0, 3, 5, 6, 7.5, 8},
c: []bool{true, false, true, false, false, false}, c: []bool{false, true, false, true, true, true},
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
wantTPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 1, 1, 1, 1}, wantTPR: []float64{0, 0.5, 0.5, 0.75, 0.75, 0.75, 1, 1, 1},
wantFPR: []float64{0, 0, 0, 0.25, 0.25, 0.25, 0.5, 0.5, 1}, wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 0.5, 0.5, 1},
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
}, },
{ // 4 { // 4
y: []float64{0, 3, 5, 6, 7.5, 8}, y: []float64{0, 3, 5, 6, 7.5, 8},
c: []bool{true, false, true, false, false, false}, c: []bool{false, true, false, true, true, true},
w: []float64{4, 1, 6, 3, 2, 2}, w: []float64{4, 1, 6, 3, 2, 2},
cutoffs: []float64{-1, 2, 4, 6, 8}, cutoffs: []float64{-1, 2, 4, 6, 8},
wantTPR: []float64{0, 0.4, 0.4, 1, 1}, wantTPR: []float64{0, 0.5, 0.875, 1, 1},
wantFPR: []float64{0, 0, 0.125, 0.5, 1}, wantFPR: []float64{0, 0, 0.6, 0.6, 1},
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
}, },
{ // 5 { // 5
y: []float64{0, 3, 5, 6, 7.5, 8}, y: []float64{0, 3, 5, 6, 7.5, 8},
c: []bool{true, false, true, false, false, false}, c: []bool{false, true, false, true, true, true},
w: []float64{4, 1, 6, 3, 2, 2}, w: []float64{4, 1, 6, 3, 2, 2},
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
wantTPR: []float64{0, 0.4, 0.4, 0.4, 0.4, 1, 1, 1, 1}, wantTPR: []float64{0, 0.5, 0.5, 0.875, 0.875, 0.875, 1, 1, 1},
wantFPR: []float64{0, 0, 0, 0.125, 0.125, 0.125, 0.5, 0.5, 1}, wantFPR: []float64{0, 0, 0, 0, 0.6, 0.6, 0.6, 0.6, 1},
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
}, },
{ // 6 { // 6
y: []float64{0, 3, 6, 6, 6, 8}, y: []float64{0, 3, 6, 6, 6, 8},
c: []bool{true, false, true, false, false, false}, c: []bool{false, true, false, true, true, true},
wantTPR: []float64{0, 0.5, 0.5, 1, 1}, wantTPR: []float64{0, 0.25, 0.75, 1, 1},
wantFPR: []float64{0, 0, 0.25, 0.75, 1}, wantFPR: []float64{0, 0, 0.5, 0.5, 1},
wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
}, },
{ // 7 { // 7
y: []float64{0, 3, 6, 6, 6, 8}, y: []float64{0, 3, 6, 6, 6, 8},
c: []bool{true, false, true, false, false, false}, c: []bool{false, true, false, true, true, true},
w: []float64{4, 1, 6, 3, 2, 2}, w: []float64{4, 1, 6, 3, 2, 2},
wantTPR: []float64{0, 0.4, 0.4, 1, 1}, wantTPR: []float64{0, 0.25, 0.875, 1, 1},
wantFPR: []float64{0, 0, 0.125, 0.75, 1}, wantFPR: []float64{0, 0, 0.6, 0.6, 1},
wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
}, },
{ // 8 { // 8
y: []float64{0, 3, 6, 6, 6, 8}, y: []float64{0, 3, 6, 6, 6, 8},
c: []bool{true, false, true, false, false, false}, c: []bool{false, true, false, true, true, true},
cutoffs: []float64{-1, 2, 4, 6, 8}, cutoffs: []float64{-1, 2, 4, 6, 8},
wantTPR: []float64{0, 0.5, 0.5, 1, 1}, wantTPR: []float64{0, 0.25, 0.75, 1, 1},
wantFPR: []float64{0, 0, 0.25, 0.75, 1}, wantFPR: []float64{0, 0, 0.5, 0.5, 1},
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
}, },
{ // 9 { // 9
y: []float64{0, 3, 6, 6, 6, 8}, y: []float64{0, 3, 6, 6, 6, 8},
c: []bool{true, false, true, false, false, false}, c: []bool{false, true, false, true, true, true},
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
wantTPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 0.5, 1, 1, 1}, wantTPR: []float64{0, 0.25, 0.25, 0.75, 0.75, 0.75, 1, 1, 1},
wantFPR: []float64{0, 0, 0, 0.25, 0.25, 0.25, 0.75, 0.75, 1}, wantFPR: []float64{0, 0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 1},
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
}, },
{ // 10 { // 10
y: []float64{0, 3, 6, 6, 6, 8}, y: []float64{0, 3, 6, 6, 6, 8},
c: []bool{true, false, true, false, false, false}, c: []bool{false, true, false, true, true, true},
w: []float64{4, 1, 6, 3, 2, 2}, w: []float64{4, 1, 6, 3, 2, 2},
cutoffs: []float64{-1, 2, 4, 6, 8}, cutoffs: []float64{-1, 2, 4, 6, 8},
wantTPR: []float64{0, 0.4, 0.4, 1, 1}, wantTPR: []float64{0, 0.25, 0.875, 1, 1},
wantFPR: []float64{0, 0, 0.125, 0.75, 1}, wantFPR: []float64{0, 0, 0.6, 0.6, 1},
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
}, },
{ // 11 { // 11
y: []float64{0, 3, 6, 6, 6, 8}, y: []float64{0, 3, 6, 6, 6, 8},
c: []bool{true, false, true, false, false, false}, c: []bool{false, true, false, true, true, true},
w: []float64{4, 1, 6, 3, 2, 2}, w: []float64{4, 1, 6, 3, 2, 2},
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
wantTPR: []float64{0, 0.4, 0.4, 0.4, 0.4, 0.4, 1, 1, 1}, wantTPR: []float64{0, 0.25, 0.25, 0.875, 0.875, 0.875, 1, 1, 1},
wantFPR: []float64{0, 0, 0, 0.125, 0.125, 0.125, 0.75, 0.75, 1}, wantFPR: []float64{0, 0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 1},
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
}, },
{ // 12 { // 12
y: []float64{1, 2}, y: []float64{0.1, 0.35, 0.4, 0.8},
c: []bool{true, true}, c: []bool{true, false, true, false},
wantTPR: []float64{0, 0.5, 1}, wantTPR: []float64{0, 0, 0.5, 0.5, 1},
wantFPR: []float64{math.NaN(), math.NaN(), math.NaN()}, wantFPR: []float64{0, 0.5, 0.5, 1, 1},
wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
}, },
{ // 13 { // 13
y: []float64{1, 2}, y: []float64{0.1, 0.35, 0.4, 0.8},
c: []bool{true, true}, c: []bool{false, false, true, true},
cutoffs: []float64{-1, 2}, wantTPR: []float64{0, 0.5, 1, 1, 1},
wantTPR: []float64{0, 1}, wantFPR: []float64{0, 0, 0, 0.5, 1},
wantFPR: []float64{math.NaN(), math.NaN()}, wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
}, },
{ // 14 { // 14
y: []float64{1, 2}, y: []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10},
c: []bool{true, true}, c: []bool{false, true, false, false, true, true, false},
cutoffs: []float64{0, 1.2, 1.4, 1.6, 1.8, 2}, cutoffs: []float64{-1, 2.5, 5, 7.5, 10},
wantTPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 1}, wantTPR: []float64{0, 0, 0, 0, 1},
wantFPR: []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()}, wantFPR: []float64{0, 0.25, 0.25, 0.25, 1},
wantThresh: []float64{math.Inf(1), 10, 7.5, 5, 2.5},
}, },
{ // 15 { // 15
y: []float64{1}, y: []float64{1, 2},
c: []bool{true}, c: []bool{false, false},
wantTPR: []float64{0, 1}, wantTPR: []float64{math.NaN(), math.NaN(), math.NaN()},
wantFPR: []float64{math.NaN(), math.NaN()}, wantFPR: []float64{0, 0.5, 1},
wantThresh: []float64{math.Inf(1), 2, 1},
}, },
{ // 16 { // 16
y: []float64{1}, y: []float64{1, 2},
c: []bool{true}, c: []bool{false, false},
cutoffs: []float64{-1, 1}, cutoffs: []float64{-1, 2},
wantTPR: []float64{0, 1}, wantTPR: []float64{math.NaN(), math.NaN()},
wantFPR: []float64{math.NaN(), math.NaN()}, wantFPR: []float64{0, 1},
wantThresh: []float64{math.Inf(1), 2},
}, },
{ // 17 { // 17
y: []float64{1}, y: []float64{1, 2},
c: []bool{false}, c: []bool{false, false},
wantTPR: []float64{math.NaN(), math.NaN()}, cutoffs: []float64{0, 1.2, 1.4, 1.6, 1.8, 2},
wantFPR: []float64{0, 1}, wantTPR: []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()},
wantFPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 1},
wantThresh: []float64{math.Inf(1), 2, 1.8, 1.6, 1.4, 1.2},
}, },
{ // 18 { // 18
y: []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10}, y: []float64{1},
c: []bool{true, false, true, true, false, false, true}, c: []bool{false},
cutoffs: []float64{-1, 2.5, 5, 7.5, 10}, wantTPR: []float64{math.NaN(), math.NaN()},
wantTPR: []float64{0, 0.75, 0.75, 0.75, 1}, wantFPR: []float64{0, 1},
wantFPR: []float64{0, 1, 1, 1, 1}, wantThresh: []float64{math.Inf(1), 1},
}, },
{ // 19 { // 19
y: []float64{}, y: []float64{1},
c: []bool{}, c: []bool{false},
wantTPR: nil, cutoffs: []float64{-1, 1},
wantFPR: nil, wantTPR: []float64{math.NaN(), math.NaN()},
wantFPR: []float64{0, 1},
wantThresh: []float64{math.Inf(1), 1},
}, },
{ // 20 { // 20
y: []float64{}, y: []float64{1},
c: []bool{}, c: []bool{true},
cutoffs: []float64{-1, 2.5, 5, 7.5, 10}, wantTPR: []float64{0, 1},
wantTPR: nil, wantFPR: []float64{math.NaN(), math.NaN()},
wantFPR: nil, wantThresh: []float64{math.Inf(1), 1},
},
{ // 21
y: []float64{},
c: []bool{},
wantTPR: nil,
wantFPR: nil,
wantThresh: nil,
},
{ // 22
y: []float64{},
c: []bool{},
cutoffs: []float64{-1, 2.5, 5, 7.5, 10},
wantTPR: nil,
wantFPR: nil,
wantThresh: nil,
}, },
} }
for i, test := range cases { for i, test := range cases {
gotTPR, gotFPR := ROC(test.cutoffs, test.y, test.c, test.w) gotTPR, gotFPR, gotThresh := ROC(test.cutoffs, test.y, test.c, test.w)
if !floats.Same(gotTPR, test.wantTPR) { if !floats.Same(gotTPR, test.wantTPR) {
t.Errorf("%d: unexpected TPR got:%v want:%v", i, gotTPR, test.wantTPR) t.Errorf("%d: unexpected TPR got:%v want:%v", i, gotTPR, test.wantTPR)
} }
if !floats.Same(gotFPR, test.wantFPR) { if !floats.Same(gotFPR, test.wantFPR) {
t.Errorf("%d: unexpected FPR got:%v want:%v", i, gotFPR, test.wantFPR) t.Errorf("%d: unexpected FPR got:%v want:%v", i, gotFPR, test.wantFPR)
} }
if !floats.Same(gotThresh, test.wantThresh) {
t.Errorf("%d: unexpected thresholds got:%#v want:%v", i, gotThresh, test.wantThresh)
}
} }
} }