Add ConditionNormal

This commit is contained in:
btracey
2015-11-09 21:28:04 -07:00
parent 74a6648c88
commit db9cfea7b0
2 changed files with 324 additions and 0 deletions

View File

@@ -7,6 +7,7 @@ package distmv
import (
"math"
"math/rand"
"sync"
"github.com/gonum/floats"
"github.com/gonum/matrix/mat64"
@@ -20,6 +21,9 @@ import (
type Normal struct {
mu []float64
once sync.Once
sigma *mat64.SymDense // only stored if needed
chol mat64.Cholesky
lower mat64.TriDense
logSqrtDet float64
@@ -118,3 +122,109 @@ func (n *Normal) Rand(x []float64) []float64 {
floats.Add(x, n.mu)
return x
}
// ConditionNormal returns the Normal distribution that is the receiver conditioned
// on the input evidence. The returned multivariate normal has dimension
// n - len(observed), where n is the dimension of the original receiver. The updated
// mean and covariance are
//	mu = mu_un + sigma_{ob,un}^T * sigma_{ob,ob}^-1 (v - mu_ob)
//	sigma = sigma_{un,un} - sigma_{ob,un}^T * sigma_{ob,ob}^-1 * sigma_{ob,un}
// where mu_un and mu_ob are the original means of the unobserved and observed
// variables respectively, sigma_{un,un} is the unobserved subset of the covariance
// matrix, sigma_{ob,ob} is the observed subset of the covariance matrix, and
// sigma_{ob,un} are the cross terms. The dimensions listed in observed have been
// observed with the corresponding entries of values, v. The dimension order is
// preserved during conditioning, so if the value of dimension 1 is observed, the
// returned normal represents dimensions {0, 2, ...} of the original Normal
// distribution.
//
// src is passed unchanged to NewNormal when constructing the returned distribution.
//
// ConditionNormal panics if len(observed) != len(values), if an element of
// observed is outside [0, n.Dim()), or if an element of observed occurs more
// than once.
//
// ConditionNormal returns {nil, false} if there is a failure during the update.
// Mathematically this is impossible, but can occur with finite precision arithmetic.
func (n *Normal) ConditionNormal(observed []int, values []float64, src *rand.Rand) (*Normal, bool) {
	if len(observed) != len(values) {
		panic("normal: input slice length mismatch")
	}

	// Validate the observed indices before doing any work so that bad input is
	// reported with a descriptive panic rather than an index-out-of-range
	// failure later in the function.
	dim := n.Dim()
	isObserved := make([]bool, dim)
	for _, idx := range observed {
		if idx < 0 || idx >= dim {
			panic("normal: observed dimension out of bounds")
		}
		if isObserved[idx] {
			panic("normal: observed dimension occurs twice")
		}
		isObserved[idx] = true
	}

	// The covariance matrix is only materialized on demand.
	n.setSigma()

	ob := len(observed)
	unob := dim - ob

	// Collect the unobserved dimensions in ascending order so that the
	// dimension order of the receiver is preserved in the result.
	unobserved := make([]int, 0, unob)
	for i, seen := range isObserved {
		if !seen {
			unobserved = append(unobserved, i)
		}
	}

	// mu1 starts as mu_un and is updated in place to the conditional mean.
	mu1 := make([]float64, unob)
	for i, v := range unobserved {
		mu1[i] = n.mu[v]
	}
	mu2 := make([]float64, ob) // really v - mu_ob
	for i, v := range observed {
		mu2[i] = values[i] - n.mu[v]
	}

	// sigma11 is sigma_{un,un}, sigma22 is sigma_{ob,ob}, and sigma21 is the
	// cross-term block sigma_{ob,un}.
	var sigma11, sigma22 mat64.SymDense
	sigma11.SubsetSym(n.sigma, unobserved)
	sigma22.SubsetSym(n.sigma, observed)

	sigma21 := mat64.NewDense(ob, unob, nil)
	for i, r := range observed {
		for j, c := range unobserved {
			sigma21.Set(i, j, n.sigma.At(r, c))
		}
	}

	var chol mat64.Cholesky
	if ok := chol.Factorize(&sigma22); !ok {
		return nil, false
	}

	// Compute sigma_{ob,un}^T * sigma_{ob,ob}^-1 (v - mu_ob).
	v := mat64.NewVector(ob, mu2)
	var tmp, tmp2 mat64.Vector
	if err := tmp.SolveCholeskyVec(&chol, v); err != nil {
		return nil, false
	}
	tmp2.MulVec(sigma21.T(), &tmp)

	// Compute sigma_{ob,un}^T * sigma_{ob,ob}^-1 * sigma_{ob,un}.
	// TODO(btracey): Should this be a method of SymDense?
	var tmp3, tmp4 mat64.Dense
	if err := tmp3.SolveCholesky(&chol, sigma21); err != nil {
		return nil, false
	}
	tmp4.Mul(sigma21.T(), &tmp3)

	for i := range mu1 {
		mu1[i] += tmp2.At(i, 0)
	}

	// TODO(btracey): If tmp4 can be constructed with a method, then this can be
	// replaced with SubSym.
	// Only the upper triangle is visited; SetSym keeps the matrix symmetric.
	for i := 0; i < len(unobserved); i++ {
		for j := i; j < len(unobserved); j++ {
			sigma11.SetSym(i, j, sigma11.At(i, j)-tmp4.At(i, j))
		}
	}
	return NewNormal(mu1, &sigma11, src)
}
// setSigma lazily materializes the covariance matrix of the distribution from
// its stored Cholesky factorization. The work is guarded by n.once, so the
// matrix is built at most once no matter how many times setSigma is called.
func (n *Normal) setSigma() {
	build := func() {
		sigma := mat64.NewSymDense(n.Dim(), nil)
		n.sigma = sigma
		sigma.FromCholesky(&n.chol)
	}
	n.once.Do(build)
}

View File

@@ -1,6 +1,7 @@
package distmv
import (
"math"
"testing"
"github.com/gonum/floats"
@@ -96,3 +97,216 @@ func TestNormRand(t *testing.T) {
}
}
}
// TestConditionNormal checks Normal.ConditionNormal in three ways: against
// hand-computed results where the observed dimensions are uncorrelated with the
// unobserved ones, against the closed-form bivariate update, and against
// empirical moments estimated by rejection sampling.
func TestConditionNormal(t *testing.T) {
	// Uncorrelated values shouldn't influence the updated values.
	for _, test := range []struct {
		mu       []float64
		sigma    *mat64.SymDense
		observed []int
		values   []float64

		newMu    []float64
		newSigma *mat64.SymDense
	}{
		{
			mu:       []float64{2, 3},
			sigma:    mat64.NewSymDense(2, []float64{2, 0, 0, 5}),
			observed: []int{0},
			values:   []float64{10},

			newMu:    []float64{3},
			newSigma: mat64.NewSymDense(1, []float64{5}),
		},
		{
			mu:       []float64{2, 3},
			sigma:    mat64.NewSymDense(2, []float64{2, 0, 0, 5}),
			observed: []int{1},
			values:   []float64{10},

			newMu:    []float64{2},
			newSigma: mat64.NewSymDense(1, []float64{2}),
		},
		{
			mu:       []float64{2, 3, 4},
			sigma:    mat64.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}),
			observed: []int{1},
			values:   []float64{10},

			newMu:    []float64{2, 4},
			newSigma: mat64.NewSymDense(2, []float64{2, 0, 0, 10}),
		},
		{
			mu:       []float64{2, 3, 4},
			sigma:    mat64.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}),
			observed: []int{0, 1},
			values:   []float64{10, 15},

			newMu:    []float64{4},
			newSigma: mat64.NewSymDense(1, []float64{10}),
		},
		{
			mu:       []float64{2, 3, 4, 5},
			sigma:    mat64.NewSymDense(4, []float64{2, 0.5, 0, 0, 0.5, 5, 0, 0, 0, 0, 10, 2, 0, 0, 2, 3}),
			observed: []int{0, 1},
			values:   []float64{10, 15},

			newMu:    []float64{4, 5},
			newSigma: mat64.NewSymDense(2, []float64{10, 2, 2, 3}),
		},
	} {
		normal, ok := NewNormal(test.mu, test.sigma, nil)
		if !ok {
			t.Fatalf("Bad test, original sigma not positive definite")
		}
		newNormal, ok := normal.ConditionNormal(test.observed, test.values, nil)
		if !ok {
			t.Fatalf("Bad test, update failure")
		}

		if !floats.EqualApprox(test.newMu, newNormal.mu, 1e-12) {
			t.Errorf("Updated mean mismatch. Want %v, got %v.", test.newMu, newNormal.mu)
		}

		// Rebuild the covariance from the stored factorization for comparison.
		var sigma mat64.SymDense
		sigma.FromCholesky(&newNormal.chol)
		if !mat64.EqualApprox(test.newSigma, &sigma, 1e-12) {
			t.Errorf("Updated sigma mismatch.\nWant:\n% v\nGot:\n% v\n", test.newSigma, &sigma)
		}
	}

	// Test bivariate case where the update rule is analytic
	for _, test := range []struct {
		mu    []float64
		std   []float64
		rho   float64
		value float64
	}{
		{
			mu:    []float64{2, 3},
			std:   []float64{3, 5},
			rho:   0.9,
			value: 1000,
		},
		{
			mu:    []float64{2, 3},
			std:   []float64{3, 5},
			rho:   -0.9,
			value: 1000,
		},
	} {
		std := test.std
		rho := test.rho
		sigma := mat64.NewSymDense(2, []float64{std[0] * std[0], std[0] * std[1] * rho, std[0] * std[1] * rho, std[1] * std[1]})
		normal, ok := NewNormal(test.mu, sigma, nil)
		if !ok {
			t.Fatalf("Bad test, original sigma not positive definite")
		}
		newNormal, ok := normal.ConditionNormal([]int{1}, []float64{test.value}, nil)
		if !ok {
			t.Fatalf("Bad test, update failed")
		}
		var newSigma mat64.SymDense
		newSigma.FromCholesky(&newNormal.chol)

		trueMean := test.mu[0] + rho*(std[0]/std[1])*(test.value-test.mu[1])
		if math.Abs(trueMean-newNormal.mu[0]) > 1e-14 {
			t.Errorf("Mean mismatch. Want %v, got %v", trueMean, newNormal.mu[0])
		}
		trueVar := (1 - rho*rho) * std[0] * std[0]
		if gotVar := newSigma.At(0, 0); math.Abs(trueVar-gotVar) > 1e-14 {
			// Report the variances being compared, not the means.
			t.Errorf("Variance mismatch. Want %v, got %v", trueVar, gotVar)
		}
	}

	// Test via sampling.
	for _, test := range []struct {
		mu         []float64
		sigma      *mat64.SymDense
		observed   []int
		unobserved []int
		value      []float64
	}{
		// The indices in unobserved must be in ascending order for this test.
		{
			mu:    []float64{2, 3, 4},
			sigma: mat64.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),

			observed:   []int{0},
			unobserved: []int{1, 2},
			value:      []float64{1.9},
		},
		{
			mu:    []float64{2, 3, 4, 5},
			sigma: mat64.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}),

			observed:   []int{0, 3},
			unobserved: []int{1, 2},
			value:      []float64{1.9, 2.9},
		},
	} {
		totalSamp := 4000000
		var nSamp int
		samples := mat64.NewDense(totalSamp, len(test.mu), nil)
		normal, ok := NewNormal(test.mu, test.sigma, nil)
		if !ok {
			// normal is unusable past this point, so stop rather than continue.
			t.Fatalf("Bad test, original sigma not positive definite")
		}

		// Rejection sampling: keep only the draws whose observed dimensions
		// fall within 0.1 of the conditioning values.
		sample := make([]float64, len(test.mu))
		for i := 0; i < totalSamp; i++ {
			normal.Rand(sample)
			isClose := true
			for k, dim := range test.observed {
				if math.Abs(sample[dim]-test.value[k]) > 1e-1 {
					isClose = false
					break
				}
			}
			if isClose {
				samples.SetRow(nSamp, sample)
				nSamp++
			}
		}

		if nSamp < 100 {
			t.Errorf("bad test, not enough samples")
			continue
		}
		samples = samples.View(0, 0, nSamp, len(test.mu)).(*mat64.Dense)

		// Compute mean and covariance matrix.
		estMean := make([]float64, len(test.mu))
		for i := range estMean {
			estMean[i] = stat.Mean(mat64.Col(nil, i, samples), nil)
		}
		estCov := stat.CovarianceMatrix(nil, samples, nil)

		// Compute update rule.
		newNormal, ok := normal.ConditionNormal(test.observed, test.value, nil)
		if !ok {
			t.Fatalf("Bad test, update failure")
		}

		// Restrict the empirical moments to the unobserved dimensions.
		subEstMean := make([]float64, 0, len(test.unobserved))
		for _, v := range test.unobserved {
			subEstMean = append(subEstMean, estMean[v])
		}
		subEstCov := mat64.NewSymDense(len(test.unobserved), nil)
		for i := 0; i < len(test.unobserved); i++ {
			for j := i; j < len(test.unobserved); j++ {
				subEstCov.SetSym(i, j, estCov.At(test.unobserved[i], test.unobserved[j]))
			}
		}

		// The sampled estimate is the reference ("want") in both checks below.
		for i, v := range subEstMean {
			if math.Abs(newNormal.mu[i]-v) > 5e-2 {
				t.Errorf("Mean mismatch. Want %v, got %v.", v, newNormal.mu[i])
			}
		}
		var sigma mat64.SymDense
		sigma.FromCholesky(&newNormal.chol)
		if !mat64.EqualApprox(&sigma, subEstCov, 1e-1) {
			t.Errorf("Covariance mismatch. Want:\n%0.8v\nGot:\n%0.8v\n", subEstCov, &sigma)
		}
	}
}