mirror of
https://github.com/gonum/gonum.git
synced 2025-10-06 15:47:01 +08:00
stat: invert class semantics and return cutoffs
This commit is contained in:
57
stat/roc.go
57
stat/roc.go
@@ -12,7 +12,9 @@ import (
|
|||||||
// ROC returns paired false positive rate (FPR) and true positive rate
|
// ROC returns paired false positive rate (FPR) and true positive rate
|
||||||
// (TPR) values corresponding to cutoff points on the receiver operator
|
// (TPR) values corresponding to cutoff points on the receiver operator
|
||||||
// characteristic (ROC) curve obtained when y is treated as a binary
|
// characteristic (ROC) curve obtained when y is treated as a binary
|
||||||
// classifier for classes with weights.
|
// classifier for classes with weights. The cutoff thresholds used to
|
||||||
|
// calculate the ROC are returned in thresh such that tpr[i] and fpr[i]
|
||||||
|
// are the true and false positive rates for y >= thresh[i].
|
||||||
//
|
//
|
||||||
// The input y and cutoffs must be sorted, and values in y must correspond
|
// The input y and cutoffs must be sorted, and values in y must correspond
|
||||||
// to values in classes and weights. SortWeightedLabeled can be used to
|
// to values in classes and weights. SortWeightedLabeled can be used to
|
||||||
@@ -34,7 +36,7 @@ import (
|
|||||||
//
|
//
|
||||||
// More details about ROC curves are available at
|
// More details about ROC curves are available at
|
||||||
// https://en.wikipedia.org/wiki/Receiver_operating_characteristic
|
// https://en.wikipedia.org/wiki/Receiver_operating_characteristic
|
||||||
func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr []float64) {
|
func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr, thresh []float64) {
|
||||||
if len(y) != len(classes) {
|
if len(y) != len(classes) {
|
||||||
panic("stat: slice length mismatch")
|
panic("stat: slice length mismatch")
|
||||||
}
|
}
|
||||||
@@ -48,47 +50,50 @@ func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr []fl
|
|||||||
panic("stat: cutoff values must be sorted ascending")
|
panic("stat: cutoff values must be sorted ascending")
|
||||||
}
|
}
|
||||||
if len(y) == 0 {
|
if len(y) == 0 {
|
||||||
return nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
var bin int
|
|
||||||
if len(cutoffs) == 0 {
|
if len(cutoffs) == 0 {
|
||||||
if cutoffs == nil || cap(cutoffs) < len(y)+1 {
|
if cutoffs == nil || cap(cutoffs) < len(y)+1 {
|
||||||
cutoffs = make([]float64, len(y)+1)
|
cutoffs = make([]float64, len(y)+1)
|
||||||
} else {
|
} else {
|
||||||
cutoffs = cutoffs[:len(y)+1]
|
cutoffs = cutoffs[:len(y)+1]
|
||||||
}
|
}
|
||||||
cutoffs[0] = math.Nextafter(y[0], y[0]-1)
|
cutoffs[0] = math.Inf(-1)
|
||||||
// Choose all possible cutoffs but remove duplicate values
|
// Choose all possible cutoffs for unique values in y.
|
||||||
// in y.
|
bin := 1
|
||||||
for i, u := range y {
|
cutoffs[bin] = y[0]
|
||||||
if i != 0 && u != y[i-1] {
|
for i, u := range y[1:] {
|
||||||
bin++
|
if u == y[i] {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
cutoffs[bin+1] = u
|
bin++
|
||||||
|
cutoffs[bin] = u
|
||||||
}
|
}
|
||||||
cutoffs = cutoffs[0 : bin+2]
|
cutoffs = cutoffs[:bin+1]
|
||||||
|
} else {
|
||||||
|
// Don't mutate the provided cutoffs.
|
||||||
|
tmp := cutoffs
|
||||||
|
cutoffs = make([]float64, len(cutoffs))
|
||||||
|
copy(cutoffs, tmp)
|
||||||
}
|
}
|
||||||
|
|
||||||
tpr = make([]float64, len(cutoffs))
|
tpr = make([]float64, len(cutoffs))
|
||||||
fpr = make([]float64, len(cutoffs))
|
fpr = make([]float64, len(cutoffs))
|
||||||
bin = 0
|
var bin int
|
||||||
var nPos, nNeg float64
|
var nPos, nNeg float64
|
||||||
for i, u := range classes {
|
for i, u := range classes {
|
||||||
// Update the bin until it matches the next y value
|
// Update the bin until it matches the next y value
|
||||||
// skipping empty bins.
|
// skipping empty bins.
|
||||||
for bin < len(cutoffs) && y[i] > cutoffs[bin] {
|
for bin < len(cutoffs)-1 && y[i] > cutoffs[bin] {
|
||||||
if bin == len(cutoffs)-1 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
bin++
|
bin++
|
||||||
tpr[bin] = tpr[bin-1]
|
tpr[bin] = tpr[bin-1]
|
||||||
fpr[bin] = fpr[bin-1]
|
fpr[bin] = fpr[bin-1]
|
||||||
}
|
}
|
||||||
var posWeight, negWeight float64 = 0, 1
|
posWeight, negWeight := 1.0, 0.0
|
||||||
if weights != nil {
|
if weights != nil {
|
||||||
negWeight = weights[i]
|
posWeight = weights[i]
|
||||||
}
|
}
|
||||||
if u {
|
if !u {
|
||||||
posWeight, negWeight = negWeight, posWeight
|
posWeight, negWeight = negWeight, posWeight
|
||||||
}
|
}
|
||||||
nPos += posWeight
|
nPos += posWeight
|
||||||
@@ -103,8 +108,18 @@ func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr []fl
|
|||||||
invPos := 1 / nPos
|
invPos := 1 / nPos
|
||||||
for i := range tpr {
|
for i := range tpr {
|
||||||
tpr[i] *= invPos
|
tpr[i] *= invPos
|
||||||
|
tpr[i] = 1 - tpr[i]
|
||||||
fpr[i] *= invNeg
|
fpr[i] *= invNeg
|
||||||
|
fpr[i] = 1 - fpr[i]
|
||||||
}
|
}
|
||||||
|
for i, j := 0, len(tpr)-1; i < j; i, j = i+1, j-1 {
|
||||||
|
tpr[i], tpr[j] = tpr[j], tpr[i]
|
||||||
|
fpr[i], fpr[j] = fpr[j], fpr[i]
|
||||||
|
}
|
||||||
|
for i, j := 1, len(cutoffs)-1; i < j; i, j = i+1, j-1 {
|
||||||
|
cutoffs[i], cutoffs[j] = cutoffs[j], cutoffs[i]
|
||||||
|
}
|
||||||
|
cutoffs[0] = math.Inf(1)
|
||||||
|
|
||||||
return tpr, fpr
|
return tpr, fpr, cutoffs
|
||||||
}
|
}
|
||||||
|
@@ -15,67 +15,83 @@ import (
|
|||||||
|
|
||||||
func ExampleROC_weighted() {
|
func ExampleROC_weighted() {
|
||||||
y := []float64{0, 3, 5, 6, 7.5, 8}
|
y := []float64{0, 3, 5, 6, 7.5, 8}
|
||||||
classes := []bool{true, false, true, false, false, false}
|
classes := []bool{false, true, false, true, true, true}
|
||||||
weights := []float64{4, 1, 6, 3, 2, 2}
|
weights := []float64{4, 1, 6, 3, 2, 2}
|
||||||
|
|
||||||
tpr, fpr := stat.ROC(nil, y, classes, weights)
|
tpr, fpr, _ := stat.ROC(nil, y, classes, weights)
|
||||||
fmt.Printf("true positive rate: %v\n", tpr)
|
fmt.Printf("true positive rate: %v\n", tpr)
|
||||||
fmt.Printf("false positive rate: %v\n", fpr)
|
fmt.Printf("false positive rate: %v\n", fpr)
|
||||||
|
|
||||||
// Output:
|
// Output:
|
||||||
// true positive rate: [0 0.4 0.4 1 1 1 1]
|
// true positive rate: [0 0.25 0.5 0.875 0.875 1 1]
|
||||||
// false positive rate: [0 0 0.125 0.125 0.5 0.75 1]
|
// false positive rate: [0 0 0 0 0.6 0.6 1]
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExampleROC_unweighted() {
|
func ExampleROC_unweighted() {
|
||||||
y := []float64{0, 3, 5, 6, 7.5, 8}
|
y := []float64{0, 3, 5, 6, 7.5, 8}
|
||||||
classes := []bool{true, false, true, false, false, false}
|
classes := []bool{false, true, false, true, true, true}
|
||||||
|
|
||||||
tpr, fpr := stat.ROC(nil, y, classes, nil)
|
tpr, fpr, _ := stat.ROC(nil, y, classes, nil)
|
||||||
fmt.Printf("true positive rate: %v\n", tpr)
|
fmt.Printf("true positive rate: %v\n", tpr)
|
||||||
fmt.Printf("false positive rate: %v\n", fpr)
|
fmt.Printf("false positive rate: %v\n", fpr)
|
||||||
|
|
||||||
// Output:
|
// Output:
|
||||||
// true positive rate: [0 0.5 0.5 1 1 1 1]
|
// true positive rate: [0 0.25 0.5 0.75 0.75 1 1]
|
||||||
// false positive rate: [0 0 0.25 0.25 0.5 0.75 1]
|
// false positive rate: [0 0 0 0 0.5 0.5 1]
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleROC_threshold() {
|
||||||
|
y := []float64{0.1, 0.4, 0.35, 0.8}
|
||||||
|
classes := []bool{false, false, true, true}
|
||||||
|
stat.SortWeightedLabeled(y, classes, nil)
|
||||||
|
|
||||||
|
tpr, fpr, thresh := stat.ROC(nil, y, classes, nil)
|
||||||
|
fmt.Printf("true positive rate: %v\n", tpr)
|
||||||
|
fmt.Printf("false positive rate: %v\n", fpr)
|
||||||
|
fmt.Printf("cutoff thresholds: %v\n", thresh)
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// true positive rate: [0 0.5 0.5 1 1]
|
||||||
|
// false positive rate: [0 0 0.5 0.5 1]
|
||||||
|
// cutoff thresholds: [+Inf 0.8 0.4 0.35 0.1]
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExampleROC_unsorted() {
|
func ExampleROC_unsorted() {
|
||||||
y := []float64{8, 7.5, 6, 5, 3, 0}
|
y := []float64{8, 7.5, 6, 5, 3, 0}
|
||||||
classes := []bool{false, false, false, true, false, true}
|
classes := []bool{true, true, true, false, true, false}
|
||||||
weights := []float64{2, 2, 3, 6, 1, 4}
|
weights := []float64{2, 2, 3, 6, 1, 4}
|
||||||
|
|
||||||
stat.SortWeightedLabeled(y, classes, weights)
|
stat.SortWeightedLabeled(y, classes, weights)
|
||||||
|
|
||||||
tpr, fpr := stat.ROC(nil, y, classes, weights)
|
tpr, fpr, _ := stat.ROC(nil, y, classes, weights)
|
||||||
fmt.Printf("true positive rate: %v\n", tpr)
|
fmt.Printf("true positive rate: %v\n", tpr)
|
||||||
fmt.Printf("false positive rate: %v\n", fpr)
|
fmt.Printf("false positive rate: %v\n", fpr)
|
||||||
|
|
||||||
// Output:
|
// Output:
|
||||||
// true positive rate: [0 0.4 0.4 1 1 1 1]
|
// true positive rate: [0 0.25 0.5 0.875 0.875 1 1]
|
||||||
// false positive rate: [0 0 0.125 0.125 0.5 0.75 1]
|
// false positive rate: [0 0 0 0 0.6 0.6 1]
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExampleROC_knownCutoffs() {
|
func ExampleROC_knownCutoffs() {
|
||||||
y := []float64{8, 7.5, 6, 5, 3, 0}
|
y := []float64{8, 7.5, 6, 5, 3, 0}
|
||||||
classes := []bool{false, false, false, true, false, true}
|
classes := []bool{true, true, true, false, true, false}
|
||||||
weights := []float64{2, 2, 3, 6, 1, 4}
|
weights := []float64{2, 2, 3, 6, 1, 4}
|
||||||
cutoffs := []float64{-1, 3, 4}
|
cutoffs := []float64{-1, 3, 4}
|
||||||
|
|
||||||
stat.SortWeightedLabeled(y, classes, weights)
|
stat.SortWeightedLabeled(y, classes, weights)
|
||||||
|
|
||||||
tpr, fpr := stat.ROC(cutoffs, y, classes, weights)
|
tpr, fpr, _ := stat.ROC(cutoffs, y, classes, weights)
|
||||||
fmt.Printf("true positive rate: %v\n", tpr)
|
fmt.Printf("true positive rate: %v\n", tpr)
|
||||||
fmt.Printf("false positive rate: %v\n", fpr)
|
fmt.Printf("false positive rate: %v\n", fpr)
|
||||||
|
|
||||||
// Output:
|
// Output:
|
||||||
// true positive rate: [0 0.4 0.4]
|
// true positive rate: [0.875 0.875 1]
|
||||||
// false positive rate: [0 0.125 0.125]
|
// false positive rate: [0.6 0.6 1]
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExampleROC_equallySpacedCutoffs() {
|
func ExampleROC_equallySpacedCutoffs() {
|
||||||
y := []float64{8, 7.5, 6, 5, 3, 0}
|
y := []float64{8, 7.5, 6, 5, 3, 0}
|
||||||
classes := []bool{false, false, false, true, false, true}
|
classes := []bool{true, true, true, false, true, true}
|
||||||
weights := []float64{2, 2, 3, 6, 1, 4}
|
weights := []float64{2, 2, 3, 6, 1, 4}
|
||||||
n := 9
|
n := 9
|
||||||
|
|
||||||
@@ -83,20 +99,20 @@ func ExampleROC_equallySpacedCutoffs() {
|
|||||||
cutoffs := make([]float64, n)
|
cutoffs := make([]float64, n)
|
||||||
floats.Span(cutoffs, math.Nextafter(y[0], y[0]-1), y[len(y)-1])
|
floats.Span(cutoffs, math.Nextafter(y[0], y[0]-1), y[len(y)-1])
|
||||||
|
|
||||||
tpr, fpr := stat.ROC(cutoffs, y, classes, weights)
|
tpr, fpr, _ := stat.ROC(cutoffs, y, classes, weights)
|
||||||
fmt.Printf("true positive rate: %v\n", tpr)
|
fmt.Printf("true positive rate: %.3v\n", tpr)
|
||||||
fmt.Printf("false positive rate: %v\n", fpr)
|
fmt.Printf("false positive rate: %.3v\n", fpr)
|
||||||
|
|
||||||
// Output:
|
// Output:
|
||||||
// true positive rate: [0 0.4 0.4 0.4 0.4 1 1 1 1]
|
// true positive rate: [0 0.333 0.333 0.583 0.583 0.583 0.667 0.667 1]
|
||||||
// false positive rate: [0 0 0 0.125 0.125 0.125 0.5 0.5 1]
|
// false positive rate: [0 0 0 0 1 1 1 1 1]
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExampleROC_aUC() {
|
func ExampleROC_aUC() {
|
||||||
y := []float64{0.1, 0.35, 0.4, 0.8}
|
y := []float64{0.1, 0.35, 0.4, 0.8}
|
||||||
classes := []bool{true, false, true, false}
|
classes := []bool{true, false, true, false}
|
||||||
|
|
||||||
tpr, fpr := stat.ROC(nil, y, classes, nil)
|
tpr, fpr, _ := stat.ROC(nil, y, classes, nil)
|
||||||
|
|
||||||
// Compute Area Under Curve.
|
// Compute Area Under Curve.
|
||||||
auc := integrate.Trapezoidal(fpr, tpr)
|
auc := integrate.Trapezoidal(fpr, tpr)
|
||||||
@@ -105,7 +121,7 @@ func ExampleROC_aUC() {
|
|||||||
fmt.Printf("auc: %v\n", auc)
|
fmt.Printf("auc: %v\n", auc)
|
||||||
|
|
||||||
// Output:
|
// Output:
|
||||||
// true positive rate: [0 0.5 0.5 1 1]
|
// true positive rate: [0 0 0.5 0.5 1]
|
||||||
// false positive rate: [0 0 0.5 0.5 1]
|
// false positive rate: [0 0.5 0.5 1 1]
|
||||||
// auc: 0.75
|
// auc: 0.25
|
||||||
}
|
}
|
||||||
|
267
stat/roc_test.go
267
stat/roc_test.go
@@ -11,169 +11,214 @@ import (
|
|||||||
"gonum.org/v1/gonum/floats"
|
"gonum.org/v1/gonum/floats"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Test cases were calculated manually.
|
|
||||||
func TestROC(t *testing.T) {
|
func TestROC(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
y []float64
|
y []float64
|
||||||
c []bool
|
c []bool
|
||||||
w []float64
|
w []float64
|
||||||
cutoffs []float64
|
cutoffs []float64
|
||||||
wantTPR []float64
|
wantTPR []float64
|
||||||
wantFPR []float64
|
wantFPR []float64
|
||||||
|
wantThresh []float64
|
||||||
}{
|
}{
|
||||||
|
// Test cases were calculated using sklearn metrics.roc_curve when
|
||||||
|
// cutoffs is nil. Where cutoffs is not nil, a visual inspection is
|
||||||
|
// used.
|
||||||
|
// Some differences exist between unweighted ROCs from our function
|
||||||
|
// and metrics.roc_curve which appears to use integer cutoffs in that
|
||||||
|
// case. sklearn also appears to do some magic that trims leading zeros
|
||||||
|
// sometimes.
|
||||||
{ // 0
|
{ // 0
|
||||||
y: []float64{0, 3, 5, 6, 7.5, 8},
|
y: []float64{0, 3, 5, 6, 7.5, 8},
|
||||||
c: []bool{true, false, true, false, false, false},
|
c: []bool{false, true, false, true, true, true},
|
||||||
wantTPR: []float64{0, 0.5, 0.5, 1, 1, 1, 1},
|
wantTPR: []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1},
|
||||||
wantFPR: []float64{0, 0, 0.25, 0.25, 0.5, 0.75, 1},
|
wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 1},
|
||||||
|
wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
|
||||||
},
|
},
|
||||||
{ // 1
|
{ // 1
|
||||||
y: []float64{0, 3, 5, 6, 7.5, 8},
|
y: []float64{0, 3, 5, 6, 7.5, 8},
|
||||||
c: []bool{true, false, true, false, false, false},
|
c: []bool{false, true, false, true, true, true},
|
||||||
w: []float64{4, 1, 6, 3, 2, 2},
|
w: []float64{4, 1, 6, 3, 2, 2},
|
||||||
wantTPR: []float64{0, 0.4, 0.4, 1, 1, 1, 1},
|
wantTPR: []float64{0, 0.25, 0.5, 0.875, 0.875, 1, 1},
|
||||||
wantFPR: []float64{0, 0, 0.125, 0.125, 0.5, 0.75, 1},
|
wantFPR: []float64{0, 0, 0, 0, 0.6, 0.6, 1},
|
||||||
|
wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
|
||||||
},
|
},
|
||||||
{ // 2
|
{ // 2
|
||||||
y: []float64{0, 3, 5, 6, 7.5, 8},
|
y: []float64{0, 3, 5, 6, 7.5, 8},
|
||||||
c: []bool{true, false, true, false, false, false},
|
c: []bool{false, true, false, true, true, true},
|
||||||
cutoffs: []float64{-1, 2, 4, 6, 8},
|
cutoffs: []float64{-1, 2, 4, 6, 8},
|
||||||
wantTPR: []float64{0, 0.5, 0.5, 1, 1},
|
wantTPR: []float64{0, 0.5, 0.75, 1, 1},
|
||||||
wantFPR: []float64{0, 0, 0.25, 0.5, 1},
|
wantFPR: []float64{0, 0, 0.5, 0.5, 1},
|
||||||
|
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
|
||||||
},
|
},
|
||||||
{ // 3
|
{ // 3
|
||||||
y: []float64{0, 3, 5, 6, 7.5, 8},
|
y: []float64{0, 3, 5, 6, 7.5, 8},
|
||||||
c: []bool{true, false, true, false, false, false},
|
c: []bool{false, true, false, true, true, true},
|
||||||
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
|
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
|
||||||
wantTPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 1, 1, 1, 1},
|
wantTPR: []float64{0, 0.5, 0.5, 0.75, 0.75, 0.75, 1, 1, 1},
|
||||||
wantFPR: []float64{0, 0, 0, 0.25, 0.25, 0.25, 0.5, 0.5, 1},
|
wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 0.5, 0.5, 1},
|
||||||
|
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
|
||||||
},
|
},
|
||||||
{ // 4
|
{ // 4
|
||||||
y: []float64{0, 3, 5, 6, 7.5, 8},
|
y: []float64{0, 3, 5, 6, 7.5, 8},
|
||||||
c: []bool{true, false, true, false, false, false},
|
c: []bool{false, true, false, true, true, true},
|
||||||
w: []float64{4, 1, 6, 3, 2, 2},
|
w: []float64{4, 1, 6, 3, 2, 2},
|
||||||
cutoffs: []float64{-1, 2, 4, 6, 8},
|
cutoffs: []float64{-1, 2, 4, 6, 8},
|
||||||
wantTPR: []float64{0, 0.4, 0.4, 1, 1},
|
wantTPR: []float64{0, 0.5, 0.875, 1, 1},
|
||||||
wantFPR: []float64{0, 0, 0.125, 0.5, 1},
|
wantFPR: []float64{0, 0, 0.6, 0.6, 1},
|
||||||
|
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
|
||||||
},
|
},
|
||||||
{ // 5
|
{ // 5
|
||||||
y: []float64{0, 3, 5, 6, 7.5, 8},
|
y: []float64{0, 3, 5, 6, 7.5, 8},
|
||||||
c: []bool{true, false, true, false, false, false},
|
c: []bool{false, true, false, true, true, true},
|
||||||
w: []float64{4, 1, 6, 3, 2, 2},
|
w: []float64{4, 1, 6, 3, 2, 2},
|
||||||
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
|
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
|
||||||
wantTPR: []float64{0, 0.4, 0.4, 0.4, 0.4, 1, 1, 1, 1},
|
wantTPR: []float64{0, 0.5, 0.5, 0.875, 0.875, 0.875, 1, 1, 1},
|
||||||
wantFPR: []float64{0, 0, 0, 0.125, 0.125, 0.125, 0.5, 0.5, 1},
|
wantFPR: []float64{0, 0, 0, 0, 0.6, 0.6, 0.6, 0.6, 1},
|
||||||
|
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
|
||||||
},
|
},
|
||||||
{ // 6
|
{ // 6
|
||||||
y: []float64{0, 3, 6, 6, 6, 8},
|
y: []float64{0, 3, 6, 6, 6, 8},
|
||||||
c: []bool{true, false, true, false, false, false},
|
c: []bool{false, true, false, true, true, true},
|
||||||
wantTPR: []float64{0, 0.5, 0.5, 1, 1},
|
wantTPR: []float64{0, 0.25, 0.75, 1, 1},
|
||||||
wantFPR: []float64{0, 0, 0.25, 0.75, 1},
|
wantFPR: []float64{0, 0, 0.5, 0.5, 1},
|
||||||
|
wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
|
||||||
},
|
},
|
||||||
{ // 7
|
{ // 7
|
||||||
y: []float64{0, 3, 6, 6, 6, 8},
|
y: []float64{0, 3, 6, 6, 6, 8},
|
||||||
c: []bool{true, false, true, false, false, false},
|
c: []bool{false, true, false, true, true, true},
|
||||||
w: []float64{4, 1, 6, 3, 2, 2},
|
w: []float64{4, 1, 6, 3, 2, 2},
|
||||||
wantTPR: []float64{0, 0.4, 0.4, 1, 1},
|
wantTPR: []float64{0, 0.25, 0.875, 1, 1},
|
||||||
wantFPR: []float64{0, 0, 0.125, 0.75, 1},
|
wantFPR: []float64{0, 0, 0.6, 0.6, 1},
|
||||||
|
wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
|
||||||
},
|
},
|
||||||
{ // 8
|
{ // 8
|
||||||
y: []float64{0, 3, 6, 6, 6, 8},
|
y: []float64{0, 3, 6, 6, 6, 8},
|
||||||
c: []bool{true, false, true, false, false, false},
|
c: []bool{false, true, false, true, true, true},
|
||||||
cutoffs: []float64{-1, 2, 4, 6, 8},
|
cutoffs: []float64{-1, 2, 4, 6, 8},
|
||||||
wantTPR: []float64{0, 0.5, 0.5, 1, 1},
|
wantTPR: []float64{0, 0.25, 0.75, 1, 1},
|
||||||
wantFPR: []float64{0, 0, 0.25, 0.75, 1},
|
wantFPR: []float64{0, 0, 0.5, 0.5, 1},
|
||||||
|
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
|
||||||
},
|
},
|
||||||
{ // 9
|
{ // 9
|
||||||
y: []float64{0, 3, 6, 6, 6, 8},
|
y: []float64{0, 3, 6, 6, 6, 8},
|
||||||
c: []bool{true, false, true, false, false, false},
|
c: []bool{false, true, false, true, true, true},
|
||||||
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
|
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
|
||||||
wantTPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 0.5, 1, 1, 1},
|
wantTPR: []float64{0, 0.25, 0.25, 0.75, 0.75, 0.75, 1, 1, 1},
|
||||||
wantFPR: []float64{0, 0, 0, 0.25, 0.25, 0.25, 0.75, 0.75, 1},
|
wantFPR: []float64{0, 0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 1},
|
||||||
|
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
|
||||||
},
|
},
|
||||||
{ // 10
|
{ // 10
|
||||||
y: []float64{0, 3, 6, 6, 6, 8},
|
y: []float64{0, 3, 6, 6, 6, 8},
|
||||||
c: []bool{true, false, true, false, false, false},
|
c: []bool{false, true, false, true, true, true},
|
||||||
w: []float64{4, 1, 6, 3, 2, 2},
|
w: []float64{4, 1, 6, 3, 2, 2},
|
||||||
cutoffs: []float64{-1, 2, 4, 6, 8},
|
cutoffs: []float64{-1, 2, 4, 6, 8},
|
||||||
wantTPR: []float64{0, 0.4, 0.4, 1, 1},
|
wantTPR: []float64{0, 0.25, 0.875, 1, 1},
|
||||||
wantFPR: []float64{0, 0, 0.125, 0.75, 1},
|
wantFPR: []float64{0, 0, 0.6, 0.6, 1},
|
||||||
|
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
|
||||||
},
|
},
|
||||||
{ // 11
|
{ // 11
|
||||||
y: []float64{0, 3, 6, 6, 6, 8},
|
y: []float64{0, 3, 6, 6, 6, 8},
|
||||||
c: []bool{true, false, true, false, false, false},
|
c: []bool{false, true, false, true, true, true},
|
||||||
w: []float64{4, 1, 6, 3, 2, 2},
|
w: []float64{4, 1, 6, 3, 2, 2},
|
||||||
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
|
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
|
||||||
wantTPR: []float64{0, 0.4, 0.4, 0.4, 0.4, 0.4, 1, 1, 1},
|
wantTPR: []float64{0, 0.25, 0.25, 0.875, 0.875, 0.875, 1, 1, 1},
|
||||||
wantFPR: []float64{0, 0, 0, 0.125, 0.125, 0.125, 0.75, 0.75, 1},
|
wantFPR: []float64{0, 0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 1},
|
||||||
|
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
|
||||||
},
|
},
|
||||||
{ // 12
|
{ // 12
|
||||||
y: []float64{1, 2},
|
y: []float64{0.1, 0.35, 0.4, 0.8},
|
||||||
c: []bool{true, true},
|
c: []bool{true, false, true, false},
|
||||||
wantTPR: []float64{0, 0.5, 1},
|
wantTPR: []float64{0, 0, 0.5, 0.5, 1},
|
||||||
wantFPR: []float64{math.NaN(), math.NaN(), math.NaN()},
|
wantFPR: []float64{0, 0.5, 0.5, 1, 1},
|
||||||
|
wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
|
||||||
},
|
},
|
||||||
{ // 13
|
{ // 13
|
||||||
y: []float64{1, 2},
|
y: []float64{0.1, 0.35, 0.4, 0.8},
|
||||||
c: []bool{true, true},
|
c: []bool{false, false, true, true},
|
||||||
cutoffs: []float64{-1, 2},
|
wantTPR: []float64{0, 0.5, 1, 1, 1},
|
||||||
wantTPR: []float64{0, 1},
|
wantFPR: []float64{0, 0, 0, 0.5, 1},
|
||||||
wantFPR: []float64{math.NaN(), math.NaN()},
|
wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
|
||||||
},
|
},
|
||||||
{ // 14
|
{ // 14
|
||||||
y: []float64{1, 2},
|
y: []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10},
|
||||||
c: []bool{true, true},
|
c: []bool{false, true, false, false, true, true, false},
|
||||||
cutoffs: []float64{0, 1.2, 1.4, 1.6, 1.8, 2},
|
cutoffs: []float64{-1, 2.5, 5, 7.5, 10},
|
||||||
wantTPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 1},
|
wantTPR: []float64{0, 0, 0, 0, 1},
|
||||||
wantFPR: []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()},
|
wantFPR: []float64{0, 0.25, 0.25, 0.25, 1},
|
||||||
|
wantThresh: []float64{math.Inf(1), 10, 7.5, 5, 2.5},
|
||||||
},
|
},
|
||||||
{ // 15
|
{ // 15
|
||||||
y: []float64{1},
|
y: []float64{1, 2},
|
||||||
c: []bool{true},
|
c: []bool{false, false},
|
||||||
wantTPR: []float64{0, 1},
|
wantTPR: []float64{math.NaN(), math.NaN(), math.NaN()},
|
||||||
wantFPR: []float64{math.NaN(), math.NaN()},
|
wantFPR: []float64{0, 0.5, 1},
|
||||||
|
wantThresh: []float64{math.Inf(1), 2, 1},
|
||||||
},
|
},
|
||||||
{ // 16
|
{ // 16
|
||||||
y: []float64{1},
|
y: []float64{1, 2},
|
||||||
c: []bool{true},
|
c: []bool{false, false},
|
||||||
cutoffs: []float64{-1, 1},
|
cutoffs: []float64{-1, 2},
|
||||||
wantTPR: []float64{0, 1},
|
wantTPR: []float64{math.NaN(), math.NaN()},
|
||||||
wantFPR: []float64{math.NaN(), math.NaN()},
|
wantFPR: []float64{0, 1},
|
||||||
|
wantThresh: []float64{math.Inf(1), 2},
|
||||||
},
|
},
|
||||||
{ // 17
|
{ // 17
|
||||||
y: []float64{1},
|
y: []float64{1, 2},
|
||||||
c: []bool{false},
|
c: []bool{false, false},
|
||||||
wantTPR: []float64{math.NaN(), math.NaN()},
|
cutoffs: []float64{0, 1.2, 1.4, 1.6, 1.8, 2},
|
||||||
wantFPR: []float64{0, 1},
|
wantTPR: []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()},
|
||||||
|
wantFPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 1},
|
||||||
|
wantThresh: []float64{math.Inf(1), 2, 1.8, 1.6, 1.4, 1.2},
|
||||||
},
|
},
|
||||||
{ // 18
|
{ // 18
|
||||||
y: []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10},
|
y: []float64{1},
|
||||||
c: []bool{true, false, true, true, false, false, true},
|
c: []bool{false},
|
||||||
cutoffs: []float64{-1, 2.5, 5, 7.5, 10},
|
wantTPR: []float64{math.NaN(), math.NaN()},
|
||||||
wantTPR: []float64{0, 0.75, 0.75, 0.75, 1},
|
wantFPR: []float64{0, 1},
|
||||||
wantFPR: []float64{0, 1, 1, 1, 1},
|
wantThresh: []float64{math.Inf(1), 1},
|
||||||
},
|
},
|
||||||
{ // 19
|
{ // 19
|
||||||
y: []float64{},
|
y: []float64{1},
|
||||||
c: []bool{},
|
c: []bool{false},
|
||||||
wantTPR: nil,
|
cutoffs: []float64{-1, 1},
|
||||||
wantFPR: nil,
|
wantTPR: []float64{math.NaN(), math.NaN()},
|
||||||
|
wantFPR: []float64{0, 1},
|
||||||
|
wantThresh: []float64{math.Inf(1), 1},
|
||||||
},
|
},
|
||||||
{ // 20
|
{ // 20
|
||||||
y: []float64{},
|
y: []float64{1},
|
||||||
c: []bool{},
|
c: []bool{true},
|
||||||
cutoffs: []float64{-1, 2.5, 5, 7.5, 10},
|
wantTPR: []float64{0, 1},
|
||||||
wantTPR: nil,
|
wantFPR: []float64{math.NaN(), math.NaN()},
|
||||||
wantFPR: nil,
|
wantThresh: []float64{math.Inf(1), 1},
|
||||||
|
},
|
||||||
|
{ // 21
|
||||||
|
y: []float64{},
|
||||||
|
c: []bool{},
|
||||||
|
wantTPR: nil,
|
||||||
|
wantFPR: nil,
|
||||||
|
wantThresh: nil,
|
||||||
|
},
|
||||||
|
{ // 22
|
||||||
|
y: []float64{},
|
||||||
|
c: []bool{},
|
||||||
|
cutoffs: []float64{-1, 2.5, 5, 7.5, 10},
|
||||||
|
wantTPR: nil,
|
||||||
|
wantFPR: nil,
|
||||||
|
wantThresh: nil,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for i, test := range cases {
|
for i, test := range cases {
|
||||||
gotTPR, gotFPR := ROC(test.cutoffs, test.y, test.c, test.w)
|
gotTPR, gotFPR, gotThresh := ROC(test.cutoffs, test.y, test.c, test.w)
|
||||||
if !floats.Same(gotTPR, test.wantTPR) {
|
if !floats.Same(gotTPR, test.wantTPR) {
|
||||||
t.Errorf("%d: unexpected TPR got:%v want:%v", i, gotTPR, test.wantTPR)
|
t.Errorf("%d: unexpected TPR got:%v want:%v", i, gotTPR, test.wantTPR)
|
||||||
}
|
}
|
||||||
if !floats.Same(gotFPR, test.wantFPR) {
|
if !floats.Same(gotFPR, test.wantFPR) {
|
||||||
t.Errorf("%d: unexpected FPR got:%v want:%v", i, gotFPR, test.wantFPR)
|
t.Errorf("%d: unexpected FPR got:%v want:%v", i, gotFPR, test.wantFPR)
|
||||||
}
|
}
|
||||||
|
if !floats.Same(gotThresh, test.wantThresh) {
|
||||||
|
t.Errorf("%d: unexpected thresholds got:%#v want:%v", i, gotThresh, test.wantThresh)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user