mirror of
https://github.com/gonum/gonum.git
synced 2025-10-18 13:10:47 +08:00
stat/all: reduce random size and test tolerance to decrease testing t… (#181)
* stat/all: reduce random size and test tolerance to decrease testing time. We were generating a lot of random numbers, which is slow. Decrease the size of those random numbers, and in some cases increase the tolerance to compensate. In a couple cases, pull out code from testFullDist to allow for more fine-grained testing. This decrases: distmat from 4.5s to 0.5s distmv from 24.8s to 9s distuv from 65.2s to 13s samplemv from 2.8s to 1.2s sampleuv from 3.5s to 2.1s
This commit is contained in:
@@ -82,7 +82,7 @@ func TestImportance(t *testing.T) {
|
||||
weights := make([]float64, nSamples)
|
||||
Importance(batch, weights, target, proposal)
|
||||
|
||||
compareNormal(t, target, batch, weights)
|
||||
compareNormal(t, target, batch, weights, 5e-2, 5e-2)
|
||||
}
|
||||
|
||||
func TestRejection(t *testing.T) {
|
||||
@@ -145,14 +145,14 @@ func TestMetropolisHastings(t *testing.T) {
|
||||
t.Fatal("bad test, sigma not pos def")
|
||||
}
|
||||
|
||||
nSamples := 1000000
|
||||
nSamples := 100000
|
||||
burnin := 5000
|
||||
batch := mat.NewDense(nSamples, dim, nil)
|
||||
initial := make([]float64, dim)
|
||||
MetropolisHastings(batch, initial, target, proposal, nil)
|
||||
batch = batch.Slice(burnin, nSamples, 0, dim).(*mat.Dense)
|
||||
|
||||
compareNormal(t, target, batch, nil)
|
||||
compareNormal(t, target, batch, nil, 5e-1, 5e-1)
|
||||
}
|
||||
|
||||
// randomNormal constructs a random Normal distribution.
|
||||
@@ -171,7 +171,7 @@ func randomNormal(dim int) (*distmv.Normal, bool) {
|
||||
return distmv.NewNormal(mu, &sigma, nil)
|
||||
}
|
||||
|
||||
func compareNormal(t *testing.T, want *distmv.Normal, batch *mat.Dense, weights []float64) {
|
||||
func compareNormal(t *testing.T, want *distmv.Normal, batch *mat.Dense, weights []float64, meanTol, covTol float64) {
|
||||
dim := want.Dim()
|
||||
mu := want.Mean(nil)
|
||||
sigma := want.CovarianceMatrix(nil)
|
||||
@@ -185,13 +185,13 @@ func compareNormal(t *testing.T, want *distmv.Normal, batch *mat.Dense, weights
|
||||
for i := 0; i < dim; i++ {
|
||||
col := mat.Col(nil, i, batch)
|
||||
ev := stat.Mean(col, weights)
|
||||
if math.Abs(ev-mu[i]) > 1e-2 {
|
||||
if math.Abs(ev-mu[i]) > meanTol {
|
||||
t.Errorf("Mean mismatch: Want %v, got %v", mu[i], ev)
|
||||
}
|
||||
}
|
||||
|
||||
cov := stat.CovarianceMatrix(nil, batch, weights)
|
||||
if !mat.EqualApprox(cov, sigma, 1.5e-1) {
|
||||
if !mat.EqualApprox(cov, sigma, covTol) {
|
||||
t.Errorf("Covariance matrix mismatch")
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user