From fac834f7a90164fd271ee76b0891f474f90d6f4b Mon Sep 17 00:00:00 2001 From: btracey Date: Wed, 20 Apr 2016 11:26:53 -0600 Subject: [PATCH] Add MarginalNormalSingle --- distmv/normal.go | 24 +++++++++++ distmv/normal_test.go | 82 ++++++++++++++++++++++++++++++++++++++ distmv/normalbench_test.go | 76 +++++++++++++++++++++++++++++++++++ 3 files changed, 182 insertions(+) create mode 100644 distmv/normalbench_test.go diff --git a/distmv/normal.go b/distmv/normal.go index 82f87aed..73ed392a 100644 --- a/distmv/normal.go +++ b/distmv/normal.go @@ -11,6 +11,7 @@ import ( "github.com/gonum/floats" "github.com/gonum/matrix/mat64" + "github.com/gonum/stat/distuv" ) // Normal is a multivariate normal distribution (also known as the multivariate @@ -278,6 +279,29 @@ func (n *Normal) MarginalNormal(vars []int, src *rand.Rand) (*Normal, bool) { return NewNormal(newMean, &s, src) } +// MarginalNormalSingle returns the marginal of the given input variable. +// That is, MarginalNormal returns +// p(x_i) = \int_{x_¬i} p(x_i | x_¬i) p(x_¬i) dx_¬i +// where i is the input index. +// The input src is passed to the constructed distuv.Normal. +func (n *Normal) MarginalNormalSingle(i int, src *rand.Rand) distuv.Normal { + var std float64 + if n.sigma != nil { + std = n.sigma.At(i, i) + } else { + // Reconstruct the {i,i} diagonal element of the covariance directly. + for j := 0; j <= i; j++ { + v := n.lower.At(i, j) + std += v * v + } + } + return distuv.Normal{ + Mu: n.mu[i], + Sigma: math.Sqrt(std), + Source: src, + } +} + // Mean returns the mean of the probability distribution at x. If the // input argument is nil, a new slice will be allocated, otherwise the result // will be put in-place into the receiver. diff --git a/distmv/normal_test.go b/distmv/normal_test.go index d1b54dd2..a5b324b2 100644 --- a/distmv/normal_test.go +++ b/distmv/normal_test.go @@ -6,6 +6,7 @@ package distmv import ( "math" + "math/rand" "testing" "github.com/gonum/floats" @@ -402,3 +403,84 @@ func TestMarginal(t *testing.T) { } } } + +func TestMarginalSingle(t *testing.T) { + for _, test := range []struct { + mu []float64 + sigma *mat64.SymDense + }{ + { + mu: []float64{2, 3, 4}, + sigma: mat64.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}), + }, + { + mu: []float64{2, 3, 4, 5}, + sigma: mat64.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}), + }, + } { + normal, ok := NewNormal(test.mu, test.sigma, nil) + if !ok { + t.Fatalf("Bad test, covariance matrix not positive definite") + } + // Verify with nil Sigma. + normal.sigma = nil + for i, mean := range test.mu { + norm := normal.MarginalNormalSingle(i, nil) + if norm.Mean() != mean { + t.Errorf("Mean mismatch nil Sigma, idx %v: want %v, got %v.", i, mean, norm.Mean()) + } + std := math.Sqrt(test.sigma.At(i, i)) + if math.Abs(norm.StdDev()-std) > 1e-14 { + t.Errorf("StdDev mismatch nil Sigma, idx %v: want %v, got %v.", i, std, norm.StdDev()) + } + } + + // Verify with non-nil Sigma. + normal.setSigma() + for i, mean := range test.mu { + norm := normal.MarginalNormalSingle(i, nil) + if norm.Mean() != mean { + t.Errorf("Mean mismatch non-nil Sigma, idx %v: want %v, got %v.", i, mean, norm.Mean()) + } + std := math.Sqrt(test.sigma.At(i, i)) + if math.Abs(norm.StdDev()-std) > 1e-14 { + t.Errorf("StdDev mismatch non-nil Sigma, idx %v: want %v, got %v.", i, std, norm.StdDev()) + } + } + } + + // Test matching with TestMarginal. + rnd := rand.New(rand.NewSource(1)) + for cas := 0; cas < 10; cas++ { + dim := rnd.Intn(10) + 1 + mu := make([]float64, dim) + for i := range mu { + mu[i] = rnd.Float64() + } + x := make([]float64, dim*dim) + for i := range x { + x[i] = rnd.Float64() + } + mat := mat64.NewDense(dim, dim, x) + var sigma mat64.SymDense + sigma.SymOuterK(1, mat) + + normal, ok := NewNormal(mu, &sigma, nil) + if !ok { + t.Fatal("bad test") + } + for i := 0; i < dim; i++ { + single := normal.MarginalNormalSingle(i, nil) + mult, ok := normal.MarginalNormal([]int{i}, nil) + if !ok { + t.Fatal("bad test") + } + if math.Abs(single.Mean()-mult.Mean(nil)[0]) > 1e-14 { + t.Errorf("Mean mismatch") + } + if math.Abs(single.Variance()-mult.CovarianceMatrix(nil).At(0, 0)) > 1e-14 { + t.Errorf("Variance mismatch") + } + } + } +} diff --git a/distmv/normalbench_test.go b/distmv/normalbench_test.go new file mode 100644 index 00000000..d497365e --- /dev/null +++ b/distmv/normalbench_test.go @@ -0,0 +1,76 @@ +// Copyright ©2016 The gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package distmv + +import ( + "log" + "math/rand" + "sync" + "testing" + + "github.com/gonum/matrix/mat64" +) + +func BenchmarkMarginalNormal10(b *testing.B) { + sz := 10 + rnd := rand.New(rand.NewSource(1)) + normal := randomNormal(sz, rnd) + _ = normal.CovarianceMatrix(nil) // pre-compute sigma + b.ResetTimer() + for i := 0; i < b.N; i++ { + marg, ok := normal.MarginalNormal([]int{1}, nil) + if !ok { + b.Error("bad test") + } + _ = marg + } +} + +func BenchmarkMarginalNormalReset10(b *testing.B) { + sz := 10 + rnd := rand.New(rand.NewSource(1)) + normal := randomNormal(sz, rnd) + b.ResetTimer() + for i := 0; i < b.N; i++ { + marg, ok := normal.MarginalNormal([]int{1}, nil) + if !ok { + b.Error("bad test") + } + normal.sigma = nil + normal.once = sync.Once{} + _ = marg + } +} + +func BenchmarkMarginalNormalSingle10(b *testing.B) { + sz := 10 + rnd := rand.New(rand.NewSource(1)) + normal := randomNormal(sz, rnd) + b.ResetTimer() + for i := 0; i < b.N; i++ { + marg := normal.MarginalNormalSingle(1, nil) + _ = marg + } +} + +func randomNormal(sz int, rnd *rand.Rand) *Normal { + mu := make([]float64, sz) + for i := range mu { + mu[i] = rnd.Float64() + } + data := make([]float64, sz*sz) + for i := range data { + data[i] = rnd.Float64() + } + dM := mat64.NewDense(sz, sz, data) + var sigma mat64.SymDense + sigma.SymOuterK(1, dM) + + normal, ok := NewNormal(mu, &sigma, nil) + if !ok { + log.Fatal("bad test, not pos def") + } + return normal +}