mirror of
https://github.com/gonum/gonum.git
synced 2025-10-25 08:10:28 +08:00
Add ConditionNormal
This commit is contained in:
110
distmv/normal.go
110
distmv/normal.go
@@ -7,6 +7,7 @@ package distmv
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"sync"
|
||||
|
||||
"github.com/gonum/floats"
|
||||
"github.com/gonum/matrix/mat64"
|
||||
@@ -20,6 +21,9 @@ import (
|
||||
type Normal struct {
|
||||
mu []float64
|
||||
|
||||
once sync.Once
|
||||
sigma *mat64.SymDense // only stored if needed
|
||||
|
||||
chol mat64.Cholesky
|
||||
lower mat64.TriDense
|
||||
logSqrtDet float64
|
||||
@@ -118,3 +122,109 @@ func (n *Normal) Rand(x []float64) []float64 {
|
||||
floats.Add(x, n.mu)
|
||||
return x
|
||||
}
|
||||
|
||||
// ConditionNormal returns the Normal distribution that is the receiver conditioned
|
||||
// on the input evidence. The returned multivariate normal has dimension
|
||||
// n - len(observed), where n is the dimension of the original receiver. The updated
|
||||
// mean and covariance are
|
||||
// mu = mu_un + sigma_{ob,un}^T * sigma_{ob,ob}^-1 (v - mu_ob)
|
||||
// sigma = sigma_{un,un} - sigma_{ob,un}^T * sigma_{ob,ob}^-1 * sigma_{ob,un}
|
||||
// where mu_un and mu_ob are the original means of the unobserved and observed
|
||||
// variables respectively, sigma_{un,un} is the unobserved subset of the covariance
|
||||
// matrix, sigma_{ob,ob} is the observed subset of the covariance matrix, and
|
||||
// sigma_{un,ob} are the cross terms. The elements of x_2 have been observed with
|
||||
// values v. The dimension order is preserved during conditioning, so if the value
|
||||
// of dimension 1 is observed, the returned normal represents dimensions {0, 2, ...}
|
||||
// of the original Normal distribution.
|
||||
//
|
||||
// ConditionNormal returns {nil, false} if there is a failure during the update.
|
||||
// Mathematically this is impossible, but can occur with finite precision arithmetic.
|
||||
func (n *Normal) ConditionNormal(observed []int, values []float64, src *rand.Rand) (*Normal, bool) {
|
||||
if len(observed) != len(values) {
|
||||
panic("normal: input slice length mismatch")
|
||||
}
|
||||
|
||||
n.setSigma()
|
||||
|
||||
ob := len(observed)
|
||||
unob := n.Dim() - ob
|
||||
obMap := make(map[int]struct{})
|
||||
for _, v := range observed {
|
||||
if _, ok := obMap[v]; ok {
|
||||
panic("normal: observed dimension occurs twice")
|
||||
}
|
||||
obMap[v] = struct{}{}
|
||||
}
|
||||
unobserved := make([]int, 0, unob)
|
||||
for i := 0; i < n.Dim(); i++ {
|
||||
if _, ok := obMap[i]; !ok {
|
||||
unobserved = append(unobserved, i)
|
||||
}
|
||||
}
|
||||
mu1 := make([]float64, unob)
|
||||
for i, v := range unobserved {
|
||||
mu1[i] = n.mu[v]
|
||||
}
|
||||
mu2 := make([]float64, ob) // really v - mu2
|
||||
for i, v := range observed {
|
||||
mu2[i] = values[i] - n.mu[v]
|
||||
}
|
||||
|
||||
var sigma11, sigma22 mat64.SymDense
|
||||
sigma11.SubsetSym(n.sigma, unobserved)
|
||||
sigma22.SubsetSym(n.sigma, observed)
|
||||
|
||||
sigma21 := mat64.NewDense(ob, unob, nil)
|
||||
for i, r := range observed {
|
||||
for j, c := range unobserved {
|
||||
v := n.sigma.At(r, c)
|
||||
sigma21.Set(i, j, v)
|
||||
}
|
||||
}
|
||||
|
||||
var chol mat64.Cholesky
|
||||
ok := chol.Factorize(&sigma22)
|
||||
if !ok {
|
||||
return nil, ok
|
||||
}
|
||||
|
||||
// Compute sigma_{2,1}^T * sigma_{2,2}^-1 (v - mu_2).
|
||||
v := mat64.NewVector(ob, mu2)
|
||||
var tmp, tmp2 mat64.Vector
|
||||
err := tmp.SolveCholeskyVec(&chol, v)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
tmp2.MulVec(sigma21.T(), &tmp)
|
||||
|
||||
// Compute sigma_{2,1}^T * sigma_{2,2}^-1 * sigma_{2,1}.
|
||||
// TODO(btracey): Should this be a method of SymDense?
|
||||
var tmp3, tmp4 mat64.Dense
|
||||
err = tmp3.SolveCholesky(&chol, sigma21)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
tmp4.Mul(sigma21.T(), &tmp3)
|
||||
|
||||
for i := range mu1 {
|
||||
mu1[i] += tmp2.At(i, 0)
|
||||
}
|
||||
|
||||
// TODO(btracey): If tmp2 can constructed with a method, then this can be
|
||||
// replaced with SubSym.
|
||||
for i := 0; i < len(unobserved); i++ {
|
||||
for j := i; j < len(unobserved); j++ {
|
||||
v := sigma11.At(i, j)
|
||||
sigma11.SetSym(i, j, v-tmp4.At(i, j))
|
||||
}
|
||||
}
|
||||
return NewNormal(mu1, &sigma11, src)
|
||||
}
|
||||
|
||||
// setSigma computes and stores the covariance matrix of the distribution.
|
||||
func (n *Normal) setSigma() {
|
||||
n.once.Do(func() {
|
||||
n.sigma = mat64.NewSymDense(n.Dim(), nil)
|
||||
n.sigma.FromCholesky(&n.chol)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package distmv
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/floats"
|
||||
@@ -96,3 +97,216 @@ func TestNormRand(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConditionNormal(t *testing.T) {
|
||||
// Uncorrelated values shouldn't influence the updated values.
|
||||
for _, test := range []struct {
|
||||
mu []float64
|
||||
sigma *mat64.SymDense
|
||||
observed []int
|
||||
values []float64
|
||||
|
||||
newMu []float64
|
||||
newSigma *mat64.SymDense
|
||||
}{
|
||||
{
|
||||
mu: []float64{2, 3},
|
||||
sigma: mat64.NewSymDense(2, []float64{2, 0, 0, 5}),
|
||||
observed: []int{0},
|
||||
values: []float64{10},
|
||||
|
||||
newMu: []float64{3},
|
||||
newSigma: mat64.NewSymDense(1, []float64{5}),
|
||||
},
|
||||
{
|
||||
mu: []float64{2, 3},
|
||||
sigma: mat64.NewSymDense(2, []float64{2, 0, 0, 5}),
|
||||
observed: []int{1},
|
||||
values: []float64{10},
|
||||
|
||||
newMu: []float64{2},
|
||||
newSigma: mat64.NewSymDense(1, []float64{2}),
|
||||
},
|
||||
{
|
||||
mu: []float64{2, 3, 4},
|
||||
sigma: mat64.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}),
|
||||
observed: []int{1},
|
||||
values: []float64{10},
|
||||
|
||||
newMu: []float64{2, 4},
|
||||
newSigma: mat64.NewSymDense(2, []float64{2, 0, 0, 10}),
|
||||
},
|
||||
{
|
||||
mu: []float64{2, 3, 4},
|
||||
sigma: mat64.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}),
|
||||
observed: []int{0, 1},
|
||||
values: []float64{10, 15},
|
||||
|
||||
newMu: []float64{4},
|
||||
newSigma: mat64.NewSymDense(1, []float64{10}),
|
||||
},
|
||||
{
|
||||
mu: []float64{2, 3, 4, 5},
|
||||
sigma: mat64.NewSymDense(4, []float64{2, 0.5, 0, 0, 0.5, 5, 0, 0, 0, 0, 10, 2, 0, 0, 2, 3}),
|
||||
observed: []int{0, 1},
|
||||
values: []float64{10, 15},
|
||||
|
||||
newMu: []float64{4, 5},
|
||||
newSigma: mat64.NewSymDense(2, []float64{10, 2, 2, 3}),
|
||||
},
|
||||
} {
|
||||
normal, ok := NewNormal(test.mu, test.sigma, nil)
|
||||
if !ok {
|
||||
t.Fatalf("Bad test, original sigma not positive definite")
|
||||
}
|
||||
newNormal, ok := normal.ConditionNormal(test.observed, test.values, nil)
|
||||
if !ok {
|
||||
t.Fatalf("Bad test, update failure")
|
||||
}
|
||||
|
||||
if !floats.EqualApprox(test.newMu, newNormal.mu, 1e-12) {
|
||||
t.Errorf("Updated mean mismatch. Want %v, got %v.", test.newMu, newNormal.mu)
|
||||
}
|
||||
|
||||
var sigma mat64.SymDense
|
||||
sigma.FromCholesky(&newNormal.chol)
|
||||
if !mat64.EqualApprox(test.newSigma, &sigma, 1e-12) {
|
||||
t.Errorf("Updated sigma mismatch\n.Want:\n% v\nGot:\n% v\n", test.newSigma, sigma)
|
||||
}
|
||||
}
|
||||
|
||||
// Test bivariate case where the update rule is analytic
|
||||
for _, test := range []struct {
|
||||
mu []float64
|
||||
std []float64
|
||||
rho float64
|
||||
value float64
|
||||
}{
|
||||
{
|
||||
mu: []float64{2, 3},
|
||||
std: []float64{3, 5},
|
||||
rho: 0.9,
|
||||
value: 1000,
|
||||
},
|
||||
{
|
||||
mu: []float64{2, 3},
|
||||
std: []float64{3, 5},
|
||||
rho: -0.9,
|
||||
value: 1000,
|
||||
},
|
||||
} {
|
||||
std := test.std
|
||||
rho := test.rho
|
||||
sigma := mat64.NewSymDense(2, []float64{std[0] * std[0], std[0] * std[1] * rho, std[0] * std[1] * rho, std[1] * std[1]})
|
||||
normal, ok := NewNormal(test.mu, sigma, nil)
|
||||
if !ok {
|
||||
t.Fatalf("Bad test, original sigma not positive definite")
|
||||
}
|
||||
newNormal, ok := normal.ConditionNormal([]int{1}, []float64{test.value}, nil)
|
||||
if !ok {
|
||||
t.Fatalf("Bad test, update failed")
|
||||
}
|
||||
var newSigma mat64.SymDense
|
||||
newSigma.FromCholesky(&newNormal.chol)
|
||||
trueMean := test.mu[0] + rho*(std[0]/std[1])*(test.value-test.mu[1])
|
||||
if math.Abs(trueMean-newNormal.mu[0]) > 1e-14 {
|
||||
t.Errorf("Mean mismatch. Want %v, got %v", trueMean, newNormal.mu[0])
|
||||
}
|
||||
trueVar := (1 - rho*rho) * std[0] * std[0]
|
||||
if math.Abs(trueVar-newSigma.At(0, 0)) > 1e-14 {
|
||||
t.Errorf("Std mismatch. Want %v, got %v", trueMean, newNormal.mu[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Test via sampling.
|
||||
for _, test := range []struct {
|
||||
mu []float64
|
||||
sigma *mat64.SymDense
|
||||
observed []int
|
||||
unobserved []int
|
||||
value []float64
|
||||
}{
|
||||
// The indices in unobserved must be in ascending order for this test.
|
||||
{
|
||||
mu: []float64{2, 3, 4},
|
||||
sigma: mat64.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
|
||||
|
||||
observed: []int{0},
|
||||
unobserved: []int{1, 2},
|
||||
value: []float64{1.9},
|
||||
},
|
||||
{
|
||||
mu: []float64{2, 3, 4, 5},
|
||||
sigma: mat64.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}),
|
||||
|
||||
observed: []int{0, 3},
|
||||
unobserved: []int{1, 2},
|
||||
value: []float64{1.9, 2.9},
|
||||
},
|
||||
} {
|
||||
totalSamp := 4000000
|
||||
var nSamp int
|
||||
samples := mat64.NewDense(totalSamp, len(test.mu), nil)
|
||||
normal, ok := NewNormal(test.mu, test.sigma, nil)
|
||||
if !ok {
|
||||
t.Errorf("bad test")
|
||||
}
|
||||
sample := make([]float64, len(test.mu))
|
||||
for i := 0; i < totalSamp; i++ {
|
||||
normal.Rand(sample)
|
||||
isClose := true
|
||||
for i, v := range test.observed {
|
||||
if math.Abs(sample[v]-test.value[i]) > 1e-1 {
|
||||
isClose = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if isClose {
|
||||
samples.SetRow(nSamp, sample)
|
||||
nSamp++
|
||||
}
|
||||
}
|
||||
|
||||
if nSamp < 100 {
|
||||
t.Errorf("bad test, not enough samples")
|
||||
continue
|
||||
}
|
||||
samples = samples.View(0, 0, nSamp, len(test.mu)).(*mat64.Dense)
|
||||
|
||||
// Compute mean and covariance matrix.
|
||||
estMean := make([]float64, len(test.mu))
|
||||
for i := range estMean {
|
||||
estMean[i] = stat.Mean(mat64.Col(nil, i, samples), nil)
|
||||
}
|
||||
estCov := stat.CovarianceMatrix(nil, samples, nil)
|
||||
|
||||
// Compute update rule.
|
||||
newNormal, ok := normal.ConditionNormal(test.observed, test.value, nil)
|
||||
if !ok {
|
||||
t.Fatalf("Bad test, update failure")
|
||||
}
|
||||
|
||||
var subEstMean []float64
|
||||
for _, v := range test.unobserved {
|
||||
|
||||
subEstMean = append(subEstMean, estMean[v])
|
||||
}
|
||||
subEstCov := mat64.NewSymDense(len(test.unobserved), nil)
|
||||
for i := 0; i < len(test.unobserved); i++ {
|
||||
for j := i; j < len(test.unobserved); j++ {
|
||||
subEstCov.SetSym(i, j, estCov.At(test.unobserved[i], test.unobserved[j]))
|
||||
}
|
||||
}
|
||||
|
||||
for i, v := range subEstMean {
|
||||
if math.Abs(newNormal.mu[i]-v) > 5e-2 {
|
||||
t.Errorf("Mean mismatch. Want %v, got %v.", newNormal.mu[i], v)
|
||||
}
|
||||
}
|
||||
var sigma mat64.SymDense
|
||||
sigma.FromCholesky(&newNormal.chol)
|
||||
if !mat64.EqualApprox(&sigma, subEstCov, 1e-1) {
|
||||
t.Errorf("Covariance mismatch. Want:\n%0.8v\nGot:\n%0.8v\n", subEstCov, sigma)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user