Add ConditionNormal

This commit is contained in:
btracey
2015-11-09 21:28:04 -07:00
parent 74a6648c88
commit db9cfea7b0
2 changed files with 324 additions and 0 deletions

View File

@@ -1,6 +1,7 @@
package distmv
import (
"math"
"testing"
"github.com/gonum/floats"
@@ -96,3 +97,216 @@ func TestNormRand(t *testing.T) {
}
}
}
func TestConditionNormal(t *testing.T) {
// Uncorrelated values shouldn't influence the updated values.
for _, test := range []struct {
mu []float64
sigma *mat64.SymDense
observed []int
values []float64
newMu []float64
newSigma *mat64.SymDense
}{
{
mu: []float64{2, 3},
sigma: mat64.NewSymDense(2, []float64{2, 0, 0, 5}),
observed: []int{0},
values: []float64{10},
newMu: []float64{3},
newSigma: mat64.NewSymDense(1, []float64{5}),
},
{
mu: []float64{2, 3},
sigma: mat64.NewSymDense(2, []float64{2, 0, 0, 5}),
observed: []int{1},
values: []float64{10},
newMu: []float64{2},
newSigma: mat64.NewSymDense(1, []float64{2}),
},
{
mu: []float64{2, 3, 4},
sigma: mat64.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}),
observed: []int{1},
values: []float64{10},
newMu: []float64{2, 4},
newSigma: mat64.NewSymDense(2, []float64{2, 0, 0, 10}),
},
{
mu: []float64{2, 3, 4},
sigma: mat64.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}),
observed: []int{0, 1},
values: []float64{10, 15},
newMu: []float64{4},
newSigma: mat64.NewSymDense(1, []float64{10}),
},
{
mu: []float64{2, 3, 4, 5},
sigma: mat64.NewSymDense(4, []float64{2, 0.5, 0, 0, 0.5, 5, 0, 0, 0, 0, 10, 2, 0, 0, 2, 3}),
observed: []int{0, 1},
values: []float64{10, 15},
newMu: []float64{4, 5},
newSigma: mat64.NewSymDense(2, []float64{10, 2, 2, 3}),
},
} {
normal, ok := NewNormal(test.mu, test.sigma, nil)
if !ok {
t.Fatalf("Bad test, original sigma not positive definite")
}
newNormal, ok := normal.ConditionNormal(test.observed, test.values, nil)
if !ok {
t.Fatalf("Bad test, update failure")
}
if !floats.EqualApprox(test.newMu, newNormal.mu, 1e-12) {
t.Errorf("Updated mean mismatch. Want %v, got %v.", test.newMu, newNormal.mu)
}
var sigma mat64.SymDense
sigma.FromCholesky(&newNormal.chol)
if !mat64.EqualApprox(test.newSigma, &sigma, 1e-12) {
t.Errorf("Updated sigma mismatch\n.Want:\n% v\nGot:\n% v\n", test.newSigma, sigma)
}
}
// Test bivariate case where the update rule is analytic
for _, test := range []struct {
mu []float64
std []float64
rho float64
value float64
}{
{
mu: []float64{2, 3},
std: []float64{3, 5},
rho: 0.9,
value: 1000,
},
{
mu: []float64{2, 3},
std: []float64{3, 5},
rho: -0.9,
value: 1000,
},
} {
std := test.std
rho := test.rho
sigma := mat64.NewSymDense(2, []float64{std[0] * std[0], std[0] * std[1] * rho, std[0] * std[1] * rho, std[1] * std[1]})
normal, ok := NewNormal(test.mu, sigma, nil)
if !ok {
t.Fatalf("Bad test, original sigma not positive definite")
}
newNormal, ok := normal.ConditionNormal([]int{1}, []float64{test.value}, nil)
if !ok {
t.Fatalf("Bad test, update failed")
}
var newSigma mat64.SymDense
newSigma.FromCholesky(&newNormal.chol)
trueMean := test.mu[0] + rho*(std[0]/std[1])*(test.value-test.mu[1])
if math.Abs(trueMean-newNormal.mu[0]) > 1e-14 {
t.Errorf("Mean mismatch. Want %v, got %v", trueMean, newNormal.mu[0])
}
trueVar := (1 - rho*rho) * std[0] * std[0]
if math.Abs(trueVar-newSigma.At(0, 0)) > 1e-14 {
t.Errorf("Std mismatch. Want %v, got %v", trueMean, newNormal.mu[0])
}
}
// Test via sampling.
for _, test := range []struct {
mu []float64
sigma *mat64.SymDense
observed []int
unobserved []int
value []float64
}{
// The indices in unobserved must be in ascending order for this test.
{
mu: []float64{2, 3, 4},
sigma: mat64.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
observed: []int{0},
unobserved: []int{1, 2},
value: []float64{1.9},
},
{
mu: []float64{2, 3, 4, 5},
sigma: mat64.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}),
observed: []int{0, 3},
unobserved: []int{1, 2},
value: []float64{1.9, 2.9},
},
} {
totalSamp := 4000000
var nSamp int
samples := mat64.NewDense(totalSamp, len(test.mu), nil)
normal, ok := NewNormal(test.mu, test.sigma, nil)
if !ok {
t.Errorf("bad test")
}
sample := make([]float64, len(test.mu))
for i := 0; i < totalSamp; i++ {
normal.Rand(sample)
isClose := true
for i, v := range test.observed {
if math.Abs(sample[v]-test.value[i]) > 1e-1 {
isClose = false
break
}
}
if isClose {
samples.SetRow(nSamp, sample)
nSamp++
}
}
if nSamp < 100 {
t.Errorf("bad test, not enough samples")
continue
}
samples = samples.View(0, 0, nSamp, len(test.mu)).(*mat64.Dense)
// Compute mean and covariance matrix.
estMean := make([]float64, len(test.mu))
for i := range estMean {
estMean[i] = stat.Mean(mat64.Col(nil, i, samples), nil)
}
estCov := stat.CovarianceMatrix(nil, samples, nil)
// Compute update rule.
newNormal, ok := normal.ConditionNormal(test.observed, test.value, nil)
if !ok {
t.Fatalf("Bad test, update failure")
}
var subEstMean []float64
for _, v := range test.unobserved {
subEstMean = append(subEstMean, estMean[v])
}
subEstCov := mat64.NewSymDense(len(test.unobserved), nil)
for i := 0; i < len(test.unobserved); i++ {
for j := i; j < len(test.unobserved); j++ {
subEstCov.SetSym(i, j, estCov.At(test.unobserved[i], test.unobserved[j]))
}
}
for i, v := range subEstMean {
if math.Abs(newNormal.mu[i]-v) > 5e-2 {
t.Errorf("Mean mismatch. Want %v, got %v.", newNormal.mu[i], v)
}
}
var sigma mat64.SymDense
sigma.FromCholesky(&newNormal.chol)
if !mat64.EqualApprox(&sigma, subEstCov, 1e-1) {
t.Errorf("Covariance mismatch. Want:\n%0.8v\nGot:\n%0.8v\n", subEstCov, sigma)
}
}
}