stat: add simple linear regression

This commit is contained in:
kortschak
2016-07-28 16:27:03 +09:30
parent 674d440284
commit a377c62d90
3 changed files with 321 additions and 0 deletions

135
stat.go
View File

@@ -648,6 +648,141 @@ func KullbackLeibler(p, q []float64) float64 {
return kl
}
// LinearRegression computes the best-fit line
// y = alpha + beta*x
// to the data in x and y with the given weights. If origin is true, the
// regression is forced to pass through the origin.
//
// Specifically, LinearRegression computes the values of alpha and
// beta such that the total residual
// \sum_i w[i]*(y[i] - alpha - beta*x[i])^2
// is minimized. If origin is true, then alpha is forced to be zero.
//
// The lengths of x and y must be equal. If weights is nil then all of the
// weights are 1. If weights is not nil, then len(x) must equal len(weights).
func LinearRegression(x, y, weights []float64, origin bool) (alpha, beta float64) {
if len(x) != len(y) {
panic("stat: slice length mismatch")
}
if weights != nil && len(weights) != len(x) {
panic("stat: slice length mismatch")
}
w := 1.0
if origin {
var x2Sum, xySum float64
for i, xi := range x {
if weights != nil {
w = weights[i]
}
yi := y[i]
xySum += w * xi * yi
x2Sum += w * xi * xi
}
beta = xySum / x2Sum
return 0, beta
}
beta = Covariance(x, y, weights) / Variance(x, weights)
alpha = Mean(y, weights) - beta*Mean(x, weights)
return alpha, beta
}
// RSquared returns the coefficient of determination defined as
// R^2 = 1 - \sum_i w[i]*(y[i] - alpha - beta*x[i])^2 / \sum_i w[i]*(y[i] - mean(y))^2
// for the line
// y = alpha + beta*x
// and the data in x and y with the given weights.
//
// The lengths of x and y must be equal. If weights is nil then all of the
// weights are 1. If weights is not nil, then len(x) must equal len(weights).
func RSquared(x, y, weights []float64, alpha, beta float64) float64 {
if len(x) != len(y) {
panic("stat: slice length mismatch")
}
if weights != nil && len(weights) != len(x) {
panic("stat: slice length mismatch")
}
w := 1.0
yMean := Mean(y, weights)
var res, tot, d float64
for i, xi := range x {
if weights != nil {
w = weights[i]
}
yi := y[i]
fi := alpha + beta*xi
d = yi - fi
res += w * d * d
d = yi - yMean
tot += w * d * d
}
return 1 - res/tot
}
// RSquaredFrom returns the coefficient of determination defined as
// R^2 = 1 - \sum_i w[i]*(estimate[i] - value[i])^2 / \sum_i w[i]*(value[i] - mean(values))^2
// and the data in estimates and values with the given weights.
//
// The lengths of estimates and values must be equal. If weights is nil then
// all of the weights are 1. If weights is not nil, then len(values) must
// equal len(weights).
func RSquaredFrom(estimates, values, weights []float64) float64 {
if len(estimates) != len(values) {
panic("stat: slice length mismatch")
}
if weights != nil && len(weights) != len(values) {
panic("stat: slice length mismatch")
}
w := 1.0
mean := Mean(values, weights)
var res, tot, d float64
for i, val := range values {
if weights != nil {
w = weights[i]
}
d = val - estimates[i]
res += w * d * d
d = val - mean
tot += w * d * d
}
return 1 - res/tot
}
// RNoughtSquared returns the coefficient of determination defined as
// R₀^2 = \sum_i w[i]*(beta*x[i])^2 / \sum_i w[i]*y[i]^2
// for the line
// y = beta*x
// and the data in x and y with the given weights. RNoughtSquared should
// only be used for best-fit lines regressed through the origin.
//
// The lengths of x and y must be equal. If weights is nil then all of the
// weights are 1. If weights is not nil, then len(x) must equal len(weights).
func RNoughtSquared(x, y, weights []float64, beta float64) float64 {
if len(x) != len(y) {
panic("stat: slice length mismatch")
}
if weights != nil && len(weights) != len(x) {
panic("stat: slice length mismatch")
}
w := 1.0
var ssr, tot float64
for i, xi := range x {
if weights != nil {
w = weights[i]
}
fi := beta * xi
ssr += w * fi * fi
yi := y[i]
tot += w * yi * yi
}
return ssr / tot
}
// Mean computes the weighted mean of the data set.
// sum_i {w_i * x_i} / sum_i {w_i}
// If weights is nil then all of the weights are 1. If weights is not nil, then