mirror of
https://github.com/gonum/gonum.git
synced 2025-09-26 19:21:17 +08:00
stat: implement Wasserstein distance calculation
This commit is contained in:
272
stat/stat.go
272
stat/stat.go
@@ -9,6 +9,8 @@ import (
|
||||
"sort"
|
||||
|
||||
"gonum.org/v1/gonum/floats"
|
||||
"gonum.org/v1/gonum/mat"
|
||||
"gonum.org/v1/gonum/optimize/convex/lp"
|
||||
)
|
||||
|
||||
// CumulantKind specifies the behavior for calculating the empirical CDF or Quantile
|
||||
@@ -24,6 +26,276 @@ const (
|
||||
LinInterp CumulantKind = 4
|
||||
)
|
||||
|
||||
// WassersteinDistance computes the Wasserstein distance (Earth Mover's Distance)
|
||||
// between two 1D distributions p and q with optional weights. p and q are
|
||||
// the support points of each distribution. pWeights and qWeights are the weights
|
||||
// for each point.
|
||||
// If a weights slice is nil, then uniform weights are used. If it is not nil,
|
||||
// then the length of the weights slice must equal the length of the corresponding points.
|
||||
// Otherwise, the function will panic.
|
||||
//
|
||||
// The function returns the 1-Wasserstein distance (L1 metric).
|
||||
// This implementation uses the CDF-based algorithm for the 1D case.
|
||||
func WassersteinDistance(p, q, pWeights, qWeights []float64) float64 {
|
||||
if len(p) == 0 || len(q) == 0 {
|
||||
return math.NaN()
|
||||
}
|
||||
|
||||
if (pWeights != nil && len(p) != len(pWeights)) || (qWeights != nil && len(q) != len(qWeights)) {
|
||||
panic("stat: input distributions and their weights must have same length")
|
||||
}
|
||||
|
||||
// Special case for single-point distributions
|
||||
if len(p) == 1 && len(q) == 1 {
|
||||
return math.Abs(p[0] - q[0])
|
||||
}
|
||||
|
||||
// Handle identical distributions case.
|
||||
if floats.Equal(p, q) && floats.Equal(pWeights, qWeights) {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Use uniform weights if not provided.
|
||||
pUniform := pWeights == nil
|
||||
qUniform := qWeights == nil
|
||||
|
||||
var pUniformWeight, qUniformWeight float64
|
||||
|
||||
if pUniform {
|
||||
pUniformWeight = 1 / float64(len(p))
|
||||
} else {
|
||||
normalizeWeights(pWeights)
|
||||
}
|
||||
|
||||
if qUniform {
|
||||
qUniformWeight = 1 / float64(len(q))
|
||||
} else {
|
||||
normalizeWeights(qWeights)
|
||||
}
|
||||
|
||||
// Pair values with weights and sort.
|
||||
type pair struct {
|
||||
value float64
|
||||
weight float64
|
||||
}
|
||||
|
||||
pPairs := make([]pair, len(p))
|
||||
qPairs := make([]pair, len(q))
|
||||
for i := range p {
|
||||
if pUniform {
|
||||
pPairs[i] = pair{p[i], pUniformWeight}
|
||||
} else {
|
||||
pPairs[i] = pair{p[i], pWeights[i]}
|
||||
}
|
||||
}
|
||||
for i := range q {
|
||||
if qUniform {
|
||||
qPairs[i] = pair{q[i], qUniformWeight}
|
||||
} else {
|
||||
qPairs[i] = pair{q[i], qWeights[i]}
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(pPairs, func(i, j int) bool { return pPairs[i].value < pPairs[j].value })
|
||||
sort.Slice(qPairs, func(i, j int) bool { return qPairs[i].value < qPairs[j].value })
|
||||
|
||||
// Compute CDFs.
|
||||
pCDF := make([]float64, len(pPairs))
|
||||
qCDF := make([]float64, len(qPairs))
|
||||
|
||||
pCDF[0] = pPairs[0].weight
|
||||
for i := 1; i < len(pPairs); i++ {
|
||||
pCDF[i] = pCDF[i-1] + pPairs[i].weight
|
||||
}
|
||||
|
||||
qCDF[0] = qPairs[0].weight
|
||||
for i := 1; i < len(qPairs); i++ {
|
||||
qCDF[i] = qCDF[i-1] + qPairs[i].weight
|
||||
}
|
||||
|
||||
// Merge the support points.
|
||||
allPoints := make([]float64, 0, len(pPairs)+len(qPairs))
|
||||
for _, pair := range pPairs {
|
||||
allPoints = append(allPoints, pair.value)
|
||||
}
|
||||
for _, pair := range qPairs {
|
||||
allPoints = append(allPoints, pair.value)
|
||||
}
|
||||
|
||||
sort.Float64s(allPoints)
|
||||
|
||||
// Remove duplicates.
|
||||
uniquePoints := make([]float64, 0, len(allPoints))
|
||||
for i := 0; i < len(allPoints); i++ {
|
||||
if i == 0 || allPoints[i] != allPoints[i-1] {
|
||||
uniquePoints = append(uniquePoints, allPoints[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate the Wasserstein distance.
|
||||
var distance float64
|
||||
|
||||
for i := 0; i < len(uniquePoints)-1; i++ {
|
||||
x1 := uniquePoints[i]
|
||||
x2 := uniquePoints[i+1]
|
||||
|
||||
// Find CDF values at x1.
|
||||
var pCDFAtX1 float64
|
||||
for j := 0; j < len(pPairs); j++ {
|
||||
if pPairs[j].value > x1 {
|
||||
break
|
||||
}
|
||||
pCDFAtX1 = pCDF[j]
|
||||
}
|
||||
|
||||
var qCDFAtX1 float64
|
||||
for j := 0; j < len(qPairs); j++ {
|
||||
if qPairs[j].value > x1 {
|
||||
break
|
||||
}
|
||||
qCDFAtX1 = qCDF[j]
|
||||
}
|
||||
|
||||
distance += float64(math.Abs(pCDFAtX1-qCDFAtX1) * (x2 - x1))
|
||||
}
|
||||
|
||||
return distance
|
||||
}
|
||||
|
||||
// WassersteinDistanceND computes the Wasserstein distance (Earth Mover's Distance)
|
||||
// between two n-dimensional distributions p and q with optional weights. p and q are
|
||||
// matrices where each row represents a point. pWeights and qWeights are the weights for each point.
|
||||
// The number of columns in p and q must be equal.
|
||||
// If a weights slice is nil, then uniform weights are used. If it is not nil, then the
|
||||
// length of the weights slice must equal the number of rows in the corresponding matrix.
|
||||
// WassersteinDistanceND will panic if these conditions are not met.
|
||||
//
|
||||
// This implementation uses linear programming to solve the optimal transport problem.
|
||||
// tol controls the solver's tolerance. See lp.Simplex for more information on this.
|
||||
func WassersteinDistanceND(p, q mat.Matrix, pWeights, qWeights []float64, tol float64) (float64, error) {
|
||||
pRows, pCols := p.Dims()
|
||||
qRows, qCols := q.Dims()
|
||||
|
||||
if pCols == 0 || qCols == 0 || pRows == 0 || qRows == 0 {
|
||||
return math.NaN(), nil
|
||||
}
|
||||
|
||||
if (pWeights != nil && pRows != len(pWeights)) || (qWeights != nil && qRows != len(qWeights)) {
|
||||
panic("stat: input distributions and their weights must have same length")
|
||||
}
|
||||
|
||||
if pCols != qCols {
|
||||
panic("stat: point dimensions must match between distributions")
|
||||
}
|
||||
|
||||
// Special case for single-point distributions
|
||||
if pRows == 1 && qRows == 1 {
|
||||
var sumSquares float64
|
||||
for j := 0; j < pCols; j++ {
|
||||
diff := p.At(0, j) - q.At(0, j)
|
||||
sumSquares += float64(diff * diff)
|
||||
}
|
||||
return math.Sqrt(sumSquares), nil
|
||||
}
|
||||
|
||||
// Handle identical distributions case.
|
||||
if pRows == qRows {
|
||||
identical := true
|
||||
for i := 0; i < pRows; i++ {
|
||||
for j := 0; j < pCols; j++ {
|
||||
if p.At(i, j) != q.At(i, j) {
|
||||
identical = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if identical {
|
||||
if floats.Equal(pWeights, qWeights) {
|
||||
return 0, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use uniform weights if not provided.
|
||||
if pWeights == nil {
|
||||
pWeights = make([]float64, pRows)
|
||||
floats.AddConst(1/float64(pRows), pWeights)
|
||||
} else {
|
||||
normalizeWeights(pWeights)
|
||||
}
|
||||
|
||||
if qWeights == nil {
|
||||
qWeights = make([]float64, qRows)
|
||||
floats.AddConst(1/float64(qRows), qWeights)
|
||||
} else {
|
||||
normalizeWeights(qWeights)
|
||||
}
|
||||
|
||||
// Create cost matrix using Euclidean distance.
|
||||
cost := mat.NewDense(pRows, qRows, nil)
|
||||
|
||||
for i := 0; i < pRows; i++ {
|
||||
for j := 0; j < qRows; j++ {
|
||||
// Calculate Euclidean distance between points
|
||||
var sumSquares float64
|
||||
for k := 0; k < pCols; k++ {
|
||||
diff := p.At(i, k) - q.At(j, k)
|
||||
sumSquares += float64(diff * diff)
|
||||
}
|
||||
dist := math.Sqrt(sumSquares)
|
||||
cost.Set(i, j, dist)
|
||||
}
|
||||
}
|
||||
|
||||
return solveOptimalTransportLP(cost, pWeights, qWeights, tol)
|
||||
}
|
||||
|
||||
// normalizeWeights checks and normalizes the provided weight array if needed.
|
||||
func normalizeWeights(weights []float64) {
|
||||
var sum float64
|
||||
|
||||
for _, w := range weights {
|
||||
if w <= 0 {
|
||||
panic("stat: all weights must be positive")
|
||||
}
|
||||
sum += w
|
||||
}
|
||||
|
||||
floats.Scale(1/sum, weights)
|
||||
}
|
||||
|
||||
// solveOptimalTransportLP solves the optimal transport problem.
|
||||
func solveOptimalTransportLP(costMatrix *mat.Dense, supply, demand []float64, tol float64) (float64, error) {
|
||||
rows, cols := costMatrix.Dims()
|
||||
|
||||
// Formulate the linear program.
|
||||
c := costMatrix.RawMatrix().Data
|
||||
constraints := mat.NewDense(rows+cols-1, rows*cols, nil)
|
||||
|
||||
// Supply constraints (row sums) - all rows except the last one
|
||||
for i := 0; i < rows-1; i++ {
|
||||
for j := 0; j < cols; j++ {
|
||||
constraints.Set(i, i*cols+j, 1.0)
|
||||
}
|
||||
}
|
||||
|
||||
// Demand constraints (column sums)
|
||||
for j := 0; j < cols; j++ {
|
||||
for i := 0; i < rows; i++ {
|
||||
constraints.Set(rows-1+j, i*cols+j, 1.0)
|
||||
}
|
||||
}
|
||||
|
||||
// Right-hand side of constraints
|
||||
b := make([]float64, rows+cols-1)
|
||||
copy(b[:rows-1], supply[:rows-1])
|
||||
copy(b[rows-1:], demand)
|
||||
|
||||
// Solve the linear program.
|
||||
optVal, _, err := lp.Simplex(c, constraints, b, tol, nil)
|
||||
return optVal, err
|
||||
}
|
||||
|
||||
// bhattacharyyaCoeff computes the Bhattacharyya Coefficient for probability distributions given by:
|
||||
//
|
||||
// \sum_i \sqrt{p_i q_i}
|
||||
|
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"gonum.org/v1/gonum/floats"
|
||||
"gonum.org/v1/gonum/floats/scalar"
|
||||
"gonum.org/v1/gonum/mat"
|
||||
)
|
||||
|
||||
func ExampleCircularMean() {
|
||||
@@ -1887,3 +1888,262 @@ func TestStdScore(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWassersteinDistance(t *testing.T) {
|
||||
const tol = 1e-8
|
||||
|
||||
// sample cases were taken from scipy's documentation
|
||||
tests := []struct {
|
||||
name string
|
||||
p []float64
|
||||
q []float64
|
||||
pWeights []float64
|
||||
qWeights []float64
|
||||
want float64
|
||||
}{
|
||||
{
|
||||
name: "Example 1: Basic case with different distributions",
|
||||
p: []float64{0, 1, 3},
|
||||
q: []float64{5, 6, 8},
|
||||
pWeights: nil,
|
||||
qWeights: nil,
|
||||
want: 5.0,
|
||||
},
|
||||
{
|
||||
name: "Example 2: Same distributions with different weights",
|
||||
p: []float64{0, 1},
|
||||
q: []float64{0, 1},
|
||||
pWeights: []float64{3, 1},
|
||||
qWeights: []float64{2, 2},
|
||||
want: 0.25,
|
||||
},
|
||||
{
|
||||
name: "Example 3: Complex case with different sizes and weights",
|
||||
p: []float64{3.4, 3.9, 7.5, 7.8},
|
||||
q: []float64{4.5, 1.4},
|
||||
pWeights: []float64{1.4, 0.9, 3.1, 7.2},
|
||||
qWeights: []float64{3.2, 3.5},
|
||||
want: 4.078133143804786,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
got := WassersteinDistance(test.p, test.q, test.pWeights, test.qWeights)
|
||||
if math.Abs(got-test.want) > tol {
|
||||
t.Errorf("WassersteinDistance() = %v, want %v", got, test.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWassersteinDistanceEdgeCases tests specific edge cases
|
||||
func TestWassersteinDistanceEdgeCases(t *testing.T) {
|
||||
// Test empty distribution.
|
||||
result := WassersteinDistance([]float64{}, []float64{1.0}, nil, nil)
|
||||
|
||||
if !math.IsNaN(result) {
|
||||
t.Errorf("WassersteinDistance() = %v, want %v", result, math.NaN())
|
||||
}
|
||||
}
|
||||
|
||||
// TestWassersteinDistanceWeightNormalization tests that weights are properly normalized
|
||||
func TestWassersteinDistanceWeightNormalization(t *testing.T) {
|
||||
const tol = 1e-8
|
||||
|
||||
p := []float64{1.0, 2.0, 3.0}
|
||||
q := []float64{2.0, 3.0, 4.0}
|
||||
|
||||
// Non-normalized weights that sum to different values
|
||||
pWeights := []float64{2.0, 3.0, 5.0} // Sum = 10
|
||||
qWeights := []float64{1.0, 1.0, 1.0} // Sum = 3
|
||||
|
||||
// Create normalized copies for reference
|
||||
pWeightsNorm := make([]float64, len(pWeights))
|
||||
qWeightsNorm := make([]float64, len(qWeights))
|
||||
copy(pWeightsNorm, pWeights)
|
||||
copy(qWeightsNorm, qWeights)
|
||||
floats.Scale(1.0/floats.Sum(pWeightsNorm), pWeightsNorm)
|
||||
floats.Scale(1.0/floats.Sum(qWeightsNorm), qWeightsNorm)
|
||||
|
||||
// Calculate with original and normalized weights
|
||||
result1 := WassersteinDistance(p, q, pWeights, qWeights)
|
||||
result2 := WassersteinDistance(p, q, pWeightsNorm, qWeightsNorm)
|
||||
|
||||
// Results should be the same
|
||||
if math.Abs(result1-result2) > tol {
|
||||
t.Errorf("Weight normalization failed: got %v and %v", result1, result2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWassersteinDistanceND(t *testing.T) {
|
||||
const tol = 1e-8
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
p *mat.Dense
|
||||
q *mat.Dense
|
||||
pWeights []float64
|
||||
qWeights []float64
|
||||
want float64
|
||||
}{
|
||||
{
|
||||
name: "Example 1: 2D with uniform weights",
|
||||
p: mat.NewDense(3, 2, []float64{
|
||||
0, 0,
|
||||
1, 1,
|
||||
2, 2,
|
||||
}),
|
||||
q: mat.NewDense(3, 2, []float64{
|
||||
0, 1,
|
||||
1, 2,
|
||||
2, 3,
|
||||
}),
|
||||
pWeights: []float64{1.0 / 3, 1.0 / 3, 1.0 / 3},
|
||||
qWeights: []float64{1.0 / 3, 1.0 / 3, 1.0 / 3},
|
||||
want: 1.0,
|
||||
},
|
||||
{
|
||||
name: "Example 2: 2D with different sizes",
|
||||
p: mat.NewDense(2, 2, []float64{
|
||||
0, 0,
|
||||
1, 1,
|
||||
}),
|
||||
q: mat.NewDense(3, 2, []float64{
|
||||
0, 1,
|
||||
1, 2,
|
||||
2, 3,
|
||||
}),
|
||||
pWeights: []float64{0.5, 0.5},
|
||||
qWeights: []float64{1.0 / 3, 1.0 / 3, 1.0 / 3},
|
||||
want: 1.618033988749895,
|
||||
},
|
||||
{
|
||||
name: "Example 3: 2D with custom weights",
|
||||
p: mat.NewDense(3, 2, []float64{
|
||||
0, 0,
|
||||
1, 1,
|
||||
2, 2,
|
||||
}),
|
||||
q: mat.NewDense(3, 2, []float64{
|
||||
0, 0,
|
||||
1, 1,
|
||||
2, 2,
|
||||
}),
|
||||
pWeights: []float64{0.2, 0.3, 0.5},
|
||||
qWeights: []float64{0.5, 0.3, 0.2},
|
||||
want: 0.848528137423857,
|
||||
},
|
||||
{
|
||||
name: "Example 4: 3D with uniform weights",
|
||||
p: mat.NewDense(3, 3, []float64{
|
||||
0, 0, 0,
|
||||
1, 1, 1,
|
||||
2, 2, 2,
|
||||
}),
|
||||
q: mat.NewDense(3, 3, []float64{
|
||||
0, 1, 0,
|
||||
1, 2, 1,
|
||||
2, 3, 2,
|
||||
}),
|
||||
pWeights: []float64{1.0 / 3, 1.0 / 3, 1.0 / 3},
|
||||
qWeights: []float64{1.0 / 3, 1.0 / 3, 1.0 / 3},
|
||||
want: 1.0,
|
||||
},
|
||||
{
|
||||
name: "Example 5: Single points",
|
||||
p: mat.NewDense(1, 2, []float64{
|
||||
1, 2,
|
||||
}),
|
||||
q: mat.NewDense(1, 2, []float64{
|
||||
3, 4,
|
||||
}),
|
||||
pWeights: []float64{1.0},
|
||||
qWeights: []float64{1.0},
|
||||
want: 2.8284271247461903,
|
||||
},
|
||||
{
|
||||
name: "Example 6: Identical distributions",
|
||||
p: mat.NewDense(2, 2, []float64{
|
||||
1, 2,
|
||||
3, 4,
|
||||
}),
|
||||
q: mat.NewDense(2, 2, []float64{
|
||||
1, 2,
|
||||
3, 4,
|
||||
}),
|
||||
pWeights: nil,
|
||||
qWeights: nil,
|
||||
want: 0.0,
|
||||
},
|
||||
{
|
||||
name: "Example 7: 3D points from SciPy docs",
|
||||
p: mat.NewDense(2, 3, []float64{
|
||||
0, 2, 3,
|
||||
1, 2, 5,
|
||||
}),
|
||||
q: mat.NewDense(2, 3, []float64{
|
||||
3, 2, 3,
|
||||
4, 2, 5,
|
||||
}),
|
||||
pWeights: nil,
|
||||
qWeights: nil,
|
||||
want: 3.0,
|
||||
},
|
||||
{
|
||||
name: "Example 8: 2D with custom weights from SciPy docs",
|
||||
p: mat.NewDense(3, 2, []float64{
|
||||
0, 2.75,
|
||||
2, 209.3,
|
||||
0, 0,
|
||||
}),
|
||||
q: mat.NewDense(2, 2, []float64{
|
||||
0.2, 0.322,
|
||||
4.5, 25.1808,
|
||||
}),
|
||||
pWeights: []float64{0.4, 5.2, 0.114},
|
||||
qWeights: []float64{0.8, 1.5},
|
||||
want: 174.15840245217169,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
got, err := WassersteinDistanceND(test.p, test.q, test.pWeights, test.qWeights, tol)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("WassersteinDistanceND() error: %v", err)
|
||||
}
|
||||
|
||||
if math.Abs(got-test.want) > tol {
|
||||
t.Errorf("WassersteinDistanceND() = %v, want %v", got, test.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeWeights(t *testing.T) {
|
||||
t.Run("panics on zero weight", func(t *testing.T) {
|
||||
weights := []float64{1.0, 0.0, 3.0}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("expected panic on zero weight, but got none")
|
||||
}
|
||||
}()
|
||||
|
||||
normalizeWeights(weights)
|
||||
})
|
||||
|
||||
t.Run("panics on negative weight", func(t *testing.T) {
|
||||
weights := []float64{1.0, -2.0, 3.0}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("expected panic on negative weight, but got none")
|
||||
}
|
||||
}()
|
||||
|
||||
normalizeWeights(weights)
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user