From e88f486561e539220b8e1ff987cb05a581386b32 Mon Sep 17 00:00:00 2001 From: btracey Date: Wed, 8 Jan 2014 13:54:26 -0800 Subject: [PATCH] Working version of Solve --- mat64/matrix_test.go | 20 ++++++----- mat64/qr.go | 83 ++++++++++++++++++++++++++++++++++++-------- 2 files changed, 79 insertions(+), 24 deletions(-) diff --git a/mat64/matrix_test.go b/mat64/matrix_test.go index ad014999..4b436466 100644 --- a/mat64/matrix_test.go +++ b/mat64/matrix_test.go @@ -808,18 +808,21 @@ func (s *S) TestSolve(c *check.C) { name: "SkinnyMatrix", panics: false, a: [][]float64{ - {0.8147, 0.9134}, - {0.9058, 0.6324}, - {0.1270, 0.0975}, + {0.8147, 0.9134, 0.9}, + {0.9058, 0.6324, 0.9}, + {0.1270, 0.0975, 0.1}, + {1.6, 2.8, -3.5}, }, b: [][]float64{ {0.278}, {0.547}, - {0.958}, + {-0.958}, + {1.452}, }, x: [][]float64{ - {1.291723965752262}, - {-0.823253621853170}, + {0.820970340787782}, + {-0.218604626527306}, + {-0.212938815234215}, }, }, } { @@ -837,9 +840,8 @@ func (s *S) TestSolve(c *check.C) { c.Check(panicked, check.Equals, test.panics, check.Commentf("Test %v panicked: %s", test.name, message)) continue } - c.Check(x.EqualsApprox(NewDense(flatten(test.x)), 1e-14), check.Equals, true, check.Commentf("Test %v ", test.name)) - a.Mul(a, x) - c.Check(a.EqualsApprox(b, 1e-14), check.Equals, true, check.Commentf("Test %v ", test.name)) + + c.Check(x.EqualsApprox(NewDense(flatten(test.x)), 1e-14), check.Equals, true, check.Commentf("Test %v solution mismatch: Found %v, expected %v ", test.name, x, test.x)) } } diff --git a/mat64/qr.go b/mat64/qr.go index fc694f0f..0f64a99d 100644 --- a/mat64/qr.go +++ b/mat64/qr.go @@ -7,6 +7,8 @@ package mat64 import ( "math" + + "fmt" ) type QRFactor struct { @@ -144,7 +146,8 @@ func (f QRFactor) Q() *Dense { // A matrix x is returned that minimizes the two norm of Q*R*X-B. QRSolve will panic // if a is not full rank. The matrix b is overwritten during the call. func (f QRFactor) Solve(b *Dense) (x *Dense) { - qr, rDiag := f.QR, f.rDiag + qr := f.QR + //rDiag := f.rDiag m, n := qr.Dims() bm, bn := b.Dims() if bm != m { @@ -154,34 +157,84 @@ func (f QRFactor) Solve(b *Dense) (x *Dense) { panic("mat64: matrix is rank deficient") } - x = NewDense(n, bn, use(b.mat.Data, n*bn)) + //x = NewDense(n, bn, use(b.mat.Data, n*bn)) + x = NewDense(n, bn, nil) nx := bn + q := f.Q() + // Compute Y = transpose(Q)*B for k := 0; k < n; k++ { for j := 0; j < nx; j++ { var s float64 - for i := k; i < n; i++ { - s += qr.At(i, k) * x.At(i, j) - } - s /= -qr.At(k, k) - for i := k; i < n; i++ { - x.Set(i, j, x.At(i, j)+s*qr.At(i, k)) + for i := 0; i < m; i++ { + s += q.At(i, k) * b.At(i, j) } + x.Set(k, j, s) } } // Solve R*X = Y; - for k := n - 1; k >= 0; k-- { - for j := 0; j < nx; j++ { - x.Set(k, j, x.At(k, j)/rDiag[k]) - } - for i := 0; i < k; i++ { - for j := 0; j < nx; j++ { - x.Set(i, j, x.At(i, j)-x.At(k, j)*qr.At(i, k)) + r := f.R() + + fmt.Println("r=", r) + fmt.Println("q=", q) + fmt.Println("y=", x) + + xcopy := DenseCopyOf(x) + for j := 0; j < nx; j++ { + for k := n - 1; k >= 0; k-- { + fmt.Println("k = ", k) + val := xcopy.At(k, j) + for i := k + 1; i < n; i++ { + fmt.Println("i = ", i) + val -= r.At(k, i) * x.At(i, j) } + val /= r.At(k, k) + x.Set(k, j, val) } } + /* + for k := n - 1; k >= 0; k-- { + for j := 0; j < nx; j++ { + x.Set(k, j, x.At(k, j)/rDiag[k]) + } + for i := 0; i < k; i++ { + for j := 0; j < nx; j++ { + x.Set(i, j, x.At(i, j)-x.At(k, j)*qr.At(i, k)) + } + } + } + */ + + /* + // Compute Y = transpose(Q)*B + for k := 0; k < n; k++ { + for j := 0; j < nx; j++ { + var s float64 + for i := k; i < n; i++ { + s += qr.At(i, k) * x.At(i, j) + } + s /= -qr.At(k, k) + for i := k; i < n; i++ { + x.Set(i, j, x.At(i, j)+s*qr.At(i, k)) + } + } + } + + // Solve R*X = Y; + for k := n - 1; k >= 0; k-- { + for j := 0; j < nx; j++ { + x.Set(k, j, x.At(k, j)/rDiag[k]) + } + for i := 0; i < k; i++ { + for j := 0; j < nx; j++ { + x.Set(i, j, x.At(i, j)-x.At(k, j)*qr.At(i, k)) + } + } + } + */ + return x }