// Copyright ©2016 The Gonum Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package stat import ( "math" "testing" "gonum.org/v1/gonum/floats" ) func TestROC(t *testing.T) { const tol = 1e-14 cases := []struct { y []float64 c []bool w []float64 cutoffs []float64 wantTPR []float64 wantFPR []float64 wantThresh []float64 }{ // Test cases were calculated using sklearn metrics.roc_curve when // cutoffs is nil. Where cutoffs is not nil, a visual inspection is // used. // Some differences exist between unweighted ROCs from our function // and metrics.roc_curve which appears to use integer cutoffs in that // case. sklearn also appears to do some magic that trims leading zeros // sometimes. { // 0 y: []float64{0, 3, 5, 6, 7.5, 8}, c: []bool{false, true, false, true, true, true}, wantTPR: []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1}, wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 1}, wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0}, }, { // 1 y: []float64{0, 3, 5, 6, 7.5, 8}, c: []bool{false, true, false, true, true, true}, w: []float64{4, 1, 6, 3, 2, 2}, wantTPR: []float64{0, 0.25, 0.5, 0.875, 0.875, 1, 1}, wantFPR: []float64{0, 0, 0, 0, 0.6, 0.6, 1}, wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0}, }, { // 2 y: []float64{0, 3, 5, 6, 7.5, 8}, c: []bool{false, true, false, true, true, true}, cutoffs: []float64{-1, 2, 4, 6, 8}, wantTPR: []float64{0, 0.5, 0.75, 1, 1}, wantFPR: []float64{0, 0, 0.5, 0.5, 1}, wantThresh: []float64{math.Inf(1), 8, 6, 4, 2}, }, { // 3 y: []float64{0, 3, 5, 6, 7.5, 8}, c: []bool{false, true, false, true, true, true}, cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, wantTPR: []float64{0, 0.5, 0.5, 0.75, 0.75, 0.75, 1, 1, 1}, wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 0.5, 0.5, 1}, wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1}, }, { // 4 y: []float64{0, 3, 5, 6, 7.5, 8}, c: []bool{false, true, false, true, true, true}, w: []float64{4, 1, 6, 3, 2, 2}, cutoffs: []float64{-1, 2, 4, 6, 8}, wantTPR: []float64{0, 0.5, 0.875, 1, 1}, wantFPR: []float64{0, 0, 0.6, 0.6, 1}, wantThresh: []float64{math.Inf(1), 8, 6, 4, 2}, }, { // 5 y: []float64{0, 3, 5, 6, 7.5, 8}, c: []bool{false, true, false, true, true, true}, w: []float64{4, 1, 6, 3, 2, 2}, cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, wantTPR: []float64{0, 0.5, 0.5, 0.875, 0.875, 0.875, 1, 1, 1}, wantFPR: []float64{0, 0, 0, 0, 0.6, 0.6, 0.6, 0.6, 1}, wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1}, }, { // 6 y: []float64{0, 3, 6, 6, 6, 8}, c: []bool{false, true, false, true, true, true}, wantTPR: []float64{0, 0.25, 0.75, 1, 1}, wantFPR: []float64{0, 0, 0.5, 0.5, 1}, wantThresh: []float64{math.Inf(1), 8, 6, 3, 0}, }, { // 7 y: []float64{0, 3, 6, 6, 6, 8}, c: []bool{false, true, false, true, true, true}, w: []float64{4, 1, 6, 3, 2, 2}, wantTPR: []float64{0, 0.25, 0.875, 1, 1}, wantFPR: []float64{0, 0, 0.6, 0.6, 1}, wantThresh: []float64{math.Inf(1), 8, 6, 3, 0}, }, { // 8 y: []float64{0, 3, 6, 6, 6, 8}, c: []bool{false, true, false, true, true, true}, cutoffs: []float64{-1, 2, 4, 6, 8}, wantTPR: []float64{0, 0.25, 0.75, 1, 1}, wantFPR: []float64{0, 0, 0.5, 0.5, 1}, wantThresh: []float64{math.Inf(1), 8, 6, 4, 2}, }, { // 9 y: []float64{0, 3, 6, 6, 6, 8}, c: []bool{false, true, false, true, true, true}, cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, wantTPR: []float64{0, 0.25, 0.25, 0.75, 0.75, 0.75, 1, 1, 1}, wantFPR: []float64{0, 0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 1}, wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1}, }, { // 10 y: []float64{0, 3, 6, 6, 6, 8}, c: []bool{false, true, false, true, true, true}, w: []float64{4, 1, 6, 3, 2, 2}, cutoffs: []float64{-1, 2, 4, 6, 8}, wantTPR: []float64{0, 0.25, 0.875, 1, 1}, wantFPR: []float64{0, 0, 0.6, 0.6, 1}, wantThresh: []float64{math.Inf(1), 8, 6, 4, 2}, }, { // 11 y: []float64{0, 3, 6, 6, 6, 8}, c: []bool{false, true, false, true, true, true}, w: []float64{4, 1, 6, 3, 2, 2}, cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, wantTPR: []float64{0, 0.25, 0.25, 0.875, 0.875, 0.875, 1, 1, 1}, wantFPR: []float64{0, 0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 1}, wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1}, }, { // 12 y: []float64{0.1, 0.35, 0.4, 0.8}, c: []bool{true, false, true, false}, wantTPR: []float64{0, 0, 0.5, 0.5, 1}, wantFPR: []float64{0, 0.5, 0.5, 1, 1}, wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1}, }, { // 13 y: []float64{0.1, 0.35, 0.4, 0.8}, c: []bool{false, false, true, true}, wantTPR: []float64{0, 0.5, 1, 1, 1}, wantFPR: []float64{0, 0, 0, 0.5, 1}, wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1}, }, { // 14 y: []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10}, c: []bool{false, true, false, false, true, true, false}, cutoffs: []float64{-1, 2.5, 5, 7.5, 10}, wantTPR: []float64{0, 0, 0, 0, 1}, wantFPR: []float64{0, 0.25, 0.25, 0.25, 1}, wantThresh: []float64{math.Inf(1), 10, 7.5, 5, 2.5}, }, { // 15 y: []float64{1, 2}, c: []bool{false, false}, wantTPR: []float64{math.NaN(), math.NaN(), math.NaN()}, wantFPR: []float64{0, 0.5, 1}, wantThresh: []float64{math.Inf(1), 2, 1}, }, { // 16 y: []float64{1, 2}, c: []bool{false, false}, cutoffs: []float64{-1, 2}, wantTPR: []float64{math.NaN(), math.NaN()}, wantFPR: []float64{0, 1}, wantThresh: []float64{math.Inf(1), 2}, }, { // 17 y: []float64{1, 2}, c: []bool{false, false}, cutoffs: []float64{0, 1.2, 1.4, 1.6, 1.8, 2}, wantTPR: []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()}, wantFPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 1}, wantThresh: []float64{math.Inf(1), 2, 1.8, 1.6, 1.4, 1.2}, }, { // 18 y: []float64{1}, c: []bool{false}, wantTPR: []float64{math.NaN(), math.NaN()}, wantFPR: []float64{0, 1}, wantThresh: []float64{math.Inf(1), 1}, }, { // 19 y: []float64{1}, c: []bool{false}, cutoffs: []float64{-1, 1}, wantTPR: []float64{math.NaN(), math.NaN()}, wantFPR: []float64{0, 1}, wantThresh: []float64{math.Inf(1), 1}, }, { // 20 y: []float64{1}, c: []bool{true}, wantTPR: []float64{0, 1}, wantFPR: []float64{math.NaN(), math.NaN()}, wantThresh: []float64{math.Inf(1), 1}, }, { // 21 y: []float64{}, c: []bool{}, wantTPR: nil, wantFPR: nil, wantThresh: nil, }, { // 22 y: []float64{}, c: []bool{}, cutoffs: []float64{-1, 2.5, 5, 7.5, 10}, wantTPR: nil, wantFPR: nil, wantThresh: nil, }, } for i, test := range cases { gotTPR, gotFPR, gotThresh := ROC(test.cutoffs, test.y, test.c, test.w) if !floats.Same(gotTPR, test.wantTPR) && !floats.EqualApprox(gotTPR, test.wantTPR, tol) { t.Errorf("%d: unexpected TPR got:%v want:%v", i, gotTPR, test.wantTPR) } if !floats.Same(gotFPR, test.wantFPR) && !floats.EqualApprox(gotFPR, test.wantFPR, tol) { t.Errorf("%d: unexpected FPR got:%v want:%v", i, gotFPR, test.wantFPR) } if !floats.Same(gotThresh, test.wantThresh) { t.Errorf("%d: unexpected thresholds got:%#v want:%v", i, gotThresh, test.wantThresh) } } }