From db9cfea7b08d58655fb691f90dc34921834514bb Mon Sep 17 00:00:00 2001 From: btracey Date: Mon, 9 Nov 2015 21:28:04 -0700 Subject: [PATCH] Add ConditionNormal --- distmv/normal.go | 110 ++++++++++++++++++++++ distmv/normal_test.go | 214 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 324 insertions(+) diff --git a/distmv/normal.go b/distmv/normal.go index 9753c1c5..ecbbbca4 100644 --- a/distmv/normal.go +++ b/distmv/normal.go @@ -7,6 +7,7 @@ package distmv import ( "math" "math/rand" + "sync" "github.com/gonum/floats" "github.com/gonum/matrix/mat64" @@ -20,6 +21,9 @@ import ( type Normal struct { mu []float64 + once sync.Once + sigma *mat64.SymDense // only stored if needed + chol mat64.Cholesky lower mat64.TriDense logSqrtDet float64 @@ -118,3 +122,109 @@ func (n *Normal) Rand(x []float64) []float64 { floats.Add(x, n.mu) return x } + +// ConditionNormal returns the Normal distribution that is the receiver conditioned +// on the input evidence. The returned multivariate normal has dimension +// n - len(observed), where n is the dimension of the original receiver. The updated +// mean and covariance are +// mu = mu_un + sigma_{ob,un}^T * sigma_{ob,ob}^-1 * (v - mu_ob) +// sigma = sigma_{un,un} - sigma_{ob,un}^T * sigma_{ob,ob}^-1 * sigma_{ob,un} +// where mu_un and mu_ob are the original means of the unobserved and observed +// variables respectively, sigma_{un,un} is the unobserved subset of the covariance +// matrix, sigma_{ob,ob} is the observed subset of the covariance matrix, and +// sigma_{ob,un} are the cross terms. The observed variables take the +// values v. The dimension order is preserved during conditioning, so if the value +// of dimension 1 is observed, the returned normal represents dimensions {0, 2, ...} +// of the original Normal distribution. +// +// ConditionNormal returns {nil, false} if there is a failure during the update. +// Mathematically this is impossible, but can occur with finite precision arithmetic.
+func (n *Normal) ConditionNormal(observed []int, values []float64, src *rand.Rand) (*Normal, bool) { + if len(observed) != len(values) { + panic("normal: input slice length mismatch") + } + + n.setSigma() + + ob := len(observed) + unob := n.Dim() - ob + obMap := make(map[int]struct{}) + for _, v := range observed { + if _, ok := obMap[v]; ok { + panic("normal: observed dimension occurs twice") + } + obMap[v] = struct{}{} + } + unobserved := make([]int, 0, unob) + for i := 0; i < n.Dim(); i++ { + if _, ok := obMap[i]; !ok { + unobserved = append(unobserved, i) + } + } + mu1 := make([]float64, unob) + for i, v := range unobserved { + mu1[i] = n.mu[v] + } + mu2 := make([]float64, ob) // holds v - mu_ob, not the mean itself + for i, v := range observed { + mu2[i] = values[i] - n.mu[v] + } + + var sigma11, sigma22 mat64.SymDense + sigma11.SubsetSym(n.sigma, unobserved) + sigma22.SubsetSym(n.sigma, observed) + + sigma21 := mat64.NewDense(ob, unob, nil) + for i, r := range observed { + for j, c := range unobserved { + v := n.sigma.At(r, c) + sigma21.Set(i, j, v) + } + } + + var chol mat64.Cholesky + ok := chol.Factorize(&sigma22) + if !ok { + return nil, ok + } + + // Compute sigma_{2,1}^T * sigma_{2,2}^-1 (v - mu_2), where 1 denotes the unobserved and 2 the observed dimensions. + v := mat64.NewVector(ob, mu2) + var tmp, tmp2 mat64.Vector + err := tmp.SolveCholeskyVec(&chol, v) + if err != nil { + return nil, false + } + tmp2.MulVec(sigma21.T(), &tmp) + + // Compute sigma_{2,1}^T * sigma_{2,2}^-1 * sigma_{2,1}. + // TODO(btracey): Should this be a method of SymDense? + var tmp3, tmp4 mat64.Dense + err = tmp3.SolveCholesky(&chol, sigma21) + if err != nil { + return nil, false + } + tmp4.Mul(sigma21.T(), &tmp3) + + for i := range mu1 { + mu1[i] += tmp2.At(i, 0) + } + + // TODO(btracey): If tmp4 can be constructed as a SymDense with a method, then this can be + // replaced with SubSym.
+ for i := 0; i < len(unobserved); i++ { + for j := i; j < len(unobserved); j++ { + v := sigma11.At(i, j) + sigma11.SetSym(i, j, v-tmp4.At(i, j)) + } + } + return NewNormal(mu1, &sigma11, src) +} + +// setSigma computes and stores the covariance matrix of the distribution. The matrix is rebuilt from the Cholesky factorization inside n.once, so the work happens on the first call only. +func (n *Normal) setSigma() { + n.once.Do(func() { + n.sigma = mat64.NewSymDense(n.Dim(), nil) + n.sigma.FromCholesky(&n.chol) + }) +} diff --git a/distmv/normal_test.go b/distmv/normal_test.go index 059b347c..4b08d677 100644 --- a/distmv/normal_test.go +++ b/distmv/normal_test.go @@ -1,6 +1,7 @@ package distmv import ( + "math" "testing" "github.com/gonum/floats" @@ -96,3 +97,216 @@ func TestNormRand(t *testing.T) { } } } + +func TestConditionNormal(t *testing.T) { + // Observing dimensions that are uncorrelated with the remaining ones must leave the remaining means and covariances unchanged. + for _, test := range []struct { + mu []float64 + sigma *mat64.SymDense + observed []int + values []float64 + + newMu []float64 + newSigma *mat64.SymDense + }{ + { + mu: []float64{2, 3}, + sigma: mat64.NewSymDense(2, []float64{2, 0, 0, 5}), + observed: []int{0}, + values: []float64{10}, + + newMu: []float64{3}, + newSigma: mat64.NewSymDense(1, []float64{5}), + }, + { + mu: []float64{2, 3}, + sigma: mat64.NewSymDense(2, []float64{2, 0, 0, 5}), + observed: []int{1}, + values: []float64{10}, + + newMu: []float64{2}, + newSigma: mat64.NewSymDense(1, []float64{2}), + }, + { + mu: []float64{2, 3, 4}, + sigma: mat64.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}), + observed: []int{1}, + values: []float64{10}, + + newMu: []float64{2, 4}, + newSigma: mat64.NewSymDense(2, []float64{2, 0, 0, 10}), + }, + { + mu: []float64{2, 3, 4}, + sigma: mat64.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}), + observed: []int{0, 1}, + values: []float64{10, 15}, + + newMu: []float64{4}, + newSigma: mat64.NewSymDense(1, []float64{10}), + }, + { + mu: []float64{2, 3, 4, 5}, + sigma: mat64.NewSymDense(4, []float64{2, 0.5, 0, 0, 0.5, 5, 0, 0, 0, 0, 10, 2, 0, 0, 2, 3}), + observed:
[]int{0, 1}, + values: []float64{10, 15}, + + newMu: []float64{4, 5}, + newSigma: mat64.NewSymDense(2, []float64{10, 2, 2, 3}), + }, + } { + normal, ok := NewNormal(test.mu, test.sigma, nil) + if !ok { + t.Fatalf("Bad test, original sigma not positive definite") + } + newNormal, ok := normal.ConditionNormal(test.observed, test.values, nil) + if !ok { + t.Fatalf("Bad test, update failure") + } + + if !floats.EqualApprox(test.newMu, newNormal.mu, 1e-12) { + t.Errorf("Updated mean mismatch. Want %v, got %v.", test.newMu, newNormal.mu) + } + + var sigma mat64.SymDense + sigma.FromCholesky(&newNormal.chol) + if !mat64.EqualApprox(test.newSigma, &sigma, 1e-12) { + t.Errorf("Updated sigma mismatch\n.Want:\n% v\nGot:\n% v\n", test.newSigma, sigma) + } + } + + // Test the bivariate case, where the update rule is analytic. + for _, test := range []struct { + mu []float64 + std []float64 + rho float64 + value float64 + }{ + { + mu: []float64{2, 3}, + std: []float64{3, 5}, + rho: 0.9, + value: 1000, + }, + { + mu: []float64{2, 3}, + std: []float64{3, 5}, + rho: -0.9, + value: 1000, + }, + } { + std := test.std + rho := test.rho + sigma := mat64.NewSymDense(2, []float64{std[0] * std[0], std[0] * std[1] * rho, std[0] * std[1] * rho, std[1] * std[1]}) + normal, ok := NewNormal(test.mu, sigma, nil) + if !ok { + t.Fatalf("Bad test, original sigma not positive definite") + } + newNormal, ok := normal.ConditionNormal([]int{1}, []float64{test.value}, nil) + if !ok { + t.Fatalf("Bad test, update failed") + } + var newSigma mat64.SymDense + newSigma.FromCholesky(&newNormal.chol) + trueMean := test.mu[0] + rho*(std[0]/std[1])*(test.value-test.mu[1]) + if math.Abs(trueMean-newNormal.mu[0]) > 1e-14 { + t.Errorf("Mean mismatch. Want %v, got %v", trueMean, newNormal.mu[0]) + } + trueVar := (1 - rho*rho) * std[0] * std[0] + if math.Abs(trueVar-newSigma.At(0, 0)) > 1e-14 { + t.Errorf("Variance mismatch. Want %v, got %v", trueVar, newSigma.At(0, 0)) + } + } + + // Test via sampling: the conditional moments are estimated from draws whose observed coordinates lie within 0.1 of the conditioning values.
+ for _, test := range []struct { + mu []float64 + sigma *mat64.SymDense + observed []int + unobserved []int + value []float64 + }{ + // The indices in unobserved must be in ascending order for this test. + { + mu: []float64{2, 3, 4}, + sigma: mat64.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}), + + observed: []int{0}, + unobserved: []int{1, 2}, + value: []float64{1.9}, + }, + { + mu: []float64{2, 3, 4, 5}, + sigma: mat64.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}), + + observed: []int{0, 3}, + unobserved: []int{1, 2}, + value: []float64{1.9, 2.9}, + }, + } { + totalSamp := 4000000 + var nSamp int + samples := mat64.NewDense(totalSamp, len(test.mu), nil) + normal, ok := NewNormal(test.mu, test.sigma, nil) + if !ok { + t.Fatalf("bad test") + } + sample := make([]float64, len(test.mu)) + for i := 0; i < totalSamp; i++ { + normal.Rand(sample) + isClose := true + for j, v := range test.observed { + if math.Abs(sample[v]-test.value[j]) > 1e-1 { + isClose = false + break + } + } + if isClose { + samples.SetRow(nSamp, sample) + nSamp++ + } + } + + if nSamp < 100 { + t.Errorf("bad test, not enough samples") + continue + } + samples = samples.View(0, 0, nSamp, len(test.mu)).(*mat64.Dense) + + // Compute the sample mean and covariance matrix of the accepted draws. + estMean := make([]float64, len(test.mu)) + for i := range estMean { + estMean[i] = stat.Mean(mat64.Col(nil, i, samples), nil) + } + estCov := stat.CovarianceMatrix(nil, samples, nil) + + // Compute the analytic update for comparison.
+ newNormal, ok := normal.ConditionNormal(test.observed, test.value, nil) + if !ok { + t.Fatalf("Bad test, update failure") + } + + var subEstMean []float64 + for _, v := range test.unobserved { + + subEstMean = append(subEstMean, estMean[v]) + } + subEstCov := mat64.NewSymDense(len(test.unobserved), nil) + for i := 0; i < len(test.unobserved); i++ { + for j := i; j < len(test.unobserved); j++ { + subEstCov.SetSym(i, j, estCov.At(test.unobserved[i], test.unobserved[j])) + } + } + + for i, v := range subEstMean { + if math.Abs(newNormal.mu[i]-v) > 5e-2 { + t.Errorf("Mean mismatch. Want %v, got %v.", v, newNormal.mu[i]) + } + } + var sigma mat64.SymDense + sigma.FromCholesky(&newNormal.chol) + if !mat64.EqualApprox(&sigma, subEstCov, 1e-1) { + t.Errorf("Covariance mismatch. Want:\n%0.8v\nGot:\n%0.8v\n", subEstCov, sigma) + } + } +}