mirror of
https://github.com/gonum/gonum.git
synced 2025-10-23 15:13:31 +08:00
Split out the non-BLAS implementation into a new function
Splitting out the non-BLAS path makes the code flow easier to read.
This commit is contained in:
@@ -6,17 +6,10 @@ package stat
|
||||
|
||||
import (
|
||||
"sync"
|
||||
// "runtime"
|
||||
|
||||
|
||||
"github.com/gonum/matrix/mat64"
|
||||
)
|
||||
|
||||
|
||||
// covMatSlice describes a single unit of work handed to the covariance
// workers: the position of one output entry together with the two data
// columns that entry is computed from.
type covMatSlice struct {
	i int       // row index of the output entry
	j int       // column index of the output entry
	x []float64 // data column associated with index i
	y []float64 // data column associated with index j
}
|
||||
|
||||
// CovarianceMatrix calculates a covariance matrix (also known as a
|
||||
// variance-covariance matrix) from a matrix of data, using a two-pass
|
||||
// algorithm. It will have better performance if a BLAS engine is
|
||||
@@ -27,71 +20,12 @@ func CovarianceMatrix(x mat64.Matrix) *mat64.Dense {
|
||||
|
||||
// matrix version of the two pass algorithm. This doesn't use
|
||||
// the correction found in the Covariance and Variance functions.
|
||||
r, c := x.Dims()
|
||||
|
||||
if x, ok := x.(mat64.Vectorer); ok {
|
||||
cols := make([][]float64, c)
|
||||
// perform the covariance or variance as required
|
||||
blockSize := 1024
|
||||
if blockSize > c {
|
||||
blockSize = c
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(c)
|
||||
for j := 0; j < c; j++ {
|
||||
go func(j int) {
|
||||
// pull the columns out and subtract the means
|
||||
cols[j] = make([]float64, r)
|
||||
x.Col(cols[j], j)
|
||||
mean := Mean(cols[j], nil)
|
||||
for i := range cols[j] {
|
||||
cols[j][i] -= mean
|
||||
}
|
||||
wg.Done()
|
||||
}(j)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
colCh := make(chan covMatSlice, blockSize)
|
||||
|
||||
wg.Add(blockSize)
|
||||
m := mat64.NewDense(c, c, nil)
|
||||
for i := 0; i < blockSize; i++ {
|
||||
go func(in <-chan covMatSlice) {
|
||||
for {
|
||||
xy, more := <-in
|
||||
if !more {
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
|
||||
if xy.i == xy.j {
|
||||
m.Set(xy.i, xy.j, centeredVariance(xy.x))
|
||||
continue
|
||||
}
|
||||
v := centeredCovariance(xy.x, xy.y)
|
||||
m.Set(xy.i, xy.j, v)
|
||||
m.Set(xy.j, xy.i, v)
|
||||
}
|
||||
}(colCh)
|
||||
}
|
||||
go func(out chan<- covMatSlice) {
|
||||
for i := 0; i < c; i++ {
|
||||
for j := 0; j <= i; j++ {
|
||||
out <- covMatSlice{
|
||||
i: i,
|
||||
j: j,
|
||||
x: cols[i],
|
||||
y: cols[j],
|
||||
}
|
||||
}
|
||||
}
|
||||
close(out)
|
||||
}(colCh)
|
||||
// create the output matrix
|
||||
wg.Wait()
|
||||
return m
|
||||
if mat64.Registered() == nil {
|
||||
// implementation that doesn't rely on a blasEngine
|
||||
return covarianceMatrixWithoutBLAS(x)
|
||||
}
|
||||
r, _ := x.Dims()
|
||||
|
||||
// determine the mean of each of the columns
|
||||
b := ones(1, r)
|
||||
b.Mul(b, x)
|
||||
@@ -107,18 +41,99 @@ func CovarianceMatrix(x mat64.Matrix) *mat64.Dense {
|
||||
}
|
||||
}
|
||||
|
||||
// todo: avoid matrix copy?
|
||||
var xt mat64.Dense
|
||||
xt.TCopy(xc)
|
||||
|
||||
// It would be nice if we could indicate that this was a symmetric
|
||||
// matrix.
|
||||
// TODO: indicate that the resulting matrix is symmetric, which
|
||||
// should improve performance.
|
||||
var ss mat64.Dense
|
||||
ss.Mul(&xt, xc)
|
||||
ss.Scale(1/float64(r-1), &ss)
|
||||
return &ss
|
||||
}
|
||||
|
||||
// covMatSlice describes a single unit of work handed to the covariance
// workers: the position of one output entry together with the two data
// columns that entry is computed from.
type covMatSlice struct {
	i int       // row index of the output entry
	j int       // column index of the output entry
	x []float64 // data column associated with index i
	y []float64 // data column associated with index j
}
|
||||
|
||||
func covarianceMatrixWithoutBLAS(x mat64.Matrix) *mat64.Dense {
|
||||
r, c := x.Dims()
|
||||
|
||||
// split out the matrix into columns
|
||||
cols := make([][]float64, c)
|
||||
for j := range cols {
|
||||
cols[j] = make([]float64, r)
|
||||
}
|
||||
|
||||
if xRaw, ok := x.(mat64.RawMatrixer); ok {
|
||||
for k, v := range xRaw.RawMatrix().Data {
|
||||
i := k / c
|
||||
j := k % c
|
||||
cols[j][i] = v
|
||||
}
|
||||
} else {
|
||||
for j := 0; j < c; j++ {
|
||||
for i := 0; i < r; i++ {
|
||||
cols[j][i] = x.At(i, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// center the columns
|
||||
for j := range cols {
|
||||
mean := Mean(cols[j], nil)
|
||||
for i := range cols[j] {
|
||||
cols[j][i] -= mean
|
||||
}
|
||||
}
|
||||
|
||||
blockSize := 1024
|
||||
if blockSize > c {
|
||||
blockSize = c
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(blockSize)
|
||||
colCh := make(chan covMatSlice, blockSize)
|
||||
|
||||
m := mat64.NewDense(c, c, nil)
|
||||
for i := 0; i < blockSize; i++ {
|
||||
go func(in <-chan covMatSlice) {
|
||||
for {
|
||||
xy, more := <-in
|
||||
if !more {
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
|
||||
if xy.i == xy.j {
|
||||
m.Set(xy.i, xy.j, centeredVariance(xy.x))
|
||||
continue
|
||||
}
|
||||
v := centeredCovariance(xy.x, xy.y)
|
||||
m.Set(xy.i, xy.j, v)
|
||||
m.Set(xy.j, xy.i, v)
|
||||
}
|
||||
}(colCh)
|
||||
}
|
||||
go func(out chan<- covMatSlice) {
|
||||
for i := 0; i < c; i++ {
|
||||
for j := 0; j <= i; j++ {
|
||||
out <- covMatSlice{
|
||||
i: i,
|
||||
j: j,
|
||||
x: cols[i],
|
||||
y: cols[j],
|
||||
}
|
||||
}
|
||||
}
|
||||
close(out)
|
||||
}(colCh)
|
||||
// create the output matrix
|
||||
wg.Wait()
|
||||
return m
|
||||
}
|
||||
|
||||
// ones is a matrix of all ones.
|
||||
func ones(r, c int) *mat64.Dense {
|
||||
x := make([]float64, r*c)
|
||||
|
Reference in New Issue
Block a user