mirror of
https://github.com/gonum/gonum.git
synced 2025-10-16 04:00:48 +08:00

Currently, we throw sigma away, and recompute it if necessary. This PR keeps sigma. This fixes an issue with concurrent calling of methods. In addition, however, it removes any possible issues with reconstructing a badly-conditioned sigma from its Cholesky decomposition, and avoids an extra n^3 work if sigma does need to be recomputed. The complexity of the implementation and the difficulties listed above are not worth the memory savings in some cases, especially since the memory of the type is already O(n^2).
539 lines
13 KiB
Go
539 lines
13 KiB
Go
// Copyright ©2015 The gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package distmv
|
|
|
|
import (
|
|
"math"
|
|
"math/rand"
|
|
"testing"
|
|
|
|
"github.com/gonum/floats"
|
|
"github.com/gonum/matrix/mat64"
|
|
"github.com/gonum/stat"
|
|
)
|
|
|
|
type mvTest struct {
|
|
Mu []float64
|
|
Sigma *mat64.SymDense
|
|
Loc []float64
|
|
Logprob float64
|
|
Prob float64
|
|
}
|
|
|
|
func TestNormProbs(t *testing.T) {
|
|
dist1, ok := NewNormal([]float64{0, 0}, mat64.NewSymDense(2, []float64{1, 0, 0, 1}), nil)
|
|
if !ok {
|
|
t.Errorf("bad test")
|
|
}
|
|
dist2, ok := NewNormal([]float64{6, 7}, mat64.NewSymDense(2, []float64{8, 2, 0, 4}), nil)
|
|
if !ok {
|
|
t.Errorf("bad test")
|
|
}
|
|
testProbability(t, []probCase{
|
|
{
|
|
dist: dist1,
|
|
loc: []float64{0, 0},
|
|
logProb: -1.837877066409345,
|
|
},
|
|
{
|
|
dist: dist2,
|
|
loc: []float64{6, 7},
|
|
logProb: -3.503979321496947,
|
|
},
|
|
{
|
|
dist: dist2,
|
|
loc: []float64{1, 2},
|
|
logProb: -7.075407892925519,
|
|
},
|
|
})
|
|
}
|
|
|
|
func TestNewNormalChol(t *testing.T) {
|
|
for _, test := range []struct {
|
|
mean []float64
|
|
cov *mat64.SymDense
|
|
}{
|
|
{
|
|
mean: []float64{2, 3},
|
|
cov: mat64.NewSymDense(2, []float64{1, 0.1, 0.1, 1}),
|
|
},
|
|
} {
|
|
var chol mat64.Cholesky
|
|
ok := chol.Factorize(test.cov)
|
|
if !ok {
|
|
panic("bad test")
|
|
}
|
|
n := NewNormalChol(test.mean, &chol, nil)
|
|
// Generate a random number and calculate probability to ensure things
|
|
// have been set properly. See issue #426.
|
|
x := n.Rand(nil)
|
|
_ = n.Prob(x)
|
|
}
|
|
}
|
|
|
|
func TestNormRand(t *testing.T) {
|
|
for _, test := range []struct {
|
|
mean []float64
|
|
cov []float64
|
|
}{
|
|
{
|
|
mean: []float64{0, 0},
|
|
cov: []float64{
|
|
1, 0,
|
|
0, 1,
|
|
},
|
|
},
|
|
{
|
|
mean: []float64{0, 0},
|
|
cov: []float64{
|
|
1, 0.9,
|
|
0.9, 1,
|
|
},
|
|
},
|
|
{
|
|
mean: []float64{6, 7},
|
|
cov: []float64{
|
|
5, 0.9,
|
|
0.9, 2,
|
|
},
|
|
},
|
|
} {
|
|
dim := len(test.mean)
|
|
cov := mat64.NewSymDense(dim, test.cov)
|
|
n, ok := NewNormal(test.mean, cov, nil)
|
|
if !ok {
|
|
t.Errorf("bad covariance matrix")
|
|
}
|
|
|
|
nSamples := 1000000
|
|
samps := mat64.NewDense(nSamples, dim, nil)
|
|
for i := 0; i < nSamples; i++ {
|
|
n.Rand(samps.RawRowView(i))
|
|
}
|
|
estMean := make([]float64, dim)
|
|
for i := range estMean {
|
|
estMean[i] = stat.Mean(mat64.Col(nil, i, samps), nil)
|
|
}
|
|
if !floats.EqualApprox(estMean, test.mean, 1e-2) {
|
|
t.Errorf("Mean mismatch: want: %v, got %v", test.mean, estMean)
|
|
}
|
|
estCov := stat.CovarianceMatrix(nil, samps, nil)
|
|
if !mat64.EqualApprox(estCov, cov, 1e-2) {
|
|
t.Errorf("Cov mismatch: want: %v, got %v", cov, estCov)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestNormalQuantile(t *testing.T) {
|
|
for _, test := range []struct {
|
|
mean []float64
|
|
cov []float64
|
|
}{
|
|
{
|
|
mean: []float64{6, 7},
|
|
cov: []float64{
|
|
5, 0.9,
|
|
0.9, 2,
|
|
},
|
|
},
|
|
} {
|
|
dim := len(test.mean)
|
|
cov := mat64.NewSymDense(dim, test.cov)
|
|
n, ok := NewNormal(test.mean, cov, nil)
|
|
if !ok {
|
|
t.Errorf("bad covariance matrix")
|
|
}
|
|
|
|
nSamples := 1000000
|
|
rnd := rand.New(rand.NewSource(1))
|
|
samps := mat64.NewDense(nSamples, dim, nil)
|
|
tmp := make([]float64, dim)
|
|
for i := 0; i < nSamples; i++ {
|
|
for j := range tmp {
|
|
tmp[j] = rnd.Float64()
|
|
}
|
|
n.Quantile(samps.RawRowView(i), tmp)
|
|
}
|
|
estMean := make([]float64, dim)
|
|
for i := range estMean {
|
|
estMean[i] = stat.Mean(mat64.Col(nil, i, samps), nil)
|
|
}
|
|
if !floats.EqualApprox(estMean, test.mean, 1e-2) {
|
|
t.Errorf("Mean mismatch: want: %v, got %v", test.mean, estMean)
|
|
}
|
|
estCov := stat.CovarianceMatrix(nil, samps, nil)
|
|
if !mat64.EqualApprox(estCov, cov, 1e-2) {
|
|
t.Errorf("Cov mismatch: want: %v, got %v", cov, estCov)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestConditionNormal(t *testing.T) {
|
|
// Uncorrelated values shouldn't influence the updated values.
|
|
for _, test := range []struct {
|
|
mu []float64
|
|
sigma *mat64.SymDense
|
|
observed []int
|
|
values []float64
|
|
|
|
newMu []float64
|
|
newSigma *mat64.SymDense
|
|
}{
|
|
{
|
|
mu: []float64{2, 3},
|
|
sigma: mat64.NewSymDense(2, []float64{2, 0, 0, 5}),
|
|
observed: []int{0},
|
|
values: []float64{10},
|
|
|
|
newMu: []float64{3},
|
|
newSigma: mat64.NewSymDense(1, []float64{5}),
|
|
},
|
|
{
|
|
mu: []float64{2, 3},
|
|
sigma: mat64.NewSymDense(2, []float64{2, 0, 0, 5}),
|
|
observed: []int{1},
|
|
values: []float64{10},
|
|
|
|
newMu: []float64{2},
|
|
newSigma: mat64.NewSymDense(1, []float64{2}),
|
|
},
|
|
{
|
|
mu: []float64{2, 3, 4},
|
|
sigma: mat64.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}),
|
|
observed: []int{1},
|
|
values: []float64{10},
|
|
|
|
newMu: []float64{2, 4},
|
|
newSigma: mat64.NewSymDense(2, []float64{2, 0, 0, 10}),
|
|
},
|
|
{
|
|
mu: []float64{2, 3, 4},
|
|
sigma: mat64.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}),
|
|
observed: []int{0, 1},
|
|
values: []float64{10, 15},
|
|
|
|
newMu: []float64{4},
|
|
newSigma: mat64.NewSymDense(1, []float64{10}),
|
|
},
|
|
{
|
|
mu: []float64{2, 3, 4, 5},
|
|
sigma: mat64.NewSymDense(4, []float64{2, 0.5, 0, 0, 0.5, 5, 0, 0, 0, 0, 10, 2, 0, 0, 2, 3}),
|
|
observed: []int{0, 1},
|
|
values: []float64{10, 15},
|
|
|
|
newMu: []float64{4, 5},
|
|
newSigma: mat64.NewSymDense(2, []float64{10, 2, 2, 3}),
|
|
},
|
|
} {
|
|
normal, ok := NewNormal(test.mu, test.sigma, nil)
|
|
if !ok {
|
|
t.Fatalf("Bad test, original sigma not positive definite")
|
|
}
|
|
newNormal, ok := normal.ConditionNormal(test.observed, test.values, nil)
|
|
if !ok {
|
|
t.Fatalf("Bad test, update failure")
|
|
}
|
|
|
|
if !floats.EqualApprox(test.newMu, newNormal.mu, 1e-12) {
|
|
t.Errorf("Updated mean mismatch. Want %v, got %v.", test.newMu, newNormal.mu)
|
|
}
|
|
|
|
var sigma mat64.SymDense
|
|
sigma.FromCholesky(&newNormal.chol)
|
|
if !mat64.EqualApprox(test.newSigma, &sigma, 1e-12) {
|
|
t.Errorf("Updated sigma mismatch\n.Want:\n% v\nGot:\n% v\n", test.newSigma, sigma)
|
|
}
|
|
}
|
|
|
|
// Test bivariate case where the update rule is analytic
|
|
for _, test := range []struct {
|
|
mu []float64
|
|
std []float64
|
|
rho float64
|
|
value float64
|
|
}{
|
|
{
|
|
mu: []float64{2, 3},
|
|
std: []float64{3, 5},
|
|
rho: 0.9,
|
|
value: 1000,
|
|
},
|
|
{
|
|
mu: []float64{2, 3},
|
|
std: []float64{3, 5},
|
|
rho: -0.9,
|
|
value: 1000,
|
|
},
|
|
} {
|
|
std := test.std
|
|
rho := test.rho
|
|
sigma := mat64.NewSymDense(2, []float64{std[0] * std[0], std[0] * std[1] * rho, std[0] * std[1] * rho, std[1] * std[1]})
|
|
normal, ok := NewNormal(test.mu, sigma, nil)
|
|
if !ok {
|
|
t.Fatalf("Bad test, original sigma not positive definite")
|
|
}
|
|
newNormal, ok := normal.ConditionNormal([]int{1}, []float64{test.value}, nil)
|
|
if !ok {
|
|
t.Fatalf("Bad test, update failed")
|
|
}
|
|
var newSigma mat64.SymDense
|
|
newSigma.FromCholesky(&newNormal.chol)
|
|
trueMean := test.mu[0] + rho*(std[0]/std[1])*(test.value-test.mu[1])
|
|
if math.Abs(trueMean-newNormal.mu[0]) > 1e-14 {
|
|
t.Errorf("Mean mismatch. Want %v, got %v", trueMean, newNormal.mu[0])
|
|
}
|
|
trueVar := (1 - rho*rho) * std[0] * std[0]
|
|
if math.Abs(trueVar-newSigma.At(0, 0)) > 1e-14 {
|
|
t.Errorf("Std mismatch. Want %v, got %v", trueMean, newNormal.mu[0])
|
|
}
|
|
}
|
|
|
|
// Test via sampling.
|
|
for _, test := range []struct {
|
|
mu []float64
|
|
sigma *mat64.SymDense
|
|
observed []int
|
|
unobserved []int
|
|
value []float64
|
|
}{
|
|
// The indices in unobserved must be in ascending order for this test.
|
|
{
|
|
mu: []float64{2, 3, 4},
|
|
sigma: mat64.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
|
|
|
|
observed: []int{0},
|
|
unobserved: []int{1, 2},
|
|
value: []float64{1.9},
|
|
},
|
|
{
|
|
mu: []float64{2, 3, 4, 5},
|
|
sigma: mat64.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}),
|
|
|
|
observed: []int{0, 3},
|
|
unobserved: []int{1, 2},
|
|
value: []float64{1.9, 2.9},
|
|
},
|
|
} {
|
|
totalSamp := 4000000
|
|
var nSamp int
|
|
samples := mat64.NewDense(totalSamp, len(test.mu), nil)
|
|
normal, ok := NewNormal(test.mu, test.sigma, nil)
|
|
if !ok {
|
|
t.Errorf("bad test")
|
|
}
|
|
sample := make([]float64, len(test.mu))
|
|
for i := 0; i < totalSamp; i++ {
|
|
normal.Rand(sample)
|
|
isClose := true
|
|
for i, v := range test.observed {
|
|
if math.Abs(sample[v]-test.value[i]) > 1e-1 {
|
|
isClose = false
|
|
break
|
|
}
|
|
}
|
|
if isClose {
|
|
samples.SetRow(nSamp, sample)
|
|
nSamp++
|
|
}
|
|
}
|
|
|
|
if nSamp < 100 {
|
|
t.Errorf("bad test, not enough samples")
|
|
continue
|
|
}
|
|
samples = samples.View(0, 0, nSamp, len(test.mu)).(*mat64.Dense)
|
|
|
|
// Compute mean and covariance matrix.
|
|
estMean := make([]float64, len(test.mu))
|
|
for i := range estMean {
|
|
estMean[i] = stat.Mean(mat64.Col(nil, i, samples), nil)
|
|
}
|
|
estCov := stat.CovarianceMatrix(nil, samples, nil)
|
|
|
|
// Compute update rule.
|
|
newNormal, ok := normal.ConditionNormal(test.observed, test.value, nil)
|
|
if !ok {
|
|
t.Fatalf("Bad test, update failure")
|
|
}
|
|
|
|
var subEstMean []float64
|
|
for _, v := range test.unobserved {
|
|
|
|
subEstMean = append(subEstMean, estMean[v])
|
|
}
|
|
subEstCov := mat64.NewSymDense(len(test.unobserved), nil)
|
|
for i := 0; i < len(test.unobserved); i++ {
|
|
for j := i; j < len(test.unobserved); j++ {
|
|
subEstCov.SetSym(i, j, estCov.At(test.unobserved[i], test.unobserved[j]))
|
|
}
|
|
}
|
|
|
|
for i, v := range subEstMean {
|
|
if math.Abs(newNormal.mu[i]-v) > 5e-2 {
|
|
t.Errorf("Mean mismatch. Want %v, got %v.", newNormal.mu[i], v)
|
|
}
|
|
}
|
|
var sigma mat64.SymDense
|
|
sigma.FromCholesky(&newNormal.chol)
|
|
if !mat64.EqualApprox(&sigma, subEstCov, 1e-1) {
|
|
t.Errorf("Covariance mismatch. Want:\n%0.8v\nGot:\n%0.8v\n", subEstCov, sigma)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCovarianceMatrix(t *testing.T) {
|
|
for _, test := range []struct {
|
|
mu []float64
|
|
sigma *mat64.SymDense
|
|
}{
|
|
{
|
|
mu: []float64{2, 3, 4},
|
|
sigma: mat64.NewSymDense(3, []float64{1, 0.5, 3, 0.5, 8, -1, 3, -1, 15}),
|
|
},
|
|
} {
|
|
normal, ok := NewNormal(test.mu, test.sigma, nil)
|
|
if !ok {
|
|
t.Fatalf("Bad test, covariance matrix not positive definite")
|
|
}
|
|
cov := normal.CovarianceMatrix(nil)
|
|
if !mat64.EqualApprox(cov, test.sigma, 1e-14) {
|
|
t.Errorf("Covariance mismatch with nil input")
|
|
}
|
|
dim := test.sigma.Symmetric()
|
|
cov = mat64.NewSymDense(dim, nil)
|
|
normal.CovarianceMatrix(cov)
|
|
if !mat64.EqualApprox(cov, test.sigma, 1e-14) {
|
|
t.Errorf("Covariance mismatch with supplied input")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestMarginal(t *testing.T) {
|
|
for _, test := range []struct {
|
|
mu []float64
|
|
sigma *mat64.SymDense
|
|
marginal []int
|
|
}{
|
|
{
|
|
mu: []float64{2, 3, 4},
|
|
sigma: mat64.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
|
|
marginal: []int{0},
|
|
},
|
|
{
|
|
mu: []float64{2, 3, 4},
|
|
sigma: mat64.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
|
|
marginal: []int{0, 2},
|
|
},
|
|
{
|
|
mu: []float64{2, 3, 4, 5},
|
|
sigma: mat64.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}),
|
|
|
|
marginal: []int{0, 3},
|
|
},
|
|
} {
|
|
normal, ok := NewNormal(test.mu, test.sigma, nil)
|
|
if !ok {
|
|
t.Fatalf("Bad test, covariance matrix not positive definite")
|
|
}
|
|
marginal, ok := normal.MarginalNormal(test.marginal, nil)
|
|
if !ok {
|
|
t.Fatalf("Bad test, marginal matrix not positive definite")
|
|
}
|
|
dim := normal.Dim()
|
|
nSamples := 1000000
|
|
samps := mat64.NewDense(nSamples, dim, nil)
|
|
for i := 0; i < nSamples; i++ {
|
|
normal.Rand(samps.RawRowView(i))
|
|
}
|
|
estMean := make([]float64, dim)
|
|
for i := range estMean {
|
|
estMean[i] = stat.Mean(mat64.Col(nil, i, samps), nil)
|
|
}
|
|
for i, v := range test.marginal {
|
|
if math.Abs(marginal.mu[i]-estMean[v]) > 1e-2 {
|
|
t.Errorf("Mean mismatch: want: %v, got %v", estMean[v], marginal.mu[i])
|
|
}
|
|
}
|
|
|
|
marginalCov := marginal.CovarianceMatrix(nil)
|
|
estCov := stat.CovarianceMatrix(nil, samps, nil)
|
|
for i, v1 := range test.marginal {
|
|
for j, v2 := range test.marginal {
|
|
c := marginalCov.At(i, j)
|
|
ec := estCov.At(v1, v2)
|
|
if math.Abs(c-ec) > 5e-2 {
|
|
t.Errorf("Cov mismatch element i = %d, j = %d: want: %v, got %v", i, j, c, ec)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestMarginalSingle(t *testing.T) {
|
|
for _, test := range []struct {
|
|
mu []float64
|
|
sigma *mat64.SymDense
|
|
}{
|
|
{
|
|
mu: []float64{2, 3, 4},
|
|
sigma: mat64.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
|
|
},
|
|
{
|
|
mu: []float64{2, 3, 4, 5},
|
|
sigma: mat64.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}),
|
|
},
|
|
} {
|
|
normal, ok := NewNormal(test.mu, test.sigma, nil)
|
|
if !ok {
|
|
t.Fatalf("Bad test, covariance matrix not positive definite")
|
|
}
|
|
for i, mean := range test.mu {
|
|
norm := normal.MarginalNormalSingle(i, nil)
|
|
if norm.Mean() != mean {
|
|
t.Errorf("Mean mismatch nil Sigma, idx %v: want %v, got %v.", i, mean, norm.Mean())
|
|
}
|
|
std := math.Sqrt(test.sigma.At(i, i))
|
|
if math.Abs(norm.StdDev()-std) > 1e-14 {
|
|
t.Errorf("StdDev mismatch nil Sigma, idx %v: want %v, got %v.", i, std, norm.StdDev())
|
|
}
|
|
}
|
|
}
|
|
|
|
// Test matching with TestMarginal.
|
|
rnd := rand.New(rand.NewSource(1))
|
|
for cas := 0; cas < 10; cas++ {
|
|
dim := rnd.Intn(10) + 1
|
|
mu := make([]float64, dim)
|
|
for i := range mu {
|
|
mu[i] = rnd.Float64()
|
|
}
|
|
x := make([]float64, dim*dim)
|
|
for i := range x {
|
|
x[i] = rnd.Float64()
|
|
}
|
|
mat := mat64.NewDense(dim, dim, x)
|
|
var sigma mat64.SymDense
|
|
sigma.SymOuterK(1, mat)
|
|
|
|
normal, ok := NewNormal(mu, &sigma, nil)
|
|
if !ok {
|
|
t.Fatal("bad test")
|
|
}
|
|
for i := 0; i < dim; i++ {
|
|
single := normal.MarginalNormalSingle(i, nil)
|
|
mult, ok := normal.MarginalNormal([]int{i}, nil)
|
|
if !ok {
|
|
t.Fatal("bad test")
|
|
}
|
|
if math.Abs(single.Mean()-mult.Mean(nil)[0]) > 1e-14 {
|
|
t.Errorf("Mean mismatch")
|
|
}
|
|
if math.Abs(single.Variance()-mult.CovarianceMatrix(nil).At(0, 0)) > 1e-14 {
|
|
t.Errorf("Variance mismatch")
|
|
}
|
|
}
|
|
}
|
|
}
|