// Copyright ©2015 The Gonum Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package distuv import ( "math" "testing" "golang.org/x/exp/rand" "gonum.org/v1/gonum/floats" ) func TestCategoricalProb(t *testing.T) { for _, test := range [][]float64{ {1, 2, 3, 0}, } { dist := NewCategorical(test, nil) norm := make([]float64, len(test)) floats.Scale(1/floats.Sum(norm), norm) for i, v := range norm { p := dist.Prob(float64(i)) if math.Abs(p-v) > 1e-14 { t.Errorf("Probability mismatch element %d", i) } p = dist.Prob(float64(i) + 0.5) if p != 0 { t.Errorf("Non-zero probability for non-integer x") } } p := dist.Prob(-1) if p != 0 { t.Errorf("Non-zero probability for -1") } p = dist.Prob(float64(len(test))) if p != 0 { t.Errorf("Non-zero probability for len(test)") } } } func TestCategoricalRand(t *testing.T) { for _, test := range [][]float64{ {1, 2, 3, 0}, } { dist := NewCategorical(test, nil) nSamples := 2000000 counts := sampleCategorical(t, dist, nSamples) probs := make([]float64, len(test)) for i := range probs { probs[i] = dist.Prob(float64(i)) } same := samedDistCategorical(dist, counts, probs, 1e-2) if !same { t.Errorf("Probability mismatch. Want %v, got %v", probs, counts) } dist.Reweight(len(test)-1, 10) counts = sampleCategorical(t, dist, nSamples) probs = make([]float64, len(test)) for i := range probs { probs[i] = dist.Prob(float64(i)) } same = samedDistCategorical(dist, counts, probs, 1e-2) if !same { t.Errorf("Probability mismatch after Reweight. Want %v, got %v", probs, counts) } w := make([]float64, len(test)) for i := range w { w[i] = rand.Float64() } dist.ReweightAll(w) counts = sampleCategorical(t, dist, nSamples) probs = make([]float64, len(test)) for i := range probs { probs[i] = dist.Prob(float64(i)) } same = samedDistCategorical(dist, counts, probs, 1e-2) if !same { t.Errorf("Probability mismatch after ReweightAll. Want %v, got %v", probs, counts) } } } func sampleCategorical(t *testing.T, dist Categorical, nSamples int) []float64 { counts := make([]float64, dist.Len()) for i := 0; i < nSamples; i++ { v := dist.Rand() if float64(int(v)) != v { t.Fatalf("Random number is not an integer") } counts[int(v)]++ } sum := floats.Sum(counts) floats.Scale(1/sum, counts) return counts } func samedDistCategorical(dist Categorical, counts, probs []float64, tol float64) bool { same := true for i, prob := range probs { if prob == 0 && counts[i] != 0 { same = false break } if !floats.EqualWithinAbsOrRel(prob, counts[i], tol, tol) { same = false break } } return same } func TestCategoricalCDF(t *testing.T) { for _, test := range [][]float64{ {1, 2, 3, 0, 4}, } { c := make([]float64, len(test)) copy(c, test) floats.Scale(1/floats.Sum(c), c) sum := make([]float64, len(test)) floats.CumSum(sum, c) dist := NewCategorical(test, nil) cdf := dist.CDF(-0.5) if cdf != 0 { t.Errorf("CDF of negative number not zero") } for i := range c { cdf := dist.CDF(float64(i)) if math.Abs(cdf-sum[i]) > 1e-14 { t.Errorf("CDF mismatch %v. Want %v, got %v.", float64(i), sum[i], cdf) } cdfp := dist.CDF(float64(i) + 0.5) if cdfp != cdf { t.Errorf("CDF mismatch for non-integer input") } } } } func TestCategoricalEntropy(t *testing.T) { for _, test := range []struct { weights []float64 entropy float64 }{ { weights: []float64{1, 1}, entropy: math.Ln2, }, { weights: []float64{1, 1, 1, 1}, entropy: math.Log(4), }, { weights: []float64{0, 0, 1, 1, 0, 0}, entropy: math.Ln2, }, } { dist := NewCategorical(test.weights, nil) entropy := dist.Entropy() if math.IsNaN(entropy) || math.Abs(entropy-test.entropy) > 1e-14 { t.Errorf("Entropy mismatch. Want %v, got %v.", test.entropy, entropy) } } } func TestCategoricalMean(t *testing.T) { for _, test := range []struct { weights []float64 mean float64 }{ { weights: []float64{10, 0, 0, 0}, mean: 0, }, { weights: []float64{0, 10, 0, 0}, mean: 1, }, { weights: []float64{1, 2, 3, 4}, mean: 2, }, } { dist := NewCategorical(test.weights, nil) mean := dist.Mean() if math.IsNaN(mean) || math.Abs(mean-test.mean) > 1e-14 { t.Errorf("Entropy mismatch. Want %v, got %v.", test.mean, mean) } } }