// WassersteinDistance computes the first Wasserstein distance (Earth Mover's
// Distance) between two 1D distributions p and q with optional weights. p and
// q are the support points of each distribution. pWeights and qWeights are the
// weights for each point.
// If a weights slice is nil, then uniform weights are used. If it is not nil,
// then the length of the weights slice must equal the length of the
// corresponding points, and all weights must be positive. Otherwise, the
// function will panic. The weight slices are not modified; normalization is
// performed on internal copies.
//
// The function returns the 1-Wasserstein distance (L1 metric). It returns NaN
// if either distribution is empty.
// This implementation uses the CDF-based algorithm for the 1D case: the
// distance is the integral of |F_p(x) - F_q(x)| over the merged support.
func WassersteinDistance(p, q, pWeights, qWeights []float64) float64 {
	if len(p) == 0 || len(q) == 0 {
		return math.NaN()
	}

	if (pWeights != nil && len(p) != len(pWeights)) || (qWeights != nil && len(q) != len(qWeights)) {
		panic("stat: input distributions and their weights must have same length")
	}

	// Special case for single-point distributions.
	if len(p) == 1 && len(q) == 1 {
		return math.Abs(p[0] - q[0])
	}

	// Handle identical distributions (including identical weights) cheaply.
	if equalFloat64s(p, q) && equalFloat64s(pWeights, qWeights) {
		return 0
	}

	pVals, pW := sortedNormalizedWeights(p, pWeights)
	qVals, qW := sortedNormalizedWeights(q, qWeights)

	// Sweep the merged support once with two cursors, integrating
	// |F_p(x) - F_q(x)| dx between consecutive distinct support points.
	var (
		dist, pCDF, qCDF float64
		i, j             int
		prev             float64
		started          bool
	)
	for i < len(pVals) || j < len(qVals) {
		// Next distinct support point across both distributions.
		var x float64
		if j >= len(qVals) || (i < len(pVals) && pVals[i] <= qVals[j]) {
			x = pVals[i]
		} else {
			x = qVals[j]
		}
		if started {
			dist += math.Abs(pCDF-qCDF) * (x - prev)
		}
		started = true
		prev = x
		// Fold every point equal to x into the running CDFs.
		for i < len(pVals) && pVals[i] == x {
			pCDF += pW[i]
			i++
		}
		for j < len(qVals) && qVals[j] == x {
			qCDF += qW[j]
			j++
		}
	}
	return dist
}

// sortedNormalizedWeights returns copies of vals and its weights, sorted in
// ascending order of value. A nil weights slice yields uniform weights 1/n;
// otherwise the weights are validated (all must be positive, or the function
// panics) and normalized to sum to 1 on a copy, leaving the input untouched.
func sortedNormalizedWeights(vals, weights []float64) (sv, sw []float64) {
	n := len(vals)
	w := make([]float64, n)
	if weights == nil {
		u := 1 / float64(n)
		for i := range w {
			w[i] = u
		}
	} else {
		var sum float64
		for i, x := range weights {
			if x <= 0 {
				panic("stat: all weights must be positive")
			}
			w[i] = x
			sum += x
		}
		for i := range w {
			w[i] /= sum
		}
	}

	idx := make([]int, n)
	for i := range idx {
		idx[i] = i
	}
	sort.Slice(idx, func(a, b int) bool { return vals[idx[a]] < vals[idx[b]] })

	sv = make([]float64, n)
	sw = make([]float64, n)
	for k, id := range idx {
		sv[k] = vals[id]
		sw[k] = w[id]
	}
	return sv, sw
}

// equalFloat64s reports whether a and b have the same length and identical
// elements (same semantics as floats.Equal; NaNs never compare equal).
func equalFloat64s(a, b []float64) bool {
	if len(a) != len(b) {
		return false
	}
	for i, x := range a {
		if x != b[i] {
			return false
		}
	}
	return true
}

// WassersteinDistanceND computes the Wasserstein distance (Earth Mover's Distance)
// between two n-dimensional distributions p and q with optional weights. p and q are
// matrices where each row represents a point. pWeights and qWeights are the weights for each point.
// The number of columns in p and q must be equal.
// If a weights slice is nil, then uniform weights are used. If it is not nil, then the
// length of the weights slice must equal the number of rows in the corresponding matrix.
// WassersteinDistanceND will panic if these conditions are not met.
//
// This implementation uses linear programming to solve the optimal transport problem.
// tol controls the solver's tolerance. See lp.Simplex for more information on this.
+func WassersteinDistanceND(p, q mat.Matrix, pWeights, qWeights []float64, tol float64) (float64, error) { + pRows, pCols := p.Dims() + qRows, qCols := q.Dims() + + if pCols == 0 || qCols == 0 || pRows == 0 || qRows == 0 { + return math.NaN(), nil + } + + if (pWeights != nil && pRows != len(pWeights)) || (qWeights != nil && qRows != len(qWeights)) { + panic("stat: input distributions and their weights must have same length") + } + + if pCols != qCols { + panic("stat: point dimensions must match between distributions") + } + + // Special case for single-point distributions + if pRows == 1 && qRows == 1 { + var sumSquares float64 + for j := 0; j < pCols; j++ { + diff := p.At(0, j) - q.At(0, j) + sumSquares += float64(diff * diff) + } + return math.Sqrt(sumSquares), nil + } + + // Handle identical distributions case. + if pRows == qRows { + identical := true + for i := 0; i < pRows; i++ { + for j := 0; j < pCols; j++ { + if p.At(i, j) != q.At(i, j) { + identical = false + break + } + } + } + if identical { + if floats.Equal(pWeights, qWeights) { + return 0, nil + } + } + } + + // Use uniform weights if not provided. + if pWeights == nil { + pWeights = make([]float64, pRows) + floats.AddConst(1/float64(pRows), pWeights) + } else { + normalizeWeights(pWeights) + } + + if qWeights == nil { + qWeights = make([]float64, qRows) + floats.AddConst(1/float64(qRows), qWeights) + } else { + normalizeWeights(qWeights) + } + + // Create cost matrix using Euclidean distance. + cost := mat.NewDense(pRows, qRows, nil) + + for i := 0; i < pRows; i++ { + for j := 0; j < qRows; j++ { + // Calculate Euclidean distance between points + var sumSquares float64 + for k := 0; k < pCols; k++ { + diff := p.At(i, k) - q.At(j, k) + sumSquares += float64(diff * diff) + } + dist := math.Sqrt(sumSquares) + cost.Set(i, j, dist) + } + } + + return solveOptimalTransportLP(cost, pWeights, qWeights, tol) +} + +// normalizeWeights checks and normalizes the provided weight array if needed. 
+func normalizeWeights(weights []float64) { + var sum float64 + + for _, w := range weights { + if w <= 0 { + panic("stat: all weights must be positive") + } + sum += w + } + + floats.Scale(1/sum, weights) +} + +// solveOptimalTransportLP solves the optimal transport problem. +func solveOptimalTransportLP(costMatrix *mat.Dense, supply, demand []float64, tol float64) (float64, error) { + rows, cols := costMatrix.Dims() + + // Formulate the linear program. + c := costMatrix.RawMatrix().Data + constraints := mat.NewDense(rows+cols-1, rows*cols, nil) + + // Supply constraints (row sums) - all rows except the last one + for i := 0; i < rows-1; i++ { + for j := 0; j < cols; j++ { + constraints.Set(i, i*cols+j, 1.0) + } + } + + // Demand constraints (column sums) + for j := 0; j < cols; j++ { + for i := 0; i < rows; i++ { + constraints.Set(rows-1+j, i*cols+j, 1.0) + } + } + + // Right-hand side of constraints + b := make([]float64, rows+cols-1) + copy(b[:rows-1], supply[:rows-1]) + copy(b[rows-1:], demand) + + // Solve the linear program. 
+ optVal, _, err := lp.Simplex(c, constraints, b, tol, nil) + return optVal, err +} + // bhattacharyyaCoeff computes the Bhattacharyya Coefficient for probability distributions given by: // // \sum_i \sqrt{p_i q_i} diff --git a/stat/stat_test.go b/stat/stat_test.go index 7e31ef33..5321423f 100644 --- a/stat/stat_test.go +++ b/stat/stat_test.go @@ -14,6 +14,7 @@ import ( "gonum.org/v1/gonum/floats" "gonum.org/v1/gonum/floats/scalar" + "gonum.org/v1/gonum/mat" ) func ExampleCircularMean() { @@ -1887,3 +1888,262 @@ func TestStdScore(t *testing.T) { } } } + +func TestWassersteinDistance(t *testing.T) { + const tol = 1e-8 + + // sample cases were taken from scipy's documentation + tests := []struct { + name string + p []float64 + q []float64 + pWeights []float64 + qWeights []float64 + want float64 + }{ + { + name: "Example 1: Basic case with different distributions", + p: []float64{0, 1, 3}, + q: []float64{5, 6, 8}, + pWeights: nil, + qWeights: nil, + want: 5.0, + }, + { + name: "Example 2: Same distributions with different weights", + p: []float64{0, 1}, + q: []float64{0, 1}, + pWeights: []float64{3, 1}, + qWeights: []float64{2, 2}, + want: 0.25, + }, + { + name: "Example 3: Complex case with different sizes and weights", + p: []float64{3.4, 3.9, 7.5, 7.8}, + q: []float64{4.5, 1.4}, + pWeights: []float64{1.4, 0.9, 3.1, 7.2}, + qWeights: []float64{3.2, 3.5}, + want: 4.078133143804786, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got := WassersteinDistance(test.p, test.q, test.pWeights, test.qWeights) + if math.Abs(got-test.want) > tol { + t.Errorf("WassersteinDistance() = %v, want %v", got, test.want) + } + }) + } +} + +// TestWassersteinDistanceEdgeCases tests specific edge cases +func TestWassersteinDistanceEdgeCases(t *testing.T) { + // Test empty distribution. 
+ result := WassersteinDistance([]float64{}, []float64{1.0}, nil, nil) + + if !math.IsNaN(result) { + t.Errorf("WassersteinDistance() = %v, want %v", result, math.NaN()) + } +} + +// TestWassersteinDistanceWeightNormalization tests that weights are properly normalized +func TestWassersteinDistanceWeightNormalization(t *testing.T) { + const tol = 1e-8 + + p := []float64{1.0, 2.0, 3.0} + q := []float64{2.0, 3.0, 4.0} + + // Non-normalized weights that sum to different values + pWeights := []float64{2.0, 3.0, 5.0} // Sum = 10 + qWeights := []float64{1.0, 1.0, 1.0} // Sum = 3 + + // Create normalized copies for reference + pWeightsNorm := make([]float64, len(pWeights)) + qWeightsNorm := make([]float64, len(qWeights)) + copy(pWeightsNorm, pWeights) + copy(qWeightsNorm, qWeights) + floats.Scale(1.0/floats.Sum(pWeightsNorm), pWeightsNorm) + floats.Scale(1.0/floats.Sum(qWeightsNorm), qWeightsNorm) + + // Calculate with original and normalized weights + result1 := WassersteinDistance(p, q, pWeights, qWeights) + result2 := WassersteinDistance(p, q, pWeightsNorm, qWeightsNorm) + + // Results should be the same + if math.Abs(result1-result2) > tol { + t.Errorf("Weight normalization failed: got %v and %v", result1, result2) + } +} + +func TestWassersteinDistanceND(t *testing.T) { + const tol = 1e-8 + + tests := []struct { + name string + p *mat.Dense + q *mat.Dense + pWeights []float64 + qWeights []float64 + want float64 + }{ + { + name: "Example 1: 2D with uniform weights", + p: mat.NewDense(3, 2, []float64{ + 0, 0, + 1, 1, + 2, 2, + }), + q: mat.NewDense(3, 2, []float64{ + 0, 1, + 1, 2, + 2, 3, + }), + pWeights: []float64{1.0 / 3, 1.0 / 3, 1.0 / 3}, + qWeights: []float64{1.0 / 3, 1.0 / 3, 1.0 / 3}, + want: 1.0, + }, + { + name: "Example 2: 2D with different sizes", + p: mat.NewDense(2, 2, []float64{ + 0, 0, + 1, 1, + }), + q: mat.NewDense(3, 2, []float64{ + 0, 1, + 1, 2, + 2, 3, + }), + pWeights: []float64{0.5, 0.5}, + qWeights: []float64{1.0 / 3, 1.0 / 3, 1.0 / 3}, + want: 
1.618033988749895, + }, + { + name: "Example 3: 2D with custom weights", + p: mat.NewDense(3, 2, []float64{ + 0, 0, + 1, 1, + 2, 2, + }), + q: mat.NewDense(3, 2, []float64{ + 0, 0, + 1, 1, + 2, 2, + }), + pWeights: []float64{0.2, 0.3, 0.5}, + qWeights: []float64{0.5, 0.3, 0.2}, + want: 0.848528137423857, + }, + { + name: "Example 4: 3D with uniform weights", + p: mat.NewDense(3, 3, []float64{ + 0, 0, 0, + 1, 1, 1, + 2, 2, 2, + }), + q: mat.NewDense(3, 3, []float64{ + 0, 1, 0, + 1, 2, 1, + 2, 3, 2, + }), + pWeights: []float64{1.0 / 3, 1.0 / 3, 1.0 / 3}, + qWeights: []float64{1.0 / 3, 1.0 / 3, 1.0 / 3}, + want: 1.0, + }, + { + name: "Example 5: Single points", + p: mat.NewDense(1, 2, []float64{ + 1, 2, + }), + q: mat.NewDense(1, 2, []float64{ + 3, 4, + }), + pWeights: []float64{1.0}, + qWeights: []float64{1.0}, + want: 2.8284271247461903, + }, + { + name: "Example 6: Identical distributions", + p: mat.NewDense(2, 2, []float64{ + 1, 2, + 3, 4, + }), + q: mat.NewDense(2, 2, []float64{ + 1, 2, + 3, 4, + }), + pWeights: nil, + qWeights: nil, + want: 0.0, + }, + { + name: "Example 7: 3D points from SciPy docs", + p: mat.NewDense(2, 3, []float64{ + 0, 2, 3, + 1, 2, 5, + }), + q: mat.NewDense(2, 3, []float64{ + 3, 2, 3, + 4, 2, 5, + }), + pWeights: nil, + qWeights: nil, + want: 3.0, + }, + { + name: "Example 8: 2D with custom weights from SciPy docs", + p: mat.NewDense(3, 2, []float64{ + 0, 2.75, + 2, 209.3, + 0, 0, + }), + q: mat.NewDense(2, 2, []float64{ + 0.2, 0.322, + 4.5, 25.1808, + }), + pWeights: []float64{0.4, 5.2, 0.114}, + qWeights: []float64{0.8, 1.5}, + want: 174.15840245217169, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got, err := WassersteinDistanceND(test.p, test.q, test.pWeights, test.qWeights, tol) + + if err != nil { + t.Errorf("WassersteinDistanceND() error: %v", err) + } + + if math.Abs(got-test.want) > tol { + t.Errorf("WassersteinDistanceND() = %v, want %v", got, test.want) + } + }) + } +} + +func 
TestNormalizeWeights(t *testing.T) { + t.Run("panics on zero weight", func(t *testing.T) { + weights := []float64{1.0, 0.0, 3.0} + + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic on zero weight, but got none") + } + }() + + normalizeWeights(weights) + }) + + t.Run("panics on negative weight", func(t *testing.T) { + weights := []float64{1.0, -2.0, 3.0} + + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic on negative weight, but got none") + } + }() + + normalizeWeights(weights) + }) +}