From a377c62d9020ccadb2137782013f1e068151cb15 Mon Sep 17 00:00:00 2001 From: kortschak Date: Thu, 28 Jul 2016 16:27:03 +0930 Subject: [PATCH] stat: add simple linear regression --- faithful_test.go | 81 ++++++++++++++++++++++++++++ stat.go | 135 +++++++++++++++++++++++++++++++++++++++++++++++ stat_test.go | 105 ++++++++++++++++++++++++++++++++++++ 3 files changed, 321 insertions(+) create mode 100644 faithful_test.go diff --git a/faithful_test.go b/faithful_test.go new file mode 100644 index 00000000..981eff1c --- /dev/null +++ b/faithful_test.go @@ -0,0 +1,81 @@ +// Copyright ©2016 The gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package stat + +// faithful is the faithful data set from R. +var faithful = struct{ waiting, eruptions []float64 }{ + waiting: []float64{ + 79, 54, 74, 62, 85, 55, 88, 85, + 51, 85, 54, 84, 78, 47, 83, 52, + 62, 84, 52, 79, 51, 47, 78, 69, + 74, 83, 55, 76, 78, 79, 73, 77, + 66, 80, 74, 52, 48, 80, 59, 90, + 80, 58, 84, 58, 73, 83, 64, 53, + 82, 59, 75, 90, 54, 80, 54, 83, + 71, 64, 77, 81, 59, 84, 48, 82, + 60, 92, 78, 78, 65, 73, 82, 56, + 79, 71, 62, 76, 60, 78, 76, 83, + 75, 82, 70, 65, 73, 88, 76, 80, + 48, 86, 60, 90, 50, 78, 63, 72, + 84, 75, 51, 82, 62, 88, 49, 83, + 81, 47, 84, 52, 86, 81, 75, 59, + 89, 79, 59, 81, 50, 85, 59, 87, + 53, 69, 77, 56, 88, 81, 45, 82, + 55, 90, 45, 83, 56, 89, 46, 82, + 51, 86, 53, 79, 81, 60, 82, 77, + 76, 59, 80, 49, 96, 53, 77, 77, + 65, 81, 71, 70, 81, 93, 53, 89, + 45, 86, 58, 78, 66, 76, 63, 88, + 52, 93, 49, 57, 77, 68, 81, 81, + 73, 50, 85, 74, 55, 77, 83, 83, + 51, 78, 84, 46, 83, 55, 81, 57, + 76, 84, 77, 81, 87, 77, 51, 78, + 60, 82, 91, 53, 78, 46, 77, 84, + 49, 83, 71, 80, 49, 75, 64, 76, + 53, 94, 55, 76, 50, 82, 54, 75, + 78, 79, 78, 78, 70, 79, 70, 54, + 86, 50, 90, 54, 54, 77, 79, 64, + 75, 47, 86, 63, 85, 82, 57, 82, + 67, 74, 54, 83, 73, 73, 88, 80, + 71, 83, 56, 79, 78, 84, 58, 83, + 43, 60, 75, 81, 46, 90, 46, 74, + }, + eruptions: []float64{ + 3.600, 1.800, 3.333, 2.283, 4.533, 2.883, 4.700, 3.600, + 1.950, 4.350, 1.833, 3.917, 4.200, 1.750, 4.700, 2.167, + 1.750, 4.800, 1.600, 4.250, 1.800, 1.750, 3.450, 3.067, + 4.533, 3.600, 1.967, 4.083, 3.850, 4.433, 4.300, 4.467, + 3.367, 4.033, 3.833, 2.017, 1.867, 4.833, 1.833, 4.783, + 4.350, 1.883, 4.567, 1.750, 4.533, 3.317, 3.833, 2.100, + 4.633, 2.000, 4.800, 4.716, 1.833, 4.833, 1.733, 4.883, + 3.717, 1.667, 4.567, 4.317, 2.233, 4.500, 1.750, 4.800, + 1.817, 4.400, 4.167, 4.700, 2.067, 4.700, 4.033, 1.967, + 4.500, 4.000, 1.983, 5.067, 2.017, 4.567, 3.883, 3.600, + 4.133, 4.333, 4.100, 2.633, 4.067, 4.933, 3.950, 4.517, + 2.167, 4.000, 2.200, 4.333, 1.867, 4.817, 1.833, 4.300, + 4.667, 3.750, 1.867, 4.900, 2.483, 4.367, 2.100, 4.500, + 4.050, 1.867, 4.700, 1.783, 4.850, 3.683, 4.733, 2.300, + 4.900, 4.417, 1.700, 4.633, 2.317, 4.600, 1.817, 4.417, + 2.617, 4.067, 4.250, 1.967, 4.600, 3.767, 1.917, 4.500, + 2.267, 4.650, 1.867, 4.167, 2.800, 4.333, 1.833, 4.383, + 1.883, 4.933, 2.033, 3.733, 4.233, 2.233, 4.533, 4.817, + 4.333, 1.983, 4.633, 2.017, 5.100, 1.800, 5.033, 4.000, + 2.400, 4.600, 3.567, 4.000, 4.500, 4.083, 1.800, 3.967, + 2.200, 4.150, 2.000, 3.833, 3.500, 4.583, 2.367, 5.000, + 1.933, 4.617, 1.917, 2.083, 4.583, 3.333, 4.167, 4.333, + 4.500, 2.417, 4.000, 4.167, 1.883, 4.583, 4.250, 3.767, + 2.033, 4.433, 4.083, 1.833, 4.417, 2.183, 4.800, 1.833, + 4.800, 4.100, 3.966, 4.233, 3.500, 4.366, 2.250, 4.667, + 2.100, 4.350, 4.133, 1.867, 4.600, 1.783, 4.367, 3.850, + 1.933, 4.500, 2.383, 4.700, 1.867, 3.833, 3.417, 4.233, + 2.400, 4.800, 2.000, 4.150, 1.867, 4.267, 1.750, 4.483, + 4.000, 4.117, 4.083, 4.267, 3.917, 4.550, 4.083, 2.417, + 4.183, 2.217, 4.450, 1.883, 1.850, 4.283, 3.950, 2.333, + 4.150, 2.350, 4.933, 2.900, 4.583, 3.833, 2.083, 4.367, + 2.133, 4.350, 2.200, 4.450, 3.567, 4.500, 4.150, 3.817, + 3.917, 4.450, 2.000, 4.283, 4.767, 4.533, 1.850, 4.250, + 1.983, 2.250, 4.750, 4.117, 2.150, 4.417, 1.817, 4.467, + }, +} diff --git a/stat.go b/stat.go index 34fd2e83..294679e6 100644 --- a/stat.go +++ b/stat.go @@ -648,6 +648,141 @@ func KullbackLeibler(p, q []float64) float64 { return kl } +// LinearRegression computes the best-fit line +// y = alpha + beta*x +// to the data in x and y with the given weights. If origin is true, the +// regression is forced to pass through the origin. +// +// Specifically, LinearRegression computes the values of alpha and +// beta such that the total residual +// \sum_i w[i]*(y[i] - alpha - beta*x[i])^2 +// is minimized. If origin is true, then alpha is forced to be zero. +// +// The lengths of x and y must be equal. If weights is nil then all of the +// weights are 1. If weights is not nil, then len(x) must equal len(weights). +func LinearRegression(x, y, weights []float64, origin bool) (alpha, beta float64) { + if len(x) != len(y) { + panic("stat: slice length mismatch") + } + if weights != nil && len(weights) != len(x) { + panic("stat: slice length mismatch") + } + + w := 1.0 + if origin { + var x2Sum, xySum float64 + for i, xi := range x { + if weights != nil { + w = weights[i] + } + yi := y[i] + xySum += w * xi * yi + x2Sum += w * xi * xi + } + beta = xySum / x2Sum + + return 0, beta + } + + beta = Covariance(x, y, weights) / Variance(x, weights) + alpha = Mean(y, weights) - beta*Mean(x, weights) + return alpha, beta +} + +// RSquared returns the coefficient of determination defined as +// R^2 = 1 - \sum_i w[i]*(y[i] - alpha - beta*x[i])^2 / \sum_i w[i]*(y[i] - mean(y))^2 +// for the line +// y = alpha + beta*x +// and the data in x and y with the given weights. +// +// The lengths of x and y must be equal. If weights is nil then all of the +// weights are 1. If weights is not nil, then len(x) must equal len(weights). +func RSquared(x, y, weights []float64, alpha, beta float64) float64 { + if len(x) != len(y) { + panic("stat: slice length mismatch") + } + if weights != nil && len(weights) != len(x) { + panic("stat: slice length mismatch") + } + + w := 1.0 + yMean := Mean(y, weights) + var res, tot, d float64 + for i, xi := range x { + if weights != nil { + w = weights[i] + } + yi := y[i] + fi := alpha + beta*xi + d = yi - fi + res += w * d * d + d = yi - yMean + tot += w * d * d + } + return 1 - res/tot +} + +// RSquaredFrom returns the coefficient of determination defined as +// R^2 = 1 - \sum_i w[i]*(estimate[i] - value[i])^2 / \sum_i w[i]*(value[i] - mean(values))^2 +// and the data in estimates and values with the given weights. +// +// The lengths of estimates and values must be equal. If weights is nil then +// all of the weights are 1. If weights is not nil, then len(values) must +// equal len(weights). +func RSquaredFrom(estimates, values, weights []float64) float64 { + if len(estimates) != len(values) { + panic("stat: slice length mismatch") + } + if weights != nil && len(weights) != len(values) { + panic("stat: slice length mismatch") + } + + w := 1.0 + mean := Mean(values, weights) + var res, tot, d float64 + for i, val := range values { + if weights != nil { + w = weights[i] + } + d = val - estimates[i] + res += w * d * d + d = val - mean + tot += w * d * d + } + return 1 - res/tot +} + +// RNoughtSquared returns the coefficient of determination defined as +// R₀^2 = \sum_i w[i]*(beta*x[i])^2 / \sum_i w[i]*y[i]^2 +// for the line +// y = beta*x +// and the data in x and y with the given weights. RNoughtSquared should +// only be used for best-fit lines regressed through the origin. +// +// The lengths of x and y must be equal. If weights is nil then all of the +// weights are 1. If weights is not nil, then len(x) must equal len(weights). +func RNoughtSquared(x, y, weights []float64, beta float64) float64 { + if len(x) != len(y) { + panic("stat: slice length mismatch") + } + if weights != nil && len(weights) != len(x) { + panic("stat: slice length mismatch") + } + + w := 1.0 + var ssr, tot float64 + for i, xi := range x { + if weights != nil { + w = weights[i] + } + fi := beta * xi + ssr += w * fi * fi + yi := y[i] + tot += w * yi * yi + } + return ssr / tot +} + // Mean computes the weighted mean of the data set. // sum_i {w_i * x_i} / sum_i {w_i} // If weights is nil then all of the weights are 1. If weights is not nil, then diff --git a/stat_test.go b/stat_test.go index 3a682817..ef246771 100644 --- a/stat_test.go +++ b/stat_test.go @@ -753,6 +753,111 @@ func TestKullbackLeibler(t *testing.T) { } } +var linearRegressionTests = []struct { + name string + + x, y []float64 + weights []float64 + origin bool + + alpha float64 + beta float64 + r float64 + + tol float64 +}{ + { + name: "faithful", + + x: faithful.waiting, + y: faithful.eruptions, + + // Values calculated by R using lm(eruptions ~ waiting, data=faithful). + alpha: -1.87402, + beta: 0.07563, + r: 0.8114608, + + tol: 1e-5, + }, + { + name: "faithful through origin", + + x: faithful.waiting, + y: faithful.eruptions, + origin: true, + + // Values calculated by R using lm(eruptions ~ waiting - 1, data=faithful). + alpha: 0, + beta: 0.05013, + r: 0.9726036, + + tol: 1e-5, + }, + { + name: "faithful explicit weights", + + x: faithful.waiting, + y: faithful.eruptions, + weights: func() []float64 { + w := make([]float64, len(faithful.eruptions)) + for i := range w { + w[i] = 1 + } + return w + }(), + + // Values calculated by R using lm(eruptions ~ waiting, data=faithful). + alpha: -1.87402, + beta: 0.07563, + r: 0.8114608, + + tol: 1e-5, + }, + { + name: "faithful non-uniform weights", + + x: faithful.waiting, + y: faithful.eruptions, + weights: faithful.waiting, // Just an arbitrary set of non-uniform weights. + + // Values calculated by R using lm(eruptions ~ waiting, data=faithful, weights=faithful$waiting). + alpha: -1.79268, + beta: 0.07452, + r: 0.7840372, + + tol: 1e-5, + }, +} + +func TestLinearRegression(t *testing.T) { + for _, test := range linearRegressionTests { + alpha, beta := LinearRegression(test.x, test.y, test.weights, test.origin) + var r float64 + if test.origin { + r = RNoughtSquared(test.x, test.y, test.weights, beta) + } else { + r = RSquared(test.x, test.y, test.weights, alpha, beta) + ests := make([]float64, len(test.y)) + for i, x := range test.x { + ests[i] = alpha + beta*x + } + rvals := RSquaredFrom(ests, test.y, test.weights) + if r != rvals { + t.Errorf("%s: RSquared and RSquaredFrom mismatch: %v != %v", test.name, r, rvals) + } + } + if !floats.EqualWithinAbsOrRel(alpha, test.alpha, test.tol, test.tol) { + t.Errorf("%s: unexpected alpha estimate: want:%v got:%v", test.name, test.alpha, alpha) + } + if !floats.EqualWithinAbsOrRel(beta, test.beta, test.tol, test.tol) { + t.Errorf("%s: unexpected beta estimate: want:%v got:%v", test.name, test.beta, beta) + } + if !floats.EqualWithinAbsOrRel(r, test.r, test.tol, test.tol) { + t.Errorf("%s: unexpected r estimate: want:%v got:%v", test.name, test.r, r) + } + } +} + func TestChiSquare(t *testing.T) { for i, test := range []struct { p []float64