stat: invert class semantics and return cutoffs

This commit is contained in:
Dan Kortschak
2019-04-15 20:19:50 +09:30
parent e080754f81
commit 50ee437d7b
3 changed files with 234 additions and 158 deletions

View File

@@ -12,7 +12,9 @@ import (
// ROC returns paired false positive rate (FPR) and true positive rate
// (TPR) values corresponding to cutoff points on the receiver operator
// characteristic (ROC) curve obtained when y is treated as a binary
// classifier for classes with weights.
// classifier for classes with weights. The cutoff thresholds used to
// calculate the ROC are returned in thresh such that tpr[i] and fpr[i]
// are the true and false positive rates for y >= thresh[i].
//
// The input y and cutoffs must be sorted, and values in y must correspond
// to values in classes and weights. SortWeightedLabeled can be used to
@@ -34,7 +36,7 @@ import (
//
// More details about ROC curves are available at
// https://en.wikipedia.org/wiki/Receiver_operating_characteristic
func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr []float64) {
func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr, thresh []float64) {
if len(y) != len(classes) {
panic("stat: slice length mismatch")
}
@@ -48,47 +50,50 @@ func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr []fl
panic("stat: cutoff values must be sorted ascending")
}
if len(y) == 0 {
return nil, nil
return nil, nil, nil
}
var bin int
if len(cutoffs) == 0 {
if cutoffs == nil || cap(cutoffs) < len(y)+1 {
cutoffs = make([]float64, len(y)+1)
} else {
cutoffs = cutoffs[:len(y)+1]
}
cutoffs[0] = math.Nextafter(y[0], y[0]-1)
// Choose all possible cutoffs but remove duplicate values
// in y.
for i, u := range y {
if i != 0 && u != y[i-1] {
cutoffs[0] = math.Inf(-1)
// Choose all possible cutoffs for unique values in y.
bin := 1
cutoffs[bin] = y[0]
for i, u := range y[1:] {
if u == y[i] {
continue
}
bin++
cutoffs[bin] = u
}
cutoffs[bin+1] = u
}
cutoffs = cutoffs[0 : bin+2]
cutoffs = cutoffs[:bin+1]
} else {
// Don't mutate the provided cutoffs.
tmp := cutoffs
cutoffs = make([]float64, len(cutoffs))
copy(cutoffs, tmp)
}
tpr = make([]float64, len(cutoffs))
fpr = make([]float64, len(cutoffs))
bin = 0
var bin int
var nPos, nNeg float64
for i, u := range classes {
// Update the bin until it matches the next y value
// skipping empty bins.
for bin < len(cutoffs) && y[i] > cutoffs[bin] {
if bin == len(cutoffs)-1 {
break
}
for bin < len(cutoffs)-1 && y[i] > cutoffs[bin] {
bin++
tpr[bin] = tpr[bin-1]
fpr[bin] = fpr[bin-1]
}
var posWeight, negWeight float64 = 0, 1
posWeight, negWeight := 1.0, 0.0
if weights != nil {
negWeight = weights[i]
posWeight = weights[i]
}
if u {
if !u {
posWeight, negWeight = negWeight, posWeight
}
nPos += posWeight
@@ -103,8 +108,18 @@ func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr []fl
invPos := 1 / nPos
for i := range tpr {
tpr[i] *= invPos
tpr[i] = 1 - tpr[i]
fpr[i] *= invNeg
fpr[i] = 1 - fpr[i]
}
for i, j := 0, len(tpr)-1; i < j; i, j = i+1, j-1 {
tpr[i], tpr[j] = tpr[j], tpr[i]
fpr[i], fpr[j] = fpr[j], fpr[i]
}
for i, j := 1, len(cutoffs)-1; i < j; i, j = i+1, j-1 {
cutoffs[i], cutoffs[j] = cutoffs[j], cutoffs[i]
}
cutoffs[0] = math.Inf(1)
return tpr, fpr
return tpr, fpr, cutoffs
}

View File

@@ -15,67 +15,83 @@ import (
func ExampleROC_weighted() {
y := []float64{0, 3, 5, 6, 7.5, 8}
classes := []bool{true, false, true, false, false, false}
classes := []bool{false, true, false, true, true, true}
weights := []float64{4, 1, 6, 3, 2, 2}
tpr, fpr := stat.ROC(nil, y, classes, weights)
tpr, fpr, _ := stat.ROC(nil, y, classes, weights)
fmt.Printf("true positive rate: %v\n", tpr)
fmt.Printf("false positive rate: %v\n", fpr)
// Output:
// true positive rate: [0 0.4 0.4 1 1 1 1]
// false positive rate: [0 0 0.125 0.125 0.5 0.75 1]
// true positive rate: [0 0.25 0.5 0.875 0.875 1 1]
// false positive rate: [0 0 0 0 0.6 0.6 1]
}
func ExampleROC_unweighted() {
y := []float64{0, 3, 5, 6, 7.5, 8}
classes := []bool{true, false, true, false, false, false}
classes := []bool{false, true, false, true, true, true}
tpr, fpr := stat.ROC(nil, y, classes, nil)
tpr, fpr, _ := stat.ROC(nil, y, classes, nil)
fmt.Printf("true positive rate: %v\n", tpr)
fmt.Printf("false positive rate: %v\n", fpr)
// Output:
// true positive rate: [0 0.5 0.5 1 1 1 1]
// false positive rate: [0 0 0.25 0.25 0.5 0.75 1]
// true positive rate: [0 0.25 0.5 0.75 0.75 1 1]
// false positive rate: [0 0 0 0 0.5 0.5 1]
}
func ExampleROC_threshold() {
y := []float64{0.1, 0.4, 0.35, 0.8}
classes := []bool{false, false, true, true}
stat.SortWeightedLabeled(y, classes, nil)
tpr, fpr, thresh := stat.ROC(nil, y, classes, nil)
fmt.Printf("true positive rate: %v\n", tpr)
fmt.Printf("false positive rate: %v\n", fpr)
fmt.Printf("cutoff thresholds: %v\n", thresh)
// Output:
// true positive rate: [0 0.5 0.5 1 1]
// false positive rate: [0 0 0.5 0.5 1]
// cutoff thresholds: [+Inf 0.8 0.4 0.35 0.1]
}
func ExampleROC_unsorted() {
y := []float64{8, 7.5, 6, 5, 3, 0}
classes := []bool{false, false, false, true, false, true}
classes := []bool{true, true, true, false, true, false}
weights := []float64{2, 2, 3, 6, 1, 4}
stat.SortWeightedLabeled(y, classes, weights)
tpr, fpr := stat.ROC(nil, y, classes, weights)
tpr, fpr, _ := stat.ROC(nil, y, classes, weights)
fmt.Printf("true positive rate: %v\n", tpr)
fmt.Printf("false positive rate: %v\n", fpr)
// Output:
// true positive rate: [0 0.4 0.4 1 1 1 1]
// false positive rate: [0 0 0.125 0.125 0.5 0.75 1]
// true positive rate: [0 0.25 0.5 0.875 0.875 1 1]
// false positive rate: [0 0 0 0 0.6 0.6 1]
}
func ExampleROC_knownCutoffs() {
y := []float64{8, 7.5, 6, 5, 3, 0}
classes := []bool{false, false, false, true, false, true}
classes := []bool{true, true, true, false, true, false}
weights := []float64{2, 2, 3, 6, 1, 4}
cutoffs := []float64{-1, 3, 4}
stat.SortWeightedLabeled(y, classes, weights)
tpr, fpr := stat.ROC(cutoffs, y, classes, weights)
tpr, fpr, _ := stat.ROC(cutoffs, y, classes, weights)
fmt.Printf("true positive rate: %v\n", tpr)
fmt.Printf("false positive rate: %v\n", fpr)
// Output:
// true positive rate: [0 0.4 0.4]
// false positive rate: [0 0.125 0.125]
// true positive rate: [0.875 0.875 1]
// false positive rate: [0.6 0.6 1]
}
func ExampleROC_equallySpacedCutoffs() {
y := []float64{8, 7.5, 6, 5, 3, 0}
classes := []bool{false, false, false, true, false, true}
classes := []bool{true, true, true, false, true, true}
weights := []float64{2, 2, 3, 6, 1, 4}
n := 9
@@ -83,20 +99,20 @@ func ExampleROC_equallySpacedCutoffs() {
cutoffs := make([]float64, n)
floats.Span(cutoffs, math.Nextafter(y[0], y[0]-1), y[len(y)-1])
tpr, fpr := stat.ROC(cutoffs, y, classes, weights)
fmt.Printf("true positive rate: %v\n", tpr)
fmt.Printf("false positive rate: %v\n", fpr)
tpr, fpr, _ := stat.ROC(cutoffs, y, classes, weights)
fmt.Printf("true positive rate: %.3v\n", tpr)
fmt.Printf("false positive rate: %.3v\n", fpr)
// Output:
// true positive rate: [0 0.4 0.4 0.4 0.4 1 1 1 1]
// false positive rate: [0 0 0 0.125 0.125 0.125 0.5 0.5 1]
// true positive rate: [0 0.333 0.333 0.583 0.583 0.583 0.667 0.667 1]
// false positive rate: [0 0 0 0 1 1 1 1 1]
}
func ExampleROC_aUC() {
y := []float64{0.1, 0.35, 0.4, 0.8}
classes := []bool{true, false, true, false}
tpr, fpr := stat.ROC(nil, y, classes, nil)
tpr, fpr, _ := stat.ROC(nil, y, classes, nil)
// Compute Area Under Curve.
auc := integrate.Trapezoidal(fpr, tpr)
@@ -105,7 +121,7 @@ func ExampleROC_aUC() {
fmt.Printf("auc: %v\n", auc)
// Output:
// true positive rate: [0 0.5 0.5 1 1]
// false positive rate: [0 0 0.5 0.5 1]
// auc: 0.75
// true positive rate: [0 0 0.5 0.5 1]
// false positive rate: [0 0.5 0.5 1 1]
// auc: 0.25
}

View File

@@ -11,7 +11,6 @@ import (
"gonum.org/v1/gonum/floats"
)
// Test cases were calculated manually.
func TestROC(t *testing.T) {
cases := []struct {
y []float64
@@ -20,160 +19,206 @@ func TestROC(t *testing.T) {
cutoffs []float64
wantTPR []float64
wantFPR []float64
wantThresh []float64
}{
// Test cases were calculated using sklearn metrics.roc_curve when
// cutoffs is nil. Where cutoffs is not nil, a visual inspection is
// used.
// Some differences exist between unweighted ROCs from our function
// and metrics.roc_curve which appears to use integer cutoffs in that
// case. sklearn also appears to do some magic that trims leading zeros
// sometimes.
{ // 0
y: []float64{0, 3, 5, 6, 7.5, 8},
c: []bool{true, false, true, false, false, false},
wantTPR: []float64{0, 0.5, 0.5, 1, 1, 1, 1},
wantFPR: []float64{0, 0, 0.25, 0.25, 0.5, 0.75, 1},
c: []bool{false, true, false, true, true, true},
wantTPR: []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1},
wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 1},
wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
},
{ // 1
y: []float64{0, 3, 5, 6, 7.5, 8},
c: []bool{true, false, true, false, false, false},
c: []bool{false, true, false, true, true, true},
w: []float64{4, 1, 6, 3, 2, 2},
wantTPR: []float64{0, 0.4, 0.4, 1, 1, 1, 1},
wantFPR: []float64{0, 0, 0.125, 0.125, 0.5, 0.75, 1},
wantTPR: []float64{0, 0.25, 0.5, 0.875, 0.875, 1, 1},
wantFPR: []float64{0, 0, 0, 0, 0.6, 0.6, 1},
wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
},
{ // 2
y: []float64{0, 3, 5, 6, 7.5, 8},
c: []bool{true, false, true, false, false, false},
c: []bool{false, true, false, true, true, true},
cutoffs: []float64{-1, 2, 4, 6, 8},
wantTPR: []float64{0, 0.5, 0.5, 1, 1},
wantFPR: []float64{0, 0, 0.25, 0.5, 1},
wantTPR: []float64{0, 0.5, 0.75, 1, 1},
wantFPR: []float64{0, 0, 0.5, 0.5, 1},
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
},
{ // 3
y: []float64{0, 3, 5, 6, 7.5, 8},
c: []bool{true, false, true, false, false, false},
c: []bool{false, true, false, true, true, true},
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
wantTPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 1, 1, 1, 1},
wantFPR: []float64{0, 0, 0, 0.25, 0.25, 0.25, 0.5, 0.5, 1},
wantTPR: []float64{0, 0.5, 0.5, 0.75, 0.75, 0.75, 1, 1, 1},
wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 0.5, 0.5, 1},
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
},
{ // 4
y: []float64{0, 3, 5, 6, 7.5, 8},
c: []bool{true, false, true, false, false, false},
c: []bool{false, true, false, true, true, true},
w: []float64{4, 1, 6, 3, 2, 2},
cutoffs: []float64{-1, 2, 4, 6, 8},
wantTPR: []float64{0, 0.4, 0.4, 1, 1},
wantFPR: []float64{0, 0, 0.125, 0.5, 1},
wantTPR: []float64{0, 0.5, 0.875, 1, 1},
wantFPR: []float64{0, 0, 0.6, 0.6, 1},
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
},
{ // 5
y: []float64{0, 3, 5, 6, 7.5, 8},
c: []bool{true, false, true, false, false, false},
c: []bool{false, true, false, true, true, true},
w: []float64{4, 1, 6, 3, 2, 2},
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
wantTPR: []float64{0, 0.4, 0.4, 0.4, 0.4, 1, 1, 1, 1},
wantFPR: []float64{0, 0, 0, 0.125, 0.125, 0.125, 0.5, 0.5, 1},
wantTPR: []float64{0, 0.5, 0.5, 0.875, 0.875, 0.875, 1, 1, 1},
wantFPR: []float64{0, 0, 0, 0, 0.6, 0.6, 0.6, 0.6, 1},
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
},
{ // 6
y: []float64{0, 3, 6, 6, 6, 8},
c: []bool{true, false, true, false, false, false},
wantTPR: []float64{0, 0.5, 0.5, 1, 1},
wantFPR: []float64{0, 0, 0.25, 0.75, 1},
c: []bool{false, true, false, true, true, true},
wantTPR: []float64{0, 0.25, 0.75, 1, 1},
wantFPR: []float64{0, 0, 0.5, 0.5, 1},
wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
},
{ // 7
y: []float64{0, 3, 6, 6, 6, 8},
c: []bool{true, false, true, false, false, false},
c: []bool{false, true, false, true, true, true},
w: []float64{4, 1, 6, 3, 2, 2},
wantTPR: []float64{0, 0.4, 0.4, 1, 1},
wantFPR: []float64{0, 0, 0.125, 0.75, 1},
wantTPR: []float64{0, 0.25, 0.875, 1, 1},
wantFPR: []float64{0, 0, 0.6, 0.6, 1},
wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
},
{ // 8
y: []float64{0, 3, 6, 6, 6, 8},
c: []bool{true, false, true, false, false, false},
c: []bool{false, true, false, true, true, true},
cutoffs: []float64{-1, 2, 4, 6, 8},
wantTPR: []float64{0, 0.5, 0.5, 1, 1},
wantFPR: []float64{0, 0, 0.25, 0.75, 1},
wantTPR: []float64{0, 0.25, 0.75, 1, 1},
wantFPR: []float64{0, 0, 0.5, 0.5, 1},
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
},
{ // 9
y: []float64{0, 3, 6, 6, 6, 8},
c: []bool{true, false, true, false, false, false},
c: []bool{false, true, false, true, true, true},
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
wantTPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 0.5, 1, 1, 1},
wantFPR: []float64{0, 0, 0, 0.25, 0.25, 0.25, 0.75, 0.75, 1},
wantTPR: []float64{0, 0.25, 0.25, 0.75, 0.75, 0.75, 1, 1, 1},
wantFPR: []float64{0, 0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 1},
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
},
{ // 10
y: []float64{0, 3, 6, 6, 6, 8},
c: []bool{true, false, true, false, false, false},
c: []bool{false, true, false, true, true, true},
w: []float64{4, 1, 6, 3, 2, 2},
cutoffs: []float64{-1, 2, 4, 6, 8},
wantTPR: []float64{0, 0.4, 0.4, 1, 1},
wantFPR: []float64{0, 0, 0.125, 0.75, 1},
wantTPR: []float64{0, 0.25, 0.875, 1, 1},
wantFPR: []float64{0, 0, 0.6, 0.6, 1},
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
},
{ // 11
y: []float64{0, 3, 6, 6, 6, 8},
c: []bool{true, false, true, false, false, false},
c: []bool{false, true, false, true, true, true},
w: []float64{4, 1, 6, 3, 2, 2},
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
wantTPR: []float64{0, 0.4, 0.4, 0.4, 0.4, 0.4, 1, 1, 1},
wantFPR: []float64{0, 0, 0, 0.125, 0.125, 0.125, 0.75, 0.75, 1},
wantTPR: []float64{0, 0.25, 0.25, 0.875, 0.875, 0.875, 1, 1, 1},
wantFPR: []float64{0, 0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 1},
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
},
{ // 12
y: []float64{1, 2},
c: []bool{true, true},
wantTPR: []float64{0, 0.5, 1},
wantFPR: []float64{math.NaN(), math.NaN(), math.NaN()},
y: []float64{0.1, 0.35, 0.4, 0.8},
c: []bool{true, false, true, false},
wantTPR: []float64{0, 0, 0.5, 0.5, 1},
wantFPR: []float64{0, 0.5, 0.5, 1, 1},
wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
},
{ // 13
y: []float64{1, 2},
c: []bool{true, true},
cutoffs: []float64{-1, 2},
wantTPR: []float64{0, 1},
wantFPR: []float64{math.NaN(), math.NaN()},
y: []float64{0.1, 0.35, 0.4, 0.8},
c: []bool{false, false, true, true},
wantTPR: []float64{0, 0.5, 1, 1, 1},
wantFPR: []float64{0, 0, 0, 0.5, 1},
wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
},
{ // 14
y: []float64{1, 2},
c: []bool{true, true},
cutoffs: []float64{0, 1.2, 1.4, 1.6, 1.8, 2},
wantTPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 1},
wantFPR: []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()},
y: []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10},
c: []bool{false, true, false, false, true, true, false},
cutoffs: []float64{-1, 2.5, 5, 7.5, 10},
wantTPR: []float64{0, 0, 0, 0, 1},
wantFPR: []float64{0, 0.25, 0.25, 0.25, 1},
wantThresh: []float64{math.Inf(1), 10, 7.5, 5, 2.5},
},
{ // 15
y: []float64{1},
c: []bool{true},
wantTPR: []float64{0, 1},
wantFPR: []float64{math.NaN(), math.NaN()},
y: []float64{1, 2},
c: []bool{false, false},
wantTPR: []float64{math.NaN(), math.NaN(), math.NaN()},
wantFPR: []float64{0, 0.5, 1},
wantThresh: []float64{math.Inf(1), 2, 1},
},
{ // 16
y: []float64{1},
c: []bool{true},
cutoffs: []float64{-1, 1},
wantTPR: []float64{0, 1},
wantFPR: []float64{math.NaN(), math.NaN()},
y: []float64{1, 2},
c: []bool{false, false},
cutoffs: []float64{-1, 2},
wantTPR: []float64{math.NaN(), math.NaN()},
wantFPR: []float64{0, 1},
wantThresh: []float64{math.Inf(1), 2},
},
{ // 17
y: []float64{1, 2},
c: []bool{false, false},
cutoffs: []float64{0, 1.2, 1.4, 1.6, 1.8, 2},
wantTPR: []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()},
wantFPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 1},
wantThresh: []float64{math.Inf(1), 2, 1.8, 1.6, 1.4, 1.2},
},
{ // 18
y: []float64{1},
c: []bool{false},
wantTPR: []float64{math.NaN(), math.NaN()},
wantFPR: []float64{0, 1},
},
{ // 18
y: []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10},
c: []bool{true, false, true, true, false, false, true},
cutoffs: []float64{-1, 2.5, 5, 7.5, 10},
wantTPR: []float64{0, 0.75, 0.75, 0.75, 1},
wantFPR: []float64{0, 1, 1, 1, 1},
wantThresh: []float64{math.Inf(1), 1},
},
{ // 19
y: []float64{1},
c: []bool{false},
cutoffs: []float64{-1, 1},
wantTPR: []float64{math.NaN(), math.NaN()},
wantFPR: []float64{0, 1},
wantThresh: []float64{math.Inf(1), 1},
},
{ // 20
y: []float64{1},
c: []bool{true},
wantTPR: []float64{0, 1},
wantFPR: []float64{math.NaN(), math.NaN()},
wantThresh: []float64{math.Inf(1), 1},
},
{ // 21
y: []float64{},
c: []bool{},
wantTPR: nil,
wantFPR: nil,
wantThresh: nil,
},
{ // 20
{ // 22
y: []float64{},
c: []bool{},
cutoffs: []float64{-1, 2.5, 5, 7.5, 10},
wantTPR: nil,
wantFPR: nil,
wantThresh: nil,
},
}
for i, test := range cases {
gotTPR, gotFPR := ROC(test.cutoffs, test.y, test.c, test.w)
gotTPR, gotFPR, gotThresh := ROC(test.cutoffs, test.y, test.c, test.w)
if !floats.Same(gotTPR, test.wantTPR) {
t.Errorf("%d: unexpected TPR got:%v want:%v", i, gotTPR, test.wantTPR)
}
if !floats.Same(gotFPR, test.wantFPR) {
t.Errorf("%d: unexpected FPR got:%v want:%v", i, gotFPR, test.wantFPR)
}
if !floats.Same(gotThresh, test.wantThresh) {
t.Errorf("%d: unexpected thresholds got:%#v want:%v", i, gotThresh, test.wantThresh)
}
}
}