mirror of
https://github.com/gonum/gonum.git
synced 2025-10-15 19:50:48 +08:00
Merge pull request #168 from gonum/keepsigma
Store sigma in Normal and StudentsT
This commit is contained in:
@@ -7,7 +7,6 @@ package distmv
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"sync"
|
||||
|
||||
"github.com/gonum/floats"
|
||||
"github.com/gonum/matrix/mat64"
|
||||
@@ -27,8 +26,7 @@ var (
|
||||
type Normal struct {
|
||||
mu []float64
|
||||
|
||||
once sync.Once
|
||||
sigma *mat64.SymDense // only stored if needed
|
||||
sigma mat64.SymDense
|
||||
|
||||
chol mat64.Cholesky
|
||||
lower mat64.TriDense
|
||||
@@ -59,6 +57,8 @@ func NewNormal(mu []float64, sigma mat64.Symmetric, src *rand.Rand) (*Normal, bo
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
n.sigma = *mat64.NewSymDense(dim, nil)
|
||||
n.sigma.CopySym(sigma)
|
||||
n.lower.LFromCholesky(&n.chol)
|
||||
n.logSqrtDet = 0.5 * n.chol.LogDet()
|
||||
return n, true
|
||||
@@ -141,9 +141,7 @@ func (n *Normal) ConditionNormal(observed []int, values []float64, src *rand.Ran
|
||||
}
|
||||
}
|
||||
|
||||
n.setSigma()
|
||||
|
||||
_, mu1, sigma11 := studentsTConditional(observed, values, math.Inf(1), n.mu, n.sigma)
|
||||
_, mu1, sigma11 := studentsTConditional(observed, values, math.Inf(1), n.mu, &n.sigma)
|
||||
if mu1 == nil {
|
||||
return nil, false
|
||||
}
|
||||
@@ -164,8 +162,7 @@ func (n *Normal) CovarianceMatrix(s *mat64.SymDense) *mat64.SymDense {
|
||||
if sn != n.Dim() {
|
||||
panic("normal: input matrix size mismatch")
|
||||
}
|
||||
n.setSigma()
|
||||
s.CopySym(n.sigma)
|
||||
s.CopySym(&n.sigma)
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -202,9 +199,8 @@ func (n *Normal) MarginalNormal(vars []int, src *rand.Rand) (*Normal, bool) {
|
||||
for i, v := range vars {
|
||||
newMean[i] = n.mu[v]
|
||||
}
|
||||
n.setSigma()
|
||||
var s mat64.SymDense
|
||||
s.SubsetSym(n.sigma, vars)
|
||||
s.SubsetSym(&n.sigma, vars)
|
||||
return NewNormal(newMean, &s, src)
|
||||
}
|
||||
|
||||
@@ -216,19 +212,9 @@ func (n *Normal) MarginalNormal(vars []int, src *rand.Rand) (*Normal, bool) {
|
||||
//
|
||||
// The input src is passed to the constructed distuv.Normal.
|
||||
func (n *Normal) MarginalNormalSingle(i int, src *rand.Rand) distuv.Normal {
|
||||
var std float64
|
||||
if n.sigma != nil {
|
||||
std = n.sigma.At(i, i)
|
||||
} else {
|
||||
// Reconstruct the {i,i} diagonal element of the covariance directly.
|
||||
for j := 0; j <= i; j++ {
|
||||
v := n.lower.At(i, j)
|
||||
std += v * v
|
||||
}
|
||||
}
|
||||
return distuv.Normal{
|
||||
Mu: n.mu[i],
|
||||
Sigma: math.Sqrt(std),
|
||||
Sigma: math.Sqrt(n.sigma.At(i, i)),
|
||||
Source: src,
|
||||
}
|
||||
}
|
||||
@@ -300,14 +286,6 @@ func (n *Normal) SetMean(mu []float64) {
|
||||
copy(n.mu, mu)
|
||||
}
|
||||
|
||||
// setSigma computes and stores the covariance matrix of the distribution.
|
||||
func (n *Normal) setSigma() {
|
||||
n.once.Do(func() {
|
||||
n.sigma = mat64.NewSymDense(n.Dim(), nil)
|
||||
n.sigma.FromCholesky(&n.chol)
|
||||
})
|
||||
}
|
||||
|
||||
// TransformNormal transforms the vector, normal, generated from a standard
|
||||
// multidimensional normal into a vector that has been generated under the
|
||||
// distribution of the receiver.
|
||||
|
@@ -489,8 +489,6 @@ func TestMarginalSingle(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatalf("Bad test, covariance matrix not positive definite")
|
||||
}
|
||||
// Verify with nil Sigma.
|
||||
normal.sigma = nil
|
||||
for i, mean := range test.mu {
|
||||
norm := normal.MarginalNormalSingle(i, nil)
|
||||
if norm.Mean() != mean {
|
||||
@@ -501,19 +499,6 @@ func TestMarginalSingle(t *testing.T) {
|
||||
t.Errorf("StdDev mismatch nil Sigma, idx %v: want %v, got %v.", i, std, norm.StdDev())
|
||||
}
|
||||
}
|
||||
|
||||
// Verify with non-nil Sigma.
|
||||
normal.setSigma()
|
||||
for i, mean := range test.mu {
|
||||
norm := normal.MarginalNormalSingle(i, nil)
|
||||
if norm.Mean() != mean {
|
||||
t.Errorf("Mean mismatch non-nil Sigma, idx %v: want %v, got %v.", i, mean, norm.Mean())
|
||||
}
|
||||
std := math.Sqrt(test.sigma.At(i, i))
|
||||
if math.Abs(norm.StdDev()-std) > 1e-14 {
|
||||
t.Errorf("StdDev mismatch non-nil Sigma, idx %v: want %v, got %v.", i, std, norm.StdDev())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test matching with TestMarginal.
|
||||
|
@@ -7,7 +7,6 @@ package distmv
|
||||
import (
|
||||
"log"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/matrix/mat64"
|
||||
@@ -38,8 +37,6 @@ func BenchmarkMarginalNormalReset10(b *testing.B) {
|
||||
if !ok {
|
||||
b.Error("bad test")
|
||||
}
|
||||
normal.sigma = nil
|
||||
normal.once = sync.Once{}
|
||||
_ = marg
|
||||
}
|
||||
}
|
||||
|
@@ -8,7 +8,6 @@ import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/tools/container/intsets"
|
||||
|
||||
@@ -36,8 +35,7 @@ type StudentsT struct {
|
||||
mu []float64
|
||||
src *rand.Rand
|
||||
|
||||
once sync.Once
|
||||
sigma *mat64.SymDense // only stored if needed
|
||||
sigma mat64.SymDense // only stored if needed
|
||||
|
||||
chol mat64.Cholesky
|
||||
lower mat64.TriDense
|
||||
@@ -71,6 +69,8 @@ func NewStudentsT(mu []float64, sigma mat64.Symmetric, nu float64, src *rand.Ran
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
s.sigma = *mat64.NewSymDense(dim, nil)
|
||||
s.sigma.CopySym(sigma)
|
||||
s.lower.LFromCholesky(&s.chol)
|
||||
s.logSqrtDet = 0.5 * s.chol.LogDet()
|
||||
return s, true
|
||||
@@ -101,9 +101,7 @@ func (s *StudentsT) ConditionStudentsT(observed []int, values []float64, src *ra
|
||||
}
|
||||
}
|
||||
|
||||
s.setSigma()
|
||||
|
||||
newNu, newMean, newSigma := studentsTConditional(observed, values, s.nu, s.mu, s.sigma)
|
||||
newNu, newMean, newSigma := studentsTConditional(observed, values, s.nu, s.mu, &s.sigma)
|
||||
if newMean == nil {
|
||||
return nil, false
|
||||
}
|
||||
@@ -231,8 +229,7 @@ func (st *StudentsT) CovarianceMatrix(s *mat64.SymDense) *mat64.SymDense {
|
||||
if sn != st.dim {
|
||||
panic("normal: input matrix size mismatch")
|
||||
}
|
||||
st.setSigma()
|
||||
s.CopySym(st.sigma)
|
||||
s.CopySym(&st.sigma)
|
||||
s.ScaleSym(st.nu/(st.nu-2), s)
|
||||
return s
|
||||
}
|
||||
@@ -286,9 +283,8 @@ func (s *StudentsT) MarginalStudentsT(vars []int, src *rand.Rand) (dist *Student
|
||||
for i, v := range vars {
|
||||
newMean[i] = s.mu[v]
|
||||
}
|
||||
s.setSigma()
|
||||
var newSigma mat64.SymDense
|
||||
newSigma.SubsetSym(s.sigma, vars)
|
||||
newSigma.SubsetSym(&s.sigma, vars)
|
||||
return NewStudentsT(newMean, &newSigma, s.nu, src)
|
||||
}
|
||||
|
||||
@@ -300,20 +296,9 @@ func (s *StudentsT) MarginalStudentsT(vars []int, src *rand.Rand) (dist *Student
|
||||
//
|
||||
// The input src is passed to the call to NewStudentsT.
|
||||
func (s *StudentsT) MarginalStudentsTSingle(i int, src *rand.Rand) distuv.StudentsT {
|
||||
var std float64
|
||||
if s.sigma != nil {
|
||||
std = s.sigma.At(i, i)
|
||||
} else {
|
||||
// Reconstruct the {i,i} diagonal element of the covariance directly.
|
||||
for j := 0; j <= i; j++ {
|
||||
v := s.lower.At(i, j)
|
||||
std += v * v
|
||||
}
|
||||
}
|
||||
|
||||
return distuv.StudentsT{
|
||||
Mu: s.mu[i],
|
||||
Sigma: math.Sqrt(std),
|
||||
Sigma: math.Sqrt(s.sigma.At(i, i)),
|
||||
Nu: s.nu,
|
||||
Src: src,
|
||||
}
|
||||
@@ -367,11 +352,3 @@ func (s *StudentsT) Rand(x []float64) []float64 {
|
||||
floats.Add(x, s.mu)
|
||||
return x
|
||||
}
|
||||
|
||||
// setSigma computes and stores the covariance matrix of the distribution.
|
||||
func (s *StudentsT) setSigma() {
|
||||
s.once.Do(func() {
|
||||
s.sigma = mat64.NewSymDense(s.dim, nil)
|
||||
s.sigma.FromCholesky(&s.chol)
|
||||
})
|
||||
}
|
||||
|
@@ -182,12 +182,9 @@ func TestStudentsTConditional(t *testing.T) {
|
||||
muOb[i] = test.mean[v]
|
||||
}
|
||||
|
||||
s.setSigma()
|
||||
sUp.setSigma()
|
||||
|
||||
var sig11, sig22 mat64.SymDense
|
||||
sig11.SubsetSym(s.sigma, unob)
|
||||
sig22.SubsetSym(s.sigma, ob)
|
||||
sig11.SubsetSym(&s.sigma, unob)
|
||||
sig22.SubsetSym(&s.sigma, ob)
|
||||
|
||||
sig12 := mat64.NewDense(len(unob), len(ob), nil)
|
||||
for i := range unob {
|
||||
@@ -221,7 +218,7 @@ func TestStudentsTConditional(t *testing.T) {
|
||||
|
||||
dot := mat64.Dot(shiftVec, &tmp)
|
||||
tmp3.Scale((test.nu+dot)/(test.nu+float64(len(ob))), &tmp3)
|
||||
if !mat64.EqualApprox(&tmp3, sUp.sigma, 1e-10) {
|
||||
if !mat64.EqualApprox(&tmp3, &sUp.sigma, 1e-10) {
|
||||
t.Errorf("Sigma mismatch")
|
||||
}
|
||||
}
|
||||
@@ -248,8 +245,6 @@ func TestStudentsTMarginalSingle(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatalf("Bad test, covariance matrix not positive definite")
|
||||
}
|
||||
// Verify with nil Sigma.
|
||||
studentst.sigma = nil
|
||||
for i, mean := range test.mu {
|
||||
st := studentst.MarginalStudentsTSingle(i, nil)
|
||||
if st.Mean() != mean {
|
||||
@@ -263,18 +258,5 @@ func TestStudentsTMarginalSingle(t *testing.T) {
|
||||
t.Errorf("Nu mismatch nil Sigma, idx %v: want %v, got %v ", i, test.nu, st.Nu)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify with non-nil Sigma.
|
||||
studentst.setSigma()
|
||||
for i, mean := range test.mu {
|
||||
st := studentst.MarginalStudentsTSingle(i, nil)
|
||||
if st.Mean() != mean {
|
||||
t.Errorf("Mean mismatch non-nil Sigma, idx %v: want %v, got %v.", i, mean, st.Mean())
|
||||
}
|
||||
std := math.Sqrt(test.sigma.At(i, i))
|
||||
if math.Abs(st.Sigma-std) > 1e-14 {
|
||||
t.Errorf("StdDev mismatch non-nil Sigma, idx %v: want %v, got %v.", i, std, st.StdDev())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user