diff --git a/distmv/normal.go b/distmv/normal.go index 585e0d66..1a20af2c 100644 --- a/distmv/normal.go +++ b/distmv/normal.go @@ -7,7 +7,6 @@ package distmv import ( "math" "math/rand" - "sync" "github.com/gonum/floats" "github.com/gonum/matrix/mat64" @@ -27,8 +26,7 @@ var ( type Normal struct { mu []float64 - once sync.Once - sigma *mat64.SymDense // only stored if needed + sigma mat64.SymDense chol mat64.Cholesky lower mat64.TriDense @@ -59,6 +57,8 @@ func NewNormal(mu []float64, sigma mat64.Symmetric, src *rand.Rand) (*Normal, bo if !ok { return nil, false } + n.sigma = *mat64.NewSymDense(dim, nil) + n.sigma.CopySym(sigma) n.lower.LFromCholesky(&n.chol) n.logSqrtDet = 0.5 * n.chol.LogDet() return n, true @@ -141,9 +141,7 @@ func (n *Normal) ConditionNormal(observed []int, values []float64, src *rand.Ran } } - n.setSigma() - - _, mu1, sigma11 := studentsTConditional(observed, values, math.Inf(1), n.mu, n.sigma) + _, mu1, sigma11 := studentsTConditional(observed, values, math.Inf(1), n.mu, &n.sigma) if mu1 == nil { return nil, false } @@ -164,8 +162,7 @@ func (n *Normal) CovarianceMatrix(s *mat64.SymDense) *mat64.SymDense { if sn != n.Dim() { panic("normal: input matrix size mismatch") } - n.setSigma() - s.CopySym(n.sigma) + s.CopySym(&n.sigma) return s } @@ -202,9 +199,8 @@ func (n *Normal) MarginalNormal(vars []int, src *rand.Rand) (*Normal, bool) { for i, v := range vars { newMean[i] = n.mu[v] } - n.setSigma() var s mat64.SymDense - s.SubsetSym(n.sigma, vars) + s.SubsetSym(&n.sigma, vars) return NewNormal(newMean, &s, src) } @@ -216,19 +212,9 @@ func (n *Normal) MarginalNormal(vars []int, src *rand.Rand) (*Normal, bool) { // // The input src is passed to the constructed distuv.Normal. func (n *Normal) MarginalNormalSingle(i int, src *rand.Rand) distuv.Normal { - var std float64 - if n.sigma != nil { - std = n.sigma.At(i, i) - } else { - // Reconstruct the {i,i} diagonal element of the covariance directly. - for j := 0; j <= i; j++ { - v := n.lower.At(i, j) - std += v * v - } - } return distuv.Normal{ Mu: n.mu[i], - Sigma: math.Sqrt(std), + Sigma: math.Sqrt(n.sigma.At(i, i)), Source: src, } } @@ -300,14 +286,6 @@ func (n *Normal) SetMean(mu []float64) { copy(n.mu, mu) } -// setSigma computes and stores the covariance matrix of the distribution. -func (n *Normal) setSigma() { - n.once.Do(func() { - n.sigma = mat64.NewSymDense(n.Dim(), nil) - n.sigma.FromCholesky(&n.chol) - }) -} - // TransformNormal transforms the vector, normal, generated from a standard // multidimensional normal into a vector that has been generated under the // distribution of the receiver. diff --git a/distmv/normal_test.go b/distmv/normal_test.go index 863928bb..7cd837ed 100644 --- a/distmv/normal_test.go +++ b/distmv/normal_test.go @@ -489,8 +489,6 @@ func TestMarginalSingle(t *testing.T) { if !ok { t.Fatalf("Bad test, covariance matrix not positive definite") } - // Verify with nil Sigma. - normal.sigma = nil for i, mean := range test.mu { norm := normal.MarginalNormalSingle(i, nil) if norm.Mean() != mean { @@ -501,19 +499,6 @@ func TestMarginalSingle(t *testing.T) { t.Errorf("StdDev mismatch nil Sigma, idx %v: want %v, got %v.", i, std, norm.StdDev()) } } - - // Verify with non-nil Sigma. - normal.setSigma() - for i, mean := range test.mu { - norm := normal.MarginalNormalSingle(i, nil) - if norm.Mean() != mean { - t.Errorf("Mean mismatch non-nil Sigma, idx %v: want %v, got %v.", i, mean, norm.Mean()) - } - std := math.Sqrt(test.sigma.At(i, i)) - if math.Abs(norm.StdDev()-std) > 1e-14 { - t.Errorf("StdDev mismatch non-nil Sigma, idx %v: want %v, got %v.", i, std, norm.StdDev()) - } - } } // Test matching with TestMarginal. diff --git a/distmv/normalbench_test.go b/distmv/normalbench_test.go index d497365e..6677ea1c 100644 --- a/distmv/normalbench_test.go +++ b/distmv/normalbench_test.go @@ -7,7 +7,6 @@ package distmv import ( "log" "math/rand" - "sync" "testing" "github.com/gonum/matrix/mat64" @@ -38,8 +37,6 @@ func BenchmarkMarginalNormalReset10(b *testing.B) { if !ok { b.Error("bad test") } - normal.sigma = nil - normal.once = sync.Once{} _ = marg } } diff --git a/distmv/studentst.go b/distmv/studentst.go index 3b0746f9..c691b96c 100644 --- a/distmv/studentst.go +++ b/distmv/studentst.go @@ -8,7 +8,6 @@ import ( "math" "math/rand" "sort" - "sync" "golang.org/x/tools/container/intsets" @@ -36,8 +35,7 @@ type StudentsT struct { mu []float64 src *rand.Rand - once sync.Once - sigma *mat64.SymDense // only stored if needed + sigma mat64.SymDense // only stored if needed chol mat64.Cholesky lower mat64.TriDense @@ -71,6 +69,8 @@ func NewStudentsT(mu []float64, sigma mat64.Symmetric, nu float64, src *rand.Ran if !ok { return nil, false } + s.sigma = *mat64.NewSymDense(dim, nil) + s.sigma.CopySym(sigma) s.lower.LFromCholesky(&s.chol) s.logSqrtDet = 0.5 * s.chol.LogDet() return s, true @@ -101,9 +101,7 @@ func (s *StudentsT) ConditionStudentsT(observed []int, values []float64, src *ra } } - s.setSigma() - - newNu, newMean, newSigma := studentsTConditional(observed, values, s.nu, s.mu, s.sigma) + newNu, newMean, newSigma := studentsTConditional(observed, values, s.nu, s.mu, &s.sigma) if newMean == nil { return nil, false } @@ -231,8 +229,7 @@ func (st *StudentsT) CovarianceMatrix(s *mat64.SymDense) *mat64.SymDense { if sn != st.dim { panic("normal: input matrix size mismatch") } - st.setSigma() - s.CopySym(st.sigma) + s.CopySym(&st.sigma) s.ScaleSym(st.nu/(st.nu-2), s) return s } @@ -286,9 +283,8 @@ func (s *StudentsT) MarginalStudentsT(vars []int, src *rand.Rand) (dist *Student for i, v := range vars { newMean[i] = s.mu[v] } - s.setSigma() var newSigma mat64.SymDense - newSigma.SubsetSym(s.sigma, vars) + newSigma.SubsetSym(&s.sigma, vars) return NewStudentsT(newMean, &newSigma, s.nu, src) } @@ -300,20 +296,9 @@ func (s *StudentsT) MarginalStudentsT(vars []int, src *rand.Rand) (dist *Student // // The input src is passed to the call to NewStudentsT. func (s *StudentsT) MarginalStudentsTSingle(i int, src *rand.Rand) distuv.StudentsT { - var std float64 - if s.sigma != nil { - std = s.sigma.At(i, i) - } else { - // Reconstruct the {i,i} diagonal element of the covariance directly. - for j := 0; j <= i; j++ { - v := s.lower.At(i, j) - std += v * v - } - } - return distuv.StudentsT{ Mu: s.mu[i], - Sigma: math.Sqrt(std), + Sigma: math.Sqrt(s.sigma.At(i, i)), Nu: s.nu, Src: src, } @@ -367,11 +352,3 @@ func (s *StudentsT) Rand(x []float64) []float64 { floats.Add(x, s.mu) return x } - -// setSigma computes and stores the covariance matrix of the distribution. -func (s *StudentsT) setSigma() { - s.once.Do(func() { - s.sigma = mat64.NewSymDense(s.dim, nil) - s.sigma.FromCholesky(&s.chol) - }) -} diff --git a/distmv/studentst_test.go b/distmv/studentst_test.go index 47543a2c..a0f6fc6f 100644 --- a/distmv/studentst_test.go +++ b/distmv/studentst_test.go @@ -182,12 +182,9 @@ func TestStudentsTConditional(t *testing.T) { muOb[i] = test.mean[v] } - s.setSigma() - sUp.setSigma() - var sig11, sig22 mat64.SymDense - sig11.SubsetSym(s.sigma, unob) - sig22.SubsetSym(s.sigma, ob) + sig11.SubsetSym(&s.sigma, unob) + sig22.SubsetSym(&s.sigma, ob) sig12 := mat64.NewDense(len(unob), len(ob), nil) for i := range unob { @@ -221,7 +218,7 @@ func TestStudentsTConditional(t *testing.T) { dot := mat64.Dot(shiftVec, &tmp) tmp3.Scale((test.nu+dot)/(test.nu+float64(len(ob))), &tmp3) - if !mat64.EqualApprox(&tmp3, sUp.sigma, 1e-10) { + if !mat64.EqualApprox(&tmp3, &sUp.sigma, 1e-10) { t.Errorf("Sigma mismatch") } } @@ -248,8 +245,6 @@ func TestStudentsTMarginalSingle(t *testing.T) { if !ok { t.Fatalf("Bad test, covariance matrix not positive definite") } - // Verify with nil Sigma. - studentst.sigma = nil for i, mean := range test.mu { st := studentst.MarginalStudentsTSingle(i, nil) if st.Mean() != mean { @@ -263,18 +258,5 @@ func TestStudentsTMarginalSingle(t *testing.T) { t.Errorf("Nu mismatch nil Sigma, idx %v: want %v, got %v ", i, test.nu, st.Nu) } } - - // Verify with non-nil Sigma. - studentst.setSigma() - for i, mean := range test.mu { - st := studentst.MarginalStudentsTSingle(i, nil) - if st.Mean() != mean { - t.Errorf("Mean mismatch non-nil Sigma, idx %v: want %v, got %v.", i, mean, st.Mean()) - } - std := math.Sqrt(test.sigma.At(i, i)) - if math.Abs(st.Sigma-std) > 1e-14 { - t.Errorf("StdDev mismatch non-nil Sigma, idx %v: want %v, got %v.", i, std, st.StdDev()) - } - } } }