Initial work on covariancematrix

This commit is contained in:
Jonathan J Lawlor
2014-11-15 00:49:46 -05:00
parent 007b3e00f6
commit 63a1bac14c
2 changed files with 100 additions and 0 deletions

48
covariancematrix.go Normal file
View File

@@ -0,0 +1,48 @@
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stat
import (
"github.com/gonum/matrix/mat64"
)
// CovarianceMatrix returns the covariance matrix (also called the
// variance-covariance matrix) of the data matrix x. The number of rows
// of x is used as the number of observations, and the result is a square
// matrix with one row and one column per column of x.
func CovarianceMatrix(x mat64.Matrix) *mat64.Dense {
	// Two-pass algorithm expressed with matrix products. The correction
	// applied by the Covariance and Variance functions is not used here.
	n, _ := x.Dims()

	// First pass: a row of ones multiplied by x sums each column, and
	// scaling by 1/n turns those sums into column means.
	colMeans := ones(1, n)
	colMeans.Mul(colMeans, x)
	colMeans.Scale(1/float64(n), colMeans)

	// Repeat the row of means n times so it has the same shape as x.
	// TODO: avoid this unneeded memory expansion.
	means := new(mat64.Dense)
	means.Mul(ones(n, 1), colMeans)

	// Second pass: subtract the means from a copy of the data. A clone
	// combined with a row viewer would work here as well.
	centered := mat64.DenseCopyOf(x)
	centered.Sub(centered, means)

	// TODO: avoid the matrix copy made for the transpose.
	centeredT := new(mat64.Dense)
	centeredT.TCopy(centered)

	// Sum of squares and cross products, scaled by 1/(n-1).
	cov := new(mat64.Dense)
	cov.Mul(centeredT, centered)
	cov.Scale(1/float64(n-1), cov)
	return cov
}
// ones returns a new r-by-c dense matrix whose backing slice is filled
// with the value 1.
func ones(r, c int) *mat64.Dense {
	data := make([]float64, r*c)
	for i := 0; i < len(data); i++ {
		data[i] = 1
	}
	return mat64.NewDense(r, c, data)
}

52
covariancematrix_test.go Normal file
View File

@@ -0,0 +1,52 @@
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stat
import (
"testing"
"github.com/gonum/blas/cblas"
"github.com/gonum/floats"
"github.com/gonum/matrix/mat64"
)
// init registers a cblas.Blas value with mat64 when the test package is
// initialized, i.e. before any test function in this file runs.
// NOTE(review): presumably mat64 operations need a registered BLAS
// implementation to work — confirm against the mat64 documentation.
func init() {
mat64.Register(cblas.Blas{})
}
// TestCovarianceMatrix runs CovarianceMatrix over a table of inputs and
// checks the dimensions and the raw data of each result against the
// expected values.
func TestCovarianceMatrix(t *testing.T) {
	for i, test := range []struct {
		mat  mat64.Matrix // input data handed to CovarianceMatrix
		r, c int          // expected dimensions of the result
		x    []float64    // expected result data, compared against RawMatrix().Data
	}{
		{
			mat: mat64.NewDense(5, 2, []float64{
				-2, -4,
				-1, 2,
				0, 0,
				1, -2,
				2, 4,
			}),
			r: 2,
			c: 2,
			x: []float64{
				2.5, 3,
				3, 10,
			},
		},
	} {
		c := CovarianceMatrix(test.mat).RawMatrix()
		if c.Rows != test.r {
			t.Errorf("%d: expected rows %d, found %d", i, test.r, c.Rows)
		}
		if c.Cols != test.c {
			t.Errorf("%d: expected cols %d, found %d", i, test.c, c.Cols)
		}
		if !floats.Equal(test.x, c.Data) {
			// %v rather than %#q: %q is not a valid verb for float64
			// values and would render as "%!q(float64=...)" noise.
			t.Errorf("%d: expected data %v, found %v", i, test.x, c.Data)
		}
	}
}