stat: implement Wasserstein distance calculation

This commit is contained in:
Shieldine
2025-04-23 00:05:14 +02:00
committed by GitHub
parent 4408afacd1
commit 672aa59ec6
2 changed files with 532 additions and 0 deletions

View File

@@ -9,6 +9,8 @@ import (
"sort"
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/mat"
"gonum.org/v1/gonum/optimize/convex/lp"
)
// CumulantKind specifies the behavior for calculating the empirical CDF or Quantile
@@ -24,6 +26,276 @@ const (
LinInterp CumulantKind = 4
)
// WassersteinDistance computes the Wasserstein distance (Earth Mover's Distance)
// between two 1D distributions p and q with optional weights. p and q are
// the support points of each distribution. pWeights and qWeights are the weights
// for each point.
// If a weights slice is nil, then uniform weights are used. If it is not nil,
// then the length of the weights slice must equal the length of the corresponding points.
// Otherwise, the function will panic. The input slices are not modified.
//
// The function returns the 1-Wasserstein distance (L1 metric).
// This implementation uses the CDF-based algorithm for the 1D case: the
// distance is the integral of |F_p(x) - F_q(x)| over the merged support.
func WassersteinDistance(p, q, pWeights, qWeights []float64) float64 {
	if len(p) == 0 || len(q) == 0 {
		return math.NaN()
	}
	if (pWeights != nil && len(p) != len(pWeights)) || (qWeights != nil && len(q) != len(qWeights)) {
		panic("stat: input distributions and their weights must have same length")
	}
	// Special case for single-point distributions: all mass moves p[0] -> q[0].
	if len(p) == 1 && len(q) == 1 {
		return math.Abs(p[0] - q[0])
	}
	// Fast path for identical distributions with identical weights.
	if wassersteinEqual(p, q) && wassersteinEqual(pWeights, qWeights) {
		return 0
	}
	// Sorted support points with normalized weights. The inputs are copied,
	// so the caller's slices are left untouched.
	pv, pw := wassersteinSorted(p, pWeights)
	qv, qw := wassersteinSorted(q, qWeights)
	// Walk the merged support once, integrating |F_p - F_q| between
	// consecutive distinct support points. This is O(n+m) after sorting,
	// rather than rescanning both CDFs for every merged point.
	var (
		i, j       int     // cursors into pv and qv
		cdfP, cdfQ float64 // running CDF values just below the current point
		distance   float64
	)
	prev := math.Min(pv[0], qv[0])
	for i < len(pv) || j < len(qv) {
		// Next distinct support point across both distributions.
		var x float64
		switch {
		case j == len(qv), i < len(pv) && pv[i] < qv[j]:
			x = pv[i]
		default:
			x = qv[j]
		}
		// Accumulate |F_p - F_q| over the segment [prev, x].
		distance += math.Abs(cdfP-cdfQ) * (x - prev)
		// Advance both CDFs past x, absorbing duplicate support points.
		for i < len(pv) && pv[i] == x {
			cdfP += pw[i]
			i++
		}
		for j < len(qv) && qv[j] == x {
			cdfQ += qw[j]
			j++
		}
		prev = x
	}
	return distance
}

// wassersteinEqual reports whether a and b have the same length and equal
// elements. It matches the semantics of floats.Equal.
func wassersteinEqual(a, b []float64) bool {
	if len(a) != len(b) {
		return false
	}
	for i, v := range a {
		if v != b[i] {
			return false
		}
	}
	return true
}

// wassersteinSorted returns copies of the support points and their normalized
// weights, sorted by support point. If weights is nil, uniform weights are
// used. It panics if any provided weight is not positive. The inputs are not
// modified.
func wassersteinSorted(points, weights []float64) (v, w []float64) {
	n := len(points)
	v = make([]float64, n)
	w = make([]float64, n)
	if weights == nil {
		copy(v, points)
		sort.Float64s(v)
		uniform := 1 / float64(n)
		for i := range w {
			w[i] = uniform
		}
		return v, w
	}
	var sum float64
	for _, wt := range weights {
		if wt <= 0 {
			panic("stat: all weights must be positive")
		}
		sum += wt
	}
	scale := 1 / sum
	// Sort indices so points and weights stay paired.
	idx := make([]int, n)
	for i := range idx {
		idx[i] = i
	}
	sort.Slice(idx, func(a, b int) bool { return points[idx[a]] < points[idx[b]] })
	for i, k := range idx {
		v[i] = points[k]
		w[i] = weights[k] * scale
	}
	return v, w
}
// WassersteinDistanceND computes the Wasserstein distance (Earth Mover's Distance)
// between two n-dimensional distributions p and q with optional weights. p and q are
// matrices where each row represents a point. pWeights and qWeights are the weights for each point.
// The number of columns in p and q must be equal.
// If a weights slice is nil, then uniform weights are used. If it is not nil, then the
// length of the weights slice must equal the number of rows in the corresponding matrix.
// WassersteinDistanceND will panic if these conditions are not met.
// The input weight slices are not modified.
//
// This implementation uses linear programming to solve the optimal transport problem.
// tol controls the solver's tolerance. See lp.Simplex for more information on this.
func WassersteinDistanceND(p, q mat.Matrix, pWeights, qWeights []float64, tol float64) (float64, error) {
	pRows, pCols := p.Dims()
	qRows, qCols := q.Dims()
	if pCols == 0 || qCols == 0 || pRows == 0 || qRows == 0 {
		return math.NaN(), nil
	}
	if (pWeights != nil && pRows != len(pWeights)) || (qWeights != nil && qRows != len(qWeights)) {
		panic("stat: input distributions and their weights must have same length")
	}
	if pCols != qCols {
		panic("stat: point dimensions must match between distributions")
	}
	// Special case for single-point distributions: the distance is the
	// Euclidean distance between the two points.
	if pRows == 1 && qRows == 1 {
		var sumSquares float64
		for j := 0; j < pCols; j++ {
			diff := p.At(0, j) - q.At(0, j)
			sumSquares += diff * diff
		}
		return math.Sqrt(sumSquares), nil
	}
	// Fast path for identical distributions with identical weights.
	if pRows == qRows && floats.Equal(pWeights, qWeights) {
		identical := true
	scan:
		for i := 0; i < pRows; i++ {
			for j := 0; j < pCols; j++ {
				if p.At(i, j) != q.At(i, j) {
					identical = false
					break scan // stop at the first mismatch
				}
			}
		}
		if identical {
			return 0, nil
		}
	}
	// Use uniform weights if not provided. Provided weights are copied
	// before normalization so the caller's slices are not modified.
	if pWeights == nil {
		pWeights = make([]float64, pRows)
		floats.AddConst(1/float64(pRows), pWeights)
	} else {
		pWeights = append([]float64(nil), pWeights...)
		normalizeWeights(pWeights)
	}
	if qWeights == nil {
		qWeights = make([]float64, qRows)
		floats.AddConst(1/float64(qRows), qWeights)
	} else {
		qWeights = append([]float64(nil), qWeights...)
		normalizeWeights(qWeights)
	}
	// Cost matrix of pairwise Euclidean distances between support points.
	cost := mat.NewDense(pRows, qRows, nil)
	for i := 0; i < pRows; i++ {
		for j := 0; j < qRows; j++ {
			var sumSquares float64
			for k := 0; k < pCols; k++ {
				diff := p.At(i, k) - q.At(j, k)
				sumSquares += diff * diff
			}
			cost.Set(i, j, math.Sqrt(sumSquares))
		}
	}
	return solveOptimalTransportLP(cost, pWeights, qWeights, tol)
}
// normalizeWeights scales weights in place so that they sum to 1.
// It panics if any weight is zero or negative.
func normalizeWeights(weights []float64) {
	total := 0.0
	for _, w := range weights {
		if w <= 0 {
			panic("stat: all weights must be positive")
		}
		total += w
	}
	// Multiply by the reciprocal, matching floats.Scale(1/total, weights).
	inv := 1 / total
	for i := range weights {
		weights[i] *= inv
	}
}
// solveOptimalTransportLP solves the discrete optimal transport problem
// min sum_{i,j} cost[i,j]*x[i,j] subject to the row sums of the plan x
// matching supply, the column sums matching demand, and x >= 0, using
// lp.Simplex. Variable x[i,j] is flattened to index i*cols+j.
//
// Because supply and demand are normalized by the callers (each sums to 1),
// one of the rows+cols equality constraints is linearly dependent on the
// others; the last supply constraint is dropped to keep the system full rank.
func solveOptimalTransportLP(costMatrix *mat.Dense, supply, demand []float64, tol float64) (float64, error) {
	rows, cols := costMatrix.Dims()
	// Formulate the linear program. The objective coefficients are the
	// flattened cost matrix.
	// NOTE(review): RawMatrix().Data is the row-major flattening only when
	// the stride equals cols; this holds for matrices built with
	// mat.NewDense as the callers do — confirm if other inputs are added.
	c := costMatrix.RawMatrix().Data
	constraints := mat.NewDense(rows+cols-1, rows*cols, nil)
	// Supply constraints (row sums) - all rows except the last one,
	// which is redundant given normalized weights (see above).
	for i := 0; i < rows-1; i++ {
		for j := 0; j < cols; j++ {
			constraints.Set(i, i*cols+j, 1.0)
		}
	}
	// Demand constraints (column sums)
	for j := 0; j < cols; j++ {
		for i := 0; i < rows; i++ {
			constraints.Set(rows-1+j, i*cols+j, 1.0)
		}
	}
	// Right-hand side of constraints: truncated supply, then full demand.
	b := make([]float64, rows+cols-1)
	copy(b[:rows-1], supply[:rows-1])
	copy(b[rows-1:], demand)
	// Solve the linear program. The optimal objective value is the
	// transport cost; the plan itself is discarded.
	optVal, _, err := lp.Simplex(c, constraints, b, tol, nil)
	return optVal, err
}
// bhattacharyyaCoeff computes the Bhattacharyya Coefficient for probability distributions given by:
//
// \sum_i \sqrt{p_i q_i}

View File

@@ -14,6 +14,7 @@ import (
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/floats/scalar"
"gonum.org/v1/gonum/mat"
)
func ExampleCircularMean() {
@@ -1887,3 +1888,262 @@ func TestStdScore(t *testing.T) {
}
}
}
func TestWassersteinDistance(t *testing.T) {
	const tol = 1e-8
	// sample cases were taken from scipy's documentation
	cases := []struct {
		name               string
		p, q               []float64
		pWeights, qWeights []float64
		want               float64
	}{
		{
			name: "Example 1: Basic case with different distributions",
			p:    []float64{0, 1, 3},
			q:    []float64{5, 6, 8},
			want: 5.0,
		},
		{
			name:     "Example 2: Same distributions with different weights",
			p:        []float64{0, 1},
			q:        []float64{0, 1},
			pWeights: []float64{3, 1},
			qWeights: []float64{2, 2},
			want:     0.25,
		},
		{
			name:     "Example 3: Complex case with different sizes and weights",
			p:        []float64{3.4, 3.9, 7.5, 7.8},
			q:        []float64{4.5, 1.4},
			pWeights: []float64{1.4, 0.9, 3.1, 7.2},
			qWeights: []float64{3.2, 3.5},
			want:     4.078133143804786,
		},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := WassersteinDistance(c.p, c.q, c.pWeights, c.qWeights); math.Abs(got-c.want) > tol {
				t.Errorf("WassersteinDistance() = %v, want %v", got, c.want)
			}
		})
	}
}
// TestWassersteinDistanceEdgeCases tests specific edge cases
func TestWassersteinDistanceEdgeCases(t *testing.T) {
	// An empty distribution has no defined distance.
	got := WassersteinDistance([]float64{}, []float64{1.0}, nil, nil)
	if !math.IsNaN(got) {
		t.Errorf("WassersteinDistance() = %v, want %v", got, math.NaN())
	}
}
// TestWassersteinDistanceWeightNormalization tests that weights are properly normalized
func TestWassersteinDistanceWeightNormalization(t *testing.T) {
	const tol = 1e-8
	p := []float64{1.0, 2.0, 3.0}
	q := []float64{2.0, 3.0, 4.0}
	// Raw weights summing to 10 and 3 respectively.
	pRaw := []float64{2.0, 3.0, 5.0}
	qRaw := []float64{1.0, 1.0, 1.0}
	// Pre-normalized copies for comparison.
	pNorm := append([]float64(nil), pRaw...)
	qNorm := append([]float64(nil), qRaw...)
	floats.Scale(1.0/floats.Sum(pNorm), pNorm)
	floats.Scale(1.0/floats.Sum(qNorm), qNorm)
	// Both calls must yield the same distance.
	raw := WassersteinDistance(p, q, pRaw, qRaw)
	norm := WassersteinDistance(p, q, pNorm, qNorm)
	if math.Abs(raw-norm) > tol {
		t.Errorf("Weight normalization failed: got %v and %v", raw, norm)
	}
}
// TestWassersteinDistanceND checks the n-dimensional solver against known
// values. Weights are raw in some cases (e.g. Example 8) and already
// normalized in others; WassersteinDistanceND normalizes internally.
func TestWassersteinDistanceND(t *testing.T) {
	const tol = 1e-8
	tests := []struct {
		name     string
		p        *mat.Dense
		q        *mat.Dense
		pWeights []float64
		qWeights []float64
		want     float64
	}{
		{
			name: "Example 1: 2D with uniform weights",
			p: mat.NewDense(3, 2, []float64{
				0, 0,
				1, 1,
				2, 2,
			}),
			q: mat.NewDense(3, 2, []float64{
				0, 1,
				1, 2,
				2, 3,
			}),
			pWeights: []float64{1.0 / 3, 1.0 / 3, 1.0 / 3},
			qWeights: []float64{1.0 / 3, 1.0 / 3, 1.0 / 3},
			want:     1.0,
		},
		{
			name: "Example 2: 2D with different sizes",
			p: mat.NewDense(2, 2, []float64{
				0, 0,
				1, 1,
			}),
			q: mat.NewDense(3, 2, []float64{
				0, 1,
				1, 2,
				2, 3,
			}),
			pWeights: []float64{0.5, 0.5},
			qWeights: []float64{1.0 / 3, 1.0 / 3, 1.0 / 3},
			want:     1.618033988749895,
		},
		{
			// Same support, different weights: cost comes purely from
			// moving mass along the shared points.
			name: "Example 3: 2D with custom weights",
			p: mat.NewDense(3, 2, []float64{
				0, 0,
				1, 1,
				2, 2,
			}),
			q: mat.NewDense(3, 2, []float64{
				0, 0,
				1, 1,
				2, 2,
			}),
			pWeights: []float64{0.2, 0.3, 0.5},
			qWeights: []float64{0.5, 0.3, 0.2},
			want:     0.848528137423857,
		},
		{
			name: "Example 4: 3D with uniform weights",
			p: mat.NewDense(3, 3, []float64{
				0, 0, 0,
				1, 1, 1,
				2, 2, 2,
			}),
			q: mat.NewDense(3, 3, []float64{
				0, 1, 0,
				1, 2, 1,
				2, 3, 2,
			}),
			pWeights: []float64{1.0 / 3, 1.0 / 3, 1.0 / 3},
			qWeights: []float64{1.0 / 3, 1.0 / 3, 1.0 / 3},
			want:     1.0,
		},
		{
			// Exercises the single-point fast path (plain Euclidean
			// distance, no LP solve).
			name: "Example 5: Single points",
			p: mat.NewDense(1, 2, []float64{
				1, 2,
			}),
			q: mat.NewDense(1, 2, []float64{
				3, 4,
			}),
			pWeights: []float64{1.0},
			qWeights: []float64{1.0},
			want:     2.8284271247461903,
		},
		{
			// Exercises the identical-distribution fast path with nil
			// (uniform) weights.
			name: "Example 6: Identical distributions",
			p: mat.NewDense(2, 2, []float64{
				1, 2,
				3, 4,
			}),
			q: mat.NewDense(2, 2, []float64{
				1, 2,
				3, 4,
			}),
			pWeights: nil,
			qWeights: nil,
			want:     0.0,
		},
		{
			name: "Example 7: 3D points from SciPy docs",
			p: mat.NewDense(2, 3, []float64{
				0, 2, 3,
				1, 2, 5,
			}),
			q: mat.NewDense(2, 3, []float64{
				3, 2, 3,
				4, 2, 5,
			}),
			pWeights: nil,
			qWeights: nil,
			want:     3.0,
		},
		{
			// Unnormalized weights (sums 5.714 and 2.3) — relies on
			// internal normalization.
			name: "Example 8: 2D with custom weights from SciPy docs",
			p: mat.NewDense(3, 2, []float64{
				0, 2.75,
				2, 209.3,
				0, 0,
			}),
			q: mat.NewDense(2, 2, []float64{
				0.2, 0.322,
				4.5, 25.1808,
			}),
			pWeights: []float64{0.4, 5.2, 0.114},
			qWeights: []float64{0.8, 1.5},
			want:     174.15840245217169,
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			// tol doubles as the LP solver tolerance and the comparison
			// tolerance here.
			got, err := WassersteinDistanceND(test.p, test.q, test.pWeights, test.qWeights, tol)
			if err != nil {
				t.Errorf("WassersteinDistanceND() error: %v", err)
			}
			if math.Abs(got-test.want) > tol {
				t.Errorf("WassersteinDistanceND() = %v, want %v", got, test.want)
			}
		})
	}
}
func TestNormalizeWeights(t *testing.T) {
	// expectPanic fails the subtest unless normalizeWeights panics.
	expectPanic := func(t *testing.T, weights []float64, msg string) {
		t.Helper()
		defer func() {
			if r := recover(); r == nil {
				t.Error(msg)
			}
		}()
		normalizeWeights(weights)
	}
	t.Run("panics on zero weight", func(t *testing.T) {
		expectPanic(t, []float64{1.0, 0.0, 3.0}, "expected panic on zero weight, but got none")
	})
	t.Run("panics on negative weight", func(t *testing.T) {
		expectPanic(t, []float64{1.0, -2.0, 3.0}, "expected panic on negative weight, but got none")
	})
}