mirror of
https://github.com/gonum/gonum.git
synced 2025-10-17 20:51:06 +08:00
stat: add simple linear regression
This commit is contained in:
135
stat.go
135
stat.go
@@ -648,6 +648,141 @@ func KullbackLeibler(p, q []float64) float64 {
|
||||
return kl
|
||||
}
|
||||
|
||||
// LinearRegression computes the best-fit line
|
||||
// y = alpha + beta*x
|
||||
// to the data in x and y with the given weights. If origin is true, the
|
||||
// regression is forced to pass through the origin.
|
||||
//
|
||||
// Specifically, LinearRegression computes the values of alpha and
|
||||
// beta such that the total residual
|
||||
// \sum_i w[i]*(y[i] - alpha - beta*x[i])^2
|
||||
// is minimized. If origin is true, then alpha is forced to be zero.
|
||||
//
|
||||
// The lengths of x and y must be equal. If weights is nil then all of the
|
||||
// weights are 1. If weights is not nil, then len(x) must equal len(weights).
|
||||
func LinearRegression(x, y, weights []float64, origin bool) (alpha, beta float64) {
|
||||
if len(x) != len(y) {
|
||||
panic("stat: slice length mismatch")
|
||||
}
|
||||
if weights != nil && len(weights) != len(x) {
|
||||
panic("stat: slice length mismatch")
|
||||
}
|
||||
|
||||
w := 1.0
|
||||
if origin {
|
||||
var x2Sum, xySum float64
|
||||
for i, xi := range x {
|
||||
if weights != nil {
|
||||
w = weights[i]
|
||||
}
|
||||
yi := y[i]
|
||||
xySum += w * xi * yi
|
||||
x2Sum += w * xi * xi
|
||||
}
|
||||
beta = xySum / x2Sum
|
||||
|
||||
return 0, beta
|
||||
}
|
||||
|
||||
beta = Covariance(x, y, weights) / Variance(x, weights)
|
||||
alpha = Mean(y, weights) - beta*Mean(x, weights)
|
||||
return alpha, beta
|
||||
}
|
||||
|
||||
// RSquared returns the coefficient of determination defined as
|
||||
// R^2 = 1 - \sum_i w[i]*(y[i] - alpha - beta*x[i])^2 / \sum_i w[i]*(y[i] - mean(y))^2
|
||||
// for the line
|
||||
// y = alpha + beta*x
|
||||
// and the data in x and y with the given weights.
|
||||
//
|
||||
// The lengths of x and y must be equal. If weights is nil then all of the
|
||||
// weights are 1. If weights is not nil, then len(x) must equal len(weights).
|
||||
func RSquared(x, y, weights []float64, alpha, beta float64) float64 {
|
||||
if len(x) != len(y) {
|
||||
panic("stat: slice length mismatch")
|
||||
}
|
||||
if weights != nil && len(weights) != len(x) {
|
||||
panic("stat: slice length mismatch")
|
||||
}
|
||||
|
||||
w := 1.0
|
||||
yMean := Mean(y, weights)
|
||||
var res, tot, d float64
|
||||
for i, xi := range x {
|
||||
if weights != nil {
|
||||
w = weights[i]
|
||||
}
|
||||
yi := y[i]
|
||||
fi := alpha + beta*xi
|
||||
d = yi - fi
|
||||
res += w * d * d
|
||||
d = yi - yMean
|
||||
tot += w * d * d
|
||||
}
|
||||
return 1 - res/tot
|
||||
}
|
||||
|
||||
// RSquaredFrom returns the coefficient of determination defined as
|
||||
// R^2 = 1 - \sum_i w[i]*(estimate[i] - value[i])^2 / \sum_i w[i]*(value[i] - mean(values))^2
|
||||
// and the data in estimates and values with the given weights.
|
||||
//
|
||||
// The lengths of estimates and values must be equal. If weights is nil then
|
||||
// all of the weights are 1. If weights is not nil, then len(values) must
|
||||
// equal len(weights).
|
||||
func RSquaredFrom(estimates, values, weights []float64) float64 {
|
||||
if len(estimates) != len(values) {
|
||||
panic("stat: slice length mismatch")
|
||||
}
|
||||
if weights != nil && len(weights) != len(values) {
|
||||
panic("stat: slice length mismatch")
|
||||
}
|
||||
|
||||
w := 1.0
|
||||
mean := Mean(values, weights)
|
||||
var res, tot, d float64
|
||||
for i, val := range values {
|
||||
if weights != nil {
|
||||
w = weights[i]
|
||||
}
|
||||
d = val - estimates[i]
|
||||
res += w * d * d
|
||||
d = val - mean
|
||||
tot += w * d * d
|
||||
}
|
||||
return 1 - res/tot
|
||||
}
|
||||
|
||||
// RNoughtSquared returns the coefficient of determination defined as
|
||||
// R₀^2 = \sum_i w[i]*(beta*x[i])^2 / \sum_i w[i]*y[i]^2
|
||||
// for the line
|
||||
// y = beta*x
|
||||
// and the data in x and y with the given weights. RNoughtSquared should
|
||||
// only be used for best-fit lines regressed through the origin.
|
||||
//
|
||||
// The lengths of x and y must be equal. If weights is nil then all of the
|
||||
// weights are 1. If weights is not nil, then len(x) must equal len(weights).
|
||||
func RNoughtSquared(x, y, weights []float64, beta float64) float64 {
|
||||
if len(x) != len(y) {
|
||||
panic("stat: slice length mismatch")
|
||||
}
|
||||
if weights != nil && len(weights) != len(x) {
|
||||
panic("stat: slice length mismatch")
|
||||
}
|
||||
|
||||
w := 1.0
|
||||
var ssr, tot float64
|
||||
for i, xi := range x {
|
||||
if weights != nil {
|
||||
w = weights[i]
|
||||
}
|
||||
fi := beta * xi
|
||||
ssr += w * fi * fi
|
||||
yi := y[i]
|
||||
tot += w * yi * yi
|
||||
}
|
||||
return ssr / tot
|
||||
}
|
||||
|
||||
// Mean computes the weighted mean of the data set.
|
||||
// sum_i {w_i * x_i} / sum_i {w_i}
|
||||
// If weights is nil then all of the weights are 1. If weights is not nil, then
|
||||
|
Reference in New Issue
Block a user