stat: add simple linear regression

This commit is contained in:
kortschak
2016-07-28 16:27:03 +09:30
parent 674d440284
commit a377c62d90
3 changed files with 321 additions and 0 deletions

81
faithful_test.go Normal file
View File

@@ -0,0 +1,81 @@
// Copyright ©2016 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stat
// faithful is the faithful data set from R, stored as two parallel slices of
// 272 observations each. The linear regression tests pair the slices by
// index, using waiting as the x values and eruptions as the y values, and
// compare the results with values computed by R's lm on the same data.
var faithful = struct{ waiting, eruptions []float64 }{
// waiting holds the "waiting" column of the data set.
waiting: []float64{
79, 54, 74, 62, 85, 55, 88, 85,
51, 85, 54, 84, 78, 47, 83, 52,
62, 84, 52, 79, 51, 47, 78, 69,
74, 83, 55, 76, 78, 79, 73, 77,
66, 80, 74, 52, 48, 80, 59, 90,
80, 58, 84, 58, 73, 83, 64, 53,
82, 59, 75, 90, 54, 80, 54, 83,
71, 64, 77, 81, 59, 84, 48, 82,
60, 92, 78, 78, 65, 73, 82, 56,
79, 71, 62, 76, 60, 78, 76, 83,
75, 82, 70, 65, 73, 88, 76, 80,
48, 86, 60, 90, 50, 78, 63, 72,
84, 75, 51, 82, 62, 88, 49, 83,
81, 47, 84, 52, 86, 81, 75, 59,
89, 79, 59, 81, 50, 85, 59, 87,
53, 69, 77, 56, 88, 81, 45, 82,
55, 90, 45, 83, 56, 89, 46, 82,
51, 86, 53, 79, 81, 60, 82, 77,
76, 59, 80, 49, 96, 53, 77, 77,
65, 81, 71, 70, 81, 93, 53, 89,
45, 86, 58, 78, 66, 76, 63, 88,
52, 93, 49, 57, 77, 68, 81, 81,
73, 50, 85, 74, 55, 77, 83, 83,
51, 78, 84, 46, 83, 55, 81, 57,
76, 84, 77, 81, 87, 77, 51, 78,
60, 82, 91, 53, 78, 46, 77, 84,
49, 83, 71, 80, 49, 75, 64, 76,
53, 94, 55, 76, 50, 82, 54, 75,
78, 79, 78, 78, 70, 79, 70, 54,
86, 50, 90, 54, 54, 77, 79, 64,
75, 47, 86, 63, 85, 82, 57, 82,
67, 74, 54, 83, 73, 73, 88, 80,
71, 83, 56, 79, 78, 84, 58, 83,
43, 60, 75, 81, 46, 90, 46, 74,
},
// eruptions holds the "eruptions" column of the data set; element i
// belongs to the same observation as waiting[i].
eruptions: []float64{
3.600, 1.800, 3.333, 2.283, 4.533, 2.883, 4.700, 3.600,
1.950, 4.350, 1.833, 3.917, 4.200, 1.750, 4.700, 2.167,
1.750, 4.800, 1.600, 4.250, 1.800, 1.750, 3.450, 3.067,
4.533, 3.600, 1.967, 4.083, 3.850, 4.433, 4.300, 4.467,
3.367, 4.033, 3.833, 2.017, 1.867, 4.833, 1.833, 4.783,
4.350, 1.883, 4.567, 1.750, 4.533, 3.317, 3.833, 2.100,
4.633, 2.000, 4.800, 4.716, 1.833, 4.833, 1.733, 4.883,
3.717, 1.667, 4.567, 4.317, 2.233, 4.500, 1.750, 4.800,
1.817, 4.400, 4.167, 4.700, 2.067, 4.700, 4.033, 1.967,
4.500, 4.000, 1.983, 5.067, 2.017, 4.567, 3.883, 3.600,
4.133, 4.333, 4.100, 2.633, 4.067, 4.933, 3.950, 4.517,
2.167, 4.000, 2.200, 4.333, 1.867, 4.817, 1.833, 4.300,
4.667, 3.750, 1.867, 4.900, 2.483, 4.367, 2.100, 4.500,
4.050, 1.867, 4.700, 1.783, 4.850, 3.683, 4.733, 2.300,
4.900, 4.417, 1.700, 4.633, 2.317, 4.600, 1.817, 4.417,
2.617, 4.067, 4.250, 1.967, 4.600, 3.767, 1.917, 4.500,
2.267, 4.650, 1.867, 4.167, 2.800, 4.333, 1.833, 4.383,
1.883, 4.933, 2.033, 3.733, 4.233, 2.233, 4.533, 4.817,
4.333, 1.983, 4.633, 2.017, 5.100, 1.800, 5.033, 4.000,
2.400, 4.600, 3.567, 4.000, 4.500, 4.083, 1.800, 3.967,
2.200, 4.150, 2.000, 3.833, 3.500, 4.583, 2.367, 5.000,
1.933, 4.617, 1.917, 2.083, 4.583, 3.333, 4.167, 4.333,
4.500, 2.417, 4.000, 4.167, 1.883, 4.583, 4.250, 3.767,
2.033, 4.433, 4.083, 1.833, 4.417, 2.183, 4.800, 1.833,
4.800, 4.100, 3.966, 4.233, 3.500, 4.366, 2.250, 4.667,
2.100, 4.350, 4.133, 1.867, 4.600, 1.783, 4.367, 3.850,
1.933, 4.500, 2.383, 4.700, 1.867, 3.833, 3.417, 4.233,
2.400, 4.800, 2.000, 4.150, 1.867, 4.267, 1.750, 4.483,
4.000, 4.117, 4.083, 4.267, 3.917, 4.550, 4.083, 2.417,
4.183, 2.217, 4.450, 1.883, 1.850, 4.283, 3.950, 2.333,
4.150, 2.350, 4.933, 2.900, 4.583, 3.833, 2.083, 4.367,
2.133, 4.350, 2.200, 4.450, 3.567, 4.500, 4.150, 3.817,
3.917, 4.450, 2.000, 4.283, 4.767, 4.533, 1.850, 4.250,
1.983, 2.250, 4.750, 4.117, 2.150, 4.417, 1.817, 4.467,
},
}

135
stat.go
View File

@@ -648,6 +648,141 @@ func KullbackLeibler(p, q []float64) float64 {
return kl
}
// LinearRegression fits the straight line
//  y = alpha + beta*x
// to the paired observations in x and y, using the given weights, and
// returns its intercept alpha and slope beta. When origin is true the line
// is constrained to pass through the origin and the returned alpha is zero.
//
// The returned parameters are those that minimize the weighted sum of
// squared residuals
//  \sum_i w[i]*(y[i] - alpha - beta*x[i])^2
// with alpha held at zero when origin is true.
//
// LinearRegression panics if len(x) != len(y). A nil weights slice gives
// every observation a weight of 1; a non-nil weights slice must have the
// same length as x, otherwise LinearRegression panics.
func LinearRegression(x, y, weights []float64, origin bool) (alpha, beta float64) {
	if len(x) != len(y) || (weights != nil && len(weights) != len(x)) {
		panic("stat: slice length mismatch")
	}

	if !origin {
		// Unconstrained fit: the slope is Covariance(x, y)/Variance(x) and
		// the intercept places the line through the weighted means.
		beta = Covariance(x, y, weights) / Variance(x, weights)
		alpha = Mean(y, weights) - beta*Mean(x, weights)
		return alpha, beta
	}

	// Fit through the origin: beta = sum(w*x*y) / sum(w*x*x).
	var sumXY, sumXX float64
	weight := 1.0 // Stays at 1 for every observation when weights is nil.
	for i := range x {
		if weights != nil {
			weight = weights[i]
		}
		xi, yi := x[i], y[i]
		sumXY += weight * xi * yi
		sumXX += weight * xi * xi
	}
	beta = sumXY / sumXX
	return 0, beta
}
// RSquared returns the coefficient of determination of the line
//  y = alpha + beta*x
// with respect to the data in x and y and the given weights. It is
// defined as
//  R^2 = 1 - \sum_i w[i]*(y[i] - alpha - beta*x[i])^2 / \sum_i w[i]*(y[i] - mean(y))^2
// where mean(y) is the value of Mean(y, weights).
//
// RSquared panics if len(x) != len(y). A nil weights slice gives every
// observation a weight of 1; a non-nil weights slice must have the same
// length as x, otherwise RSquared panics.
func RSquared(x, y, weights []float64, alpha, beta float64) float64 {
	if len(x) != len(y) || (weights != nil && len(weights) != len(x)) {
		panic("stat: slice length mismatch")
	}

	meanY := Mean(y, weights)
	var (
		ssRes float64 // Weighted squared residuals about the line.
		ssTot float64 // Weighted squared deviations about the mean of y.
	)
	weight := 1.0 // Stays at 1 for every observation when weights is nil.
	for i := range x {
		if weights != nil {
			weight = weights[i]
		}
		yi := y[i]
		fit := alpha + beta*x[i]
		resid := yi - fit
		ssRes += weight * resid * resid
		dev := yi - meanY
		ssTot += weight * dev * dev
	}
	return 1 - ssRes/ssTot
}
// RSquaredFrom returns the coefficient of determination of a set of
// estimates with respect to the observed values and the given weights.
// It is defined as
//  R^2 = 1 - \sum_i w[i]*(estimate[i] - value[i])^2 / \sum_i w[i]*(value[i] - mean(values))^2
// where mean(values) is the value of Mean(values, weights).
//
// RSquaredFrom panics if len(estimates) != len(values). A nil weights
// slice gives every observation a weight of 1; a non-nil weights slice
// must have the same length as values, otherwise RSquaredFrom panics.
func RSquaredFrom(estimates, values, weights []float64) float64 {
	if len(estimates) != len(values) || (weights != nil && len(weights) != len(values)) {
		panic("stat: slice length mismatch")
	}

	center := Mean(values, weights)
	var (
		ssRes float64 // Weighted squared differences between value and estimate.
		ssTot float64 // Weighted squared deviations of the values about their mean.
	)
	weight := 1.0 // Stays at 1 for every observation when weights is nil.
	for i := range values {
		if weights != nil {
			weight = weights[i]
		}
		observed := values[i]
		resid := observed - estimates[i]
		ssRes += weight * resid * resid
		dev := observed - center
		ssTot += weight * dev * dev
	}
	return 1 - ssRes/ssTot
}
// RNoughtSquared returns the coefficient of determination of the line
//  y = beta*x
// with respect to the data in x and y and the given weights. It is
// defined as
//  R₀^2 = \sum_i w[i]*(beta*x[i])^2 / \sum_i w[i]*y[i]^2
// and should only be used for best-fit lines regressed through the
// origin.
//
// RNoughtSquared panics if len(x) != len(y). A nil weights slice gives
// every observation a weight of 1; a non-nil weights slice must have the
// same length as x, otherwise RNoughtSquared panics.
func RNoughtSquared(x, y, weights []float64, beta float64) float64 {
	if len(x) != len(y) || (weights != nil && len(weights) != len(x)) {
		panic("stat: slice length mismatch")
	}

	var (
		explained float64 // Weighted sum of squared fitted values.
		total     float64 // Weighted sum of squared observed values.
	)
	weight := 1.0 // Stays at 1 for every observation when weights is nil.
	for i := range x {
		if weights != nil {
			weight = weights[i]
		}
		fit := beta * x[i]
		explained += weight * fit * fit
		yi := y[i]
		total += weight * yi * yi
	}
	return explained / total
}
// Mean computes the weighted mean of the data set.
// sum_i {w_i * x_i} / sum_i {w_i}
// If weights is nil then all of the weights are 1. If weights is not nil, then

View File

@@ -753,6 +753,111 @@ func TestKullbackLeibler(t *testing.T) {
}
}
// linearRegressionTests lists the cases run by TestLinearRegression. All
// cases use the faithful data set; the expected values were computed with
// R's lm function using the call noted on each case.
var linearRegressionTests = []struct {
	name    string    // Label used in failure messages.
	x       []float64 // x values passed to LinearRegression.
	y       []float64 // y values passed to LinearRegression.
	weights []float64 // Observation weights; nil means all weights are 1.
	origin  bool      // Whether the fit is forced through the origin.

	alpha, beta float64 // Expected intercept and slope.
	r           float64 // Expected R^2, or R₀^2 when origin is true.
	tol         float64 // Absolute and relative tolerance for all comparisons.
}{
	{
		name: "faithful",
		x:    faithful.waiting,
		y:    faithful.eruptions,

		// Expected values from R: lm(eruptions ~ waiting, data=faithful).
		alpha: -1.87402,
		beta:  0.07563,
		r:     0.8114608,
		tol:   1e-5,
	},
	{
		name:   "faithful through origin",
		x:      faithful.waiting,
		y:      faithful.eruptions,
		origin: true,

		// Expected values from R: lm(eruptions ~ waiting - 1, data=faithful).
		alpha: 0,
		beta:  0.05013,
		r:     0.9726036,
		tol:   1e-5,
	},
	{
		name: "faithful explicit weights",
		x:    faithful.waiting,
		y:    faithful.eruptions,
		// One unit weight per observation, so the results must match the
		// unweighted case.
		weights: func() []float64 {
			ones := make([]float64, len(faithful.eruptions))
			for i := range ones {
				ones[i] = 1
			}
			return ones
		}(),

		// Expected values from R: lm(eruptions ~ waiting, data=faithful).
		alpha: -1.87402,
		beta:  0.07563,
		r:     0.8114608,
		tol:   1e-5,
	},
	{
		name: "faithful non-uniform weights",
		x:    faithful.waiting,
		y:    faithful.eruptions,
		// The waiting times double as an arbitrary set of non-uniform weights.
		weights: faithful.waiting,

		// Expected values from R:
		// lm(eruptions ~ waiting, data=faithful, weights=faithful$waiting).
		alpha: -1.79268,
		beta:  0.07452,
		r:     0.7840372,
		tol:   1e-5,
	},
}
// TestLinearRegression fits every case in linearRegressionTests and
// compares the estimated intercept, slope and coefficient of determination
// with the expected values, to within the tolerance of the case. For fits
// that are not forced through the origin it also checks that RSquared and
// RSquaredFrom return identical results for the fitted line.
func TestLinearRegression(t *testing.T) {
	for i := range linearRegressionTests {
		tc := &linearRegressionTests[i]
		gotAlpha, gotBeta := LinearRegression(tc.x, tc.y, tc.weights, tc.origin)

		var gotR float64
		switch {
		case tc.origin:
			// Lines through the origin are scored with R₀^2.
			gotR = RNoughtSquared(tc.x, tc.y, tc.weights, gotBeta)
		default:
			gotR = RSquared(tc.x, tc.y, tc.weights, gotAlpha, gotBeta)

			// Evaluate the fitted line at each x and score the estimates
			// directly; both routes must give exactly the same value.
			fitted := make([]float64, len(tc.y))
			for j, xj := range tc.x {
				fitted[j] = gotAlpha + gotBeta*xj
			}
			fromFit := RSquaredFrom(fitted, tc.y, tc.weights)
			if gotR != fromFit {
				t.Errorf("%s: RSquared and RSquaredFrom mismatch: %v != %v", tc.name, gotR, fromFit)
			}
		}

		if ok := floats.EqualWithinAbsOrRel(gotAlpha, tc.alpha, tc.tol, tc.tol); !ok {
			t.Errorf("%s: unexpected alpha estimate: want:%v got:%v", tc.name, tc.alpha, gotAlpha)
		}
		if ok := floats.EqualWithinAbsOrRel(gotBeta, tc.beta, tc.tol, tc.tol); !ok {
			t.Errorf("%s: unexpected beta estimate: want:%v got:%v", tc.name, tc.beta, gotBeta)
		}
		if ok := floats.EqualWithinAbsOrRel(gotR, tc.r, tc.tol, tc.tol); !ok {
			t.Errorf("%s: unexpected r estimate: want:%v got:%v", tc.name, tc.r, gotR)
		}
	}
}
func TestChiSquare(t *testing.T) {
for i, test := range []struct {
p []float64