mirror of
https://github.com/gonum/gonum.git
synced 2025-10-15 11:40:45 +08:00
stat: add simple linear regression
This commit is contained in:
81
faithful_test.go
Normal file
81
faithful_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
// Copyright ©2016 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package stat
|
||||
|
||||
// faithful holds the "faithful" data set that ships with R. It is used as a
// fixture for the regression tests, whose expected values were computed in R
// from the same data.
//
// The two slices are parallel: waiting[i] and eruptions[i] belong to the same
// observation, and each slice holds 272 observations.
//
// NOTE(review): per R's datasets documentation these are Old Faithful geyser
// measurements in minutes (waiting time to the next eruption, and eruption
// duration) — confirm against the R help page if the units matter.
var faithful = struct {
	waiting   []float64
	eruptions []float64
}{
	// Used as the x variable in the regression tests (and, in one case,
	// as an arbitrary set of non-uniform weights).
	waiting: []float64{
		79, 54, 74, 62, 85, 55, 88, 85,
		51, 85, 54, 84, 78, 47, 83, 52,
		62, 84, 52, 79, 51, 47, 78, 69,
		74, 83, 55, 76, 78, 79, 73, 77,
		66, 80, 74, 52, 48, 80, 59, 90,
		80, 58, 84, 58, 73, 83, 64, 53,
		82, 59, 75, 90, 54, 80, 54, 83,
		71, 64, 77, 81, 59, 84, 48, 82,
		60, 92, 78, 78, 65, 73, 82, 56,
		79, 71, 62, 76, 60, 78, 76, 83,
		75, 82, 70, 65, 73, 88, 76, 80,
		48, 86, 60, 90, 50, 78, 63, 72,
		84, 75, 51, 82, 62, 88, 49, 83,
		81, 47, 84, 52, 86, 81, 75, 59,
		89, 79, 59, 81, 50, 85, 59, 87,
		53, 69, 77, 56, 88, 81, 45, 82,
		55, 90, 45, 83, 56, 89, 46, 82,
		51, 86, 53, 79, 81, 60, 82, 77,
		76, 59, 80, 49, 96, 53, 77, 77,
		65, 81, 71, 70, 81, 93, 53, 89,
		45, 86, 58, 78, 66, 76, 63, 88,
		52, 93, 49, 57, 77, 68, 81, 81,
		73, 50, 85, 74, 55, 77, 83, 83,
		51, 78, 84, 46, 83, 55, 81, 57,
		76, 84, 77, 81, 87, 77, 51, 78,
		60, 82, 91, 53, 78, 46, 77, 84,
		49, 83, 71, 80, 49, 75, 64, 76,
		53, 94, 55, 76, 50, 82, 54, 75,
		78, 79, 78, 78, 70, 79, 70, 54,
		86, 50, 90, 54, 54, 77, 79, 64,
		75, 47, 86, 63, 85, 82, 57, 82,
		67, 74, 54, 83, 73, 73, 88, 80,
		71, 83, 56, 79, 78, 84, 58, 83,
		43, 60, 75, 81, 46, 90, 46, 74,
	},
	// Used as the y variable in the regression tests.
	eruptions: []float64{
		3.600, 1.800, 3.333, 2.283, 4.533, 2.883, 4.700, 3.600,
		1.950, 4.350, 1.833, 3.917, 4.200, 1.750, 4.700, 2.167,
		1.750, 4.800, 1.600, 4.250, 1.800, 1.750, 3.450, 3.067,
		4.533, 3.600, 1.967, 4.083, 3.850, 4.433, 4.300, 4.467,
		3.367, 4.033, 3.833, 2.017, 1.867, 4.833, 1.833, 4.783,
		4.350, 1.883, 4.567, 1.750, 4.533, 3.317, 3.833, 2.100,
		4.633, 2.000, 4.800, 4.716, 1.833, 4.833, 1.733, 4.883,
		3.717, 1.667, 4.567, 4.317, 2.233, 4.500, 1.750, 4.800,
		1.817, 4.400, 4.167, 4.700, 2.067, 4.700, 4.033, 1.967,
		4.500, 4.000, 1.983, 5.067, 2.017, 4.567, 3.883, 3.600,
		4.133, 4.333, 4.100, 2.633, 4.067, 4.933, 3.950, 4.517,
		2.167, 4.000, 2.200, 4.333, 1.867, 4.817, 1.833, 4.300,
		4.667, 3.750, 1.867, 4.900, 2.483, 4.367, 2.100, 4.500,
		4.050, 1.867, 4.700, 1.783, 4.850, 3.683, 4.733, 2.300,
		4.900, 4.417, 1.700, 4.633, 2.317, 4.600, 1.817, 4.417,
		2.617, 4.067, 4.250, 1.967, 4.600, 3.767, 1.917, 4.500,
		2.267, 4.650, 1.867, 4.167, 2.800, 4.333, 1.833, 4.383,
		1.883, 4.933, 2.033, 3.733, 4.233, 2.233, 4.533, 4.817,
		4.333, 1.983, 4.633, 2.017, 5.100, 1.800, 5.033, 4.000,
		2.400, 4.600, 3.567, 4.000, 4.500, 4.083, 1.800, 3.967,
		2.200, 4.150, 2.000, 3.833, 3.500, 4.583, 2.367, 5.000,
		1.933, 4.617, 1.917, 2.083, 4.583, 3.333, 4.167, 4.333,
		4.500, 2.417, 4.000, 4.167, 1.883, 4.583, 4.250, 3.767,
		2.033, 4.433, 4.083, 1.833, 4.417, 2.183, 4.800, 1.833,
		4.800, 4.100, 3.966, 4.233, 3.500, 4.366, 2.250, 4.667,
		2.100, 4.350, 4.133, 1.867, 4.600, 1.783, 4.367, 3.850,
		1.933, 4.500, 2.383, 4.700, 1.867, 3.833, 3.417, 4.233,
		2.400, 4.800, 2.000, 4.150, 1.867, 4.267, 1.750, 4.483,
		4.000, 4.117, 4.083, 4.267, 3.917, 4.550, 4.083, 2.417,
		4.183, 2.217, 4.450, 1.883, 1.850, 4.283, 3.950, 2.333,
		4.150, 2.350, 4.933, 2.900, 4.583, 3.833, 2.083, 4.367,
		2.133, 4.350, 2.200, 4.450, 3.567, 4.500, 4.150, 3.817,
		3.917, 4.450, 2.000, 4.283, 4.767, 4.533, 1.850, 4.250,
		1.983, 2.250, 4.750, 4.117, 2.150, 4.417, 1.817, 4.467,
	},
}
|
135
stat.go
135
stat.go
@@ -648,6 +648,141 @@ func KullbackLeibler(p, q []float64) float64 {
|
||||
return kl
|
||||
}
|
||||
|
||||
// LinearRegression computes the best-fit line
|
||||
// y = alpha + beta*x
|
||||
// to the data in x and y with the given weights. If origin is true, the
|
||||
// regression is forced to pass through the origin.
|
||||
//
|
||||
// Specifically, LinearRegression computes the values of alpha and
|
||||
// beta such that the total residual
|
||||
// \sum_i w[i]*(y[i] - alpha - beta*x[i])^2
|
||||
// is minimized. If origin is true, then alpha is forced to be zero.
|
||||
//
|
||||
// The lengths of x and y must be equal. If weights is nil then all of the
|
||||
// weights are 1. If weights is not nil, then len(x) must equal len(weights).
|
||||
func LinearRegression(x, y, weights []float64, origin bool) (alpha, beta float64) {
|
||||
if len(x) != len(y) {
|
||||
panic("stat: slice length mismatch")
|
||||
}
|
||||
if weights != nil && len(weights) != len(x) {
|
||||
panic("stat: slice length mismatch")
|
||||
}
|
||||
|
||||
w := 1.0
|
||||
if origin {
|
||||
var x2Sum, xySum float64
|
||||
for i, xi := range x {
|
||||
if weights != nil {
|
||||
w = weights[i]
|
||||
}
|
||||
yi := y[i]
|
||||
xySum += w * xi * yi
|
||||
x2Sum += w * xi * xi
|
||||
}
|
||||
beta = xySum / x2Sum
|
||||
|
||||
return 0, beta
|
||||
}
|
||||
|
||||
beta = Covariance(x, y, weights) / Variance(x, weights)
|
||||
alpha = Mean(y, weights) - beta*Mean(x, weights)
|
||||
return alpha, beta
|
||||
}
|
||||
|
||||
// RSquared returns the coefficient of determination defined as
|
||||
// R^2 = 1 - \sum_i w[i]*(y[i] - alpha - beta*x[i])^2 / \sum_i w[i]*(y[i] - mean(y))^2
|
||||
// for the line
|
||||
// y = alpha + beta*x
|
||||
// and the data in x and y with the given weights.
|
||||
//
|
||||
// The lengths of x and y must be equal. If weights is nil then all of the
|
||||
// weights are 1. If weights is not nil, then len(x) must equal len(weights).
|
||||
func RSquared(x, y, weights []float64, alpha, beta float64) float64 {
|
||||
if len(x) != len(y) {
|
||||
panic("stat: slice length mismatch")
|
||||
}
|
||||
if weights != nil && len(weights) != len(x) {
|
||||
panic("stat: slice length mismatch")
|
||||
}
|
||||
|
||||
w := 1.0
|
||||
yMean := Mean(y, weights)
|
||||
var res, tot, d float64
|
||||
for i, xi := range x {
|
||||
if weights != nil {
|
||||
w = weights[i]
|
||||
}
|
||||
yi := y[i]
|
||||
fi := alpha + beta*xi
|
||||
d = yi - fi
|
||||
res += w * d * d
|
||||
d = yi - yMean
|
||||
tot += w * d * d
|
||||
}
|
||||
return 1 - res/tot
|
||||
}
|
||||
|
||||
// RSquaredFrom returns the coefficient of determination defined as
|
||||
// R^2 = 1 - \sum_i w[i]*(estimate[i] - value[i])^2 / \sum_i w[i]*(value[i] - mean(values))^2
|
||||
// and the data in estimates and values with the given weights.
|
||||
//
|
||||
// The lengths of estimates and values must be equal. If weights is nil then
|
||||
// all of the weights are 1. If weights is not nil, then len(values) must
|
||||
// equal len(weights).
|
||||
func RSquaredFrom(estimates, values, weights []float64) float64 {
|
||||
if len(estimates) != len(values) {
|
||||
panic("stat: slice length mismatch")
|
||||
}
|
||||
if weights != nil && len(weights) != len(values) {
|
||||
panic("stat: slice length mismatch")
|
||||
}
|
||||
|
||||
w := 1.0
|
||||
mean := Mean(values, weights)
|
||||
var res, tot, d float64
|
||||
for i, val := range values {
|
||||
if weights != nil {
|
||||
w = weights[i]
|
||||
}
|
||||
d = val - estimates[i]
|
||||
res += w * d * d
|
||||
d = val - mean
|
||||
tot += w * d * d
|
||||
}
|
||||
return 1 - res/tot
|
||||
}
|
||||
|
||||
// RNoughtSquared returns the coefficient of determination
//
//	R₀^2 = \sum_i w[i]*(beta*x[i])^2 / \sum_i w[i]*y[i]^2
//
// of the line
//
//	y = beta*x
//
// with respect to the data in x and y and the given weights. Both sums are
// taken about zero rather than about the mean of y, so RNoughtSquared should
// only be used for best-fit lines regressed through the origin.
//
// A nil weights slice means every observation has weight 1. RNoughtSquared
// panics if len(x) != len(y), or if weights is non-nil and
// len(weights) != len(x).
func RNoughtSquared(x, y, weights []float64, beta float64) float64 {
	switch {
	case len(x) != len(y):
		panic("stat: slice length mismatch")
	case weights != nil && len(weights) != len(x):
		panic("stat: slice length mismatch")
	}

	// explained is the weighted sum of squared fitted values and total the
	// weighted sum of squared observations.
	var explained, total float64
	wi := 1.0 // Stays 1 for every observation when weights is nil.
	for i := range x {
		if weights != nil {
			wi = weights[i]
		}
		fit := beta * x[i]
		explained += wi * fit * fit

		obs := y[i]
		total += wi * obs * obs
	}
	return explained / total
}
|
||||
|
||||
// Mean computes the weighted mean of the data set.
|
||||
// sum_i {w_i * x_i} / sum_i {w_i}
|
||||
// If weights is nil then all of the weights are 1. If weights is not nil, then
|
||||
|
105
stat_test.go
105
stat_test.go
@@ -753,6 +753,111 @@ func TestKullbackLeibler(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
var linearRegressionTests = []struct {
|
||||
name string
|
||||
|
||||
x, y []float64
|
||||
weights []float64
|
||||
origin bool
|
||||
|
||||
alpha float64
|
||||
beta float64
|
||||
r float64
|
||||
|
||||
tol float64
|
||||
}{
|
||||
{
|
||||
name: "faithful",
|
||||
|
||||
x: faithful.waiting,
|
||||
y: faithful.eruptions,
|
||||
|
||||
// Values calculated by R using lm(eruptions ~ waiting, data=faithful).
|
||||
alpha: -1.87402,
|
||||
beta: 0.07563,
|
||||
r: 0.8114608,
|
||||
|
||||
tol: 1e-5,
|
||||
},
|
||||
{
|
||||
name: "faithful through origin",
|
||||
|
||||
x: faithful.waiting,
|
||||
y: faithful.eruptions,
|
||||
origin: true,
|
||||
|
||||
// Values calculated by R using lm(eruptions ~ waiting - 1, data=faithful).
|
||||
alpha: 0,
|
||||
beta: 0.05013,
|
||||
r: 0.9726036,
|
||||
|
||||
tol: 1e-5,
|
||||
},
|
||||
{
|
||||
name: "faithful explicit weights",
|
||||
|
||||
x: faithful.waiting,
|
||||
y: faithful.eruptions,
|
||||
weights: func() []float64 {
|
||||
w := make([]float64, len(faithful.eruptions))
|
||||
for i := range w {
|
||||
w[i] = 1
|
||||
}
|
||||
return w
|
||||
}(),
|
||||
|
||||
// Values calculated by R using lm(eruptions ~ waiting, data=faithful).
|
||||
alpha: -1.87402,
|
||||
beta: 0.07563,
|
||||
r: 0.8114608,
|
||||
|
||||
tol: 1e-5,
|
||||
},
|
||||
{
|
||||
name: "faithful non-uniform weights",
|
||||
|
||||
x: faithful.waiting,
|
||||
y: faithful.eruptions,
|
||||
weights: faithful.waiting, // Just an arbitrary set of non-uniform weights.
|
||||
|
||||
// Values calculated by R using lm(eruptions ~ waiting, data=faithful, weights=faithful$waiting).
|
||||
alpha: -1.79268,
|
||||
beta: 0.07452,
|
||||
r: 0.7840372,
|
||||
|
||||
tol: 1e-5,
|
||||
},
|
||||
}
|
||||
|
||||
func TestLinearRegression(t *testing.T) {
|
||||
for _, test := range linearRegressionTests {
|
||||
alpha, beta := LinearRegression(test.x, test.y, test.weights, test.origin)
|
||||
var r float64
|
||||
if test.origin {
|
||||
r = RNoughtSquared(test.x, test.y, test.weights, beta)
|
||||
} else {
|
||||
r = RSquared(test.x, test.y, test.weights, alpha, beta)
|
||||
ests := make([]float64, len(test.y))
|
||||
for i, x := range test.x {
|
||||
ests[i] = alpha + beta*x
|
||||
}
|
||||
rvals := RSquaredFrom(ests, test.y, test.weights)
|
||||
if r != rvals {
|
||||
t.Errorf("%s: RSquared and RSquaredFrom mismatch: %v != %v", test.name, r, rvals)
|
||||
}
|
||||
}
|
||||
if !floats.EqualWithinAbsOrRel(alpha, test.alpha, test.tol, test.tol) {
|
||||
t.Errorf("%s: unexpected alpha estimate: want:%v got:%v", test.name, test.alpha, alpha)
|
||||
}
|
||||
if !floats.EqualWithinAbsOrRel(beta, test.beta, test.tol, test.tol) {
|
||||
t.Errorf("%s: unexpected beta estimate: want:%v got:%v", test.name, test.beta, beta)
|
||||
}
|
||||
if !floats.EqualWithinAbsOrRel(r, test.r, test.tol, test.tol) {
|
||||
t.Errorf("%s: unexpected r estimate: want:%v got:%v", test.name, test.r, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestChiSquare(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
p []float64
|
||||
|
Reference in New Issue
Block a user