mirror of
https://github.com/gonum/gonum.git
synced 2025-10-08 08:30:14 +08:00
205 lines
5.7 KiB
Go
205 lines
5.7 KiB
Go
// Copyright ©2014 The Gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package distuv
|
|
|
|
import (
|
|
"fmt"
|
|
"math"
|
|
"testing"
|
|
|
|
"gonum.org/v1/gonum/diff/fd"
|
|
"gonum.org/v1/gonum/floats"
|
|
)
|
|
|
|
// univariateProbPoint is one reference point for testDistributionProbs: a
// location together with the values LogProb, CDF and Prob are expected to
// return there (Survival is checked against 1-cumProb).
type univariateProbPoint struct {
	// loc is the evaluation point; logProb, cumProb and prob are the
	// expected log density, cumulative probability and density at loc.
	loc, logProb, cumProb, prob float64
}
|
|
|
|
// UniProbDist is the set of univariate distribution methods that
// testDistributionProbs evaluates at each reference point.
type UniProbDist interface {
	// Density-related evaluations at a location x.
	LogProb(x float64) float64
	Prob(x float64) float64

	// Cumulative evaluations at x, and the inverse of CDF at probability p.
	CDF(x float64) float64
	Survival(x float64) float64
	Quantile(p float64) float64
}
|
|
|
|
// absEq reports whether a and b agree to within an absolute tolerance of
// 1e-14. Infinities of the same sign compare equal. A NaN operand never
// compares equal to anything, including another NaN, so a distribution
// method that returns NaN cannot silently pass a check.
func absEq(a, b float64) bool {
	if math.IsNaN(a) || math.IsNaN(b) {
		return false
	}
	// This is expressed as the inverse to catch the case a = Inf and
	// b = Inf of the same sign, where a-b is NaN and NaN > tol is false.
	return !(math.Abs(a-b) > 1e-14)
}
|
|
|
|
// TODO: Implement a better test for Quantile
|
|
func testDistributionProbs(t *testing.T, dist UniProbDist, name string, pts []univariateProbPoint) {
|
|
for _, pt := range pts {
|
|
logProb := dist.LogProb(pt.loc)
|
|
if !absEq(logProb, pt.logProb) {
|
|
t.Errorf("Log probability doesnt match for "+name+". Expected %v. Found %v", pt.logProb, logProb)
|
|
}
|
|
prob := dist.Prob(pt.loc)
|
|
if !absEq(prob, pt.prob) {
|
|
t.Errorf("Probability doesn't match for "+name+". Expected %v. Found %v", pt.prob, prob)
|
|
}
|
|
cumProb := dist.CDF(pt.loc)
|
|
if !absEq(cumProb, pt.cumProb) {
|
|
t.Errorf("Cumulative Probability doesn't match for "+name+". Expected %v. Found %v", pt.cumProb, cumProb)
|
|
}
|
|
if !absEq(dist.Survival(pt.loc), 1-pt.cumProb) {
|
|
t.Errorf("Survival doesn't match for %v. Expected %v, Found %v", name, 1-pt.cumProb, dist.Survival(pt.loc))
|
|
}
|
|
if pt.prob != 0 {
|
|
if math.Abs(dist.Quantile(pt.cumProb)-pt.loc) > 1e-4 {
|
|
fmt.Println("true =", pt.loc)
|
|
fmt.Println("calculated=", dist.Quantile(pt.cumProb))
|
|
t.Errorf("Quantile doesn't match for "+name+", loc = %v", pt.loc)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
type ConjugateUpdater interface {
|
|
NumParameters() int
|
|
parameters([]Parameter) []Parameter
|
|
|
|
NumSuffStat() int
|
|
SuffStat([]float64, []float64, []float64) float64
|
|
ConjugateUpdate([]float64, float64, []float64)
|
|
|
|
Rand() float64
|
|
}
|
|
|
|
func testConjugateUpdate(t *testing.T, newFittable func() ConjugateUpdater) {
|
|
for i, test := range []struct {
|
|
samps []float64
|
|
weights []float64
|
|
}{
|
|
{
|
|
samps: randn(newFittable(), 10),
|
|
weights: nil,
|
|
},
|
|
{
|
|
samps: randn(newFittable(), 10),
|
|
weights: ones(10),
|
|
},
|
|
{
|
|
samps: randn(newFittable(), 10),
|
|
weights: randn(&Exponential{Rate: 1}, 10),
|
|
},
|
|
} {
|
|
// ensure that conjugate produces the same result both incrementally and all at once
|
|
incDist := newFittable()
|
|
stats := make([]float64, incDist.NumSuffStat())
|
|
prior := make([]float64, incDist.NumParameters())
|
|
for j := range test.samps {
|
|
var incWeights, allWeights []float64
|
|
if test.weights != nil {
|
|
incWeights = test.weights[j : j+1]
|
|
allWeights = test.weights[0 : j+1]
|
|
}
|
|
nsInc := incDist.SuffStat(stats, test.samps[j:j+1], incWeights)
|
|
incDist.ConjugateUpdate(stats, nsInc, prior)
|
|
|
|
allDist := newFittable()
|
|
nsAll := allDist.SuffStat(stats, test.samps[0:j+1], allWeights)
|
|
allDist.ConjugateUpdate(stats, nsAll, make([]float64, allDist.NumParameters()))
|
|
if !parametersEqual(incDist.parameters(nil), allDist.parameters(nil), 1e-12) {
|
|
t.Errorf("prior doesn't match after incremental update for (%d, %d). Incremental is %v, all at once is %v", i, j, incDist, allDist)
|
|
}
|
|
|
|
if test.weights == nil {
|
|
onesDist := newFittable()
|
|
nsOnes := onesDist.SuffStat(stats, test.samps[0:j+1], ones(j+1))
|
|
onesDist.ConjugateUpdate(stats, nsOnes, make([]float64, onesDist.NumParameters()))
|
|
if !parametersEqual(onesDist.parameters(nil), incDist.parameters(nil), 1e-14) {
|
|
t.Errorf("nil and uniform weighted prior doesn't match for incremental update for (%d, %d). Uniform weighted is %v, nil is %v", i, j, onesDist, incDist)
|
|
}
|
|
if !parametersEqual(onesDist.parameters(nil), allDist.parameters(nil), 1e-14) {
|
|
t.Errorf("nil and uniform weighted prior doesn't match for all at once update for (%d, %d). Uniform weighted is %v, nil is %v", i, j, onesDist, incDist)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// randn generates a specified number of random samples
|
|
func randn(dist Rander, n int) []float64 {
|
|
x := make([]float64, n)
|
|
for i := range x {
|
|
x[i] = dist.Rand()
|
|
}
|
|
return x
|
|
}
|
|
|
|
// ones returns a newly allocated slice of length n with every element set
// to 1.
func ones(n int) []float64 {
	v := make([]float64, n)
	for k := 0; k < n; k++ {
		v[k] = 1
	}
	return v
}
|
|
|
|
func parametersEqual(p1, p2 []Parameter, tol float64) bool {
|
|
for i, p := range p1 {
|
|
if p.Name != p2[i].Name {
|
|
return false
|
|
}
|
|
if math.Abs(p.Value-p2[i].Value) > tol {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
type derivParamTester interface {
|
|
LogProb(x float64) float64
|
|
Score(deriv []float64, x float64) []float64
|
|
Quantile(p float64) float64
|
|
NumParameters() int
|
|
parameters([]Parameter) []Parameter
|
|
setParameters([]Parameter)
|
|
}
|
|
|
|
func testDerivParam(t *testing.T, d derivParamTester) {
|
|
// Tests that the derivative matches for a number of different quantiles
|
|
// along the distribution.
|
|
nTest := 10
|
|
quantiles := make([]float64, nTest)
|
|
floats.Span(quantiles, 0.1, 0.9)
|
|
|
|
deriv := make([]float64, d.NumParameters())
|
|
fdDeriv := make([]float64, d.NumParameters())
|
|
|
|
initParams := d.parameters(nil)
|
|
init := make([]float64, d.NumParameters())
|
|
for i, v := range initParams {
|
|
init[i] = v.Value
|
|
}
|
|
for _, v := range quantiles {
|
|
d.setParameters(initParams)
|
|
x := d.Quantile(v)
|
|
d.Score(deriv, x)
|
|
f := func(p []float64) float64 {
|
|
params := d.parameters(nil)
|
|
for i, v := range p {
|
|
params[i].Value = v
|
|
}
|
|
d.setParameters(params)
|
|
return d.LogProb(x)
|
|
}
|
|
fd.Gradient(fdDeriv, f, init, nil)
|
|
if !floats.EqualApprox(deriv, fdDeriv, 1e-6) {
|
|
t.Fatal("Derivative mismatch. Want", fdDeriv, ", got", deriv, ".")
|
|
}
|
|
d.setParameters(initParams)
|
|
d2 := d.Score(nil, x)
|
|
if !floats.EqualApprox(d2, deriv, 1e-14) {
|
|
t.Errorf("Derivative mismatch when input nil Want %v, got %v", d2, deriv)
|
|
}
|
|
}
|
|
}
|