mirror of
https://github.com/gonum/gonum.git
synced 2025-10-19 13:35:51 +08:00
stat: add simple linear regression
This commit is contained in:
105
stat_test.go
105
stat_test.go
@@ -753,6 +753,111 @@ func TestKullbackLeibler(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
var linearRegressionTests = []struct {
|
||||
name string
|
||||
|
||||
x, y []float64
|
||||
weights []float64
|
||||
origin bool
|
||||
|
||||
alpha float64
|
||||
beta float64
|
||||
r float64
|
||||
|
||||
tol float64
|
||||
}{
|
||||
{
|
||||
name: "faithful",
|
||||
|
||||
x: faithful.waiting,
|
||||
y: faithful.eruptions,
|
||||
|
||||
// Values calculated by R using lm(eruptions ~ waiting, data=faithful).
|
||||
alpha: -1.87402,
|
||||
beta: 0.07563,
|
||||
r: 0.8114608,
|
||||
|
||||
tol: 1e-5,
|
||||
},
|
||||
{
|
||||
name: "faithful through origin",
|
||||
|
||||
x: faithful.waiting,
|
||||
y: faithful.eruptions,
|
||||
origin: true,
|
||||
|
||||
// Values calculated by R using lm(eruptions ~ waiting - 1, data=faithful).
|
||||
alpha: 0,
|
||||
beta: 0.05013,
|
||||
r: 0.9726036,
|
||||
|
||||
tol: 1e-5,
|
||||
},
|
||||
{
|
||||
name: "faithful explicit weights",
|
||||
|
||||
x: faithful.waiting,
|
||||
y: faithful.eruptions,
|
||||
weights: func() []float64 {
|
||||
w := make([]float64, len(faithful.eruptions))
|
||||
for i := range w {
|
||||
w[i] = 1
|
||||
}
|
||||
return w
|
||||
}(),
|
||||
|
||||
// Values calculated by R using lm(eruptions ~ waiting, data=faithful).
|
||||
alpha: -1.87402,
|
||||
beta: 0.07563,
|
||||
r: 0.8114608,
|
||||
|
||||
tol: 1e-5,
|
||||
},
|
||||
{
|
||||
name: "faithful non-uniform weights",
|
||||
|
||||
x: faithful.waiting,
|
||||
y: faithful.eruptions,
|
||||
weights: faithful.waiting, // Just an arbitrary set of non-uniform weights.
|
||||
|
||||
// Values calculated by R using lm(eruptions ~ waiting, data=faithful, weights=faithful$waiting).
|
||||
alpha: -1.79268,
|
||||
beta: 0.07452,
|
||||
r: 0.7840372,
|
||||
|
||||
tol: 1e-5,
|
||||
},
|
||||
}
|
||||
|
||||
func TestLinearRegression(t *testing.T) {
|
||||
for _, test := range linearRegressionTests {
|
||||
alpha, beta := LinearRegression(test.x, test.y, test.weights, test.origin)
|
||||
var r float64
|
||||
if test.origin {
|
||||
r = RNoughtSquared(test.x, test.y, test.weights, beta)
|
||||
} else {
|
||||
r = RSquared(test.x, test.y, test.weights, alpha, beta)
|
||||
ests := make([]float64, len(test.y))
|
||||
for i, x := range test.x {
|
||||
ests[i] = alpha + beta*x
|
||||
}
|
||||
rvals := RSquaredFrom(ests, test.y, test.weights)
|
||||
if r != rvals {
|
||||
t.Errorf("%s: RSquared and RSquaredFrom mismatch: %v != %v", test.name, r, rvals)
|
||||
}
|
||||
}
|
||||
if !floats.EqualWithinAbsOrRel(alpha, test.alpha, test.tol, test.tol) {
|
||||
t.Errorf("%s: unexpected alpha estimate: want:%v got:%v", test.name, test.alpha, alpha)
|
||||
}
|
||||
if !floats.EqualWithinAbsOrRel(beta, test.beta, test.tol, test.tol) {
|
||||
t.Errorf("%s: unexpected beta estimate: want:%v got:%v", test.name, test.beta, beta)
|
||||
}
|
||||
if !floats.EqualWithinAbsOrRel(r, test.r, test.tol, test.tol) {
|
||||
t.Errorf("%s: unexpected r estimate: want:%v got:%v", test.name, test.r, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestChiSquare(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
p []float64
|
||||
|
Reference in New Issue
Block a user