mirror of
https://github.com/gonum/gonum.git
synced 2025-12-24 13:47:56 +08:00
stat: add TOC function
This commit is contained in:
71
stat/roc.go
71
stat/roc.go
@@ -127,3 +127,74 @@ func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr, thr
|
||||
|
||||
return tpr, fpr, cutoffs
|
||||
}
|
||||
|
||||
// TOC computes the Total Operating Characteristic curve for the given
// classes, together with the lower and upper bounds that enclose it.
//
// classes and weights must be ordered so that the values they describe are
// ascending: classes[i] and weights[i] are the class and the weight of the
// value with rank i+1. SortWeightedLabeled can be used to arrange classes
// and weights in that order.
//
// All three returned slices have length len(classes)+1. Element i of ntp is
// the weighted number of true positives obtained when the i highest ranked
// values are assigned class true:
//
//	ntp[0] = 0
//	ntp[i] = sum_{j = len(classes)-i}^{len(classes)-1} [classes[j]] * weights[j]
//
// where [x] is 1 if x is true and 0 otherwise.
//
// min[i] and max[i] bound ntp[i] from below and above. max[i] is the smaller
// of the total positive weight and the weight assigned class true so far.
// min[i] is the total positive weight less the weight not yet assigned class
// true, clamped at zero. The first element of min, ntp and max is always
// zero (every value assigned false) and the last element is always the
// weighted sum of the true classes (every value assigned true).
//
// If weights is nil, every weight is taken to be 1. When weights is not nil
// the bounds allow a single data point to be partially assigned, and weights
// must have the same length as classes, otherwise TOC will panic. If classes
// is empty all three results are nil.
//
// More details about TOC curves are available at
// https://en.wikipedia.org/wiki/Total_operating_characteristic
func TOC(classes []bool, weights []float64) (min, ntp, max []float64) {
	n := len(classes)
	if weights != nil && len(weights) != n {
		panic("stat: slice length mismatch")
	}
	if n == 0 {
		return nil, nil, nil
	}

	ntp = make([]float64, n+1)
	min = make([]float64, n+1)
	max = make([]float64, n+1)

	if weights == nil {
		// Unit weights: count positives from the highest rank downwards.
		for k := 1; k <= n; k++ {
			ntp[k] = ntp[k-1]
			if classes[n-k] {
				ntp[k]++
			}
		}
		positives := ntp[n]
		for k := range ntp {
			min[k] = math.Max(0, positives-float64(n-k))
			max[k] = math.Min(positives, float64(k))
		}
		return min, ntp, max
	}

	// Weighted case. max temporarily holds the cumulative weight of the
	// k highest ranked values; it is overwritten with the real upper bound
	// in the second pass, after each cumulative weight has been read.
	for k := 1; k <= n; k++ {
		w := weights[n-k]
		max[k] = max[k-1] + w
		ntp[k] = ntp[k-1]
		if classes[n-k] {
			ntp[k] += w
		}
	}
	total := max[n]
	positives := ntp[n]
	for k := range ntp {
		assigned := max[k]
		min[k] = math.Max(0, positives-(total-assigned))
		max[k] = math.Min(positives, assigned)
	}
	return min, ntp, max
}
|
||||
|
||||
@@ -108,7 +108,7 @@ func ExampleROC_equallySpacedCutoffs() {
|
||||
// false positive rate: [0 0 0 1 1 1 1 1 1]
|
||||
}
|
||||
|
||||
func ExampleROC_aUC() {
|
||||
func ExampleROC_aUC_unweighted() {
|
||||
y := []float64{0.1, 0.35, 0.4, 0.8}
|
||||
classes := []bool{true, false, true, false}
|
||||
|
||||
@@ -125,3 +125,126 @@ func ExampleROC_aUC() {
|
||||
// false positive rate: [0 0.5 0.5 1 1]
|
||||
// auc: 0.25
|
||||
}
|
||||
|
||||
func ExampleROC_aUC_weighted() {
|
||||
y := []float64{0.1, 0.35, 0.4, 0.8}
|
||||
classes := []bool{true, false, true, false}
|
||||
weights := []float64{1, 2, 2, 1}
|
||||
|
||||
tpr, fpr, _ := stat.ROC(nil, y, classes, weights)
|
||||
|
||||
// Compute Area Under Curve.
|
||||
auc := integrate.Trapezoidal(fpr, tpr)
|
||||
fmt.Printf("auc: %f\n", auc)
|
||||
|
||||
// Output:
|
||||
// auc: 0.444444
|
||||
|
||||
}
|
||||
|
||||
func ExampleTOC() {
|
||||
classes := []bool{
|
||||
false, false, false, false, false, false,
|
||||
false, false, false, false, false, false,
|
||||
false, false, true, true, true, true,
|
||||
true, true, true, false, false, true,
|
||||
false, true, false, false, true, false,
|
||||
}
|
||||
|
||||
min, ntp, max := stat.TOC(classes, nil)
|
||||
fmt.Printf("minimum bound: %v\n", min)
|
||||
fmt.Printf("TOC: %v\n", ntp)
|
||||
fmt.Printf("maximum bound: %v\n", max)
|
||||
|
||||
// Output:
|
||||
// minimum bound: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 2 3 4 5 6 7 8 9 10]
|
||||
// TOC: [0 0 1 1 1 2 2 3 3 3 4 5 6 7 8 9 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10]
|
||||
// maximum bound: [0 1 2 3 4 5 6 7 8 9 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10]
|
||||
}
|
||||
|
||||
func ExampleTOC_unsorted() {
|
||||
y := []float64{8, 7.5, 6, 5, 3, 0}
|
||||
classes := []bool{true, false, true, false, false, false}
|
||||
weights := []float64{4, 1, 6, 3, 2, 2}
|
||||
|
||||
stat.SortWeightedLabeled(y, classes, weights)
|
||||
|
||||
min, ntp, max := stat.TOC(classes, weights)
|
||||
fmt.Printf("minimum bound: %v\n", min)
|
||||
fmt.Printf("TOC: %v\n", ntp)
|
||||
fmt.Printf("maximum bound: %v\n", max)
|
||||
|
||||
// Output:
|
||||
// minimum bound: [0 0 0 3 6 8 10]
|
||||
// TOC: [0 4 4 10 10 10 10]
|
||||
// maximum bound: [0 4 5 10 10 10 10]
|
||||
}
|
||||
|
||||
func ExampleTOC_aUC_unweighted() {
|
||||
classes := []bool{true, false, true, false}
|
||||
|
||||
_, ntp, _ := stat.TOC(classes, nil)
|
||||
pos := ntp[len(ntp)-1]
|
||||
base := float64(len(classes)) - pos
|
||||
|
||||
// Compute the area under ntp and under the
|
||||
// minimum bound.
|
||||
x := floats.Span(make([]float64, len(classes)+1), 0, float64(len(classes)))
|
||||
aucNTP := integrate.Trapezoidal(x, ntp)
|
||||
aucMin := pos * pos / 2
|
||||
|
||||
// Calculate the the area under the curve
|
||||
// within the bounding parallelogram.
|
||||
auc := aucNTP - aucMin
|
||||
|
||||
// Calculate the area within the bounding
|
||||
// parallelogram.
|
||||
par := pos * base
|
||||
|
||||
// The AUC is the ratio of the area under
|
||||
// the curve within the bounding parallelogram
|
||||
// and the total parallelogram bound.
|
||||
auc /= par
|
||||
|
||||
fmt.Printf("number of true positives: %v\n", ntp)
|
||||
fmt.Printf("auc: %v\n", auc)
|
||||
|
||||
// Output:
|
||||
// number of true positives: [0 0 1 1 2]
|
||||
// auc: 0.25
|
||||
}
|
||||
|
||||
func ExampleTOC_aUC_weighted() {
|
||||
classes := []bool{true, false, true, false}
|
||||
weights := []float64{1, 2, 2, 1}
|
||||
|
||||
min, ntp, max := stat.TOC(classes, weights)
|
||||
|
||||
// Compute the area under ntp and under the
|
||||
// minimum and maximum bounds.
|
||||
x := make([]float64, len(classes)+1)
|
||||
floats.CumSum(x[1:], weights)
|
||||
aucNTP := integrate.Trapezoidal(x, ntp)
|
||||
aucMin := integrate.Trapezoidal(x, min)
|
||||
aucMax := integrate.Trapezoidal(x, max)
|
||||
|
||||
// Calculate the the area under the curve
|
||||
// within the bounding parallelogram.
|
||||
auc := aucNTP - aucMin
|
||||
|
||||
// Calculate the area within the bounding
|
||||
// parallelogram.
|
||||
par := aucMax - aucMin
|
||||
|
||||
// The AUC is the ratio of the area under
|
||||
// the curve within the bounding parallelogram
|
||||
// and the total parallelogram bound.
|
||||
auc /= par
|
||||
|
||||
fmt.Printf("number of true positives: %v\n", ntp)
|
||||
fmt.Printf("auc: %f\n", auc)
|
||||
|
||||
// Output:
|
||||
// number of true positives: [0 0 2 2 3]
|
||||
// auc: 0.444444
|
||||
}
|
||||
|
||||
@@ -256,3 +256,71 @@ func TestROC(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOC(t *testing.T) {
|
||||
cases := []struct {
|
||||
c []bool
|
||||
w []float64
|
||||
wantMin []float64
|
||||
wantMax []float64
|
||||
wantTOC []float64
|
||||
}{
|
||||
{ // 0
|
||||
// This is the example given in the paper's supplement.
|
||||
// http://www2.clarku.edu/~rpontius/TOCexample2.xlsx
|
||||
// It is also shown in the WP article.
|
||||
// https://en.wikipedia.org/wiki/Total_operating_characteristic#/media/File:TOC_labeled.png
|
||||
c: []bool{
|
||||
false, false, false, false, false, false,
|
||||
false, false, false, false, false, false,
|
||||
false, false, true, true, true, true,
|
||||
true, true, true, false, false, true,
|
||||
false, true, false, false, true, false,
|
||||
},
|
||||
wantMin: []float64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
wantMax: []float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10},
|
||||
wantTOC: []float64{0, 0, 1, 1, 1, 2, 2, 3, 3, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10},
|
||||
},
|
||||
{ // 1
|
||||
c: []bool{},
|
||||
wantMin: nil,
|
||||
wantMax: nil,
|
||||
wantTOC: nil,
|
||||
},
|
||||
{ // 2
|
||||
c: []bool{
|
||||
true, true, true, true, true,
|
||||
},
|
||||
wantMin: []float64{0, 1, 2, 3, 4, 5},
|
||||
wantMax: []float64{0, 1, 2, 3, 4, 5},
|
||||
wantTOC: []float64{0, 1, 2, 3, 4, 5},
|
||||
},
|
||||
{ // 3
|
||||
c: []bool{
|
||||
false, false, false, false, false,
|
||||
},
|
||||
wantMin: []float64{0, 0, 0, 0, 0, 0},
|
||||
wantMax: []float64{0, 0, 0, 0, 0, 0},
|
||||
wantTOC: []float64{0, 0, 0, 0, 0, 0},
|
||||
},
|
||||
{ // 4
|
||||
c: []bool{false, false, false, true, false, true},
|
||||
w: []float64{2, 2, 3, 6, 1, 4},
|
||||
wantMin: []float64{0, 0, 0, 3, 6, 8, 10},
|
||||
wantMax: []float64{0, 4, 5, 10, 10, 10, 10},
|
||||
wantTOC: []float64{0, 4, 4, 10, 10, 10, 10},
|
||||
},
|
||||
}
|
||||
for i, test := range cases {
|
||||
gotMin, gotTOC, gotMax := TOC(test.c, test.w)
|
||||
if !floats.Same(gotMin, test.wantMin) {
|
||||
t.Errorf("%d: unexpected minimum bound got:%v want:%v", i, gotMin, test.wantMin)
|
||||
}
|
||||
if !floats.Same(gotMax, test.wantMax) {
|
||||
t.Errorf("%d: unexpected maximum bound got:%v want:%v", i, gotMax, test.wantMax)
|
||||
}
|
||||
if !floats.Same(gotTOC, test.wantTOC) {
|
||||
t.Errorf("%d: unexpected TOC got:%v want:%v", i, gotTOC, test.wantTOC)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user