mirror of
https://github.com/gonum/gonum.git
synced 2025-10-16 20:20:41 +08:00
Merge pull request #168 from gonum/keepsigma
Store sigma in Normal and StudentsT
This commit is contained in:
@@ -7,7 +7,6 @@ package distmv
|
|||||||
import (
|
import (
|
||||||
"math"
|
"math"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/gonum/floats"
|
"github.com/gonum/floats"
|
||||||
"github.com/gonum/matrix/mat64"
|
"github.com/gonum/matrix/mat64"
|
||||||
@@ -27,8 +26,7 @@ var (
|
|||||||
type Normal struct {
|
type Normal struct {
|
||||||
mu []float64
|
mu []float64
|
||||||
|
|
||||||
once sync.Once
|
sigma mat64.SymDense
|
||||||
sigma *mat64.SymDense // only stored if needed
|
|
||||||
|
|
||||||
chol mat64.Cholesky
|
chol mat64.Cholesky
|
||||||
lower mat64.TriDense
|
lower mat64.TriDense
|
||||||
@@ -59,6 +57,8 @@ func NewNormal(mu []float64, sigma mat64.Symmetric, src *rand.Rand) (*Normal, bo
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
n.sigma = *mat64.NewSymDense(dim, nil)
|
||||||
|
n.sigma.CopySym(sigma)
|
||||||
n.lower.LFromCholesky(&n.chol)
|
n.lower.LFromCholesky(&n.chol)
|
||||||
n.logSqrtDet = 0.5 * n.chol.LogDet()
|
n.logSqrtDet = 0.5 * n.chol.LogDet()
|
||||||
return n, true
|
return n, true
|
||||||
@@ -141,9 +141,7 @@ func (n *Normal) ConditionNormal(observed []int, values []float64, src *rand.Ran
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
n.setSigma()
|
_, mu1, sigma11 := studentsTConditional(observed, values, math.Inf(1), n.mu, &n.sigma)
|
||||||
|
|
||||||
_, mu1, sigma11 := studentsTConditional(observed, values, math.Inf(1), n.mu, n.sigma)
|
|
||||||
if mu1 == nil {
|
if mu1 == nil {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -164,8 +162,7 @@ func (n *Normal) CovarianceMatrix(s *mat64.SymDense) *mat64.SymDense {
|
|||||||
if sn != n.Dim() {
|
if sn != n.Dim() {
|
||||||
panic("normal: input matrix size mismatch")
|
panic("normal: input matrix size mismatch")
|
||||||
}
|
}
|
||||||
n.setSigma()
|
s.CopySym(&n.sigma)
|
||||||
s.CopySym(n.sigma)
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -202,9 +199,8 @@ func (n *Normal) MarginalNormal(vars []int, src *rand.Rand) (*Normal, bool) {
|
|||||||
for i, v := range vars {
|
for i, v := range vars {
|
||||||
newMean[i] = n.mu[v]
|
newMean[i] = n.mu[v]
|
||||||
}
|
}
|
||||||
n.setSigma()
|
|
||||||
var s mat64.SymDense
|
var s mat64.SymDense
|
||||||
s.SubsetSym(n.sigma, vars)
|
s.SubsetSym(&n.sigma, vars)
|
||||||
return NewNormal(newMean, &s, src)
|
return NewNormal(newMean, &s, src)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -216,19 +212,9 @@ func (n *Normal) MarginalNormal(vars []int, src *rand.Rand) (*Normal, bool) {
|
|||||||
//
|
//
|
||||||
// The input src is passed to the constructed distuv.Normal.
|
// The input src is passed to the constructed distuv.Normal.
|
||||||
func (n *Normal) MarginalNormalSingle(i int, src *rand.Rand) distuv.Normal {
|
func (n *Normal) MarginalNormalSingle(i int, src *rand.Rand) distuv.Normal {
|
||||||
var std float64
|
|
||||||
if n.sigma != nil {
|
|
||||||
std = n.sigma.At(i, i)
|
|
||||||
} else {
|
|
||||||
// Reconstruct the {i,i} diagonal element of the covariance directly.
|
|
||||||
for j := 0; j <= i; j++ {
|
|
||||||
v := n.lower.At(i, j)
|
|
||||||
std += v * v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return distuv.Normal{
|
return distuv.Normal{
|
||||||
Mu: n.mu[i],
|
Mu: n.mu[i],
|
||||||
Sigma: math.Sqrt(std),
|
Sigma: math.Sqrt(n.sigma.At(i, i)),
|
||||||
Source: src,
|
Source: src,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -300,14 +286,6 @@ func (n *Normal) SetMean(mu []float64) {
|
|||||||
copy(n.mu, mu)
|
copy(n.mu, mu)
|
||||||
}
|
}
|
||||||
|
|
||||||
// setSigma computes and stores the covariance matrix of the distribution.
|
|
||||||
func (n *Normal) setSigma() {
|
|
||||||
n.once.Do(func() {
|
|
||||||
n.sigma = mat64.NewSymDense(n.Dim(), nil)
|
|
||||||
n.sigma.FromCholesky(&n.chol)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// TransformNormal transforms the vector, normal, generated from a standard
|
// TransformNormal transforms the vector, normal, generated from a standard
|
||||||
// multidimensional normal into a vector that has been generated under the
|
// multidimensional normal into a vector that has been generated under the
|
||||||
// distribution of the receiver.
|
// distribution of the receiver.
|
||||||
|
@@ -489,8 +489,6 @@ func TestMarginalSingle(t *testing.T) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("Bad test, covariance matrix not positive definite")
|
t.Fatalf("Bad test, covariance matrix not positive definite")
|
||||||
}
|
}
|
||||||
// Verify with nil Sigma.
|
|
||||||
normal.sigma = nil
|
|
||||||
for i, mean := range test.mu {
|
for i, mean := range test.mu {
|
||||||
norm := normal.MarginalNormalSingle(i, nil)
|
norm := normal.MarginalNormalSingle(i, nil)
|
||||||
if norm.Mean() != mean {
|
if norm.Mean() != mean {
|
||||||
@@ -501,19 +499,6 @@ func TestMarginalSingle(t *testing.T) {
|
|||||||
t.Errorf("StdDev mismatch nil Sigma, idx %v: want %v, got %v.", i, std, norm.StdDev())
|
t.Errorf("StdDev mismatch nil Sigma, idx %v: want %v, got %v.", i, std, norm.StdDev())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify with non-nil Sigma.
|
|
||||||
normal.setSigma()
|
|
||||||
for i, mean := range test.mu {
|
|
||||||
norm := normal.MarginalNormalSingle(i, nil)
|
|
||||||
if norm.Mean() != mean {
|
|
||||||
t.Errorf("Mean mismatch non-nil Sigma, idx %v: want %v, got %v.", i, mean, norm.Mean())
|
|
||||||
}
|
|
||||||
std := math.Sqrt(test.sigma.At(i, i))
|
|
||||||
if math.Abs(norm.StdDev()-std) > 1e-14 {
|
|
||||||
t.Errorf("StdDev mismatch non-nil Sigma, idx %v: want %v, got %v.", i, std, norm.StdDev())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test matching with TestMarginal.
|
// Test matching with TestMarginal.
|
||||||
|
@@ -7,7 +7,6 @@ package distmv
|
|||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gonum/matrix/mat64"
|
"github.com/gonum/matrix/mat64"
|
||||||
@@ -38,8 +37,6 @@ func BenchmarkMarginalNormalReset10(b *testing.B) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
b.Error("bad test")
|
b.Error("bad test")
|
||||||
}
|
}
|
||||||
normal.sigma = nil
|
|
||||||
normal.once = sync.Once{}
|
|
||||||
_ = marg
|
_ = marg
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -8,7 +8,6 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"golang.org/x/tools/container/intsets"
|
"golang.org/x/tools/container/intsets"
|
||||||
|
|
||||||
@@ -36,8 +35,7 @@ type StudentsT struct {
|
|||||||
mu []float64
|
mu []float64
|
||||||
src *rand.Rand
|
src *rand.Rand
|
||||||
|
|
||||||
once sync.Once
|
sigma mat64.SymDense // only stored if needed
|
||||||
sigma *mat64.SymDense // only stored if needed
|
|
||||||
|
|
||||||
chol mat64.Cholesky
|
chol mat64.Cholesky
|
||||||
lower mat64.TriDense
|
lower mat64.TriDense
|
||||||
@@ -71,6 +69,8 @@ func NewStudentsT(mu []float64, sigma mat64.Symmetric, nu float64, src *rand.Ran
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
s.sigma = *mat64.NewSymDense(dim, nil)
|
||||||
|
s.sigma.CopySym(sigma)
|
||||||
s.lower.LFromCholesky(&s.chol)
|
s.lower.LFromCholesky(&s.chol)
|
||||||
s.logSqrtDet = 0.5 * s.chol.LogDet()
|
s.logSqrtDet = 0.5 * s.chol.LogDet()
|
||||||
return s, true
|
return s, true
|
||||||
@@ -101,9 +101,7 @@ func (s *StudentsT) ConditionStudentsT(observed []int, values []float64, src *ra
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.setSigma()
|
newNu, newMean, newSigma := studentsTConditional(observed, values, s.nu, s.mu, &s.sigma)
|
||||||
|
|
||||||
newNu, newMean, newSigma := studentsTConditional(observed, values, s.nu, s.mu, s.sigma)
|
|
||||||
if newMean == nil {
|
if newMean == nil {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -231,8 +229,7 @@ func (st *StudentsT) CovarianceMatrix(s *mat64.SymDense) *mat64.SymDense {
|
|||||||
if sn != st.dim {
|
if sn != st.dim {
|
||||||
panic("normal: input matrix size mismatch")
|
panic("normal: input matrix size mismatch")
|
||||||
}
|
}
|
||||||
st.setSigma()
|
s.CopySym(&st.sigma)
|
||||||
s.CopySym(st.sigma)
|
|
||||||
s.ScaleSym(st.nu/(st.nu-2), s)
|
s.ScaleSym(st.nu/(st.nu-2), s)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
@@ -286,9 +283,8 @@ func (s *StudentsT) MarginalStudentsT(vars []int, src *rand.Rand) (dist *Student
|
|||||||
for i, v := range vars {
|
for i, v := range vars {
|
||||||
newMean[i] = s.mu[v]
|
newMean[i] = s.mu[v]
|
||||||
}
|
}
|
||||||
s.setSigma()
|
|
||||||
var newSigma mat64.SymDense
|
var newSigma mat64.SymDense
|
||||||
newSigma.SubsetSym(s.sigma, vars)
|
newSigma.SubsetSym(&s.sigma, vars)
|
||||||
return NewStudentsT(newMean, &newSigma, s.nu, src)
|
return NewStudentsT(newMean, &newSigma, s.nu, src)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -300,20 +296,9 @@ func (s *StudentsT) MarginalStudentsT(vars []int, src *rand.Rand) (dist *Student
|
|||||||
//
|
//
|
||||||
// The input src is passed to the call to NewStudentsT.
|
// The input src is passed to the call to NewStudentsT.
|
||||||
func (s *StudentsT) MarginalStudentsTSingle(i int, src *rand.Rand) distuv.StudentsT {
|
func (s *StudentsT) MarginalStudentsTSingle(i int, src *rand.Rand) distuv.StudentsT {
|
||||||
var std float64
|
|
||||||
if s.sigma != nil {
|
|
||||||
std = s.sigma.At(i, i)
|
|
||||||
} else {
|
|
||||||
// Reconstruct the {i,i} diagonal element of the covariance directly.
|
|
||||||
for j := 0; j <= i; j++ {
|
|
||||||
v := s.lower.At(i, j)
|
|
||||||
std += v * v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return distuv.StudentsT{
|
return distuv.StudentsT{
|
||||||
Mu: s.mu[i],
|
Mu: s.mu[i],
|
||||||
Sigma: math.Sqrt(std),
|
Sigma: math.Sqrt(s.sigma.At(i, i)),
|
||||||
Nu: s.nu,
|
Nu: s.nu,
|
||||||
Src: src,
|
Src: src,
|
||||||
}
|
}
|
||||||
@@ -367,11 +352,3 @@ func (s *StudentsT) Rand(x []float64) []float64 {
|
|||||||
floats.Add(x, s.mu)
|
floats.Add(x, s.mu)
|
||||||
return x
|
return x
|
||||||
}
|
}
|
||||||
|
|
||||||
// setSigma computes and stores the covariance matrix of the distribution.
|
|
||||||
func (s *StudentsT) setSigma() {
|
|
||||||
s.once.Do(func() {
|
|
||||||
s.sigma = mat64.NewSymDense(s.dim, nil)
|
|
||||||
s.sigma.FromCholesky(&s.chol)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
@@ -182,12 +182,9 @@ func TestStudentsTConditional(t *testing.T) {
|
|||||||
muOb[i] = test.mean[v]
|
muOb[i] = test.mean[v]
|
||||||
}
|
}
|
||||||
|
|
||||||
s.setSigma()
|
|
||||||
sUp.setSigma()
|
|
||||||
|
|
||||||
var sig11, sig22 mat64.SymDense
|
var sig11, sig22 mat64.SymDense
|
||||||
sig11.SubsetSym(s.sigma, unob)
|
sig11.SubsetSym(&s.sigma, unob)
|
||||||
sig22.SubsetSym(s.sigma, ob)
|
sig22.SubsetSym(&s.sigma, ob)
|
||||||
|
|
||||||
sig12 := mat64.NewDense(len(unob), len(ob), nil)
|
sig12 := mat64.NewDense(len(unob), len(ob), nil)
|
||||||
for i := range unob {
|
for i := range unob {
|
||||||
@@ -221,7 +218,7 @@ func TestStudentsTConditional(t *testing.T) {
|
|||||||
|
|
||||||
dot := mat64.Dot(shiftVec, &tmp)
|
dot := mat64.Dot(shiftVec, &tmp)
|
||||||
tmp3.Scale((test.nu+dot)/(test.nu+float64(len(ob))), &tmp3)
|
tmp3.Scale((test.nu+dot)/(test.nu+float64(len(ob))), &tmp3)
|
||||||
if !mat64.EqualApprox(&tmp3, sUp.sigma, 1e-10) {
|
if !mat64.EqualApprox(&tmp3, &sUp.sigma, 1e-10) {
|
||||||
t.Errorf("Sigma mismatch")
|
t.Errorf("Sigma mismatch")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -248,8 +245,6 @@ func TestStudentsTMarginalSingle(t *testing.T) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("Bad test, covariance matrix not positive definite")
|
t.Fatalf("Bad test, covariance matrix not positive definite")
|
||||||
}
|
}
|
||||||
// Verify with nil Sigma.
|
|
||||||
studentst.sigma = nil
|
|
||||||
for i, mean := range test.mu {
|
for i, mean := range test.mu {
|
||||||
st := studentst.MarginalStudentsTSingle(i, nil)
|
st := studentst.MarginalStudentsTSingle(i, nil)
|
||||||
if st.Mean() != mean {
|
if st.Mean() != mean {
|
||||||
@@ -263,18 +258,5 @@ func TestStudentsTMarginalSingle(t *testing.T) {
|
|||||||
t.Errorf("Nu mismatch nil Sigma, idx %v: want %v, got %v ", i, test.nu, st.Nu)
|
t.Errorf("Nu mismatch nil Sigma, idx %v: want %v, got %v ", i, test.nu, st.Nu)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify with non-nil Sigma.
|
|
||||||
studentst.setSigma()
|
|
||||||
for i, mean := range test.mu {
|
|
||||||
st := studentst.MarginalStudentsTSingle(i, nil)
|
|
||||||
if st.Mean() != mean {
|
|
||||||
t.Errorf("Mean mismatch non-nil Sigma, idx %v: want %v, got %v.", i, mean, st.Mean())
|
|
||||||
}
|
|
||||||
std := math.Sqrt(test.sigma.At(i, i))
|
|
||||||
if math.Abs(st.Sigma-std) > 1e-14 {
|
|
||||||
t.Errorf("StdDev mismatch non-nil Sigma, idx %v: want %v, got %v.", i, std, st.StdDev())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user