mirror of
https://github.com/gonum/gonum.git
synced 2025-10-07 08:01:20 +08:00
131 lines
2.9 KiB
Go
131 lines
2.9 KiB
Go
// Copyright ©2016 The Gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package distmat
|
|
|
|
import (
|
|
"math"
|
|
"testing"
|
|
|
|
"golang.org/x/exp/rand"
|
|
|
|
"gonum.org/v1/gonum/floats"
|
|
"gonum.org/v1/gonum/mat"
|
|
)
|
|
|
|
func TestWishart(t *testing.T) {
|
|
for c, test := range []struct {
|
|
v *mat.SymDense
|
|
nu float64
|
|
xs []*mat.SymDense
|
|
lps []float64
|
|
}{
|
|
// Logprob data compared with scipy.
|
|
{
|
|
v: mat.NewSymDense(2, []float64{1, 0, 0, 1}),
|
|
nu: 4,
|
|
xs: []*mat.SymDense{
|
|
mat.NewSymDense(2, []float64{0.9, 0.1, 0.1, 0.9}),
|
|
},
|
|
lps: []float64{-4.2357432031863409},
|
|
},
|
|
{
|
|
v: mat.NewSymDense(2, []float64{0.8, -0.2, -0.2, 0.7}),
|
|
nu: 5,
|
|
xs: []*mat.SymDense{
|
|
mat.NewSymDense(2, []float64{0.9, 0.1, 0.1, 0.9}),
|
|
mat.NewSymDense(2, []float64{0.3, -0.1, -0.1, 0.7}),
|
|
},
|
|
lps: []float64{-4.2476495605333575, -4.9993285370378633},
|
|
},
|
|
{
|
|
v: mat.NewSymDense(3, []float64{0.8, 0.3, 0.1, 0.3, 0.7, -0.1, 0.1, -0.1, 7}),
|
|
nu: 5,
|
|
xs: []*mat.SymDense{
|
|
mat.NewSymDense(3, []float64{1, 0.2, -0.3, 0.2, 0.6, -0.2, -0.3, -0.2, 6}),
|
|
},
|
|
lps: []float64{-11.010982249229421},
|
|
},
|
|
} {
|
|
w, ok := NewWishart(test.v, test.nu, nil)
|
|
if !ok {
|
|
panic("bad test")
|
|
}
|
|
for i, x := range test.xs {
|
|
lp := w.LogProbSym(x)
|
|
|
|
var chol mat.Cholesky
|
|
ok := chol.Factorize(x)
|
|
if !ok {
|
|
panic("bad test")
|
|
}
|
|
lpc := w.LogProbSymChol(&chol)
|
|
|
|
if math.Abs(lp-lpc) > 1e-14 {
|
|
t.Errorf("Case %d, test %d: probability mismatch between chol and not", c, i)
|
|
}
|
|
if !floats.EqualWithinAbsOrRel(lp, test.lps[i], 1e-14, 1e-14) {
|
|
t.Errorf("Case %d, test %d: got %v, want %v", c, i, lp, test.lps[i])
|
|
}
|
|
}
|
|
|
|
ch := w.RandChol(nil)
|
|
w.RandChol(ch)
|
|
|
|
s := w.RandSym(nil)
|
|
w.RandSym(s)
|
|
|
|
}
|
|
}
|
|
|
|
func TestWishartRand(t *testing.T) {
|
|
for c, test := range []struct {
|
|
v *mat.SymDense
|
|
nu float64
|
|
samples int
|
|
tol float64
|
|
}{
|
|
{
|
|
v: mat.NewSymDense(2, []float64{0.8, -0.2, -0.2, 0.7}),
|
|
nu: 5,
|
|
samples: 30000,
|
|
tol: 3e-2,
|
|
},
|
|
{
|
|
v: mat.NewSymDense(3, []float64{0.8, 0.3, 0.1, 0.3, 0.7, -0.1, 0.1, -0.1, 7}),
|
|
nu: 5,
|
|
samples: 30000,
|
|
tol: 3e-1,
|
|
},
|
|
{
|
|
v: mat.NewSymDense(4, []float64{
|
|
0.8, 0.3, 0.1, -0.2,
|
|
0.3, 0.7, -0.1, 0.4,
|
|
0.1, -0.1, 7, 1,
|
|
-0.2, -0.1, 1, 6}),
|
|
nu: 6,
|
|
samples: 30000,
|
|
tol: 1e-1,
|
|
},
|
|
} {
|
|
rnd := rand.New(rand.NewSource(1))
|
|
dim := test.v.Symmetric()
|
|
w, ok := NewWishart(test.v, test.nu, rnd)
|
|
if !ok {
|
|
panic("bad test")
|
|
}
|
|
mean := mat.NewSymDense(dim, nil)
|
|
x := mat.NewSymDense(dim, nil)
|
|
for i := 0; i < test.samples; i++ {
|
|
w.RandSym(x)
|
|
x.ScaleSym(1/float64(test.samples), x)
|
|
mean.AddSym(mean, x)
|
|
}
|
|
trueMean := w.MeanSym(nil)
|
|
if !mat.EqualApprox(trueMean, mean, test.tol) {
|
|
t.Errorf("Case %d: Mismatch between estimated and true mean. Got\n%0.4v\nWant\n%0.4v\n", c, mat.Formatted(mean), mat.Formatted(trueMean))
|
|
}
|
|
}
|
|
}
|