mat: Rename Solve(Vec) to Solve(Vec)To (#922)

* mat: Rename Solve(Vec) to Solev(Vec)To

Fix #830.
This commit is contained in:
Brendan Tracey
2019-03-28 01:01:36 +00:00
committed by GitHub
parent 9996f1428e
commit a65628b4b5
17 changed files with 120 additions and 118 deletions

View File

@@ -162,9 +162,9 @@ func (c *Cholesky) LogDet() float64 {
return det return det
} }
// Solve finds the matrix x that solves A * X = B where A is represented // SolveTo finds the matrix X that solves A * X = B where A is represented
// by the Cholesky decomposition, placing the result in x. // by the Cholesky decomposition. The result is stored in-place into dst.
func (c *Cholesky) Solve(x *Dense, b Matrix) error { func (c *Cholesky) SolveTo(dst *Dense, b Matrix) error {
if !c.valid() { if !c.valid() {
panic(badCholesky) panic(badCholesky)
} }
@@ -174,20 +174,21 @@ func (c *Cholesky) Solve(x *Dense, b Matrix) error {
panic(ErrShape) panic(ErrShape)
} }
x.reuseAs(bm, bn) dst.reuseAs(bm, bn)
if b != x { if b != dst {
x.Copy(b) dst.Copy(b)
} }
lapack64.Potrs(c.chol.mat, x.mat) lapack64.Potrs(c.chol.mat, dst.mat)
if c.cond > ConditionTolerance { if c.cond > ConditionTolerance {
return Condition(c.cond) return Condition(c.cond)
} }
return nil return nil
} }
// SolveChol finds the matrix x that solves A * X = B where A and B are represented // SolveCholTo finds the matrix X that solves A * X = B where A and B are represented
// by their Cholesky decompositions a and b, placing the result in x. // by their Cholesky decompositions a and b. The result is stored in-place into
func (a *Cholesky) SolveChol(x *Dense, b *Cholesky) error { // dst.
func (a *Cholesky) SolveCholTo(dst *Dense, b *Cholesky) error {
if !a.valid() || !b.valid() { if !a.valid() || !b.valid() {
panic(badCholesky) panic(badCholesky)
} }
@@ -196,20 +197,21 @@ func (a *Cholesky) SolveChol(x *Dense, b *Cholesky) error {
panic(ErrShape) panic(ErrShape)
} }
x.reuseAsZeroed(bn, bn) dst.reuseAsZeroed(bn, bn)
x.Copy(b.chol.T()) dst.Copy(b.chol.T())
blas64.Trsm(blas.Left, blas.Trans, 1, a.chol.mat, x.mat) blas64.Trsm(blas.Left, blas.Trans, 1, a.chol.mat, dst.mat)
blas64.Trsm(blas.Left, blas.NoTrans, 1, a.chol.mat, x.mat) blas64.Trsm(blas.Left, blas.NoTrans, 1, a.chol.mat, dst.mat)
blas64.Trmm(blas.Right, blas.NoTrans, 1, b.chol.mat, x.mat) blas64.Trmm(blas.Right, blas.NoTrans, 1, b.chol.mat, dst.mat)
if a.cond > ConditionTolerance { if a.cond > ConditionTolerance {
return Condition(a.cond) return Condition(a.cond)
} }
return nil return nil
} }
// SolveVec finds the vector x that solves A * x = b where A is represented // SolveVecTo finds the vector X that solves A * x = b where A is represented
// by the Cholesky decomposition, placing the result in x. // by the Cholesky decomposition. The result is stored in-place into
func (c *Cholesky) SolveVec(x *VecDense, b Vector) error { // dst.
func (c *Cholesky) SolveVecTo(dst *VecDense, b Vector) error {
if !c.valid() { if !c.valid() {
panic(badCholesky) panic(badCholesky)
} }
@@ -219,18 +221,18 @@ func (c *Cholesky) SolveVec(x *VecDense, b Vector) error {
} }
switch rv := b.(type) { switch rv := b.(type) {
default: default:
x.reuseAs(n) dst.reuseAs(n)
return c.Solve(x.asDense(), b) return c.SolveTo(dst.asDense(), b)
case RawVectorer: case RawVectorer:
bmat := rv.RawVector() bmat := rv.RawVector()
if x != b { if dst != b {
x.checkOverlap(bmat) dst.checkOverlap(bmat)
} }
x.reuseAs(n) dst.reuseAs(n)
if x != b { if dst != b {
x.CopyVec(b) dst.CopyVec(b)
} }
lapack64.Potrs(c.chol.mat, x.asGeneral()) lapack64.Potrs(c.chol.mat, dst.asGeneral())
if c.cond > ConditionTolerance { if c.cond > ConditionTolerance {
return Condition(c.cond) return Condition(c.cond)
} }

View File

@@ -35,7 +35,7 @@ func ExampleCholesky() {
// Use the factorization to solve the system of equations a * x = b. // Use the factorization to solve the system of equations a * x = b.
b := mat.NewVecDense(4, []float64{1, 2, 3, 4}) b := mat.NewVecDense(4, []float64{1, 2, 3, 4})
var x mat.VecDense var x mat.VecDense
if err := chol.SolveVec(&x, b); err != nil { if err := chol.SolveVecTo(&x, b); err != nil {
fmt.Println("Matrix is near singular: ", err) fmt.Println("Matrix is near singular: ", err)
} }
fmt.Println("Solve a * x = b") fmt.Println("Solve a * x = b")

View File

@@ -72,7 +72,7 @@ func TestCholesky(t *testing.T) {
} }
} }
func TestCholeskySolve(t *testing.T) { func TestCholeskySolveTo(t *testing.T) {
for _, test := range []struct { for _, test := range []struct {
a *SymDense a *SymDense
b *Dense b *Dense
@@ -103,7 +103,7 @@ func TestCholeskySolve(t *testing.T) {
} }
var x Dense var x Dense
chol.Solve(&x, test.b) chol.SolveTo(&x, test.b)
if !EqualApprox(&x, test.ans, 1e-12) { if !EqualApprox(&x, test.ans, 1e-12) {
t.Error("incorrect Cholesky solve solution") t.Error("incorrect Cholesky solve solution")
} }
@@ -116,7 +116,7 @@ func TestCholeskySolve(t *testing.T) {
} }
} }
func TestCholeskySolveChol(t *testing.T) { func TestCholeskySolveCholTo(t *testing.T) {
for _, test := range []struct { for _, test := range []struct {
a, b *SymDense a, b *SymDense
}{ }{
@@ -164,7 +164,7 @@ func TestCholeskySolveChol(t *testing.T) {
} }
var x Dense var x Dense
chola.SolveChol(&x, &cholb) chola.SolveCholTo(&x, &cholb)
var ans Dense var ans Dense
ans.Mul(test.a, &x) ans.Mul(test.a, &x)
@@ -177,7 +177,7 @@ func TestCholeskySolveChol(t *testing.T) {
} }
} }
func TestCholeskySolveVec(t *testing.T) { func TestCholeskySolveVecTo(t *testing.T) {
for _, test := range []struct { for _, test := range []struct {
a *SymDense a *SymDense
b *VecDense b *VecDense
@@ -208,7 +208,7 @@ func TestCholeskySolveVec(t *testing.T) {
} }
var x VecDense var x VecDense
chol.SolveVec(&x, test.b) chol.SolveVecTo(&x, test.b)
if !EqualApprox(&x, test.ans, 1e-12) { if !EqualApprox(&x, test.ans, 1e-12) {
t.Error("incorrect Cholesky solve solution") t.Error("incorrect Cholesky solve solution")
} }

View File

@@ -85,13 +85,13 @@ func (gsvd *HOGSVD) Factorize(m ...Matrix) (ok bool) {
defer putWorkspace(sij) defer putWorkspace(sij)
for i, ai := range a { for i, ai := range a {
for _, aj := range a[i+1:] { for _, aj := range a[i+1:] {
gsvd.err = ai.SolveChol(sij, &aj) gsvd.err = ai.SolveCholTo(sij, &aj)
if gsvd.err != nil { if gsvd.err != nil {
return false return false
} }
s.Add(s, sij) s.Add(s, sij)
gsvd.err = aj.SolveChol(sij, &ai) gsvd.err = aj.SolveCholTo(sij, &ai)
if gsvd.err != nil { if gsvd.err != nil {
return false return false
} }

View File

@@ -140,7 +140,7 @@ func (lq *LQ) QTo(dst *Dense) *Dense {
return dst return dst
} }
// Solve finds a minimum-norm solution to a system of linear equations defined // SolveTo finds a minimum-norm solution to a system of linear equations defined
// by the matrices A and b, where A is an m×n matrix represented in its LQ factorized // by the matrices A and b, where A is an m×n matrix represented in its LQ factorized
// form. If A is singular or near-singular a Condition error is returned. // form. If A is singular or near-singular a Condition error is returned.
// See the documentation for Condition for more information. // See the documentation for Condition for more information.
@@ -148,8 +148,8 @@ func (lq *LQ) QTo(dst *Dense) *Dense {
// The minimization problem solved depends on the input parameters. // The minimization problem solved depends on the input parameters.
// If trans == false, find the minimum norm solution of A * X = B. // If trans == false, find the minimum norm solution of A * X = B.
// If trans == true, find X such that ||A*X - B||_2 is minimized. // If trans == true, find X such that ||A*X - B||_2 is minimized.
// The solution matrix, X, is stored in place into x. // The solution matrix, X, is stored in place into dst.
func (lq *LQ) Solve(x *Dense, trans bool, b Matrix) error { func (lq *LQ) SolveTo(dst *Dense, trans bool, b Matrix) error {
r, c := lq.lq.Dims() r, c := lq.lq.Dims()
br, bc := b.Dims() br, bc := b.Dims()
@@ -161,12 +161,12 @@ func (lq *LQ) Solve(x *Dense, trans bool, b Matrix) error {
if c != br { if c != br {
panic(ErrShape) panic(ErrShape)
} }
x.reuseAs(r, bc) dst.reuseAs(r, bc)
} else { } else {
if r != br { if r != br {
panic(ErrShape) panic(ErrShape)
} }
x.reuseAs(c, bc) dst.reuseAs(c, bc)
} }
// Do not need to worry about overlap between x and b because w has its own // Do not need to worry about overlap between x and b because w has its own
// independent storage. // independent storage.
@@ -199,7 +199,7 @@ func (lq *LQ) Solve(x *Dense, trans bool, b Matrix) error {
putFloats(work) putFloats(work)
} }
// x was set above to be the correct size for the result. // x was set above to be the correct size for the result.
x.Copy(w) dst.Copy(w)
putWorkspace(w) putWorkspace(w)
if lq.cond > ConditionTolerance { if lq.cond > ConditionTolerance {
return Condition(lq.cond) return Condition(lq.cond)
@@ -207,9 +207,9 @@ func (lq *LQ) Solve(x *Dense, trans bool, b Matrix) error {
return nil return nil
} }
// SolveVec finds a minimum-norm solution to a system of linear equations. // SolveVecTo finds a minimum-norm solution to a system of linear equations.
// See LQ.Solve for the full documentation. // See LQ.SolveTo for the full documentation.
func (lq *LQ) SolveVec(x *VecDense, trans bool, b Vector) error { func (lq *LQ) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
r, c := lq.lq.Dims() r, c := lq.lq.Dims()
if _, bc := b.Dims(); bc != 1 { if _, bc := b.Dims(); bc != 1 {
panic(ErrShape) panic(ErrShape)
@@ -220,16 +220,16 @@ func (lq *LQ) SolveVec(x *VecDense, trans bool, b Vector) error {
bm := Matrix(b) bm := Matrix(b)
if rv, ok := b.(RawVectorer); ok { if rv, ok := b.(RawVectorer); ok {
bmat := rv.RawVector() bmat := rv.RawVector()
if x != b { if dst != b {
x.checkOverlap(bmat) dst.checkOverlap(bmat)
} }
b := VecDense{mat: bmat} b := VecDense{mat: bmat}
bm = b.asDense() bm = b.asDense()
} }
if trans { if trans {
x.reuseAs(r) dst.reuseAs(r)
} else { } else {
x.reuseAs(c) dst.reuseAs(c)
} }
return lq.Solve(x.asDense(), trans, bm) return lq.SolveTo(dst.asDense(), trans, bm)
} }

View File

@@ -46,7 +46,7 @@ func TestLQ(t *testing.T) {
} }
} }
func TestSolveLQ(t *testing.T) { func TestLQSolveTo(t *testing.T) {
for _, trans := range []bool{false, true} { for _, trans := range []bool{false, true} {
for _, test := range []struct { for _, test := range []struct {
m, n, bc int m, n, bc int
@@ -78,7 +78,7 @@ func TestSolveLQ(t *testing.T) {
var x Dense var x Dense
lq := &LQ{} lq := &LQ{}
lq.Factorize(a) lq.Factorize(a)
lq.Solve(&x, trans, b) lq.SolveTo(&x, trans, b)
// Test that the normal equations hold. // Test that the normal equations hold.
// A^T * A * x = A^T * b if !trans // A^T * A * x = A^T * b if !trans
@@ -104,7 +104,7 @@ func TestSolveLQ(t *testing.T) {
// TODO(btracey): Add in testOneInput when it exists. // TODO(btracey): Add in testOneInput when it exists.
} }
func TestSolveLQVec(t *testing.T) { func TestLQSolveToVec(t *testing.T) {
for _, trans := range []bool{false, true} { for _, trans := range []bool{false, true} {
for _, test := range []struct { for _, test := range []struct {
m, n int m, n int
@@ -131,7 +131,7 @@ func TestSolveLQVec(t *testing.T) {
var x VecDense var x VecDense
lq := &LQ{} lq := &LQ{}
lq.Factorize(a) lq.Factorize(a)
lq.SolveVec(&x, trans, b) lq.SolveVecTo(&x, trans, b)
// Test that the normal equations hold. // Test that the normal equations hold.
// A^T * A * x = A^T * b if !trans // A^T * A * x = A^T * b if !trans
@@ -157,7 +157,7 @@ func TestSolveLQVec(t *testing.T) {
// TODO(btracey): Add in testOneInput when it exists. // TODO(btracey): Add in testOneInput when it exists.
} }
func TestSolveLQCond(t *testing.T) { func TestLQSolveToCond(t *testing.T) {
for _, test := range []*Dense{ for _, test := range []*Dense{
NewDense(2, 2, []float64{1, 0, 0, 1e-20}), NewDense(2, 2, []float64{1, 0, 0, 1e-20}),
NewDense(2, 3, []float64{1, 0, 0, 0, 1e-20, 0}), NewDense(2, 3, []float64{1, 0, 0, 0, 1e-20, 0}),
@@ -167,13 +167,13 @@ func TestSolveLQCond(t *testing.T) {
lq.Factorize(test) lq.Factorize(test)
b := NewDense(m, 2, nil) b := NewDense(m, 2, nil)
var x Dense var x Dense
if err := lq.Solve(&x, false, b); err == nil { if err := lq.SolveTo(&x, false, b); err == nil {
t.Error("No error for near-singular matrix in matrix solve.") t.Error("No error for near-singular matrix in matrix solve.")
} }
bvec := NewVecDense(m, nil) bvec := NewVecDense(m, nil)
var xvec VecDense var xvec VecDense
if err := lq.SolveVec(&xvec, false, bvec); err == nil { if err := lq.SolveVecTo(&xvec, false, bvec); err == nil {
t.Error("No error for near-singular matrix in matrix solve.") t.Error("No error for near-singular matrix in matrix solve.")
} }
} }

View File

@@ -281,16 +281,16 @@ func (m *Dense) Permutation(r int, swaps []int) {
} }
} }
// Solve solves a system of linear equations using the LU decomposition of a matrix. // SolveTo solves a system of linear equations using the LU decomposition of a matrix.
// It computes // It computes
// A * X = B if trans == false // A * X = B if trans == false
// A^T * X = B if trans == true // A^T * X = B if trans == true
// In both cases, A is represented in LU factorized form, and the matrix X is // In both cases, A is represented in LU factorized form, and the matrix X is
// stored into x. // stored into dst.
// //
// If A is singular or near-singular a Condition error is returned. See // If A is singular or near-singular a Condition error is returned. See
// the documentation for Condition for more information. // the documentation for Condition for more information.
func (lu *LU) Solve(x *Dense, trans bool, b Matrix) error { func (lu *LU) SolveTo(dst *Dense, trans bool, b Matrix) error {
_, n := lu.lu.Dims() _, n := lu.lu.Dims()
br, bc := b.Dims() br, bc := b.Dims()
if br != n { if br != n {
@@ -302,49 +302,49 @@ func (lu *LU) Solve(x *Dense, trans bool, b Matrix) error {
return Condition(math.Inf(1)) return Condition(math.Inf(1))
} }
x.reuseAs(n, bc) dst.reuseAs(n, bc)
bU, _ := untranspose(b) bU, _ := untranspose(b)
var restore func() var restore func()
if x == bU { if dst == bU {
x, restore = x.isolatedWorkspace(bU) dst, restore = dst.isolatedWorkspace(bU)
defer restore() defer restore()
} else if rm, ok := bU.(RawMatrixer); ok { } else if rm, ok := bU.(RawMatrixer); ok {
x.checkOverlap(rm.RawMatrix()) dst.checkOverlap(rm.RawMatrix())
} }
x.Copy(b) dst.Copy(b)
t := blas.NoTrans t := blas.NoTrans
if trans { if trans {
t = blas.Trans t = blas.Trans
} }
lapack64.Getrs(t, lu.lu.mat, x.mat, lu.pivot) lapack64.Getrs(t, lu.lu.mat, dst.mat, lu.pivot)
if lu.cond > ConditionTolerance { if lu.cond > ConditionTolerance {
return Condition(lu.cond) return Condition(lu.cond)
} }
return nil return nil
} }
// SolveVec solves a system of linear equations using the LU decomposition of a matrix. // SolveVecTo solves a system of linear equations using the LU decomposition of a matrix.
// It computes // It computes
// A * x = b if trans == false // A * x = b if trans == false
// A^T * x = b if trans == true // A^T * x = b if trans == true
// In both cases, A is represented in LU factorized form, and the vector x is // In both cases, A is represented in LU factorized form, and the vector x is
// stored into x. // stored into dst.
// //
// If A is singular or near-singular a Condition error is returned. See // If A is singular or near-singular a Condition error is returned. See
// the documentation for Condition for more information. // the documentation for Condition for more information.
func (lu *LU) SolveVec(x *VecDense, trans bool, b Vector) error { func (lu *LU) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
_, n := lu.lu.Dims() _, n := lu.lu.Dims()
if br, bc := b.Dims(); br != n || bc != 1 { if br, bc := b.Dims(); br != n || bc != 1 {
panic(ErrShape) panic(ErrShape)
} }
switch rv := b.(type) { switch rv := b.(type) {
default: default:
x.reuseAs(n) dst.reuseAs(n)
return lu.Solve(x.asDense(), trans, b) return lu.SolveTo(dst.asDense(), trans, b)
case RawVectorer: case RawVectorer:
if x != b { if dst != b {
x.checkOverlap(rv.RawVector()) dst.checkOverlap(rv.RawVector())
} }
// TODO(btracey): Should test the condition number instead of testing that // TODO(btracey): Should test the condition number instead of testing that
// the determinant is exactly zero. // the determinant is exactly zero.
@@ -352,18 +352,18 @@ func (lu *LU) SolveVec(x *VecDense, trans bool, b Vector) error {
return Condition(math.Inf(1)) return Condition(math.Inf(1))
} }
x.reuseAs(n) dst.reuseAs(n)
var restore func() var restore func()
if x == b { if dst == b {
x, restore = x.isolatedWorkspace(b) dst, restore = dst.isolatedWorkspace(b)
defer restore() defer restore()
} }
x.CopyVec(b) dst.CopyVec(b)
vMat := blas64.General{ vMat := blas64.General{
Rows: n, Rows: n,
Cols: 1, Cols: 1,
Stride: x.mat.Inc, Stride: dst.mat.Inc,
Data: x.mat.Data, Data: dst.mat.Data,
} }
t := blas.NoTrans t := blas.NoTrans
if trans { if trans {

View File

@@ -104,7 +104,7 @@ func luReconstruct(lu *LU) *Dense {
return &a return &a
} }
func TestSolveLU(t *testing.T) { func TestLUSolveTo(t *testing.T) {
for _, test := range []struct { for _, test := range []struct {
n, bc int n, bc int
}{ }{
@@ -129,19 +129,19 @@ func TestSolveLU(t *testing.T) {
var lu LU var lu LU
lu.Factorize(a) lu.Factorize(a)
var x Dense var x Dense
if err := lu.Solve(&x, false, b); err != nil { if err := lu.SolveTo(&x, false, b); err != nil {
continue continue
} }
var got Dense var got Dense
got.Mul(a, &x) got.Mul(a, &x)
if !EqualApprox(&got, b, 1e-12) { if !EqualApprox(&got, b, 1e-12) {
t.Errorf("Solve mismatch for non-singular matrix. n = %v, bc = %v.\nWant: %v\nGot: %v", n, bc, b, got) t.Errorf("SolveTo mismatch for non-singular matrix. n = %v, bc = %v.\nWant: %v\nGot: %v", n, bc, b, got)
} }
} }
// TODO(btracey): Add testOneInput test when such a function exists. // TODO(btracey): Add testOneInput test when such a function exists.
} }
func TestSolveLUCond(t *testing.T) { func TestLUSolveToCond(t *testing.T) {
for _, test := range []*Dense{ for _, test := range []*Dense{
NewDense(2, 2, []float64{1, 0, 0, 1e-20}), NewDense(2, 2, []float64{1, 0, 0, 1e-20}),
} { } {
@@ -150,19 +150,19 @@ func TestSolveLUCond(t *testing.T) {
lu.Factorize(test) lu.Factorize(test)
b := NewDense(m, 2, nil) b := NewDense(m, 2, nil)
var x Dense var x Dense
if err := lu.Solve(&x, false, b); err == nil { if err := lu.SolveTo(&x, false, b); err == nil {
t.Error("No error for near-singular matrix in matrix solve.") t.Error("No error for near-singular matrix in matrix solve.")
} }
bvec := NewVecDense(m, nil) bvec := NewVecDense(m, nil)
var xvec VecDense var xvec VecDense
if err := lu.SolveVec(&xvec, false, bvec); err == nil { if err := lu.SolveVecTo(&xvec, false, bvec); err == nil {
t.Error("No error for near-singular matrix in matrix solve.") t.Error("No error for near-singular matrix in matrix solve.")
} }
} }
} }
func TestSolveLUVec(t *testing.T) { func TestLUSolveVecTo(t *testing.T) {
for _, n := range []int{5, 10} { for _, n := range []int{5, 10} {
a := NewDense(n, n, nil) a := NewDense(n, n, nil)
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
@@ -177,13 +177,13 @@ func TestSolveLUVec(t *testing.T) {
var lu LU var lu LU
lu.Factorize(a) lu.Factorize(a)
var x VecDense var x VecDense
if err := lu.SolveVec(&x, false, b); err != nil { if err := lu.SolveVecTo(&x, false, b); err != nil {
continue continue
} }
var got VecDense var got VecDense
got.MulVec(a, &x) got.MulVec(a, &x)
if !EqualApprox(&got, b, 1e-12) { if !EqualApprox(&got, b, 1e-12) {
t.Errorf("Solve mismatch n = %v.\nWant: %v\nGot: %v", n, b, got) t.Errorf("SolveTo mismatch n = %v.\nWant: %v\nGot: %v", n, b, got)
} }
} }
// TODO(btracey): Add testOneInput test when such a function exists. // TODO(btracey): Add testOneInput test when such a function exists.

View File

@@ -136,7 +136,7 @@ func (qr *QR) QTo(dst *Dense) *Dense {
return dst return dst
} }
// Solve finds a minimum-norm solution to a system of linear equations defined // SolveTo finds a minimum-norm solution to a system of linear equations defined
// by the matrices A and b, where A is an m×n matrix represented in its QR factorized // by the matrices A and b, where A is an m×n matrix represented in its QR factorized
// form. If A is singular or near-singular a Condition error is returned. // form. If A is singular or near-singular a Condition error is returned.
// See the documentation for Condition for more information. // See the documentation for Condition for more information.
@@ -144,8 +144,8 @@ func (qr *QR) QTo(dst *Dense) *Dense {
// The minimization problem solved depends on the input parameters. // The minimization problem solved depends on the input parameters.
// If trans == false, find X such that ||A*X - B||_2 is minimized. // If trans == false, find X such that ||A*X - B||_2 is minimized.
// If trans == true, find the minimum norm solution of A^T * X = B. // If trans == true, find the minimum norm solution of A^T * X = B.
// The solution matrix, X, is stored in place into m. // The solution matrix, X, is stored in place into dst.
func (qr *QR) Solve(x *Dense, trans bool, b Matrix) error { func (qr *QR) SolveTo(dst *Dense, trans bool, b Matrix) error {
r, c := qr.qr.Dims() r, c := qr.qr.Dims()
br, bc := b.Dims() br, bc := b.Dims()
@@ -157,12 +157,12 @@ func (qr *QR) Solve(x *Dense, trans bool, b Matrix) error {
if c != br { if c != br {
panic(ErrShape) panic(ErrShape)
} }
x.reuseAs(r, bc) dst.reuseAs(r, bc)
} else { } else {
if r != br { if r != br {
panic(ErrShape) panic(ErrShape)
} }
x.reuseAs(c, bc) dst.reuseAs(c, bc)
} }
// Do not need to worry about overlap between m and b because x has its own // Do not need to worry about overlap between m and b because x has its own
// independent storage. // independent storage.
@@ -195,7 +195,7 @@ func (qr *QR) Solve(x *Dense, trans bool, b Matrix) error {
} }
} }
// X was set above to be the correct size for the result. // X was set above to be the correct size for the result.
x.Copy(w) dst.Copy(w)
putWorkspace(w) putWorkspace(w)
if qr.cond > ConditionTolerance { if qr.cond > ConditionTolerance {
return Condition(qr.cond) return Condition(qr.cond)
@@ -203,10 +203,10 @@ func (qr *QR) Solve(x *Dense, trans bool, b Matrix) error {
return nil return nil
} }
// SolveVec finds a minimum-norm solution to a system of linear equations, // SolveVecTo finds a minimum-norm solution to a system of linear equations,
// Ax = b. // Ax = b.
// See QR.Solve for the full documentation. // See QR.SolveTo for the full documentation.
func (qr *QR) SolveVec(x *VecDense, trans bool, b Vector) error { func (qr *QR) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
r, c := qr.qr.Dims() r, c := qr.qr.Dims()
if _, bc := b.Dims(); bc != 1 { if _, bc := b.Dims(); bc != 1 {
panic(ErrShape) panic(ErrShape)
@@ -217,17 +217,17 @@ func (qr *QR) SolveVec(x *VecDense, trans bool, b Vector) error {
bm := Matrix(b) bm := Matrix(b)
if rv, ok := b.(RawVectorer); ok { if rv, ok := b.(RawVectorer); ok {
bmat := rv.RawVector() bmat := rv.RawVector()
if x != b { if dst != b {
x.checkOverlap(bmat) dst.checkOverlap(bmat)
} }
b := VecDense{mat: bmat} b := VecDense{mat: bmat}
bm = b.asDense() bm = b.asDense()
} }
if trans { if trans {
x.reuseAs(r) dst.reuseAs(r)
} else { } else {
x.reuseAs(c) dst.reuseAs(c)
} }
return qr.Solve(x.asDense(), trans, bm) return qr.SolveTo(dst.asDense(), trans, bm)
} }

View File

@@ -70,7 +70,7 @@ func isOrthonormal(q *Dense, tol float64) bool {
return true return true
} }
func TestSolveQR(t *testing.T) { func TestQRSolveTo(t *testing.T) {
for _, trans := range []bool{false, true} { for _, trans := range []bool{false, true} {
for _, test := range []struct { for _, test := range []struct {
m, n, bc int m, n, bc int
@@ -102,7 +102,7 @@ func TestSolveQR(t *testing.T) {
var x Dense var x Dense
var qr QR var qr QR
qr.Factorize(a) qr.Factorize(a)
qr.Solve(&x, trans, b) qr.SolveTo(&x, trans, b)
// Test that the normal equations hold. // Test that the normal equations hold.
// A^T * A * x = A^T * b if !trans // A^T * A * x = A^T * b if !trans
@@ -128,7 +128,7 @@ func TestSolveQR(t *testing.T) {
// TODO(btracey): Add in testOneInput when it exists. // TODO(btracey): Add in testOneInput when it exists.
} }
func TestSolveQRVec(t *testing.T) { func TestQRSolveVecTo(t *testing.T) {
for _, trans := range []bool{false, true} { for _, trans := range []bool{false, true} {
for _, test := range []struct { for _, test := range []struct {
m, n int m, n int
@@ -155,7 +155,7 @@ func TestSolveQRVec(t *testing.T) {
var x VecDense var x VecDense
var qr QR var qr QR
qr.Factorize(a) qr.Factorize(a)
qr.SolveVec(&x, trans, b) qr.SolveVecTo(&x, trans, b)
// Test that the normal equations hold. // Test that the normal equations hold.
// A^T * A * x = A^T * b if !trans // A^T * A * x = A^T * b if !trans
@@ -181,7 +181,7 @@ func TestSolveQRVec(t *testing.T) {
// TODO(btracey): Add in testOneInput when it exists. // TODO(btracey): Add in testOneInput when it exists.
} }
func TestSolveQRCond(t *testing.T) { func TestQRSolveCondTo(t *testing.T) {
for _, test := range []*Dense{ for _, test := range []*Dense{
NewDense(2, 2, []float64{1, 0, 0, 1e-20}), NewDense(2, 2, []float64{1, 0, 0, 1e-20}),
NewDense(3, 2, []float64{1, 0, 0, 1e-20, 0, 0}), NewDense(3, 2, []float64{1, 0, 0, 1e-20, 0, 0}),
@@ -191,13 +191,13 @@ func TestSolveQRCond(t *testing.T) {
qr.Factorize(test) qr.Factorize(test)
b := NewDense(m, 2, nil) b := NewDense(m, 2, nil)
var x Dense var x Dense
if err := qr.Solve(&x, false, b); err == nil { if err := qr.SolveTo(&x, false, b); err == nil {
t.Error("No error for near-singular matrix in matrix solve.") t.Error("No error for near-singular matrix in matrix solve.")
} }
bvec := NewVecDense(m, nil) bvec := NewVecDense(m, nil)
var xvec VecDense var xvec VecDense
if err := qr.SolveVec(&xvec, false, bvec); err == nil { if err := qr.SolveVecTo(&xvec, false, bvec); err == nil {
t.Error("No error for near-singular matrix in matrix solve.") t.Error("No error for near-singular matrix in matrix solve.")
} }
} }

View File

@@ -91,15 +91,15 @@ func (m *Dense) Solve(a, b Matrix) error {
} }
var lu LU var lu LU
lu.Factorize(a) lu.Factorize(a)
return lu.Solve(m, false, b) return lu.SolveTo(m, false, b)
case ar > ac: case ar > ac:
var qr QR var qr QR
qr.Factorize(a) qr.Factorize(a)
return qr.Solve(m, false, b) return qr.SolveTo(m, false, b)
default: default:
var lq LQ var lq LQ
lq.Factorize(a) lq.Factorize(a)
return lq.Solve(m, false, b) return lq.SolveTo(m, false, b)
} }
} }

View File

@@ -153,7 +153,7 @@ func (n *Newton) NextDirection(loc *Location, dir []float64) (stepSize float64)
pd := n.chol.Factorize(n.hess) pd := n.chol.Factorize(n.hess)
if pd { if pd {
// Store the solution in d's backing array, dir. // Store the solution in d's backing array, dir.
n.chol.SolveVec(d, grad) n.chol.SolveVecTo(d, grad)
d.ScaleVec(-1, d) d.ScaleVec(-1, d)
return 1 return 1
} }

View File

@@ -132,7 +132,7 @@ func (w *Wishart) logProbSymChol(cholX *mat.Cholesky) float64 {
cholX.UTo(&u) cholX.UTo(&u)
var vinvx mat.Dense var vinvx mat.Dense
err := w.cholv.Solve(&vinvx, u.T()) err := w.cholv.SolveTo(&vinvx, u.T())
if err != nil { if err != nil {
return math.Inf(-1) return math.Inf(-1)
} }

View File

@@ -343,7 +343,7 @@ func (n *Normal) ScoreInput(score, x []float64) []float64 {
copy(tmp, x) copy(tmp, x)
floats.Sub(tmp, n.mu) floats.Sub(tmp, n.mu)
n.chol.SolveVec(mat.NewVecDense(len(score), score), mat.NewVecDense(len(tmp), tmp)) n.chol.SolveVecTo(mat.NewVecDense(len(score), score), mat.NewVecDense(len(tmp), tmp))
floats.Scale(-1, score) floats.Scale(-1, score)
return score return score
} }

View File

@@ -195,7 +195,7 @@ func (KullbackLeibler) DistNormal(l, r *Normal) float64 {
var u mat.TriDense var u mat.TriDense
l.chol.UTo(&u) l.chol.UTo(&u)
var m mat.Dense var m mat.Dense
err := r.chol.Solve(&m, u.T()) err := r.chol.SolveTo(&m, u.T())
if err != nil { if err != nil {
return math.NaN() return math.NaN()
} }

View File

@@ -160,7 +160,7 @@ func studentsTConditional(observed []int, values []float64, nu float64, mu []flo
// Compute mu_1 + sigma_{2,1}^T * sigma_{2,2}^-1 (v - mu_2). // Compute mu_1 + sigma_{2,1}^T * sigma_{2,2}^-1 (v - mu_2).
v := mat.NewVecDense(ob, mu2) v := mat.NewVecDense(ob, mu2)
var tmp, tmp2 mat.VecDense var tmp, tmp2 mat.VecDense
err := chol.SolveVec(&tmp, v) err := chol.SolveVecTo(&tmp, v)
if err != nil { if err != nil {
return math.NaN(), nil, nil return math.NaN(), nil, nil
} }
@@ -173,7 +173,7 @@ func studentsTConditional(observed []int, values []float64, nu float64, mu []flo
// Compute tmp4 = sigma_{2,1}^T * sigma_{2,2}^-1 * sigma_{2,1}. // Compute tmp4 = sigma_{2,1}^T * sigma_{2,2}^-1 * sigma_{2,1}.
// TODO(btracey): Should this be a method of SymDense? // TODO(btracey): Should this be a method of SymDense?
var tmp3, tmp4 mat.Dense var tmp3, tmp4 mat.Dense
err = chol.Solve(&tmp3, sigma21) err = chol.SolveTo(&tmp3, sigma21)
if err != nil { if err != nil {
return math.NaN(), nil, nil return math.NaN(), nil, nil
} }

View File

@@ -138,7 +138,7 @@ func Mahalanobis(x, y mat.Vector, chol *mat.Cholesky) float64 {
var diff mat.VecDense var diff mat.VecDense
diff.SubVec(x, y) diff.SubVec(x, y)
var tmp mat.VecDense var tmp mat.VecDense
err := chol.SolveVec(&tmp, &diff) err := chol.SolveVecTo(&tmp, &diff)
if err != nil { if err != nil {
return math.NaN() return math.NaN()
} }