mirror of
https://github.com/gonum/gonum.git
synced 2025-10-25 08:10:28 +08:00
Add ConditionNormal
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package distmv
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/floats"
|
||||
@@ -96,3 +97,216 @@ func TestNormRand(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConditionNormal(t *testing.T) {
|
||||
// Uncorrelated values shouldn't influence the updated values.
|
||||
for _, test := range []struct {
|
||||
mu []float64
|
||||
sigma *mat64.SymDense
|
||||
observed []int
|
||||
values []float64
|
||||
|
||||
newMu []float64
|
||||
newSigma *mat64.SymDense
|
||||
}{
|
||||
{
|
||||
mu: []float64{2, 3},
|
||||
sigma: mat64.NewSymDense(2, []float64{2, 0, 0, 5}),
|
||||
observed: []int{0},
|
||||
values: []float64{10},
|
||||
|
||||
newMu: []float64{3},
|
||||
newSigma: mat64.NewSymDense(1, []float64{5}),
|
||||
},
|
||||
{
|
||||
mu: []float64{2, 3},
|
||||
sigma: mat64.NewSymDense(2, []float64{2, 0, 0, 5}),
|
||||
observed: []int{1},
|
||||
values: []float64{10},
|
||||
|
||||
newMu: []float64{2},
|
||||
newSigma: mat64.NewSymDense(1, []float64{2}),
|
||||
},
|
||||
{
|
||||
mu: []float64{2, 3, 4},
|
||||
sigma: mat64.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}),
|
||||
observed: []int{1},
|
||||
values: []float64{10},
|
||||
|
||||
newMu: []float64{2, 4},
|
||||
newSigma: mat64.NewSymDense(2, []float64{2, 0, 0, 10}),
|
||||
},
|
||||
{
|
||||
mu: []float64{2, 3, 4},
|
||||
sigma: mat64.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}),
|
||||
observed: []int{0, 1},
|
||||
values: []float64{10, 15},
|
||||
|
||||
newMu: []float64{4},
|
||||
newSigma: mat64.NewSymDense(1, []float64{10}),
|
||||
},
|
||||
{
|
||||
mu: []float64{2, 3, 4, 5},
|
||||
sigma: mat64.NewSymDense(4, []float64{2, 0.5, 0, 0, 0.5, 5, 0, 0, 0, 0, 10, 2, 0, 0, 2, 3}),
|
||||
observed: []int{0, 1},
|
||||
values: []float64{10, 15},
|
||||
|
||||
newMu: []float64{4, 5},
|
||||
newSigma: mat64.NewSymDense(2, []float64{10, 2, 2, 3}),
|
||||
},
|
||||
} {
|
||||
normal, ok := NewNormal(test.mu, test.sigma, nil)
|
||||
if !ok {
|
||||
t.Fatalf("Bad test, original sigma not positive definite")
|
||||
}
|
||||
newNormal, ok := normal.ConditionNormal(test.observed, test.values, nil)
|
||||
if !ok {
|
||||
t.Fatalf("Bad test, update failure")
|
||||
}
|
||||
|
||||
if !floats.EqualApprox(test.newMu, newNormal.mu, 1e-12) {
|
||||
t.Errorf("Updated mean mismatch. Want %v, got %v.", test.newMu, newNormal.mu)
|
||||
}
|
||||
|
||||
var sigma mat64.SymDense
|
||||
sigma.FromCholesky(&newNormal.chol)
|
||||
if !mat64.EqualApprox(test.newSigma, &sigma, 1e-12) {
|
||||
t.Errorf("Updated sigma mismatch\n.Want:\n% v\nGot:\n% v\n", test.newSigma, sigma)
|
||||
}
|
||||
}
|
||||
|
||||
// Test bivariate case where the update rule is analytic
|
||||
for _, test := range []struct {
|
||||
mu []float64
|
||||
std []float64
|
||||
rho float64
|
||||
value float64
|
||||
}{
|
||||
{
|
||||
mu: []float64{2, 3},
|
||||
std: []float64{3, 5},
|
||||
rho: 0.9,
|
||||
value: 1000,
|
||||
},
|
||||
{
|
||||
mu: []float64{2, 3},
|
||||
std: []float64{3, 5},
|
||||
rho: -0.9,
|
||||
value: 1000,
|
||||
},
|
||||
} {
|
||||
std := test.std
|
||||
rho := test.rho
|
||||
sigma := mat64.NewSymDense(2, []float64{std[0] * std[0], std[0] * std[1] * rho, std[0] * std[1] * rho, std[1] * std[1]})
|
||||
normal, ok := NewNormal(test.mu, sigma, nil)
|
||||
if !ok {
|
||||
t.Fatalf("Bad test, original sigma not positive definite")
|
||||
}
|
||||
newNormal, ok := normal.ConditionNormal([]int{1}, []float64{test.value}, nil)
|
||||
if !ok {
|
||||
t.Fatalf("Bad test, update failed")
|
||||
}
|
||||
var newSigma mat64.SymDense
|
||||
newSigma.FromCholesky(&newNormal.chol)
|
||||
trueMean := test.mu[0] + rho*(std[0]/std[1])*(test.value-test.mu[1])
|
||||
if math.Abs(trueMean-newNormal.mu[0]) > 1e-14 {
|
||||
t.Errorf("Mean mismatch. Want %v, got %v", trueMean, newNormal.mu[0])
|
||||
}
|
||||
trueVar := (1 - rho*rho) * std[0] * std[0]
|
||||
if math.Abs(trueVar-newSigma.At(0, 0)) > 1e-14 {
|
||||
t.Errorf("Std mismatch. Want %v, got %v", trueMean, newNormal.mu[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Test via sampling.
|
||||
for _, test := range []struct {
|
||||
mu []float64
|
||||
sigma *mat64.SymDense
|
||||
observed []int
|
||||
unobserved []int
|
||||
value []float64
|
||||
}{
|
||||
// The indices in unobserved must be in ascending order for this test.
|
||||
{
|
||||
mu: []float64{2, 3, 4},
|
||||
sigma: mat64.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
|
||||
|
||||
observed: []int{0},
|
||||
unobserved: []int{1, 2},
|
||||
value: []float64{1.9},
|
||||
},
|
||||
{
|
||||
mu: []float64{2, 3, 4, 5},
|
||||
sigma: mat64.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}),
|
||||
|
||||
observed: []int{0, 3},
|
||||
unobserved: []int{1, 2},
|
||||
value: []float64{1.9, 2.9},
|
||||
},
|
||||
} {
|
||||
totalSamp := 4000000
|
||||
var nSamp int
|
||||
samples := mat64.NewDense(totalSamp, len(test.mu), nil)
|
||||
normal, ok := NewNormal(test.mu, test.sigma, nil)
|
||||
if !ok {
|
||||
t.Errorf("bad test")
|
||||
}
|
||||
sample := make([]float64, len(test.mu))
|
||||
for i := 0; i < totalSamp; i++ {
|
||||
normal.Rand(sample)
|
||||
isClose := true
|
||||
for i, v := range test.observed {
|
||||
if math.Abs(sample[v]-test.value[i]) > 1e-1 {
|
||||
isClose = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if isClose {
|
||||
samples.SetRow(nSamp, sample)
|
||||
nSamp++
|
||||
}
|
||||
}
|
||||
|
||||
if nSamp < 100 {
|
||||
t.Errorf("bad test, not enough samples")
|
||||
continue
|
||||
}
|
||||
samples = samples.View(0, 0, nSamp, len(test.mu)).(*mat64.Dense)
|
||||
|
||||
// Compute mean and covariance matrix.
|
||||
estMean := make([]float64, len(test.mu))
|
||||
for i := range estMean {
|
||||
estMean[i] = stat.Mean(mat64.Col(nil, i, samples), nil)
|
||||
}
|
||||
estCov := stat.CovarianceMatrix(nil, samples, nil)
|
||||
|
||||
// Compute update rule.
|
||||
newNormal, ok := normal.ConditionNormal(test.observed, test.value, nil)
|
||||
if !ok {
|
||||
t.Fatalf("Bad test, update failure")
|
||||
}
|
||||
|
||||
var subEstMean []float64
|
||||
for _, v := range test.unobserved {
|
||||
|
||||
subEstMean = append(subEstMean, estMean[v])
|
||||
}
|
||||
subEstCov := mat64.NewSymDense(len(test.unobserved), nil)
|
||||
for i := 0; i < len(test.unobserved); i++ {
|
||||
for j := i; j < len(test.unobserved); j++ {
|
||||
subEstCov.SetSym(i, j, estCov.At(test.unobserved[i], test.unobserved[j]))
|
||||
}
|
||||
}
|
||||
|
||||
for i, v := range subEstMean {
|
||||
if math.Abs(newNormal.mu[i]-v) > 5e-2 {
|
||||
t.Errorf("Mean mismatch. Want %v, got %v.", newNormal.mu[i], v)
|
||||
}
|
||||
}
|
||||
var sigma mat64.SymDense
|
||||
sigma.FromCholesky(&newNormal.chol)
|
||||
if !mat64.EqualApprox(&sigma, subEstCov, 1e-1) {
|
||||
t.Errorf("Covariance mismatch. Want:\n%0.8v\nGot:\n%0.8v\n", subEstCov, sigma)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user