Working version of Solve

This commit is contained in:
btracey
2014-01-08 13:54:26 -08:00
parent 430f7a3c1b
commit e88f486561
2 changed files with 79 additions and 24 deletions

View File

@@ -808,18 +808,21 @@ func (s *S) TestSolve(c *check.C) {
name: "SkinnyMatrix", name: "SkinnyMatrix",
panics: false, panics: false,
a: [][]float64{ a: [][]float64{
{0.8147, 0.9134}, {0.8147, 0.9134, 0.9},
{0.9058, 0.6324}, {0.9058, 0.6324, 0.9},
{0.1270, 0.0975}, {0.1270, 0.0975, 0.1},
{1.6, 2.8, -3.5},
}, },
b: [][]float64{ b: [][]float64{
{0.278}, {0.278},
{0.547}, {0.547},
{0.958}, {-0.958},
{1.452},
}, },
x: [][]float64{ x: [][]float64{
{1.291723965752262}, {0.820970340787782},
{-0.823253621853170}, {-0.218604626527306},
{-0.212938815234215},
}, },
}, },
} { } {
@@ -837,9 +840,8 @@ func (s *S) TestSolve(c *check.C) {
c.Check(panicked, check.Equals, test.panics, check.Commentf("Test %v panicked: %s", test.name, message)) c.Check(panicked, check.Equals, test.panics, check.Commentf("Test %v panicked: %s", test.name, message))
continue continue
} }
c.Check(x.EqualsApprox(NewDense(flatten(test.x)), 1e-14), check.Equals, true, check.Commentf("Test %v ", test.name))
a.Mul(a, x) c.Check(x.EqualsApprox(NewDense(flatten(test.x)), 1e-14), check.Equals, true, check.Commentf("Test %v solution mismatch: Found %v, expected %v ", test.name, x, test.x))
c.Check(a.EqualsApprox(b, 1e-14), check.Equals, true, check.Commentf("Test %v ", test.name))
} }
} }

View File

@@ -7,6 +7,8 @@ package mat64
import ( import (
"math" "math"
"fmt"
) )
type QRFactor struct { type QRFactor struct {
@@ -144,7 +146,8 @@ func (f QRFactor) Q() *Dense {
// A matrix x is returned that minimizes the two norm of Q*R*X-B. QRSolve will panic // A matrix x is returned that minimizes the two norm of Q*R*X-B. QRSolve will panic
// if a is not full rank. The matrix b is overwritten during the call. // if a is not full rank. The matrix b is overwritten during the call.
func (f QRFactor) Solve(b *Dense) (x *Dense) { func (f QRFactor) Solve(b *Dense) (x *Dense) {
qr, rDiag := f.QR, f.rDiag qr := f.QR
//rDiag := f.rDiag
m, n := qr.Dims() m, n := qr.Dims()
bm, bn := b.Dims() bm, bn := b.Dims()
if bm != m { if bm != m {
@@ -154,34 +157,84 @@ func (f QRFactor) Solve(b *Dense) (x *Dense) {
panic("mat64: matrix is rank deficient") panic("mat64: matrix is rank deficient")
} }
x = NewDense(n, bn, use(b.mat.Data, n*bn)) //x = NewDense(n, bn, use(b.mat.Data, n*bn))
x = NewDense(n, bn, nil)
nx := bn nx := bn
q := f.Q()
// Compute Y = transpose(Q)*B // Compute Y = transpose(Q)*B
for k := 0; k < n; k++ { for k := 0; k < n; k++ {
for j := 0; j < nx; j++ { for j := 0; j < nx; j++ {
var s float64 var s float64
for i := k; i < n; i++ { for i := 0; i < m; i++ {
s += qr.At(i, k) * x.At(i, j) s += q.At(i, k) * b.At(i, j)
}
s /= -qr.At(k, k)
for i := k; i < n; i++ {
x.Set(i, j, x.At(i, j)+s*qr.At(i, k))
} }
x.Set(k, j, s)
} }
} }
// Solve R*X = Y; // Solve R*X = Y;
for k := n - 1; k >= 0; k-- { r := f.R()
for j := 0; j < nx; j++ {
x.Set(k, j, x.At(k, j)/rDiag[k]) fmt.Println("r=", r)
} fmt.Println("q=", q)
for i := 0; i < k; i++ { fmt.Println("y=", x)
for j := 0; j < nx; j++ {
x.Set(i, j, x.At(i, j)-x.At(k, j)*qr.At(i, k)) xcopy := DenseCopyOf(x)
for j := 0; j < nx; j++ {
for k := n - 1; k >= 0; k-- {
fmt.Println("k = ", k)
val := xcopy.At(k, j)
for i := k + 1; i < n; i++ {
fmt.Println("i = ", i)
val -= r.At(k, i) * x.At(i, j)
} }
val /= r.At(k, k)
x.Set(k, j, val)
} }
} }
/*
for k := n - 1; k >= 0; k-- {
for j := 0; j < nx; j++ {
x.Set(k, j, x.At(k, j)/rDiag[k])
}
for i := 0; i < k; i++ {
for j := 0; j < nx; j++ {
x.Set(i, j, x.At(i, j)-x.At(k, j)*qr.At(i, k))
}
}
}
*/
/*
// Compute Y = transpose(Q)*B
for k := 0; k < n; k++ {
for j := 0; j < nx; j++ {
var s float64
for i := k; i < n; i++ {
s += qr.At(i, k) * x.At(i, j)
}
s /= -qr.At(k, k)
for i := k; i < n; i++ {
x.Set(i, j, x.At(i, j)+s*qr.At(i, k))
}
}
}
// Solve R*X = Y;
for k := n - 1; k >= 0; k-- {
for j := 0; j < nx; j++ {
x.Set(k, j, x.At(k, j)/rDiag[k])
}
for i := 0; i < k; i++ {
for j := 0; j < nx; j++ {
x.Set(i, j, x.At(i, j)-x.At(k, j)*qr.At(i, k))
}
}
}
*/
return x return x
} }