stat: add TOC function

This commit is contained in:
Dan Kortschak
2021-04-12 18:31:26 +09:30
committed by GitHub
parent eae7c2a69e
commit 11e9fee16c
3 changed files with 263 additions and 1 deletions

View File

@@ -127,3 +127,74 @@ func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr, thr
return tpr, fpr, cutoffs
}
// TOC returns the Total Operating Characteristic for the classes provided
// and the minimum and maximum bounds for the TOC.
//
// classes and weights must be ordered by ascending value of the rank
// variable that they describe: classes[i] is the class of the i'th datum
// in that order and weights[i] is its weight. SortWeightedLabeled can be
// used to sort classes together with weights by the rank variable.
//
// The returned ntp values can be interpreted as the number of true positives
// where values above the given rank are assigned class true for each given
// rank from 1 to len(classes). The data are accumulated from the end of
// classes, so with n = len(classes) and zero-based j
//
//	ntp_0 = 0
//	ntp_i = sum_{j = n-i}^{n-1} [ classes_j ] * weights_j, where [x] = 1 if x else 0.
//
// The values of min and max provide the minimum and maximum possible values
// of ntp at each index for the set of classes: min[i] is the total positive
// weight less the weight that has not yet been assigned class true, bounded
// below by zero, and max[i] is the weight that has been assigned class true,
// bounded above by the total positive weight. The first element of ntp, min
// and max are always zero as this corresponds to assigning all data class
// false and the last elements are always the weighted sum of classes as this
// corresponds to assigning every data class true.
//
// The returned slices all have length len(classes)+1. If classes is empty,
// all three returned slices are nil.
//
// If weights is nil, all weights are treated as 1. When weights are not nil,
// the calculation of min and max allows for partial assignment of single data
// points. If weights is not nil it must have the same length as classes,
// otherwise TOC will panic.
//
// More details about TOC curves are available at
// https://en.wikipedia.org/wiki/Total_operating_characteristic
func TOC(classes []bool, weights []float64) (min, ntp, max []float64) {
	if weights != nil && len(classes) != len(weights) {
		panic("stat: slice length mismatch")
	}
	if len(classes) == 0 {
		return nil, nil, nil
	}

	n := len(classes)
	ntp = make([]float64, n+1)
	min = make([]float64, n+1)
	max = make([]float64, n+1)

	if weights == nil {
		// Every datum has unit weight, so the weight assigned class
		// true at index i is i itself and no cumulative weight slice
		// is needed.
		for i := 0; i < n; i++ {
			ntp[i+1] = ntp[i]
			if classes[n-i-1] {
				ntp[i+1]++
			}
		}
		totalPositive := ntp[n]
		for i := range ntp {
			min[i] = math.Max(0, totalPositive-float64(n-i))
			max[i] = math.Min(totalPositive, float64(i))
		}
		return min, ntp, max
	}

	// Reuse max to hold the cumulative weight. Each element is read
	// for the last time immediately before it is overwritten below.
	cumw := max
	for i := 0; i < n; i++ {
		w := weights[n-i-1]
		cumw[i+1] = cumw[i] + w
		ntp[i+1] = ntp[i]
		if classes[n-i-1] {
			ntp[i+1] += w
		}
	}
	totw := cumw[n]
	totalPositive := ntp[n]
	for i := range ntp {
		// min[i] must be calculated first since the assignment
		// to max[i] clobbers cumw[i].
		min[i] = math.Max(0, totalPositive-(totw-cumw[i]))
		max[i] = math.Min(totalPositive, cumw[i])
	}
	return min, ntp, max
}

View File

@@ -108,7 +108,7 @@ func ExampleROC_equallySpacedCutoffs() {
// false positive rate: [0 0 0 1 1 1 1 1 1]
}
func ExampleROC_aUC() {
func ExampleROC_aUC_unweighted() {
y := []float64{0.1, 0.35, 0.4, 0.8}
classes := []bool{true, false, true, false}
@@ -125,3 +125,126 @@ func ExampleROC_aUC() {
// false positive rate: [0 0.5 0.5 1 1]
// auc: 0.25
}
// ExampleROC_aUC_weighted prints the area under a weighted ROC curve
// obtained by trapezoidal integration of the true positive rate over
// the false positive rate.
func ExampleROC_aUC_weighted() {
	var (
		scores = []float64{0.1, 0.35, 0.4, 0.8}
		labels = []bool{true, false, true, false}
		w      = []float64{1, 2, 2, 1}
	)

	// Passing nil cutoffs leaves the choice of thresholds to ROC.
	truePos, falsePos, _ := stat.ROC(nil, scores, labels, w)

	// The area under the curve is the integral of the true positive
	// rate with respect to the false positive rate.
	area := integrate.Trapezoidal(falsePos, truePos)

	fmt.Printf("auc: %f\n", area)

	// Output:
	// auc: 0.444444
}
// ExampleTOC prints the unweighted TOC curve and its bounds for a set
// of thirty class labels.
func ExampleTOC() {
	labels := []bool{
		false, false, false, false, false, false,
		false, false, false, false, false, false,
		false, false, true, true, true, true,
		true, true, true, false, false, true,
		false, true, false, false, true, false,
	}

	// A nil weights slice gives every label a weight of one.
	lower, curve, upper := stat.TOC(labels, nil)

	fmt.Printf("minimum bound: %v\n", lower)
	fmt.Printf("TOC: %v\n", curve)
	fmt.Printf("maximum bound: %v\n", upper)

	// Output:
	// minimum bound: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 2 3 4 5 6 7 8 9 10]
	// TOC: [0 0 1 1 1 2 2 3 3 3 4 5 6 7 8 9 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10]
	// maximum bound: [0 1 2 3 4 5 6 7 8 9 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10]
}
// ExampleTOC_unsorted prints the weighted TOC curve and its bounds for
// data that are not supplied in rank order.
func ExampleTOC_unsorted() {
	var (
		rank   = []float64{8, 7.5, 6, 5, 3, 0}
		labels = []bool{true, false, true, false, false, false}
		w      = []float64{4, 1, 6, 3, 2, 2}
	)

	// TOC expects its inputs ordered by the rank variable, so
	// sort the labels and weights together with it first.
	stat.SortWeightedLabeled(rank, labels, w)

	lower, curve, upper := stat.TOC(labels, w)

	fmt.Printf("minimum bound: %v\n", lower)
	fmt.Printf("TOC: %v\n", curve)
	fmt.Printf("maximum bound: %v\n", upper)

	// Output:
	// minimum bound: [0 0 0 3 6 8 10]
	// TOC: [0 4 4 10 10 10 10]
	// maximum bound: [0 4 5 10 10 10 10]
}
// ExampleTOC_aUC_unweighted prints the unweighted TOC curve and the
// area under it expressed as a fraction of the bounding parallelogram.
func ExampleTOC_aUC_unweighted() {
	labels := []bool{true, false, true, false}

	_, curve, _ := stat.TOC(labels, nil)

	// The final element of the curve is the number of true labels;
	// the remainder of the labels are false.
	positives := curve[len(curve)-1]
	negatives := float64(len(labels)) - positives

	// Integrate the curve over the ranks 0..len(labels).
	ranks := floats.Span(make([]float64, len(labels)+1), 0, float64(len(labels)))
	underCurve := integrate.Trapezoidal(ranks, curve)

	// The area under the minimum bound is taken to be a triangle
	// with both base and height equal to the number of positives.
	underMin := positives * positives / 2

	// Area between the curve and the minimum bound, that is
	// the part of the area under the curve lying within the
	// bounding parallelogram.
	auc := underCurve - underMin

	// Area of the bounding parallelogram itself.
	parallelogram := positives * negatives

	// The AUC is the fraction of the parallelogram that
	// lies under the curve.
	auc /= parallelogram

	fmt.Printf("number of true positives: %v\n", curve)
	fmt.Printf("auc: %v\n", auc)

	// Output:
	// number of true positives: [0 0 1 1 2]
	// auc: 0.25
}
// ExampleTOC_aUC_weighted prints the weighted TOC curve and the area
// under it expressed as a fraction of the area between the minimum and
// maximum bounds.
func ExampleTOC_aUC_weighted() {
	classes := []bool{true, false, true, false}
	weights := []float64{1, 2, 2, 1}

	min, ntp, max := stat.TOC(classes, weights)

	// x[i] is the total weight that has been assigned class true
	// at ntp[i]. TOC accumulates the data from the end of classes
	// and weights, so the weights must be summed in that same
	// order for x to line up with ntp, min and max. A cumulative
	// sum from the start of weights only agrees with this when
	// the weights read the same in both directions.
	x := make([]float64, len(weights)+1)
	for i := range weights {
		x[i+1] = x[i] + weights[len(weights)-i-1]
	}

	// Compute the area under ntp and under the
	// minimum and maximum bounds.
	aucNTP := integrate.Trapezoidal(x, ntp)
	aucMin := integrate.Trapezoidal(x, min)
	aucMax := integrate.Trapezoidal(x, max)

	// Calculate the area under the curve
	// within the bounding parallelogram.
	auc := aucNTP - aucMin

	// Calculate the area within the bounding
	// parallelogram.
	par := aucMax - aucMin

	// The AUC is the ratio of the area under
	// the curve within the bounding parallelogram
	// and the total parallelogram bound.
	auc /= par

	fmt.Printf("number of true positives: %v\n", ntp)
	fmt.Printf("auc: %f\n", auc)

	// Output:
	// number of true positives: [0 0 2 2 3]
	// auc: 0.444444
}

View File

@@ -256,3 +256,71 @@ func TestROC(t *testing.T) {
}
}
}
// TestTOC checks the curve and bounds returned by TOC against known
// values for unweighted, empty, single-class and weighted inputs.
func TestTOC(t *testing.T) {
	type tocTest struct {
		classes []bool
		weights []float64

		min []float64
		toc []float64
		max []float64
	}

	tests := []tocTest{
		{ // 0
			// This is the example given in the paper's supplement.
			// http://www2.clarku.edu/~rpontius/TOCexample2.xlsx
			// It is also shown in the WP article.
			// https://en.wikipedia.org/wiki/Total_operating_characteristic#/media/File:TOC_labeled.png
			classes: []bool{
				false, false, false, false, false, false,
				false, false, false, false, false, false,
				false, false, true, true, true, true,
				true, true, true, false, false, true,
				false, true, false, false, true, false,
			},
			min: []float64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
			max: []float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10},
			toc: []float64{0, 0, 1, 1, 1, 2, 2, 3, 3, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10},
		},
		{ // 1
			// No data: all results are nil.
			classes: []bool{},
			min:     nil,
			max:     nil,
			toc:     nil,
		},
		{ // 2
			// Only true labels: the curve and both bounds coincide.
			classes: []bool{
				true, true, true, true, true,
			},
			min: []float64{0, 1, 2, 3, 4, 5},
			max: []float64{0, 1, 2, 3, 4, 5},
			toc: []float64{0, 1, 2, 3, 4, 5},
		},
		{ // 3
			// Only false labels: everything stays at zero.
			classes: []bool{
				false, false, false, false, false,
			},
			min: []float64{0, 0, 0, 0, 0, 0},
			max: []float64{0, 0, 0, 0, 0, 0},
			toc: []float64{0, 0, 0, 0, 0, 0},
		},
		{ // 4
			// Weighted data.
			classes: []bool{false, false, false, true, false, true},
			weights: []float64{2, 2, 3, 6, 1, 4},
			min:     []float64{0, 0, 0, 3, 6, 8, 10},
			max:     []float64{0, 4, 5, 10, 10, 10, 10},
			toc:     []float64{0, 4, 4, 10, 10, 10, 10},
		},
	}

	for idx := range tests {
		tc := &tests[idx]
		lower, curve, upper := TOC(tc.classes, tc.weights)

		if same := floats.Same(lower, tc.min); !same {
			t.Errorf("%d: unexpected minimum bound got:%v want:%v", idx, lower, tc.min)
		}
		if same := floats.Same(upper, tc.max); !same {
			t.Errorf("%d: unexpected maximum bound got:%v want:%v", idx, upper, tc.max)
		}
		if same := floats.Same(curve, tc.toc); !same {
			t.Errorf("%d: unexpected TOC got:%v want:%v", idx, curve, tc.toc)
		}
	}
}