mirror of
https://github.com/gonum/gonum.git
synced 2025-11-01 19:12:45 +08:00
Add Mahalanobis distance and update Normal.LogProb
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/gonum/floats"
|
||||
"github.com/gonum/matrix/mat64"
|
||||
"github.com/gonum/stat"
|
||||
"github.com/gonum/stat/distuv"
|
||||
)
|
||||
|
||||
@@ -183,17 +184,9 @@ func (n *Normal) LogProb(x []float64) float64 {
|
||||
if len(x) != dim {
|
||||
panic(badSizeMismatch)
|
||||
}
|
||||
// Compute the normalization constant
|
||||
c := -0.5*float64(dim)*logTwoPi - n.logSqrtDet
|
||||
|
||||
// Compute (x-mu)'Sigma^-1 (x-mu)
|
||||
xMinusMu := make([]float64, dim)
|
||||
floats.SubTo(xMinusMu, x, n.mu)
|
||||
d := mat64.NewVector(dim, xMinusMu)
|
||||
tmp := make([]float64, dim)
|
||||
tmpVec := mat64.NewVector(dim, tmp)
|
||||
tmpVec.SolveCholeskyVec(&n.chol, d)
|
||||
return c - 0.5*floats.Dot(tmp, xMinusMu)
|
||||
dst := stat.Mahalanobis(mat64.NewVector(dim, x), mat64.NewVector(dim, n.mu), &n.chol)
|
||||
return c - 0.5*dst*dst
|
||||
}
|
||||
|
||||
// MarginalNormal returns the marginal distribution of the given input variables.
|
||||
|
||||
@@ -128,3 +128,20 @@ func corrToCov(c *mat64.SymDense, sigma []float64) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Mahalanobis computes the Mahalanobis distance
|
||||
// D = sqrt((x-y)^T * Σ^-1 * (x-y))
|
||||
// between the vectors x and y given the cholesky decomposition of Σ.
|
||||
// Mahalanobis returns NaN if the linear solve fails.
|
||||
//
|
||||
// See https://en.wikipedia.org/wiki/Mahalanobis_distance for more information.
|
||||
func Mahalanobis(x, y *mat64.Vector, chol *mat64.Cholesky) float64 {
|
||||
var diff mat64.Vector
|
||||
diff.SubVec(x, y)
|
||||
var tmp mat64.Vector
|
||||
err := tmp.SolveCholeskyVec(chol, &diff)
|
||||
if err != nil {
|
||||
return math.NaN()
|
||||
}
|
||||
return math.Sqrt(mat64.Dot(&tmp, &diff))
|
||||
}
|
||||
@@ -262,6 +262,36 @@ func TestCorrCov(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMahalanobis(t *testing.T) {
|
||||
// Comparison with scipy.
|
||||
for cas, test := range []struct {
|
||||
x, y *mat64.Vector
|
||||
Sigma *mat64.SymDense
|
||||
ans float64
|
||||
}{
|
||||
{
|
||||
x: mat64.NewVector(3, []float64{1, 2, 3}),
|
||||
y: mat64.NewVector(3, []float64{0.8, 1.1, -1}),
|
||||
Sigma: mat64.NewSymDense(3,
|
||||
[]float64{
|
||||
0.8, 0.3, 0.1,
|
||||
0.3, 0.7, -0.1,
|
||||
0.1, -0.1, 7}),
|
||||
ans: 1.9251757377680914,
|
||||
},
|
||||
} {
|
||||
var chol mat64.Cholesky
|
||||
ok := chol.Factorize(test.Sigma)
|
||||
if !ok {
|
||||
panic("bad test")
|
||||
}
|
||||
ans := Mahalanobis(test.x, test.y, &chol)
|
||||
if math.Abs(ans-test.ans) > 1e-14 {
|
||||
t.Errorf("Cas %d: got %v, want %v", cas, ans, test.ans)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// benchmarks
|
||||
|
||||
func randMat(r, c int) mat64.Matrix {
|
||||
Reference in New Issue
Block a user