mirror of
https://github.com/gonum/gonum.git
synced 2025-11-01 11:02:45 +08:00
Add Mahalanobis distance and update Normal.LogProb
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gonum/floats"
|
"github.com/gonum/floats"
|
||||||
"github.com/gonum/matrix/mat64"
|
"github.com/gonum/matrix/mat64"
|
||||||
|
"github.com/gonum/stat"
|
||||||
"github.com/gonum/stat/distuv"
|
"github.com/gonum/stat/distuv"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -183,17 +184,9 @@ func (n *Normal) LogProb(x []float64) float64 {
|
|||||||
if len(x) != dim {
|
if len(x) != dim {
|
||||||
panic(badSizeMismatch)
|
panic(badSizeMismatch)
|
||||||
}
|
}
|
||||||
// Compute the normalization constant
|
|
||||||
c := -0.5*float64(dim)*logTwoPi - n.logSqrtDet
|
c := -0.5*float64(dim)*logTwoPi - n.logSqrtDet
|
||||||
|
dst := stat.Mahalanobis(mat64.NewVector(dim, x), mat64.NewVector(dim, n.mu), &n.chol)
|
||||||
// Compute (x-mu)'Sigma^-1 (x-mu)
|
return c - 0.5*dst*dst
|
||||||
xMinusMu := make([]float64, dim)
|
|
||||||
floats.SubTo(xMinusMu, x, n.mu)
|
|
||||||
d := mat64.NewVector(dim, xMinusMu)
|
|
||||||
tmp := make([]float64, dim)
|
|
||||||
tmpVec := mat64.NewVector(dim, tmp)
|
|
||||||
tmpVec.SolveCholeskyVec(&n.chol, d)
|
|
||||||
return c - 0.5*floats.Dot(tmp, xMinusMu)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarginalNormal returns the marginal distribution of the given input variables.
|
// MarginalNormal returns the marginal distribution of the given input variables.
|
||||||
|
|||||||
@@ -128,3 +128,20 @@ func corrToCov(c *mat64.SymDense, sigma []float64) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mahalanobis computes the Mahalanobis distance
|
||||||
|
// D = sqrt((x-y)^T * Σ^-1 * (x-y))
|
||||||
|
// between the vectors x and y given the cholesky decomposition of Σ.
|
||||||
|
// Mahalanobis returns NaN if the linear solve fails.
|
||||||
|
//
|
||||||
|
// See https://en.wikipedia.org/wiki/Mahalanobis_distance for more information.
|
||||||
|
func Mahalanobis(x, y *mat64.Vector, chol *mat64.Cholesky) float64 {
|
||||||
|
var diff mat64.Vector
|
||||||
|
diff.SubVec(x, y)
|
||||||
|
var tmp mat64.Vector
|
||||||
|
err := tmp.SolveCholeskyVec(chol, &diff)
|
||||||
|
if err != nil {
|
||||||
|
return math.NaN()
|
||||||
|
}
|
||||||
|
return math.Sqrt(mat64.Dot(&tmp, &diff))
|
||||||
|
}
|
||||||
@@ -262,6 +262,36 @@ func TestCorrCov(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMahalanobis(t *testing.T) {
|
||||||
|
// Comparison with scipy.
|
||||||
|
for cas, test := range []struct {
|
||||||
|
x, y *mat64.Vector
|
||||||
|
Sigma *mat64.SymDense
|
||||||
|
ans float64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
x: mat64.NewVector(3, []float64{1, 2, 3}),
|
||||||
|
y: mat64.NewVector(3, []float64{0.8, 1.1, -1}),
|
||||||
|
Sigma: mat64.NewSymDense(3,
|
||||||
|
[]float64{
|
||||||
|
0.8, 0.3, 0.1,
|
||||||
|
0.3, 0.7, -0.1,
|
||||||
|
0.1, -0.1, 7}),
|
||||||
|
ans: 1.9251757377680914,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
var chol mat64.Cholesky
|
||||||
|
ok := chol.Factorize(test.Sigma)
|
||||||
|
if !ok {
|
||||||
|
panic("bad test")
|
||||||
|
}
|
||||||
|
ans := Mahalanobis(test.x, test.y, &chol)
|
||||||
|
if math.Abs(ans-test.ans) > 1e-14 {
|
||||||
|
t.Errorf("Cas %d: got %v, want %v", cas, ans, test.ans)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// benchmarks
|
// benchmarks
|
||||||
|
|
||||||
func randMat(r, c int) mat64.Matrix {
|
func randMat(r, c int) mat64.Matrix {
|
||||||
Reference in New Issue
Block a user