Add MarginalNormalSingle

This commit is contained in:
btracey
2016-04-20 11:26:53 -06:00
committed by Brendan Tracey
parent 9032651ba9
commit fac834f7a9
3 changed files with 182 additions and 0 deletions

View File

@@ -11,6 +11,7 @@ import (
"github.com/gonum/floats"
"github.com/gonum/matrix/mat64"
"github.com/gonum/stat/distuv"
)
// Normal is a multivariate normal distribution (also known as the multivariate
@@ -278,6 +279,29 @@ func (n *Normal) MarginalNormal(vars []int, src *rand.Rand) (*Normal, bool) {
return NewNormal(newMean, &s, src)
}
// MarginalNormalSingle returns the marginal of the given input variable.
// That is, MarginalNormal returns
// p(x_i) = \int_{x_¬i} p(x_i | x_¬i) p(x_¬i) dx_¬i
// where i is the input index.
// The input src is passed to the constructed distuv.Normal.
func (n *Normal) MarginalNormalSingle(i int, src *rand.Rand) distuv.Normal {
var std float64
if n.sigma != nil {
std = n.sigma.At(i, i)
} else {
// Reconstruct the {i,i} diagonal element of the covariance directly.
for j := 0; j <= i; j++ {
v := n.lower.At(i, j)
std += v * v
}
}
return distuv.Normal{
Mu: n.mu[i],
Sigma: math.Sqrt(std),
Source: src,
}
}
// Mean returns the mean of the probability distribution at x. If the
// input argument is nil, a new slice will be allocated, otherwise the result
// will be put in-place into the receiver.

View File

@@ -6,6 +6,7 @@ package distmv
import (
"math"
"math/rand"
"testing"
"github.com/gonum/floats"
@@ -402,3 +403,84 @@ func TestMarginal(t *testing.T) {
}
}
}
func TestMarginalSingle(t *testing.T) {
for _, test := range []struct {
mu []float64
sigma *mat64.SymDense
}{
{
mu: []float64{2, 3, 4},
sigma: mat64.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
},
{
mu: []float64{2, 3, 4, 5},
sigma: mat64.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}),
},
} {
normal, ok := NewNormal(test.mu, test.sigma, nil)
if !ok {
t.Fatalf("Bad test, covariance matrix not positive definite")
}
// Verify with nil Sigma.
normal.sigma = nil
for i, mean := range test.mu {
norm := normal.MarginalNormalSingle(i, nil)
if norm.Mean() != mean {
t.Errorf("Mean mismatch nil Sigma, idx %v: want %v, got %v.", i, mean, norm.Mean())
}
std := math.Sqrt(test.sigma.At(i, i))
if math.Abs(norm.StdDev()-std) > 1e-14 {
t.Errorf("StdDev mismatch nil Sigma, idx %v: want %v, got %v.", i, std, norm.StdDev())
}
}
// Verify with non-nil Sigma.
normal.setSigma()
for i, mean := range test.mu {
norm := normal.MarginalNormalSingle(i, nil)
if norm.Mean() != mean {
t.Errorf("Mean mismatch non-nil Sigma, idx %v: want %v, got %v.", i, mean, norm.Mean())
}
std := math.Sqrt(test.sigma.At(i, i))
if math.Abs(norm.StdDev()-std) > 1e-14 {
t.Errorf("StdDev mismatch non-nil Sigma, idx %v: want %v, got %v.", i, std, norm.StdDev())
}
}
}
// Test matching with TestMarginal.
rnd := rand.New(rand.NewSource(1))
for cas := 0; cas < 10; cas++ {
dim := rnd.Intn(10) + 1
mu := make([]float64, dim)
for i := range mu {
mu[i] = rnd.Float64()
}
x := make([]float64, dim*dim)
for i := range x {
x[i] = rnd.Float64()
}
mat := mat64.NewDense(dim, dim, x)
var sigma mat64.SymDense
sigma.SymOuterK(1, mat)
normal, ok := NewNormal(mu, &sigma, nil)
if !ok {
t.Fatal("bad test")
}
for i := 0; i < dim; i++ {
single := normal.MarginalNormalSingle(i, nil)
mult, ok := normal.MarginalNormal([]int{i}, nil)
if !ok {
t.Fatal("bad test")
}
if math.Abs(single.Mean()-mult.Mean(nil)[0]) > 1e-14 {
t.Errorf("Mean mismatch")
}
if math.Abs(single.Variance()-mult.CovarianceMatrix(nil).At(0, 0)) > 1e-14 {
t.Errorf("Variance mismatch")
}
}
}
}

View File

@@ -0,0 +1,76 @@
// Copyright ©2016 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package distmv
import (
"log"
"math/rand"
"sync"
"testing"
"github.com/gonum/matrix/mat64"
)
func BenchmarkMarginalNormal10(b *testing.B) {
sz := 10
rnd := rand.New(rand.NewSource(1))
normal := randomNormal(sz, rnd)
_ = normal.CovarianceMatrix(nil) // pre-compute sigma
b.ResetTimer()
for i := 0; i < b.N; i++ {
marg, ok := normal.MarginalNormal([]int{1}, nil)
if !ok {
b.Error("bad test")
}
_ = marg
}
}
func BenchmarkMarginalNormalReset10(b *testing.B) {
sz := 10
rnd := rand.New(rand.NewSource(1))
normal := randomNormal(sz, rnd)
b.ResetTimer()
for i := 0; i < b.N; i++ {
marg, ok := normal.MarginalNormal([]int{1}, nil)
if !ok {
b.Error("bad test")
}
normal.sigma = nil
normal.once = sync.Once{}
_ = marg
}
}
func BenchmarkMarginalNormalSingle10(b *testing.B) {
sz := 10
rnd := rand.New(rand.NewSource(1))
normal := randomNormal(sz, rnd)
b.ResetTimer()
for i := 0; i < b.N; i++ {
marg := normal.MarginalNormalSingle(1, nil)
_ = marg
}
}
func randomNormal(sz int, rnd *rand.Rand) *Normal {
mu := make([]float64, sz)
for i := range mu {
mu[i] = rnd.Float64()
}
data := make([]float64, sz*sz)
for i := range data {
data[i] = rnd.Float64()
}
dM := mat64.NewDense(sz, sz, data)
var sigma mat64.SymDense
sigma.SymOuterK(1, dM)
normal, ok := NewNormal(mu, &sigma, nil)
if !ok {
log.Fatal("bad test, not pos def")
}
return normal
}