mirror of
https://github.com/gonum/gonum.git
synced 2025-10-14 11:23:50 +08:00
Add distance functions between probability distributions
This commit is contained in:
184
distmv/statdist.go
Normal file
184
distmv/statdist.go
Normal file
@@ -0,0 +1,184 @@
|
||||
// Copyright ©2016 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package distmv
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/gonum/floats"
|
||||
"github.com/gonum/matrix/mat64"
|
||||
"github.com/gonum/stat"
|
||||
)
|
||||
|
||||
// Bhattacharyya is a type for computing the Bhattacharyya distance between
|
||||
// probability distributions.
|
||||
//
|
||||
// The Battachara distance is defined as
|
||||
// D_B = -ln(BC(l,r))
|
||||
// BC = \int_x (p(x)q(x))^(1/2) dx
|
||||
// Where BC is known as the Bhattacharyya coefficient.
|
||||
// The Bhattacharyya distance is related to the Hellinger distance by
|
||||
// H = sqrt(1-BC)
|
||||
// For more information, see
|
||||
// https://en.wikipedia.org/wiki/Bhattacharyya_distance
|
||||
type Bhattacharyya struct{}
|
||||
|
||||
// DistNormal computes the Bhattacharyya distance between normal distributions l and r.
|
||||
// The dimensions of the input distributions must match or DistNormal will panic.
|
||||
//
|
||||
// For Normal distributions, the Bhattacharyya distance is
|
||||
// Σ = (Σ_l + Σ_r)/2
|
||||
// D_B = (1/8)*(μ_l - μ_r)^T*Σ^-1*(μ_l - μ_r) + (1/2)*ln(det(Σ)/(det(Σ_l)*det(Σ_r))^(1/2))
|
||||
func (Bhattacharyya) DistNormal(l, r *Normal) float64 {
|
||||
dim := l.Dim()
|
||||
if dim != r.Dim() {
|
||||
panic(badSizeMismatch)
|
||||
}
|
||||
|
||||
var sigma mat64.SymDense
|
||||
sigma.AddSym(&l.sigma, &r.sigma)
|
||||
sigma.ScaleSym(0.5, &sigma)
|
||||
|
||||
var chol mat64.Cholesky
|
||||
chol.Factorize(&sigma)
|
||||
|
||||
mahalanobis := stat.Mahalanobis(mat64.NewVector(dim, l.mu), mat64.NewVector(dim, r.mu), &chol)
|
||||
mahalanobisSq := mahalanobis * mahalanobis
|
||||
|
||||
dl := l.chol.LogDet()
|
||||
dr := r.chol.LogDet()
|
||||
ds := chol.LogDet()
|
||||
|
||||
return 0.125*mahalanobisSq + 0.5*ds - 0.25*dl - 0.25*dr
|
||||
}
|
||||
|
||||
// CrossEntropy is a type for computing the cross-entropy between probability
|
||||
// distributions.
|
||||
//
|
||||
// The cross-entropy is defined as
|
||||
// - \int_x l(x) log(r(x)) dx = KL(l || r) + H(l)
|
||||
// where KL is the Kullback-Leibler divergence and H is the entropy.
|
||||
// For more information, see
|
||||
// https://en.wikipedia.org/wiki/Cross_entropy
|
||||
type CrossEntropy struct{}
|
||||
|
||||
// DistNormal returns the cross-entropy between normal distributions l and r.
|
||||
// The dimensions of the input distributions must match or DistNormal will panic.
|
||||
func (CrossEntropy) DistNormal(l, r *Normal) float64 {
|
||||
if l.Dim() != r.Dim() {
|
||||
panic(badSizeMismatch)
|
||||
}
|
||||
kl := KullbackLeibler{}.DistNormal(l, r)
|
||||
return kl + l.Entropy()
|
||||
}
|
||||
|
||||
// Hellinger is a type for computing the Hellinger distance between probability
|
||||
// distributions.
|
||||
//
|
||||
// The Hellinger distance is defined as
|
||||
// H^2(l,r) = 1/2 * int_x (\sqrt(l(x)) - \sqrt(r(x)))^2 dx
|
||||
// and is bounded between 0 and 1.
|
||||
// The Hellinger distance is related to the Bhattacharyya distance by
|
||||
// H^2 = 1 - exp(-Db)
|
||||
// For more information, see
|
||||
// https://en.wikipedia.org/wiki/Hellinger_distance
|
||||
type Hellinger struct{}
|
||||
|
||||
// DistNormal returns the Hellinger distance between normal distributions l and r.
|
||||
// The dimensions of the input distributions must match or DistNormal will panic.
|
||||
//
|
||||
// See the documentation of Bhattacharyya.DistNormal for the formula for Normal
|
||||
// distributions.
|
||||
func (Hellinger) DistNormal(l, r *Normal) float64 {
|
||||
if l.Dim() != r.Dim() {
|
||||
panic(badSizeMismatch)
|
||||
}
|
||||
db := Bhattacharyya{}.DistNormal(l, r)
|
||||
bc := math.Exp(-db)
|
||||
return math.Sqrt(1 - bc)
|
||||
}
|
||||
|
||||
// KullbackLiebler is a type for computing the Kullback-Leibler divergence from l to r.
|
||||
// The dimensions of the input distributions must match or the function will panic.
|
||||
//
|
||||
// The Kullback-Liebler divergence is defined as
|
||||
// D_KL(l || r ) = \int_x p(x) log(p(x)/q(x)) dx
|
||||
// Note that the Kullback-Liebler divergence is not symmetric with respect to
|
||||
// the order of the input arguments.
|
||||
type KullbackLeibler struct{}
|
||||
|
||||
// DistNormal returns the KullbackLeibler distance between normal distributions l and r.
|
||||
// The dimensions of the input distributions must match or DistNormal will panic.
|
||||
//
|
||||
// For two normal distributions, the KL divergence is computed as
|
||||
// D_KL(l || r) = 0.5*[ln(|Σ_r|) - ln(|Σ_l|) + (μ_l - μ_r)^T*Σ_r^-1*(μ_l - μ_r) + tr(Σ_r^-1*Σ_l)-d]
|
||||
func (KullbackLeibler) DistNormal(l, r *Normal) float64 {
|
||||
dim := l.Dim()
|
||||
if dim != r.Dim() {
|
||||
panic(badSizeMismatch)
|
||||
}
|
||||
|
||||
mahalanobis := stat.Mahalanobis(mat64.NewVector(dim, l.mu), mat64.NewVector(dim, r.mu), &r.chol)
|
||||
mahalanobisSq := mahalanobis * mahalanobis
|
||||
|
||||
// TODO(btracey): Optimize where there is a SolveCholeskySym
|
||||
// TODO(btracey): There may be a more efficient way to just compute the trace
|
||||
// Compute tr(Σ_r^-1*Σ_l) using the fact that Σ_l = U^T * U
|
||||
var u mat64.TriDense
|
||||
u.UFromCholesky(&l.chol)
|
||||
var m mat64.Dense
|
||||
err := m.SolveCholesky(&r.chol, u.T())
|
||||
if err != nil {
|
||||
return math.NaN()
|
||||
}
|
||||
m.Mul(&m, &u)
|
||||
tr := mat64.Trace(&m)
|
||||
|
||||
return r.logSqrtDet - l.logSqrtDet + 0.5*(mahalanobisSq+tr-float64(l.dim))
|
||||
}
|
||||
|
||||
// Wasserstein is a type for computing the Wasserstein distance between two
|
||||
// probability distributions.
|
||||
//
|
||||
// The Wasserstein distance is defined as
|
||||
// W(l,r) := inf 𝔼(||X-Y||_2^2)^1/2
|
||||
// For more information, see
|
||||
// https://en.wikipedia.org/wiki/Wasserstein_metric
|
||||
type Wasserstein struct{}
|
||||
|
||||
// DistNormal returns the Wasserstein distance between normal distributions l and r.
|
||||
// The dimensions of the input distributions must match or DistNormal will panic.
|
||||
//
|
||||
// The Wasserstein distance for Normal distributions is
|
||||
// d^2 = ||m_l - m_r||_2^2 + Tr(Σ_l + Σ_r - 2(Σ_l^(1/2)*Σ_r*Σ_l^(1/2))^(1/2))
|
||||
// For more information, see
|
||||
// http://djalil.chafai.net/blog/2010/04/30/wasserstein-distance-between-two-gaussians/
|
||||
func (Wasserstein) DistNormal(l, r *Normal) float64 {
|
||||
dim := l.Dim()
|
||||
if dim != r.Dim() {
|
||||
panic(badSizeMismatch)
|
||||
}
|
||||
|
||||
d := floats.Distance(l.mu, r.mu, 2)
|
||||
d = d * d
|
||||
|
||||
// Compute Σ_l^(1/2)
|
||||
var ssl mat64.SymDense
|
||||
ssl.PowPSD(&l.sigma, 0.5)
|
||||
// Compute Σ_l^(1/2)*Σ_r*Σ_l^(1/2)
|
||||
var mean mat64.Dense
|
||||
mean.Mul(&ssl, &r.sigma)
|
||||
mean.Mul(&mean, &ssl)
|
||||
|
||||
// Reinterpret as symdense, and take Σ^(1/2)
|
||||
meanSym := mat64.NewSymDense(dim, mean.RawMatrix().Data)
|
||||
ssl.PowPSD(meanSym, 0.5)
|
||||
|
||||
tr := mat64.Trace(&r.sigma)
|
||||
tl := mat64.Trace(&l.sigma)
|
||||
tm := mat64.Trace(&ssl)
|
||||
|
||||
return d + tl + tr - 2*tm
|
||||
}
|
181
distmv/statdist_test.go
Normal file
181
distmv/statdist_test.go
Normal file
@@ -0,0 +1,181 @@
|
||||
// Copyright ©2016 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package distmv
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/floats"
|
||||
"github.com/gonum/matrix/mat64"
|
||||
)
|
||||
|
||||
func TestBhattacharyyaNormal(t *testing.T) {
|
||||
for cas, test := range []struct {
|
||||
am, bm []float64
|
||||
ac, bc *mat64.SymDense
|
||||
samples int
|
||||
tol float64
|
||||
}{
|
||||
{
|
||||
am: []float64{2, 3},
|
||||
ac: mat64.NewSymDense(2, []float64{3, -1, -1, 2}),
|
||||
bm: []float64{-1, 1},
|
||||
bc: mat64.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}),
|
||||
samples: 100000,
|
||||
tol: 1e-2,
|
||||
},
|
||||
} {
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
a, ok := NewNormal(test.am, test.ac, rnd)
|
||||
if !ok {
|
||||
panic("bad test")
|
||||
}
|
||||
b, ok := NewNormal(test.bm, test.bc, rnd)
|
||||
if !ok {
|
||||
panic("bad test")
|
||||
}
|
||||
lBhatt := make([]float64, test.samples)
|
||||
x := make([]float64, a.Dim())
|
||||
for i := 0; i < test.samples; i++ {
|
||||
// Do importance sampling over a: \int sqrt(a*b)/a * a dx
|
||||
a.Rand(x)
|
||||
pa := a.LogProb(x)
|
||||
pb := b.LogProb(x)
|
||||
lBhatt[i] = 0.5*pb - 0.5*pa
|
||||
}
|
||||
logBc := floats.LogSumExp(lBhatt) - math.Log(float64(test.samples))
|
||||
db := -logBc
|
||||
got := Bhattacharyya{}.DistNormal(a, b)
|
||||
if math.Abs(db-got) > test.tol {
|
||||
t.Errorf("Bhattacharyya mismatch, case %d: got %v, want %v", cas, got, db)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCrossEntropyNormal(t *testing.T) {
|
||||
for cas, test := range []struct {
|
||||
am, bm []float64
|
||||
ac, bc *mat64.SymDense
|
||||
samples int
|
||||
tol float64
|
||||
}{
|
||||
{
|
||||
am: []float64{2, 3},
|
||||
ac: mat64.NewSymDense(2, []float64{3, -1, -1, 2}),
|
||||
bm: []float64{-1, 1},
|
||||
bc: mat64.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}),
|
||||
samples: 100000,
|
||||
tol: 1e-2,
|
||||
},
|
||||
} {
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
a, ok := NewNormal(test.am, test.ac, rnd)
|
||||
if !ok {
|
||||
panic("bad test")
|
||||
}
|
||||
b, ok := NewNormal(test.bm, test.bc, rnd)
|
||||
if !ok {
|
||||
panic("bad test")
|
||||
}
|
||||
var ce float64
|
||||
x := make([]float64, a.Dim())
|
||||
for i := 0; i < test.samples; i++ {
|
||||
a.Rand(x)
|
||||
ce -= b.LogProb(x)
|
||||
}
|
||||
ce /= float64(test.samples)
|
||||
got := CrossEntropy{}.DistNormal(a, b)
|
||||
if math.Abs(ce-got) > test.tol {
|
||||
t.Errorf("CrossEntropy mismatch, case %d: got %v, want %v", cas, got, ce)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHellingerNormal(t *testing.T) {
|
||||
for cas, test := range []struct {
|
||||
am, bm []float64
|
||||
ac, bc *mat64.SymDense
|
||||
samples int
|
||||
tol float64
|
||||
}{
|
||||
{
|
||||
am: []float64{2, 3},
|
||||
ac: mat64.NewSymDense(2, []float64{3, -1, -1, 2}),
|
||||
bm: []float64{-1, 1},
|
||||
bc: mat64.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}),
|
||||
samples: 100000,
|
||||
tol: 5e-1,
|
||||
},
|
||||
} {
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
a, ok := NewNormal(test.am, test.ac, rnd)
|
||||
if !ok {
|
||||
panic("bad test")
|
||||
}
|
||||
b, ok := NewNormal(test.bm, test.bc, rnd)
|
||||
if !ok {
|
||||
panic("bad test")
|
||||
}
|
||||
lAitchEDoubleHockeySticks := make([]float64, test.samples)
|
||||
x := make([]float64, a.Dim())
|
||||
for i := 0; i < test.samples; i++ {
|
||||
// Do importance sampling over a: \int (\sqrt(a)-\sqrt(b))^2/a * a dx
|
||||
a.Rand(x)
|
||||
pa := a.LogProb(x)
|
||||
pb := b.LogProb(x)
|
||||
d := math.Exp(0.5*pa) - math.Exp(0.5*pb)
|
||||
d = d * d
|
||||
lAitchEDoubleHockeySticks[i] = math.Log(d) - pa
|
||||
}
|
||||
want := math.Sqrt(0.5 * math.Exp(floats.LogSumExp(lAitchEDoubleHockeySticks)-math.Log(float64(test.samples))))
|
||||
got := Hellinger{}.DistNormal(a, b)
|
||||
if math.Abs(want-got) > test.tol {
|
||||
t.Errorf("Hellinger mismatch, case %d: got %v, want %v", cas, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestKullbackLieblerNormal(t *testing.T) {
|
||||
for cas, test := range []struct {
|
||||
am, bm []float64
|
||||
ac, bc *mat64.SymDense
|
||||
samples int
|
||||
tol float64
|
||||
}{
|
||||
{
|
||||
am: []float64{2, 3},
|
||||
ac: mat64.NewSymDense(2, []float64{3, -1, -1, 2}),
|
||||
bm: []float64{-1, 1},
|
||||
bc: mat64.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}),
|
||||
samples: 10000,
|
||||
tol: 1e-2,
|
||||
},
|
||||
} {
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
a, ok := NewNormal(test.am, test.ac, rnd)
|
||||
if !ok {
|
||||
panic("bad test")
|
||||
}
|
||||
b, ok := NewNormal(test.bm, test.bc, rnd)
|
||||
if !ok {
|
||||
panic("bad test")
|
||||
}
|
||||
var klmc float64
|
||||
x := make([]float64, a.Dim())
|
||||
for i := 0; i < test.samples; i++ {
|
||||
a.Rand(x)
|
||||
pa := a.LogProb(x)
|
||||
pb := b.LogProb(x)
|
||||
klmc += pa - pb
|
||||
}
|
||||
klmc /= float64(test.samples)
|
||||
kl := KullbackLeibler{}.DistNormal(a, b)
|
||||
if !floats.EqualWithinAbsOrRel(kl, klmc, test.tol, test.tol) {
|
||||
t.Errorf("Case %d, KL mismatch: got %v, want %v", cas, kl, klmc)
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user