diff --git a/stat/roc.go b/stat/roc.go index ff579720..516a100b 100644 --- a/stat/roc.go +++ b/stat/roc.go @@ -12,7 +12,9 @@ import ( // ROC returns paired false positive rate (FPR) and true positive rate // (TPR) values corresponding to cutoff points on the receiver operator // characteristic (ROC) curve obtained when y is treated as a binary -// classifier for classes with weights. +// classifier for classes with weights. The cutoff thresholds used to +// calculate the ROC are returned in thresh such that tpr[i] and fpr[i] +// are the true and false positive rates for y >= thresh[i]. // // The input y and cutoffs must be sorted, and values in y must correspond // to values in classes and weights. SortWeightedLabeled can be used to @@ -34,7 +36,7 @@ import ( // // More details about ROC curves are available at // https://en.wikipedia.org/wiki/Receiver_operating_characteristic -func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr []float64) { +func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr, thresh []float64) { if len(y) != len(classes) { panic("stat: slice length mismatch") } @@ -48,47 +50,50 @@ func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr []fl panic("stat: cutoff values must be sorted ascending") } if len(y) == 0 { - return nil, nil + return nil, nil, nil } - var bin int if len(cutoffs) == 0 { if cutoffs == nil || cap(cutoffs) < len(y)+1 { cutoffs = make([]float64, len(y)+1) } else { cutoffs = cutoffs[:len(y)+1] } - cutoffs[0] = math.Nextafter(y[0], y[0]-1) - // Choose all possible cutoffs but remove duplicate values - // in y. - for i, u := range y { - if i != 0 && u != y[i-1] { - bin++ + cutoffs[0] = math.Inf(-1) + // Choose all possible cutoffs for unique values in y. + bin := 1 + cutoffs[bin] = y[0] + for i, u := range y[1:] { + if u == y[i] { + continue } - cutoffs[bin+1] = u + bin++ + cutoffs[bin] = u } - cutoffs = cutoffs[0 : bin+2] + cutoffs = cutoffs[:bin+1] + } else { + // Don't mutate the provided cutoffs. + tmp := cutoffs + cutoffs = make([]float64, len(cutoffs)) + copy(cutoffs, tmp) } tpr = make([]float64, len(cutoffs)) fpr = make([]float64, len(cutoffs)) - bin = 0 + var bin int var nPos, nNeg float64 for i, u := range classes { // Update the bin until it matches the next y value // skipping empty bins. - for bin < len(cutoffs) && y[i] > cutoffs[bin] { - if bin == len(cutoffs)-1 { - break - } + for bin < len(cutoffs)-1 && y[i] > cutoffs[bin] { bin++ tpr[bin] = tpr[bin-1] fpr[bin] = fpr[bin-1] } - var posWeight, negWeight float64 = 0, 1 + posWeight, negWeight := 1.0, 0.0 if weights != nil { - negWeight = weights[i] + posWeight = weights[i] } - if u { + if !u { posWeight, negWeight = negWeight, posWeight } nPos += posWeight @@ -103,8 +108,18 @@ func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr []fl invPos := 1 / nPos for i := range tpr { tpr[i] *= invPos + tpr[i] = 1 - tpr[i] fpr[i] *= invNeg + fpr[i] = 1 - fpr[i] } + for i, j := 0, len(tpr)-1; i < j; i, j = i+1, j-1 { + tpr[i], tpr[j] = tpr[j], tpr[i] + fpr[i], fpr[j] = fpr[j], fpr[i] + } + for i, j := 1, len(cutoffs)-1; i < j; i, j = i+1, j-1 { + cutoffs[i], cutoffs[j] = cutoffs[j], cutoffs[i] + } + cutoffs[0] = math.Inf(1) - return tpr, fpr + return tpr, fpr, cutoffs } diff --git a/stat/roc_example_test.go b/stat/roc_example_test.go index 4d8af8ea..9948f702 100644 --- a/stat/roc_example_test.go +++ b/stat/roc_example_test.go @@ -15,67 +15,83 @@ import ( func ExampleROC_weighted() { y := []float64{0, 3, 5, 6, 7.5, 8} - classes := []bool{true, false, true, false, false, false} + classes := []bool{false, true, false, true, true, true} weights := []float64{4, 1, 6, 3, 2, 2} - tpr, fpr := stat.ROC(nil, y, classes, weights) + tpr, fpr, _ := stat.ROC(nil, y, classes, weights) fmt.Printf("true positive rate: %v\n", tpr) fmt.Printf("false positive rate: %v\n", fpr) // Output: - // true positive rate: [0 0.4 0.4 1 1 1 1] - // false positive rate: [0 0 0.125 0.125 0.5 0.75 1] + // true positive rate: [0 0.25 0.5 0.875 0.875 1 1] + // false positive rate: [0 0 0 0 0.6 0.6 1] } func ExampleROC_unweighted() { y := []float64{0, 3, 5, 6, 7.5, 8} - classes := []bool{true, false, true, false, false, false} + classes := []bool{false, true, false, true, true, true} - tpr, fpr := stat.ROC(nil, y, classes, nil) + tpr, fpr, _ := stat.ROC(nil, y, classes, nil) fmt.Printf("true positive rate: %v\n", tpr) fmt.Printf("false positive rate: %v\n", fpr) // Output: - // true positive rate: [0 0.5 0.5 1 1 1 1] - // false positive rate: [0 0 0.25 0.25 0.5 0.75 1] + // true positive rate: [0 0.25 0.5 0.75 0.75 1 1] + // false positive rate: [0 0 0 0 0.5 0.5 1] +} + +func ExampleROC_threshold() { + y := []float64{0.1, 0.4, 0.35, 0.8} + classes := []bool{false, false, true, true} + stat.SortWeightedLabeled(y, classes, nil) + + tpr, fpr, thresh := stat.ROC(nil, y, classes, nil) + fmt.Printf("true positive rate: %v\n", tpr) + fmt.Printf("false positive rate: %v\n", fpr) + fmt.Printf("cutoff thresholds: %v\n", thresh) + + // Output: + // true positive rate: [0 0.5 0.5 1 1] + // false positive rate: [0 0 0.5 0.5 1] + // cutoff thresholds: [+Inf 0.8 0.4 0.35 0.1] } func ExampleROC_unsorted() { y := []float64{8, 7.5, 6, 5, 3, 0} - classes := []bool{false, false, false, true, false, true} + classes := []bool{true, true, true, false, true, false} weights := []float64{2, 2, 3, 6, 1, 4} stat.SortWeightedLabeled(y, classes, weights) - tpr, fpr := stat.ROC(nil, y, classes, weights) + tpr, fpr, _ := stat.ROC(nil, y, classes, weights) fmt.Printf("true positive rate: %v\n", tpr) fmt.Printf("false positive rate: %v\n", fpr) // Output: - // true positive rate: [0 0.4 0.4 1 1 1 1] - // false positive rate: [0 0 0.125 0.125 0.5 0.75 1] + // true positive rate: [0 0.25 0.5 0.875 0.875 1 1] + // false positive rate: [0 0 0 0 0.6 0.6 1] } func ExampleROC_knownCutoffs() { y := []float64{8, 7.5, 6, 5, 3, 0} - classes := []bool{false, false, false, true, false, true} + classes := []bool{true, true, true, false, true, false} weights := []float64{2, 2, 3, 6, 1, 4} cutoffs := []float64{-1, 3, 4} stat.SortWeightedLabeled(y, classes, weights) - tpr, fpr := stat.ROC(cutoffs, y, classes, weights) + tpr, fpr, _ := stat.ROC(cutoffs, y, classes, weights) fmt.Printf("true positive rate: %v\n", tpr) fmt.Printf("false positive rate: %v\n", fpr) // Output: - // true positive rate: [0 0.4 0.4] - // false positive rate: [0 0.125 0.125] + // true positive rate: [0.875 0.875 1] + // false positive rate: [0.6 0.6 1] } func ExampleROC_equallySpacedCutoffs() { y := []float64{8, 7.5, 6, 5, 3, 0} - classes := []bool{false, false, false, true, false, true} + classes := []bool{true, true, true, false, true, true} weights := []float64{2, 2, 3, 6, 1, 4} n := 9 @@ -83,20 +99,20 @@ func ExampleROC_equallySpacedCutoffs() { cutoffs := make([]float64, n) floats.Span(cutoffs, math.Nextafter(y[0], y[0]-1), y[len(y)-1]) - tpr, fpr := stat.ROC(cutoffs, y, classes, weights) - fmt.Printf("true positive rate: %v\n", tpr) - fmt.Printf("false positive rate: %v\n", fpr) + tpr, fpr, _ := stat.ROC(cutoffs, y, classes, weights) + fmt.Printf("true positive rate: %.3v\n", tpr) + fmt.Printf("false positive rate: %.3v\n", fpr) // Output: - // true positive rate: [0 0.4 0.4 0.4 0.4 1 1 1 1] - // false positive rate: [0 0 0 0.125 0.125 0.125 0.5 0.5 1] + // true positive rate: [0 0.333 0.333 0.583 0.583 0.583 0.667 0.667 1] + // false positive rate: [0 0 0 0 1 1 1 1 1] } func ExampleROC_aUC() { y := []float64{0.1, 0.35, 0.4, 0.8} classes := []bool{true, false, true, false} - tpr, fpr := stat.ROC(nil, y, classes, nil) + tpr, fpr, _ := stat.ROC(nil, y, classes, nil) // Compute Area Under Curve. auc := integrate.Trapezoidal(fpr, tpr) @@ -105,7 +121,7 @@ func ExampleROC_aUC() { fmt.Printf("auc: %v\n", auc) // Output: - // true positive rate: [0 0.5 0.5 1 1] - // false positive rate: [0 0 0.5 0.5 1] - // auc: 0.75 + // true positive rate: [0 0 0.5 0.5 1] + // false positive rate: [0 0.5 0.5 1 1] + // auc: 0.25 } diff --git a/stat/roc_test.go b/stat/roc_test.go index f42b4129..36c658b8 100644 --- a/stat/roc_test.go +++ b/stat/roc_test.go @@ -11,169 +11,214 @@ import ( "gonum.org/v1/gonum/floats" ) -// Test cases were calculated manually. func TestROC(t *testing.T) { cases := []struct { - y []float64 - c []bool - w []float64 - cutoffs []float64 - wantTPR []float64 - wantFPR []float64 + y []float64 + c []bool + w []float64 + cutoffs []float64 + wantTPR []float64 + wantFPR []float64 + wantThresh []float64 }{ + // Test cases were calculated using sklearn metrics.roc_curve when + // cutoffs is nil. Where cutoffs is not nil, a visual inspection is + // used. + // Some differences exist between unweighted ROCs from our function + // and metrics.roc_curve which appears to use integer cutoffs in that + // case. sklearn also appears to do some magic that trims leading zeros + // sometimes. { // 0 - y: []float64{0, 3, 5, 6, 7.5, 8}, - c: []bool{true, false, true, false, false, false}, - wantTPR: []float64{0, 0.5, 0.5, 1, 1, 1, 1}, - wantFPR: []float64{0, 0, 0.25, 0.25, 0.5, 0.75, 1}, + y: []float64{0, 3, 5, 6, 7.5, 8}, + c: []bool{false, true, false, true, true, true}, + wantTPR: []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1}, + wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 1}, + wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0}, }, { // 1 - y: []float64{0, 3, 5, 6, 7.5, 8}, - c: []bool{true, false, true, false, false, false}, - w: []float64{4, 1, 6, 3, 2, 2}, - wantTPR: []float64{0, 0.4, 0.4, 1, 1, 1, 1}, - wantFPR: []float64{0, 0, 0.125, 0.125, 0.5, 0.75, 1}, + y: []float64{0, 3, 5, 6, 7.5, 8}, + c: []bool{false, true, false, true, true, true}, + w: []float64{4, 1, 6, 3, 2, 2}, + wantTPR: []float64{0, 0.25, 0.5, 0.875, 0.875, 1, 1}, + wantFPR: []float64{0, 0, 0, 0, 0.6, 0.6, 1}, + wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0}, }, { // 2 - y: []float64{0, 3, 5, 6, 7.5, 8}, - c: []bool{true, false, true, false, false, false}, - cutoffs: []float64{-1, 2, 4, 6, 8}, - wantTPR: []float64{0, 0.5, 0.5, 1, 1}, - wantFPR: []float64{0, 0, 0.25, 0.5, 1}, + y: []float64{0, 3, 5, 6, 7.5, 8}, + c: []bool{false, true, false, true, true, true}, + cutoffs: []float64{-1, 2, 4, 6, 8}, + wantTPR: []float64{0, 0.5, 0.75, 1, 1}, + wantFPR: []float64{0, 0, 0.5, 0.5, 1}, + wantThresh: []float64{math.Inf(1), 8, 6, 4, 2}, }, { // 3 - y: []float64{0, 3, 5, 6, 7.5, 8}, - c: []bool{true, false, true, false, false, false}, - cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, - wantTPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 1, 1, 1, 1}, - wantFPR: []float64{0, 0, 0, 0.25, 0.25, 0.25, 0.5, 0.5, 1}, + y: []float64{0, 3, 5, 6, 7.5, 8}, + c: []bool{false, true, false, true, true, true}, + cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, + wantTPR: []float64{0, 0.5, 0.5, 0.75, 0.75, 0.75, 1, 1, 1}, + wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 0.5, 0.5, 1}, + wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1}, }, { // 4 - y: []float64{0, 3, 5, 6, 7.5, 8}, - c: []bool{true, false, true, false, false, false}, - w: []float64{4, 1, 6, 3, 2, 2}, - cutoffs: []float64{-1, 2, 4, 6, 8}, - wantTPR: []float64{0, 0.4, 0.4, 1, 1}, - wantFPR: []float64{0, 0, 0.125, 0.5, 1}, + y: []float64{0, 3, 5, 6, 7.5, 8}, + c: []bool{false, true, false, true, true, true}, + w: []float64{4, 1, 6, 3, 2, 2}, + cutoffs: []float64{-1, 2, 4, 6, 8}, + wantTPR: []float64{0, 0.5, 0.875, 1, 1}, + wantFPR: []float64{0, 0, 0.6, 0.6, 1}, + wantThresh: []float64{math.Inf(1), 8, 6, 4, 2}, }, { // 5 - y: []float64{0, 3, 5, 6, 7.5, 8}, - c: []bool{true, false, true, false, false, false}, - w: []float64{4, 1, 6, 3, 2, 2}, - cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, - wantTPR: []float64{0, 0.4, 0.4, 0.4, 0.4, 1, 1, 1, 1}, - wantFPR: []float64{0, 0, 0, 0.125, 0.125, 0.125, 0.5, 0.5, 1}, + y: []float64{0, 3, 5, 6, 7.5, 8}, + c: []bool{false, true, false, true, true, true}, + w: []float64{4, 1, 6, 3, 2, 2}, + cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, + wantTPR: []float64{0, 0.5, 0.5, 0.875, 0.875, 0.875, 1, 1, 1}, + wantFPR: []float64{0, 0, 0, 0, 0.6, 0.6, 0.6, 0.6, 1}, + wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1}, }, { // 6 - y: []float64{0, 3, 6, 6, 6, 8}, - c: []bool{true, false, true, false, false, false}, - wantTPR: []float64{0, 0.5, 0.5, 1, 1}, - wantFPR: []float64{0, 0, 0.25, 0.75, 1}, + y: []float64{0, 3, 6, 6, 6, 8}, + c: []bool{false, true, false, true, true, true}, + wantTPR: []float64{0, 0.25, 0.75, 1, 1}, + wantFPR: []float64{0, 0, 0.5, 0.5, 1}, + wantThresh: []float64{math.Inf(1), 8, 6, 3, 0}, }, { // 7 - y: []float64{0, 3, 6, 6, 6, 8}, - c: []bool{true, false, true, false, false, false}, - w: []float64{4, 1, 6, 3, 2, 2}, - wantTPR: []float64{0, 0.4, 0.4, 1, 1}, - wantFPR: []float64{0, 0, 0.125, 0.75, 1}, + y: []float64{0, 3, 6, 6, 6, 8}, + c: []bool{false, true, false, true, true, true}, + w: []float64{4, 1, 6, 3, 2, 2}, + wantTPR: []float64{0, 0.25, 0.875, 1, 1}, + wantFPR: []float64{0, 0, 0.6, 0.6, 1}, + wantThresh: []float64{math.Inf(1), 8, 6, 3, 0}, }, { // 8 - y: []float64{0, 3, 6, 6, 6, 8}, - c: []bool{true, false, true, false, false, false}, - cutoffs: []float64{-1, 2, 4, 6, 8}, - wantTPR: []float64{0, 0.5, 0.5, 1, 1}, - wantFPR: []float64{0, 0, 0.25, 0.75, 1}, + y: []float64{0, 3, 6, 6, 6, 8}, + c: []bool{false, true, false, true, true, true}, + cutoffs: []float64{-1, 2, 4, 6, 8}, + wantTPR: []float64{0, 0.25, 0.75, 1, 1}, + wantFPR: []float64{0, 0, 0.5, 0.5, 1}, + wantThresh: []float64{math.Inf(1), 8, 6, 4, 2}, }, { // 9 - y: []float64{0, 3, 6, 6, 6, 8}, - c: []bool{true, false, true, false, false, false}, - cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, - wantTPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 0.5, 1, 1, 1}, - wantFPR: []float64{0, 0, 0, 0.25, 0.25, 0.25, 0.75, 0.75, 1}, + y: []float64{0, 3, 6, 6, 6, 8}, + c: []bool{false, true, false, true, true, true}, + cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, + wantTPR: []float64{0, 0.25, 0.25, 0.75, 0.75, 0.75, 1, 1, 1}, + wantFPR: []float64{0, 0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 1}, + wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1}, }, { // 10 - y: []float64{0, 3, 6, 6, 6, 8}, - c: []bool{true, false, true, false, false, false}, - w: []float64{4, 1, 6, 3, 2, 2}, - cutoffs: []float64{-1, 2, 4, 6, 8}, - wantTPR: []float64{0, 0.4, 0.4, 1, 1}, - wantFPR: []float64{0, 0, 0.125, 0.75, 1}, + y: []float64{0, 3, 6, 6, 6, 8}, + c: []bool{false, true, false, true, true, true}, + w: []float64{4, 1, 6, 3, 2, 2}, + cutoffs: []float64{-1, 2, 4, 6, 8}, + wantTPR: []float64{0, 0.25, 0.875, 1, 1}, + wantFPR: []float64{0, 0, 0.6, 0.6, 1}, + wantThresh: []float64{math.Inf(1), 8, 6, 4, 2}, }, { // 11 - y: []float64{0, 3, 6, 6, 6, 8}, - c: []bool{true, false, true, false, false, false}, - w: []float64{4, 1, 6, 3, 2, 2}, - cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, - wantTPR: []float64{0, 0.4, 0.4, 0.4, 0.4, 0.4, 1, 1, 1}, - wantFPR: []float64{0, 0, 0, 0.125, 0.125, 0.125, 0.75, 0.75, 1}, + y: []float64{0, 3, 6, 6, 6, 8}, + c: []bool{false, true, false, true, true, true}, + w: []float64{4, 1, 6, 3, 2, 2}, + cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, + wantTPR: []float64{0, 0.25, 0.25, 0.875, 0.875, 0.875, 1, 1, 1}, + wantFPR: []float64{0, 0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 1}, + wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1}, }, { // 12 - y: []float64{1, 2}, - c: []bool{true, true}, - wantTPR: []float64{0, 0.5, 1}, - wantFPR: []float64{math.NaN(), math.NaN(), math.NaN()}, + y: []float64{0.1, 0.35, 0.4, 0.8}, + c: []bool{true, false, true, false}, + wantTPR: []float64{0, 0, 0.5, 0.5, 1}, + wantFPR: []float64{0, 0.5, 0.5, 1, 1}, + wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1}, }, { // 13 - y: []float64{1, 2}, - c: []bool{true, true}, - cutoffs: []float64{-1, 2}, - wantTPR: []float64{0, 1}, - wantFPR: []float64{math.NaN(), math.NaN()}, + y: []float64{0.1, 0.35, 0.4, 0.8}, + c: []bool{false, false, true, true}, + wantTPR: []float64{0, 0.5, 1, 1, 1}, + wantFPR: []float64{0, 0, 0, 0.5, 1}, + wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1}, }, { // 14 - y: []float64{1, 2}, - c: []bool{true, true}, - cutoffs: []float64{0, 1.2, 1.4, 1.6, 1.8, 2}, - wantTPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 1}, - wantFPR: []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()}, + y: []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10}, + c: []bool{false, true, false, false, true, true, false}, + cutoffs: []float64{-1, 2.5, 5, 7.5, 10}, + wantTPR: []float64{0, 0, 0, 0, 1}, + wantFPR: []float64{0, 0.25, 0.25, 0.25, 1}, + wantThresh: []float64{math.Inf(1), 10, 7.5, 5, 2.5}, }, { // 15 - y: []float64{1}, - c: []bool{true}, - wantTPR: []float64{0, 1}, - wantFPR: []float64{math.NaN(), math.NaN()}, + y: []float64{1, 2}, + c: []bool{false, false}, + wantTPR: []float64{math.NaN(), math.NaN(), math.NaN()}, + wantFPR: []float64{0, 0.5, 1}, + wantThresh: []float64{math.Inf(1), 2, 1}, }, { // 16 - y: []float64{1}, - c: []bool{true}, - cutoffs: []float64{-1, 1}, - wantTPR: []float64{0, 1}, - wantFPR: []float64{math.NaN(), math.NaN()}, + y: []float64{1, 2}, + c: []bool{false, false}, + cutoffs: []float64{-1, 2}, + wantTPR: []float64{math.NaN(), math.NaN()}, + wantFPR: []float64{0, 1}, + wantThresh: []float64{math.Inf(1), 2}, }, { // 17 - y: []float64{1}, - c: []bool{false}, - wantTPR: []float64{math.NaN(), math.NaN()}, - wantFPR: []float64{0, 1}, + y: []float64{1, 2}, + c: []bool{false, false}, + cutoffs: []float64{0, 1.2, 1.4, 1.6, 1.8, 2}, + wantTPR: []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()}, + wantFPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 1}, + wantThresh: []float64{math.Inf(1), 2, 1.8, 1.6, 1.4, 1.2}, }, { // 18 - y: []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10}, - c: []bool{true, false, true, true, false, false, true}, - cutoffs: []float64{-1, 2.5, 5, 7.5, 10}, - wantTPR: []float64{0, 0.75, 0.75, 0.75, 1}, - wantFPR: []float64{0, 1, 1, 1, 1}, + y: []float64{1}, + c: []bool{false}, + wantTPR: []float64{math.NaN(), math.NaN()}, + wantFPR: []float64{0, 1}, + wantThresh: []float64{math.Inf(1), 1}, }, { // 19 - y: []float64{}, - c: []bool{}, - wantTPR: nil, - wantFPR: nil, + y: []float64{1}, + c: []bool{false}, + cutoffs: []float64{-1, 1}, + wantTPR: []float64{math.NaN(), math.NaN()}, + wantFPR: []float64{0, 1}, + wantThresh: []float64{math.Inf(1), 1}, }, { // 20 - y: []float64{}, - c: []bool{}, - cutoffs: []float64{-1, 2.5, 5, 7.5, 10}, - wantTPR: nil, - wantFPR: nil, + y: []float64{1}, + c: []bool{true}, + wantTPR: []float64{0, 1}, + wantFPR: []float64{math.NaN(), math.NaN()}, + wantThresh: []float64{math.Inf(1), 1}, + }, + { // 21 + y: []float64{}, + c: []bool{}, + wantTPR: nil, + wantFPR: nil, + wantThresh: nil, + }, + { // 22 + y: []float64{}, + c: []bool{}, + cutoffs: []float64{-1, 2.5, 5, 7.5, 10}, + wantTPR: nil, + wantFPR: nil, + wantThresh: nil, }, } for i, test := range cases { - gotTPR, gotFPR := ROC(test.cutoffs, test.y, test.c, test.w) + gotTPR, gotFPR, gotThresh := ROC(test.cutoffs, test.y, test.c, test.w) if !floats.Same(gotTPR, test.wantTPR) { t.Errorf("%d: unexpected TPR got:%v want:%v", i, gotTPR, test.wantTPR) } if !floats.Same(gotFPR, test.wantFPR) { t.Errorf("%d: unexpected FPR got:%v want:%v", i, gotFPR, test.wantFPR) } + if !floats.Same(gotThresh, test.wantThresh) { + t.Errorf("%d: unexpected thresholds got:%#v want:%v", i, gotThresh, test.wantThresh) + } } }