diff --git a/stat_test.go b/stat_test.go index 3079e333..8296fad7 100644 --- a/stat_test.go +++ b/stat_test.go @@ -116,6 +116,51 @@ func ExampleCovariance() { // Cov2 is 37.7000, VarX is 37.7000 } +func TestCovariance(t *testing.T) { + for i, test := range []struct { + p []float64 + q []float64 + weights []float64 + ans float64 + }{ + { + p: []float64{0.75, 0.1, 0.05}, + q: []float64{0.5, 0.25, 0.25}, + ans: 0.05625, + }, + { + p: []float64{1, 2, 3}, + q: []float64{2, 4, 6}, + ans: 2, + }, + { + p: []float64{1, 2, 3}, + q: []float64{1, 4, 9}, + ans: 4, + }, + { + p: []float64{1, 2, 3}, + q: []float64{1, 4, 9}, + weights: []float64{1, 1.5, 1}, + ans: 3.2, + }, + } { + c := Covariance(test.p, Mean(test.p, test.weights), test.q, Mean(test.q, test.weights), test.weights) + if math.Abs(c-test.ans) > 1e-14 { + t.Errorf("Covariance mismatch case %d: Expected %v, Found %v", i, test.ans, c) + } + } + + // test the panic states + if !Panics(func() { Covariance(make([]float64, 2), 0.0, make([]float64, 3), 0.0, nil) }) { + t.Errorf("Covariance did not panic with x, y length mismatch") + } + if !Panics(func() { Covariance(make([]float64, 3), 0.0, make([]float64, 3), 0.0, make([]float64, 2)) }) { + t.Errorf("Covariance did not panic with x, weights length mismatch") + } + +} + func TestCrossEntropy(t *testing.T) { for i, test := range []struct { p []float64