mirror of
				https://github.com/gonum/gonum.git
				synced 2025-10-27 01:00:26 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			227 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			227 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright ©2016 The Gonum Authors. All rights reserved.
 | |
| // Use of this source code is governed by a BSD-style
 | |
| // license that can be found in the LICENSE file.
 | |
| 
 | |
| package stat
 | |
| 
 | |
| import (
 | |
| 	"math"
 | |
| 	"testing"
 | |
| 
 | |
| 	"gonum.org/v1/gonum/floats"
 | |
| )
 | |
| 
 | |
| func TestROC(t *testing.T) {
 | |
| 	const tol = 1e-14
 | |
| 
 | |
| 	cases := []struct {
 | |
| 		y          []float64
 | |
| 		c          []bool
 | |
| 		w          []float64
 | |
| 		cutoffs    []float64
 | |
| 		wantTPR    []float64
 | |
| 		wantFPR    []float64
 | |
| 		wantThresh []float64
 | |
| 	}{
 | |
| 		// Test cases were calculated using sklearn metrics.roc_curve when
 | |
| 		// cutoffs is nil. Where cutoffs is not nil, a visual inspection is
 | |
| 		// used.
 | |
| 		// Some differences exist between unweighted ROCs from our function
 | |
| 		// and metrics.roc_curve which appears to use integer cutoffs in that
 | |
| 		// case. sklearn also appears to do some magic that trims leading zeros
 | |
| 		// sometimes.
 | |
| 		{ // 0
 | |
| 			y:          []float64{0, 3, 5, 6, 7.5, 8},
 | |
| 			c:          []bool{false, true, false, true, true, true},
 | |
| 			wantTPR:    []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1},
 | |
| 			wantFPR:    []float64{0, 0, 0, 0, 0.5, 0.5, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
 | |
| 		},
 | |
| 		{ // 1
 | |
| 			y:          []float64{0, 3, 5, 6, 7.5, 8},
 | |
| 			c:          []bool{false, true, false, true, true, true},
 | |
| 			w:          []float64{4, 1, 6, 3, 2, 2},
 | |
| 			wantTPR:    []float64{0, 0.25, 0.5, 0.875, 0.875, 1, 1},
 | |
| 			wantFPR:    []float64{0, 0, 0, 0, 0.6, 0.6, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
 | |
| 		},
 | |
| 		{ // 2
 | |
| 			y:          []float64{0, 3, 5, 6, 7.5, 8},
 | |
| 			c:          []bool{false, true, false, true, true, true},
 | |
| 			cutoffs:    []float64{-1, 2, 4, 6, 8},
 | |
| 			wantTPR:    []float64{0, 0.5, 0.75, 1, 1},
 | |
| 			wantFPR:    []float64{0, 0, 0.5, 0.5, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
 | |
| 		},
 | |
| 		{ // 3
 | |
| 			y:          []float64{0, 3, 5, 6, 7.5, 8},
 | |
| 			c:          []bool{false, true, false, true, true, true},
 | |
| 			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
 | |
| 			wantTPR:    []float64{0, 0.5, 0.5, 0.75, 0.75, 0.75, 1, 1, 1},
 | |
| 			wantFPR:    []float64{0, 0, 0, 0, 0.5, 0.5, 0.5, 0.5, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
 | |
| 		},
 | |
| 		{ // 4
 | |
| 			y:          []float64{0, 3, 5, 6, 7.5, 8},
 | |
| 			c:          []bool{false, true, false, true, true, true},
 | |
| 			w:          []float64{4, 1, 6, 3, 2, 2},
 | |
| 			cutoffs:    []float64{-1, 2, 4, 6, 8},
 | |
| 			wantTPR:    []float64{0, 0.5, 0.875, 1, 1},
 | |
| 			wantFPR:    []float64{0, 0, 0.6, 0.6, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
 | |
| 		},
 | |
| 		{ // 5
 | |
| 			y:          []float64{0, 3, 5, 6, 7.5, 8},
 | |
| 			c:          []bool{false, true, false, true, true, true},
 | |
| 			w:          []float64{4, 1, 6, 3, 2, 2},
 | |
| 			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
 | |
| 			wantTPR:    []float64{0, 0.5, 0.5, 0.875, 0.875, 0.875, 1, 1, 1},
 | |
| 			wantFPR:    []float64{0, 0, 0, 0, 0.6, 0.6, 0.6, 0.6, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
 | |
| 		},
 | |
| 		{ // 6
 | |
| 			y:          []float64{0, 3, 6, 6, 6, 8},
 | |
| 			c:          []bool{false, true, false, true, true, true},
 | |
| 			wantTPR:    []float64{0, 0.25, 0.75, 1, 1},
 | |
| 			wantFPR:    []float64{0, 0, 0.5, 0.5, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
 | |
| 		},
 | |
| 		{ // 7
 | |
| 			y:          []float64{0, 3, 6, 6, 6, 8},
 | |
| 			c:          []bool{false, true, false, true, true, true},
 | |
| 			w:          []float64{4, 1, 6, 3, 2, 2},
 | |
| 			wantTPR:    []float64{0, 0.25, 0.875, 1, 1},
 | |
| 			wantFPR:    []float64{0, 0, 0.6, 0.6, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
 | |
| 		},
 | |
| 		{ // 8
 | |
| 			y:          []float64{0, 3, 6, 6, 6, 8},
 | |
| 			c:          []bool{false, true, false, true, true, true},
 | |
| 			cutoffs:    []float64{-1, 2, 4, 6, 8},
 | |
| 			wantTPR:    []float64{0, 0.25, 0.75, 1, 1},
 | |
| 			wantFPR:    []float64{0, 0, 0.5, 0.5, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
 | |
| 		},
 | |
| 		{ // 9
 | |
| 			y:          []float64{0, 3, 6, 6, 6, 8},
 | |
| 			c:          []bool{false, true, false, true, true, true},
 | |
| 			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
 | |
| 			wantTPR:    []float64{0, 0.25, 0.25, 0.75, 0.75, 0.75, 1, 1, 1},
 | |
| 			wantFPR:    []float64{0, 0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
 | |
| 		},
 | |
| 		{ // 10
 | |
| 			y:          []float64{0, 3, 6, 6, 6, 8},
 | |
| 			c:          []bool{false, true, false, true, true, true},
 | |
| 			w:          []float64{4, 1, 6, 3, 2, 2},
 | |
| 			cutoffs:    []float64{-1, 2, 4, 6, 8},
 | |
| 			wantTPR:    []float64{0, 0.25, 0.875, 1, 1},
 | |
| 			wantFPR:    []float64{0, 0, 0.6, 0.6, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 8, 6, 4, 2},
 | |
| 		},
 | |
| 		{ // 11
 | |
| 			y:          []float64{0, 3, 6, 6, 6, 8},
 | |
| 			c:          []bool{false, true, false, true, true, true},
 | |
| 			w:          []float64{4, 1, 6, 3, 2, 2},
 | |
| 			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
 | |
| 			wantTPR:    []float64{0, 0.25, 0.25, 0.875, 0.875, 0.875, 1, 1, 1},
 | |
| 			wantFPR:    []float64{0, 0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 8, 7, 6, 5, 4, 3, 2, 1},
 | |
| 		},
 | |
| 		{ // 12
 | |
| 			y:          []float64{0.1, 0.35, 0.4, 0.8},
 | |
| 			c:          []bool{true, false, true, false},
 | |
| 			wantTPR:    []float64{0, 0, 0.5, 0.5, 1},
 | |
| 			wantFPR:    []float64{0, 0.5, 0.5, 1, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
 | |
| 		},
 | |
| 		{ // 13
 | |
| 			y:          []float64{0.1, 0.35, 0.4, 0.8},
 | |
| 			c:          []bool{false, false, true, true},
 | |
| 			wantTPR:    []float64{0, 0.5, 1, 1, 1},
 | |
| 			wantFPR:    []float64{0, 0, 0, 0.5, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
 | |
| 		},
 | |
| 		{ // 14
 | |
| 			y:          []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10},
 | |
| 			c:          []bool{false, true, false, false, true, true, false},
 | |
| 			cutoffs:    []float64{-1, 2.5, 5, 7.5, 10},
 | |
| 			wantTPR:    []float64{0, 0, 0, 0, 1},
 | |
| 			wantFPR:    []float64{0, 0.25, 0.25, 0.25, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 10, 7.5, 5, 2.5},
 | |
| 		},
 | |
| 		{ // 15
 | |
| 			y:          []float64{1, 2},
 | |
| 			c:          []bool{false, false},
 | |
| 			wantTPR:    []float64{math.NaN(), math.NaN(), math.NaN()},
 | |
| 			wantFPR:    []float64{0, 0.5, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 2, 1},
 | |
| 		},
 | |
| 		{ // 16
 | |
| 			y:          []float64{1, 2},
 | |
| 			c:          []bool{false, false},
 | |
| 			cutoffs:    []float64{-1, 2},
 | |
| 			wantTPR:    []float64{math.NaN(), math.NaN()},
 | |
| 			wantFPR:    []float64{0, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 2},
 | |
| 		},
 | |
| 		{ // 17
 | |
| 			y:          []float64{1, 2},
 | |
| 			c:          []bool{false, false},
 | |
| 			cutoffs:    []float64{0, 1.2, 1.4, 1.6, 1.8, 2},
 | |
| 			wantTPR:    []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()},
 | |
| 			wantFPR:    []float64{0, 0.5, 0.5, 0.5, 0.5, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 2, 1.8, 1.6, 1.4, 1.2},
 | |
| 		},
 | |
| 		{ // 18
 | |
| 			y:          []float64{1},
 | |
| 			c:          []bool{false},
 | |
| 			wantTPR:    []float64{math.NaN(), math.NaN()},
 | |
| 			wantFPR:    []float64{0, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 1},
 | |
| 		},
 | |
| 		{ // 19
 | |
| 			y:          []float64{1},
 | |
| 			c:          []bool{false},
 | |
| 			cutoffs:    []float64{-1, 1},
 | |
| 			wantTPR:    []float64{math.NaN(), math.NaN()},
 | |
| 			wantFPR:    []float64{0, 1},
 | |
| 			wantThresh: []float64{math.Inf(1), 1},
 | |
| 		},
 | |
| 		{ // 20
 | |
| 			y:          []float64{1},
 | |
| 			c:          []bool{true},
 | |
| 			wantTPR:    []float64{0, 1},
 | |
| 			wantFPR:    []float64{math.NaN(), math.NaN()},
 | |
| 			wantThresh: []float64{math.Inf(1), 1},
 | |
| 		},
 | |
| 		{ // 21
 | |
| 			y:          []float64{},
 | |
| 			c:          []bool{},
 | |
| 			wantTPR:    nil,
 | |
| 			wantFPR:    nil,
 | |
| 			wantThresh: nil,
 | |
| 		},
 | |
| 		{ // 22
 | |
| 			y:          []float64{},
 | |
| 			c:          []bool{},
 | |
| 			cutoffs:    []float64{-1, 2.5, 5, 7.5, 10},
 | |
| 			wantTPR:    nil,
 | |
| 			wantFPR:    nil,
 | |
| 			wantThresh: nil,
 | |
| 		},
 | |
| 	}
 | |
| 	for i, test := range cases {
 | |
| 		gotTPR, gotFPR, gotThresh := ROC(test.cutoffs, test.y, test.c, test.w)
 | |
| 		if !floats.Same(gotTPR, test.wantTPR) && !floats.EqualApprox(gotTPR, test.wantTPR, tol) {
 | |
| 			t.Errorf("%d: unexpected TPR got:%v want:%v", i, gotTPR, test.wantTPR)
 | |
| 		}
 | |
| 		if !floats.Same(gotFPR, test.wantFPR) && !floats.EqualApprox(gotFPR, test.wantFPR, tol) {
 | |
| 			t.Errorf("%d: unexpected FPR got:%v want:%v", i, gotFPR, test.wantFPR)
 | |
| 		}
 | |
| 		if !floats.Same(gotThresh, test.wantThresh) {
 | |
| 			t.Errorf("%d: unexpected thresholds got:%#v want:%v", i, gotThresh, test.wantThresh)
 | |
| 		}
 | |
| 	}
 | |
| }
 | 
