mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 23:26:52 +08:00
227 lines
7.9 KiB
Go
227 lines
7.9 KiB
Go
// Copyright ©2016 The Gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package stat
|
|
|
|
import (
|
|
"math"
|
|
"testing"
|
|
|
|
"gonum.org/v1/gonum/floats"
|
|
)
|
|
|
|
func TestROC(t *testing.T) {
|
|
const tol = 1e-14
|
|
|
|
cases := []struct {
|
|
y []float64
|
|
c []bool
|
|
w []float64
|
|
cutoffs []float64
|
|
wantTPR []float64
|
|
wantFPR []float64
|
|
wantThresh []float64
|
|
}{
|
|
// Test cases were calculated using sklearn metrics.roc_curve when
|
|
// cutoffs is nil. Where cutoffs is not nil, a visual inspection is
|
|
// used.
|
|
// Some differences exist between unweighted ROCs from our function
|
|
// and metrics.roc_curve which appears to use integer cutoffs in that
|
|
// case. sklearn also appears to do some magic that trims leading zeros
|
|
// sometimes.
|
|
{ // 0
|
|
y: []float64{0, 3, 5, 6, 7.5, 8},
|
|
c: []bool{false, true, false, true, true, true},
|
|
wantTPR: []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1},
|
|
wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 1},
|
|
wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
|
|
},
|
|
{ // 1
|
|
y: []float64{0, 3, 5, 6, 7.5, 8},
|
|
c: []bool{false, true, false, true, true, true},
|
|
w: []float64{4, 1, 6, 3, 2, 2},
|
|
wantTPR: []float64{0, 0.25, 0.5, 0.875, 0.875, 1, 1},
|
|
wantFPR: []float64{0, 0, 0, 0, 0.6, 0.6, 1},
|
|
wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
|
|
},
|
|
{ // 2
|
|
y: []float64{0, 3, 5, 6, 7.5, 8},
|
|
c: []bool{false, true, false, true, true, true},
|
|
cutoffs: []float64{-1, 2, 4, 6, 8},
|
|
wantTPR: []float64{0, 0.5, 0.75, 1, 1},
|
|
wantFPR: []float64{0, 0, 0.5, 0.5, 1},
|
|
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
|
|
},
|
|
{ // 3
|
|
y: []float64{0, 3, 5, 6, 7.5, 8},
|
|
c: []bool{false, true, false, true, true, true},
|
|
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
|
|
wantTPR: []float64{0, 0.5, 0.5, 0.75, 0.75, 0.75, 1, 1, 1},
|
|
wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 0.5, 0.5, 1},
|
|
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
|
|
},
|
|
{ // 4
|
|
y: []float64{0, 3, 5, 6, 7.5, 8},
|
|
c: []bool{false, true, false, true, true, true},
|
|
w: []float64{4, 1, 6, 3, 2, 2},
|
|
cutoffs: []float64{-1, 2, 4, 6, 8},
|
|
wantTPR: []float64{0, 0.5, 0.875, 1, 1},
|
|
wantFPR: []float64{0, 0, 0.6, 0.6, 1},
|
|
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
|
|
},
|
|
{ // 5
|
|
y: []float64{0, 3, 5, 6, 7.5, 8},
|
|
c: []bool{false, true, false, true, true, true},
|
|
w: []float64{4, 1, 6, 3, 2, 2},
|
|
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
|
|
wantTPR: []float64{0, 0.5, 0.5, 0.875, 0.875, 0.875, 1, 1, 1},
|
|
wantFPR: []float64{0, 0, 0, 0, 0.6, 0.6, 0.6, 0.6, 1},
|
|
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
|
|
},
|
|
{ // 6
|
|
y: []float64{0, 3, 6, 6, 6, 8},
|
|
c: []bool{false, true, false, true, true, true},
|
|
wantTPR: []float64{0, 0.25, 0.75, 1, 1},
|
|
wantFPR: []float64{0, 0, 0.5, 0.5, 1},
|
|
wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
|
|
},
|
|
{ // 7
|
|
y: []float64{0, 3, 6, 6, 6, 8},
|
|
c: []bool{false, true, false, true, true, true},
|
|
w: []float64{4, 1, 6, 3, 2, 2},
|
|
wantTPR: []float64{0, 0.25, 0.875, 1, 1},
|
|
wantFPR: []float64{0, 0, 0.6, 0.6, 1},
|
|
wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
|
|
},
|
|
{ // 8
|
|
y: []float64{0, 3, 6, 6, 6, 8},
|
|
c: []bool{false, true, false, true, true, true},
|
|
cutoffs: []float64{-1, 2, 4, 6, 8},
|
|
wantTPR: []float64{0, 0.25, 0.75, 1, 1},
|
|
wantFPR: []float64{0, 0, 0.5, 0.5, 1},
|
|
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
|
|
},
|
|
{ // 9
|
|
y: []float64{0, 3, 6, 6, 6, 8},
|
|
c: []bool{false, true, false, true, true, true},
|
|
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
|
|
wantTPR: []float64{0, 0.25, 0.25, 0.75, 0.75, 0.75, 1, 1, 1},
|
|
wantFPR: []float64{0, 0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 1},
|
|
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
|
|
},
|
|
{ // 10
|
|
y: []float64{0, 3, 6, 6, 6, 8},
|
|
c: []bool{false, true, false, true, true, true},
|
|
w: []float64{4, 1, 6, 3, 2, 2},
|
|
cutoffs: []float64{-1, 2, 4, 6, 8},
|
|
wantTPR: []float64{0, 0.25, 0.875, 1, 1},
|
|
wantFPR: []float64{0, 0, 0.6, 0.6, 1},
|
|
wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
|
|
},
|
|
{ // 11
|
|
y: []float64{0, 3, 6, 6, 6, 8},
|
|
c: []bool{false, true, false, true, true, true},
|
|
w: []float64{4, 1, 6, 3, 2, 2},
|
|
cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
|
|
wantTPR: []float64{0, 0.25, 0.25, 0.875, 0.875, 0.875, 1, 1, 1},
|
|
wantFPR: []float64{0, 0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 1},
|
|
wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
|
|
},
|
|
{ // 12
|
|
y: []float64{0.1, 0.35, 0.4, 0.8},
|
|
c: []bool{true, false, true, false},
|
|
wantTPR: []float64{0, 0, 0.5, 0.5, 1},
|
|
wantFPR: []float64{0, 0.5, 0.5, 1, 1},
|
|
wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
|
|
},
|
|
{ // 13
|
|
y: []float64{0.1, 0.35, 0.4, 0.8},
|
|
c: []bool{false, false, true, true},
|
|
wantTPR: []float64{0, 0.5, 1, 1, 1},
|
|
wantFPR: []float64{0, 0, 0, 0.5, 1},
|
|
wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
|
|
},
|
|
{ // 14
|
|
y: []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10},
|
|
c: []bool{false, true, false, false, true, true, false},
|
|
cutoffs: []float64{-1, 2.5, 5, 7.5, 10},
|
|
wantTPR: []float64{0, 0, 0, 0, 1},
|
|
wantFPR: []float64{0, 0.25, 0.25, 0.25, 1},
|
|
wantThresh: []float64{math.Inf(1), 10, 7.5, 5, 2.5},
|
|
},
|
|
{ // 15
|
|
y: []float64{1, 2},
|
|
c: []bool{false, false},
|
|
wantTPR: []float64{math.NaN(), math.NaN(), math.NaN()},
|
|
wantFPR: []float64{0, 0.5, 1},
|
|
wantThresh: []float64{math.Inf(1), 2, 1},
|
|
},
|
|
{ // 16
|
|
y: []float64{1, 2},
|
|
c: []bool{false, false},
|
|
cutoffs: []float64{-1, 2},
|
|
wantTPR: []float64{math.NaN(), math.NaN()},
|
|
wantFPR: []float64{0, 1},
|
|
wantThresh: []float64{math.Inf(1), 2},
|
|
},
|
|
{ // 17
|
|
y: []float64{1, 2},
|
|
c: []bool{false, false},
|
|
cutoffs: []float64{0, 1.2, 1.4, 1.6, 1.8, 2},
|
|
wantTPR: []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()},
|
|
wantFPR: []float64{0, 0.5, 0.5, 0.5, 0.5, 1},
|
|
wantThresh: []float64{math.Inf(1), 2, 1.8, 1.6, 1.4, 1.2},
|
|
},
|
|
{ // 18
|
|
y: []float64{1},
|
|
c: []bool{false},
|
|
wantTPR: []float64{math.NaN(), math.NaN()},
|
|
wantFPR: []float64{0, 1},
|
|
wantThresh: []float64{math.Inf(1), 1},
|
|
},
|
|
{ // 19
|
|
y: []float64{1},
|
|
c: []bool{false},
|
|
cutoffs: []float64{-1, 1},
|
|
wantTPR: []float64{math.NaN(), math.NaN()},
|
|
wantFPR: []float64{0, 1},
|
|
wantThresh: []float64{math.Inf(1), 1},
|
|
},
|
|
{ // 20
|
|
y: []float64{1},
|
|
c: []bool{true},
|
|
wantTPR: []float64{0, 1},
|
|
wantFPR: []float64{math.NaN(), math.NaN()},
|
|
wantThresh: []float64{math.Inf(1), 1},
|
|
},
|
|
{ // 21
|
|
y: []float64{},
|
|
c: []bool{},
|
|
wantTPR: nil,
|
|
wantFPR: nil,
|
|
wantThresh: nil,
|
|
},
|
|
{ // 22
|
|
y: []float64{},
|
|
c: []bool{},
|
|
cutoffs: []float64{-1, 2.5, 5, 7.5, 10},
|
|
wantTPR: nil,
|
|
wantFPR: nil,
|
|
wantThresh: nil,
|
|
},
|
|
}
|
|
for i, test := range cases {
|
|
gotTPR, gotFPR, gotThresh := ROC(test.cutoffs, test.y, test.c, test.w)
|
|
if !floats.Same(gotTPR, test.wantTPR) && !floats.EqualApprox(gotTPR, test.wantTPR, tol) {
|
|
t.Errorf("%d: unexpected TPR got:%v want:%v", i, gotTPR, test.wantTPR)
|
|
}
|
|
if !floats.Same(gotFPR, test.wantFPR) && !floats.EqualApprox(gotFPR, test.wantFPR, tol) {
|
|
t.Errorf("%d: unexpected FPR got:%v want:%v", i, gotFPR, test.wantFPR)
|
|
}
|
|
if !floats.Same(gotThresh, test.wantThresh) {
|
|
t.Errorf("%d: unexpected thresholds got:%#v want:%v", i, gotThresh, test.wantThresh)
|
|
}
|
|
}
|
|
}
|