mirror of
https://github.com/gonum/gonum.git
synced 2025-10-25 00:00:24 +08:00
Add MarginalNormalSingle
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/gonum/floats"
|
||||
"github.com/gonum/matrix/mat64"
|
||||
"github.com/gonum/stat/distuv"
|
||||
)
|
||||
|
||||
// Normal is a multivariate normal distribution (also known as the multivariate
|
||||
@@ -278,6 +279,29 @@ func (n *Normal) MarginalNormal(vars []int, src *rand.Rand) (*Normal, bool) {
|
||||
return NewNormal(newMean, &s, src)
|
||||
}
|
||||
|
||||
// MarginalNormalSingle returns the marginal of the given input variable.
|
||||
// That is, MarginalNormal returns
|
||||
// p(x_i) = \int_{x_¬i} p(x_i | x_¬i) p(x_¬i) dx_¬i
|
||||
// where i is the input index.
|
||||
// The input src is passed to the constructed distuv.Normal.
|
||||
func (n *Normal) MarginalNormalSingle(i int, src *rand.Rand) distuv.Normal {
|
||||
var std float64
|
||||
if n.sigma != nil {
|
||||
std = n.sigma.At(i, i)
|
||||
} else {
|
||||
// Reconstruct the {i,i} diagonal element of the covariance directly.
|
||||
for j := 0; j <= i; j++ {
|
||||
v := n.lower.At(i, j)
|
||||
std += v * v
|
||||
}
|
||||
}
|
||||
return distuv.Normal{
|
||||
Mu: n.mu[i],
|
||||
Sigma: math.Sqrt(std),
|
||||
Source: src,
|
||||
}
|
||||
}
|
||||
|
||||
// Mean returns the mean of the probability distribution at x. If the
|
||||
// input argument is nil, a new slice will be allocated, otherwise the result
|
||||
// will be put in-place into the receiver.
|
||||
|
||||
@@ -6,6 +6,7 @@ package distmv
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/floats"
|
||||
@@ -402,3 +403,84 @@ func TestMarginal(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarginalSingle(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
mu []float64
|
||||
sigma *mat64.SymDense
|
||||
}{
|
||||
{
|
||||
mu: []float64{2, 3, 4},
|
||||
sigma: mat64.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
|
||||
},
|
||||
{
|
||||
mu: []float64{2, 3, 4, 5},
|
||||
sigma: mat64.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}),
|
||||
},
|
||||
} {
|
||||
normal, ok := NewNormal(test.mu, test.sigma, nil)
|
||||
if !ok {
|
||||
t.Fatalf("Bad test, covariance matrix not positive definite")
|
||||
}
|
||||
// Verify with nil Sigma.
|
||||
normal.sigma = nil
|
||||
for i, mean := range test.mu {
|
||||
norm := normal.MarginalNormalSingle(i, nil)
|
||||
if norm.Mean() != mean {
|
||||
t.Errorf("Mean mismatch nil Sigma, idx %v: want %v, got %v.", i, mean, norm.Mean())
|
||||
}
|
||||
std := math.Sqrt(test.sigma.At(i, i))
|
||||
if math.Abs(norm.StdDev()-std) > 1e-14 {
|
||||
t.Errorf("StdDev mismatch nil Sigma, idx %v: want %v, got %v.", i, std, norm.StdDev())
|
||||
}
|
||||
}
|
||||
|
||||
// Verify with non-nil Sigma.
|
||||
normal.setSigma()
|
||||
for i, mean := range test.mu {
|
||||
norm := normal.MarginalNormalSingle(i, nil)
|
||||
if norm.Mean() != mean {
|
||||
t.Errorf("Mean mismatch non-nil Sigma, idx %v: want %v, got %v.", i, mean, norm.Mean())
|
||||
}
|
||||
std := math.Sqrt(test.sigma.At(i, i))
|
||||
if math.Abs(norm.StdDev()-std) > 1e-14 {
|
||||
t.Errorf("StdDev mismatch non-nil Sigma, idx %v: want %v, got %v.", i, std, norm.StdDev())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test matching with TestMarginal.
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
for cas := 0; cas < 10; cas++ {
|
||||
dim := rnd.Intn(10) + 1
|
||||
mu := make([]float64, dim)
|
||||
for i := range mu {
|
||||
mu[i] = rnd.Float64()
|
||||
}
|
||||
x := make([]float64, dim*dim)
|
||||
for i := range x {
|
||||
x[i] = rnd.Float64()
|
||||
}
|
||||
mat := mat64.NewDense(dim, dim, x)
|
||||
var sigma mat64.SymDense
|
||||
sigma.SymOuterK(1, mat)
|
||||
|
||||
normal, ok := NewNormal(mu, &sigma, nil)
|
||||
if !ok {
|
||||
t.Fatal("bad test")
|
||||
}
|
||||
for i := 0; i < dim; i++ {
|
||||
single := normal.MarginalNormalSingle(i, nil)
|
||||
mult, ok := normal.MarginalNormal([]int{i}, nil)
|
||||
if !ok {
|
||||
t.Fatal("bad test")
|
||||
}
|
||||
if math.Abs(single.Mean()-mult.Mean(nil)[0]) > 1e-14 {
|
||||
t.Errorf("Mean mismatch")
|
||||
}
|
||||
if math.Abs(single.Variance()-mult.CovarianceMatrix(nil).At(0, 0)) > 1e-14 {
|
||||
t.Errorf("Variance mismatch")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
76
distmv/normalbench_test.go
Normal file
76
distmv/normalbench_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
// Copyright ©2016 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package distmv
|
||||
|
||||
import (
|
||||
"log"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/matrix/mat64"
|
||||
)
|
||||
|
||||
func BenchmarkMarginalNormal10(b *testing.B) {
|
||||
sz := 10
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
normal := randomNormal(sz, rnd)
|
||||
_ = normal.CovarianceMatrix(nil) // pre-compute sigma
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
marg, ok := normal.MarginalNormal([]int{1}, nil)
|
||||
if !ok {
|
||||
b.Error("bad test")
|
||||
}
|
||||
_ = marg
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMarginalNormalReset10(b *testing.B) {
|
||||
sz := 10
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
normal := randomNormal(sz, rnd)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
marg, ok := normal.MarginalNormal([]int{1}, nil)
|
||||
if !ok {
|
||||
b.Error("bad test")
|
||||
}
|
||||
normal.sigma = nil
|
||||
normal.once = sync.Once{}
|
||||
_ = marg
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMarginalNormalSingle10(b *testing.B) {
|
||||
sz := 10
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
normal := randomNormal(sz, rnd)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
marg := normal.MarginalNormalSingle(1, nil)
|
||||
_ = marg
|
||||
}
|
||||
}
|
||||
|
||||
func randomNormal(sz int, rnd *rand.Rand) *Normal {
|
||||
mu := make([]float64, sz)
|
||||
for i := range mu {
|
||||
mu[i] = rnd.Float64()
|
||||
}
|
||||
data := make([]float64, sz*sz)
|
||||
for i := range data {
|
||||
data[i] = rnd.Float64()
|
||||
}
|
||||
dM := mat64.NewDense(sz, sz, data)
|
||||
var sigma mat64.SymDense
|
||||
sigma.SymOuterK(1, dM)
|
||||
|
||||
normal, ok := NewNormal(mu, &sigma, nil)
|
||||
if !ok {
|
||||
log.Fatal("bad test, not pos def")
|
||||
}
|
||||
return normal
|
||||
}
|
||||
Reference in New Issue
Block a user