mirror of
https://github.com/gonum/gonum.git
synced 2025-10-04 23:02:42 +08:00
mat: Rename Solve(Vec) to Solve(Vec)To (#922)
* mat: Rename Solve(Vec) to Solev(Vec)To Fix #830.
This commit is contained in:
@@ -162,9 +162,9 @@ func (c *Cholesky) LogDet() float64 {
|
|||||||
return det
|
return det
|
||||||
}
|
}
|
||||||
|
|
||||||
// Solve finds the matrix x that solves A * X = B where A is represented
|
// SolveTo finds the matrix X that solves A * X = B where A is represented
|
||||||
// by the Cholesky decomposition, placing the result in x.
|
// by the Cholesky decomposition. The result is stored in-place into dst.
|
||||||
func (c *Cholesky) Solve(x *Dense, b Matrix) error {
|
func (c *Cholesky) SolveTo(dst *Dense, b Matrix) error {
|
||||||
if !c.valid() {
|
if !c.valid() {
|
||||||
panic(badCholesky)
|
panic(badCholesky)
|
||||||
}
|
}
|
||||||
@@ -174,20 +174,21 @@ func (c *Cholesky) Solve(x *Dense, b Matrix) error {
|
|||||||
panic(ErrShape)
|
panic(ErrShape)
|
||||||
}
|
}
|
||||||
|
|
||||||
x.reuseAs(bm, bn)
|
dst.reuseAs(bm, bn)
|
||||||
if b != x {
|
if b != dst {
|
||||||
x.Copy(b)
|
dst.Copy(b)
|
||||||
}
|
}
|
||||||
lapack64.Potrs(c.chol.mat, x.mat)
|
lapack64.Potrs(c.chol.mat, dst.mat)
|
||||||
if c.cond > ConditionTolerance {
|
if c.cond > ConditionTolerance {
|
||||||
return Condition(c.cond)
|
return Condition(c.cond)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SolveChol finds the matrix x that solves A * X = B where A and B are represented
|
// SolveCholTo finds the matrix X that solves A * X = B where A and B are represented
|
||||||
// by their Cholesky decompositions a and b, placing the result in x.
|
// by their Cholesky decompositions a and b. The result is stored in-place into
|
||||||
func (a *Cholesky) SolveChol(x *Dense, b *Cholesky) error {
|
// dst.
|
||||||
|
func (a *Cholesky) SolveCholTo(dst *Dense, b *Cholesky) error {
|
||||||
if !a.valid() || !b.valid() {
|
if !a.valid() || !b.valid() {
|
||||||
panic(badCholesky)
|
panic(badCholesky)
|
||||||
}
|
}
|
||||||
@@ -196,20 +197,21 @@ func (a *Cholesky) SolveChol(x *Dense, b *Cholesky) error {
|
|||||||
panic(ErrShape)
|
panic(ErrShape)
|
||||||
}
|
}
|
||||||
|
|
||||||
x.reuseAsZeroed(bn, bn)
|
dst.reuseAsZeroed(bn, bn)
|
||||||
x.Copy(b.chol.T())
|
dst.Copy(b.chol.T())
|
||||||
blas64.Trsm(blas.Left, blas.Trans, 1, a.chol.mat, x.mat)
|
blas64.Trsm(blas.Left, blas.Trans, 1, a.chol.mat, dst.mat)
|
||||||
blas64.Trsm(blas.Left, blas.NoTrans, 1, a.chol.mat, x.mat)
|
blas64.Trsm(blas.Left, blas.NoTrans, 1, a.chol.mat, dst.mat)
|
||||||
blas64.Trmm(blas.Right, blas.NoTrans, 1, b.chol.mat, x.mat)
|
blas64.Trmm(blas.Right, blas.NoTrans, 1, b.chol.mat, dst.mat)
|
||||||
if a.cond > ConditionTolerance {
|
if a.cond > ConditionTolerance {
|
||||||
return Condition(a.cond)
|
return Condition(a.cond)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SolveVec finds the vector x that solves A * x = b where A is represented
|
// SolveVecTo finds the vector X that solves A * x = b where A is represented
|
||||||
// by the Cholesky decomposition, placing the result in x.
|
// by the Cholesky decomposition. The result is stored in-place into
|
||||||
func (c *Cholesky) SolveVec(x *VecDense, b Vector) error {
|
// dst.
|
||||||
|
func (c *Cholesky) SolveVecTo(dst *VecDense, b Vector) error {
|
||||||
if !c.valid() {
|
if !c.valid() {
|
||||||
panic(badCholesky)
|
panic(badCholesky)
|
||||||
}
|
}
|
||||||
@@ -219,18 +221,18 @@ func (c *Cholesky) SolveVec(x *VecDense, b Vector) error {
|
|||||||
}
|
}
|
||||||
switch rv := b.(type) {
|
switch rv := b.(type) {
|
||||||
default:
|
default:
|
||||||
x.reuseAs(n)
|
dst.reuseAs(n)
|
||||||
return c.Solve(x.asDense(), b)
|
return c.SolveTo(dst.asDense(), b)
|
||||||
case RawVectorer:
|
case RawVectorer:
|
||||||
bmat := rv.RawVector()
|
bmat := rv.RawVector()
|
||||||
if x != b {
|
if dst != b {
|
||||||
x.checkOverlap(bmat)
|
dst.checkOverlap(bmat)
|
||||||
}
|
}
|
||||||
x.reuseAs(n)
|
dst.reuseAs(n)
|
||||||
if x != b {
|
if dst != b {
|
||||||
x.CopyVec(b)
|
dst.CopyVec(b)
|
||||||
}
|
}
|
||||||
lapack64.Potrs(c.chol.mat, x.asGeneral())
|
lapack64.Potrs(c.chol.mat, dst.asGeneral())
|
||||||
if c.cond > ConditionTolerance {
|
if c.cond > ConditionTolerance {
|
||||||
return Condition(c.cond)
|
return Condition(c.cond)
|
||||||
}
|
}
|
||||||
|
@@ -35,7 +35,7 @@ func ExampleCholesky() {
|
|||||||
// Use the factorization to solve the system of equations a * x = b.
|
// Use the factorization to solve the system of equations a * x = b.
|
||||||
b := mat.NewVecDense(4, []float64{1, 2, 3, 4})
|
b := mat.NewVecDense(4, []float64{1, 2, 3, 4})
|
||||||
var x mat.VecDense
|
var x mat.VecDense
|
||||||
if err := chol.SolveVec(&x, b); err != nil {
|
if err := chol.SolveVecTo(&x, b); err != nil {
|
||||||
fmt.Println("Matrix is near singular: ", err)
|
fmt.Println("Matrix is near singular: ", err)
|
||||||
}
|
}
|
||||||
fmt.Println("Solve a * x = b")
|
fmt.Println("Solve a * x = b")
|
||||||
|
@@ -72,7 +72,7 @@ func TestCholesky(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCholeskySolve(t *testing.T) {
|
func TestCholeskySolveTo(t *testing.T) {
|
||||||
for _, test := range []struct {
|
for _, test := range []struct {
|
||||||
a *SymDense
|
a *SymDense
|
||||||
b *Dense
|
b *Dense
|
||||||
@@ -103,7 +103,7 @@ func TestCholeskySolve(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var x Dense
|
var x Dense
|
||||||
chol.Solve(&x, test.b)
|
chol.SolveTo(&x, test.b)
|
||||||
if !EqualApprox(&x, test.ans, 1e-12) {
|
if !EqualApprox(&x, test.ans, 1e-12) {
|
||||||
t.Error("incorrect Cholesky solve solution")
|
t.Error("incorrect Cholesky solve solution")
|
||||||
}
|
}
|
||||||
@@ -116,7 +116,7 @@ func TestCholeskySolve(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCholeskySolveChol(t *testing.T) {
|
func TestCholeskySolveCholTo(t *testing.T) {
|
||||||
for _, test := range []struct {
|
for _, test := range []struct {
|
||||||
a, b *SymDense
|
a, b *SymDense
|
||||||
}{
|
}{
|
||||||
@@ -164,7 +164,7 @@ func TestCholeskySolveChol(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var x Dense
|
var x Dense
|
||||||
chola.SolveChol(&x, &cholb)
|
chola.SolveCholTo(&x, &cholb)
|
||||||
|
|
||||||
var ans Dense
|
var ans Dense
|
||||||
ans.Mul(test.a, &x)
|
ans.Mul(test.a, &x)
|
||||||
@@ -177,7 +177,7 @@ func TestCholeskySolveChol(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCholeskySolveVec(t *testing.T) {
|
func TestCholeskySolveVecTo(t *testing.T) {
|
||||||
for _, test := range []struct {
|
for _, test := range []struct {
|
||||||
a *SymDense
|
a *SymDense
|
||||||
b *VecDense
|
b *VecDense
|
||||||
@@ -208,7 +208,7 @@ func TestCholeskySolveVec(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var x VecDense
|
var x VecDense
|
||||||
chol.SolveVec(&x, test.b)
|
chol.SolveVecTo(&x, test.b)
|
||||||
if !EqualApprox(&x, test.ans, 1e-12) {
|
if !EqualApprox(&x, test.ans, 1e-12) {
|
||||||
t.Error("incorrect Cholesky solve solution")
|
t.Error("incorrect Cholesky solve solution")
|
||||||
}
|
}
|
||||||
|
@@ -85,13 +85,13 @@ func (gsvd *HOGSVD) Factorize(m ...Matrix) (ok bool) {
|
|||||||
defer putWorkspace(sij)
|
defer putWorkspace(sij)
|
||||||
for i, ai := range a {
|
for i, ai := range a {
|
||||||
for _, aj := range a[i+1:] {
|
for _, aj := range a[i+1:] {
|
||||||
gsvd.err = ai.SolveChol(sij, &aj)
|
gsvd.err = ai.SolveCholTo(sij, &aj)
|
||||||
if gsvd.err != nil {
|
if gsvd.err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
s.Add(s, sij)
|
s.Add(s, sij)
|
||||||
|
|
||||||
gsvd.err = aj.SolveChol(sij, &ai)
|
gsvd.err = aj.SolveCholTo(sij, &ai)
|
||||||
if gsvd.err != nil {
|
if gsvd.err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
28
mat/lq.go
28
mat/lq.go
@@ -140,7 +140,7 @@ func (lq *LQ) QTo(dst *Dense) *Dense {
|
|||||||
return dst
|
return dst
|
||||||
}
|
}
|
||||||
|
|
||||||
// Solve finds a minimum-norm solution to a system of linear equations defined
|
// SolveTo finds a minimum-norm solution to a system of linear equations defined
|
||||||
// by the matrices A and b, where A is an m×n matrix represented in its LQ factorized
|
// by the matrices A and b, where A is an m×n matrix represented in its LQ factorized
|
||||||
// form. If A is singular or near-singular a Condition error is returned.
|
// form. If A is singular or near-singular a Condition error is returned.
|
||||||
// See the documentation for Condition for more information.
|
// See the documentation for Condition for more information.
|
||||||
@@ -148,8 +148,8 @@ func (lq *LQ) QTo(dst *Dense) *Dense {
|
|||||||
// The minimization problem solved depends on the input parameters.
|
// The minimization problem solved depends on the input parameters.
|
||||||
// If trans == false, find the minimum norm solution of A * X = B.
|
// If trans == false, find the minimum norm solution of A * X = B.
|
||||||
// If trans == true, find X such that ||A*X - B||_2 is minimized.
|
// If trans == true, find X such that ||A*X - B||_2 is minimized.
|
||||||
// The solution matrix, X, is stored in place into x.
|
// The solution matrix, X, is stored in place into dst.
|
||||||
func (lq *LQ) Solve(x *Dense, trans bool, b Matrix) error {
|
func (lq *LQ) SolveTo(dst *Dense, trans bool, b Matrix) error {
|
||||||
r, c := lq.lq.Dims()
|
r, c := lq.lq.Dims()
|
||||||
br, bc := b.Dims()
|
br, bc := b.Dims()
|
||||||
|
|
||||||
@@ -161,12 +161,12 @@ func (lq *LQ) Solve(x *Dense, trans bool, b Matrix) error {
|
|||||||
if c != br {
|
if c != br {
|
||||||
panic(ErrShape)
|
panic(ErrShape)
|
||||||
}
|
}
|
||||||
x.reuseAs(r, bc)
|
dst.reuseAs(r, bc)
|
||||||
} else {
|
} else {
|
||||||
if r != br {
|
if r != br {
|
||||||
panic(ErrShape)
|
panic(ErrShape)
|
||||||
}
|
}
|
||||||
x.reuseAs(c, bc)
|
dst.reuseAs(c, bc)
|
||||||
}
|
}
|
||||||
// Do not need to worry about overlap between x and b because w has its own
|
// Do not need to worry about overlap between x and b because w has its own
|
||||||
// independent storage.
|
// independent storage.
|
||||||
@@ -199,7 +199,7 @@ func (lq *LQ) Solve(x *Dense, trans bool, b Matrix) error {
|
|||||||
putFloats(work)
|
putFloats(work)
|
||||||
}
|
}
|
||||||
// x was set above to be the correct size for the result.
|
// x was set above to be the correct size for the result.
|
||||||
x.Copy(w)
|
dst.Copy(w)
|
||||||
putWorkspace(w)
|
putWorkspace(w)
|
||||||
if lq.cond > ConditionTolerance {
|
if lq.cond > ConditionTolerance {
|
||||||
return Condition(lq.cond)
|
return Condition(lq.cond)
|
||||||
@@ -207,9 +207,9 @@ func (lq *LQ) Solve(x *Dense, trans bool, b Matrix) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SolveVec finds a minimum-norm solution to a system of linear equations.
|
// SolveVecTo finds a minimum-norm solution to a system of linear equations.
|
||||||
// See LQ.Solve for the full documentation.
|
// See LQ.SolveTo for the full documentation.
|
||||||
func (lq *LQ) SolveVec(x *VecDense, trans bool, b Vector) error {
|
func (lq *LQ) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
|
||||||
r, c := lq.lq.Dims()
|
r, c := lq.lq.Dims()
|
||||||
if _, bc := b.Dims(); bc != 1 {
|
if _, bc := b.Dims(); bc != 1 {
|
||||||
panic(ErrShape)
|
panic(ErrShape)
|
||||||
@@ -220,16 +220,16 @@ func (lq *LQ) SolveVec(x *VecDense, trans bool, b Vector) error {
|
|||||||
bm := Matrix(b)
|
bm := Matrix(b)
|
||||||
if rv, ok := b.(RawVectorer); ok {
|
if rv, ok := b.(RawVectorer); ok {
|
||||||
bmat := rv.RawVector()
|
bmat := rv.RawVector()
|
||||||
if x != b {
|
if dst != b {
|
||||||
x.checkOverlap(bmat)
|
dst.checkOverlap(bmat)
|
||||||
}
|
}
|
||||||
b := VecDense{mat: bmat}
|
b := VecDense{mat: bmat}
|
||||||
bm = b.asDense()
|
bm = b.asDense()
|
||||||
}
|
}
|
||||||
if trans {
|
if trans {
|
||||||
x.reuseAs(r)
|
dst.reuseAs(r)
|
||||||
} else {
|
} else {
|
||||||
x.reuseAs(c)
|
dst.reuseAs(c)
|
||||||
}
|
}
|
||||||
return lq.Solve(x.asDense(), trans, bm)
|
return lq.SolveTo(dst.asDense(), trans, bm)
|
||||||
}
|
}
|
||||||
|
@@ -46,7 +46,7 @@ func TestLQ(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSolveLQ(t *testing.T) {
|
func TestLQSolveTo(t *testing.T) {
|
||||||
for _, trans := range []bool{false, true} {
|
for _, trans := range []bool{false, true} {
|
||||||
for _, test := range []struct {
|
for _, test := range []struct {
|
||||||
m, n, bc int
|
m, n, bc int
|
||||||
@@ -78,7 +78,7 @@ func TestSolveLQ(t *testing.T) {
|
|||||||
var x Dense
|
var x Dense
|
||||||
lq := &LQ{}
|
lq := &LQ{}
|
||||||
lq.Factorize(a)
|
lq.Factorize(a)
|
||||||
lq.Solve(&x, trans, b)
|
lq.SolveTo(&x, trans, b)
|
||||||
|
|
||||||
// Test that the normal equations hold.
|
// Test that the normal equations hold.
|
||||||
// A^T * A * x = A^T * b if !trans
|
// A^T * A * x = A^T * b if !trans
|
||||||
@@ -104,7 +104,7 @@ func TestSolveLQ(t *testing.T) {
|
|||||||
// TODO(btracey): Add in testOneInput when it exists.
|
// TODO(btracey): Add in testOneInput when it exists.
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSolveLQVec(t *testing.T) {
|
func TestLQSolveToVec(t *testing.T) {
|
||||||
for _, trans := range []bool{false, true} {
|
for _, trans := range []bool{false, true} {
|
||||||
for _, test := range []struct {
|
for _, test := range []struct {
|
||||||
m, n int
|
m, n int
|
||||||
@@ -131,7 +131,7 @@ func TestSolveLQVec(t *testing.T) {
|
|||||||
var x VecDense
|
var x VecDense
|
||||||
lq := &LQ{}
|
lq := &LQ{}
|
||||||
lq.Factorize(a)
|
lq.Factorize(a)
|
||||||
lq.SolveVec(&x, trans, b)
|
lq.SolveVecTo(&x, trans, b)
|
||||||
|
|
||||||
// Test that the normal equations hold.
|
// Test that the normal equations hold.
|
||||||
// A^T * A * x = A^T * b if !trans
|
// A^T * A * x = A^T * b if !trans
|
||||||
@@ -157,7 +157,7 @@ func TestSolveLQVec(t *testing.T) {
|
|||||||
// TODO(btracey): Add in testOneInput when it exists.
|
// TODO(btracey): Add in testOneInput when it exists.
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSolveLQCond(t *testing.T) {
|
func TestLQSolveToCond(t *testing.T) {
|
||||||
for _, test := range []*Dense{
|
for _, test := range []*Dense{
|
||||||
NewDense(2, 2, []float64{1, 0, 0, 1e-20}),
|
NewDense(2, 2, []float64{1, 0, 0, 1e-20}),
|
||||||
NewDense(2, 3, []float64{1, 0, 0, 0, 1e-20, 0}),
|
NewDense(2, 3, []float64{1, 0, 0, 0, 1e-20, 0}),
|
||||||
@@ -167,13 +167,13 @@ func TestSolveLQCond(t *testing.T) {
|
|||||||
lq.Factorize(test)
|
lq.Factorize(test)
|
||||||
b := NewDense(m, 2, nil)
|
b := NewDense(m, 2, nil)
|
||||||
var x Dense
|
var x Dense
|
||||||
if err := lq.Solve(&x, false, b); err == nil {
|
if err := lq.SolveTo(&x, false, b); err == nil {
|
||||||
t.Error("No error for near-singular matrix in matrix solve.")
|
t.Error("No error for near-singular matrix in matrix solve.")
|
||||||
}
|
}
|
||||||
|
|
||||||
bvec := NewVecDense(m, nil)
|
bvec := NewVecDense(m, nil)
|
||||||
var xvec VecDense
|
var xvec VecDense
|
||||||
if err := lq.SolveVec(&xvec, false, bvec); err == nil {
|
if err := lq.SolveVecTo(&xvec, false, bvec); err == nil {
|
||||||
t.Error("No error for near-singular matrix in matrix solve.")
|
t.Error("No error for near-singular matrix in matrix solve.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
44
mat/lu.go
44
mat/lu.go
@@ -281,16 +281,16 @@ func (m *Dense) Permutation(r int, swaps []int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Solve solves a system of linear equations using the LU decomposition of a matrix.
|
// SolveTo solves a system of linear equations using the LU decomposition of a matrix.
|
||||||
// It computes
|
// It computes
|
||||||
// A * X = B if trans == false
|
// A * X = B if trans == false
|
||||||
// A^T * X = B if trans == true
|
// A^T * X = B if trans == true
|
||||||
// In both cases, A is represented in LU factorized form, and the matrix X is
|
// In both cases, A is represented in LU factorized form, and the matrix X is
|
||||||
// stored into x.
|
// stored into dst.
|
||||||
//
|
//
|
||||||
// If A is singular or near-singular a Condition error is returned. See
|
// If A is singular or near-singular a Condition error is returned. See
|
||||||
// the documentation for Condition for more information.
|
// the documentation for Condition for more information.
|
||||||
func (lu *LU) Solve(x *Dense, trans bool, b Matrix) error {
|
func (lu *LU) SolveTo(dst *Dense, trans bool, b Matrix) error {
|
||||||
_, n := lu.lu.Dims()
|
_, n := lu.lu.Dims()
|
||||||
br, bc := b.Dims()
|
br, bc := b.Dims()
|
||||||
if br != n {
|
if br != n {
|
||||||
@@ -302,49 +302,49 @@ func (lu *LU) Solve(x *Dense, trans bool, b Matrix) error {
|
|||||||
return Condition(math.Inf(1))
|
return Condition(math.Inf(1))
|
||||||
}
|
}
|
||||||
|
|
||||||
x.reuseAs(n, bc)
|
dst.reuseAs(n, bc)
|
||||||
bU, _ := untranspose(b)
|
bU, _ := untranspose(b)
|
||||||
var restore func()
|
var restore func()
|
||||||
if x == bU {
|
if dst == bU {
|
||||||
x, restore = x.isolatedWorkspace(bU)
|
dst, restore = dst.isolatedWorkspace(bU)
|
||||||
defer restore()
|
defer restore()
|
||||||
} else if rm, ok := bU.(RawMatrixer); ok {
|
} else if rm, ok := bU.(RawMatrixer); ok {
|
||||||
x.checkOverlap(rm.RawMatrix())
|
dst.checkOverlap(rm.RawMatrix())
|
||||||
}
|
}
|
||||||
|
|
||||||
x.Copy(b)
|
dst.Copy(b)
|
||||||
t := blas.NoTrans
|
t := blas.NoTrans
|
||||||
if trans {
|
if trans {
|
||||||
t = blas.Trans
|
t = blas.Trans
|
||||||
}
|
}
|
||||||
lapack64.Getrs(t, lu.lu.mat, x.mat, lu.pivot)
|
lapack64.Getrs(t, lu.lu.mat, dst.mat, lu.pivot)
|
||||||
if lu.cond > ConditionTolerance {
|
if lu.cond > ConditionTolerance {
|
||||||
return Condition(lu.cond)
|
return Condition(lu.cond)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SolveVec solves a system of linear equations using the LU decomposition of a matrix.
|
// SolveVecTo solves a system of linear equations using the LU decomposition of a matrix.
|
||||||
// It computes
|
// It computes
|
||||||
// A * x = b if trans == false
|
// A * x = b if trans == false
|
||||||
// A^T * x = b if trans == true
|
// A^T * x = b if trans == true
|
||||||
// In both cases, A is represented in LU factorized form, and the vector x is
|
// In both cases, A is represented in LU factorized form, and the vector x is
|
||||||
// stored into x.
|
// stored into dst.
|
||||||
//
|
//
|
||||||
// If A is singular or near-singular a Condition error is returned. See
|
// If A is singular or near-singular a Condition error is returned. See
|
||||||
// the documentation for Condition for more information.
|
// the documentation for Condition for more information.
|
||||||
func (lu *LU) SolveVec(x *VecDense, trans bool, b Vector) error {
|
func (lu *LU) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
|
||||||
_, n := lu.lu.Dims()
|
_, n := lu.lu.Dims()
|
||||||
if br, bc := b.Dims(); br != n || bc != 1 {
|
if br, bc := b.Dims(); br != n || bc != 1 {
|
||||||
panic(ErrShape)
|
panic(ErrShape)
|
||||||
}
|
}
|
||||||
switch rv := b.(type) {
|
switch rv := b.(type) {
|
||||||
default:
|
default:
|
||||||
x.reuseAs(n)
|
dst.reuseAs(n)
|
||||||
return lu.Solve(x.asDense(), trans, b)
|
return lu.SolveTo(dst.asDense(), trans, b)
|
||||||
case RawVectorer:
|
case RawVectorer:
|
||||||
if x != b {
|
if dst != b {
|
||||||
x.checkOverlap(rv.RawVector())
|
dst.checkOverlap(rv.RawVector())
|
||||||
}
|
}
|
||||||
// TODO(btracey): Should test the condition number instead of testing that
|
// TODO(btracey): Should test the condition number instead of testing that
|
||||||
// the determinant is exactly zero.
|
// the determinant is exactly zero.
|
||||||
@@ -352,18 +352,18 @@ func (lu *LU) SolveVec(x *VecDense, trans bool, b Vector) error {
|
|||||||
return Condition(math.Inf(1))
|
return Condition(math.Inf(1))
|
||||||
}
|
}
|
||||||
|
|
||||||
x.reuseAs(n)
|
dst.reuseAs(n)
|
||||||
var restore func()
|
var restore func()
|
||||||
if x == b {
|
if dst == b {
|
||||||
x, restore = x.isolatedWorkspace(b)
|
dst, restore = dst.isolatedWorkspace(b)
|
||||||
defer restore()
|
defer restore()
|
||||||
}
|
}
|
||||||
x.CopyVec(b)
|
dst.CopyVec(b)
|
||||||
vMat := blas64.General{
|
vMat := blas64.General{
|
||||||
Rows: n,
|
Rows: n,
|
||||||
Cols: 1,
|
Cols: 1,
|
||||||
Stride: x.mat.Inc,
|
Stride: dst.mat.Inc,
|
||||||
Data: x.mat.Data,
|
Data: dst.mat.Data,
|
||||||
}
|
}
|
||||||
t := blas.NoTrans
|
t := blas.NoTrans
|
||||||
if trans {
|
if trans {
|
||||||
|
@@ -104,7 +104,7 @@ func luReconstruct(lu *LU) *Dense {
|
|||||||
return &a
|
return &a
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSolveLU(t *testing.T) {
|
func TestLUSolveTo(t *testing.T) {
|
||||||
for _, test := range []struct {
|
for _, test := range []struct {
|
||||||
n, bc int
|
n, bc int
|
||||||
}{
|
}{
|
||||||
@@ -129,19 +129,19 @@ func TestSolveLU(t *testing.T) {
|
|||||||
var lu LU
|
var lu LU
|
||||||
lu.Factorize(a)
|
lu.Factorize(a)
|
||||||
var x Dense
|
var x Dense
|
||||||
if err := lu.Solve(&x, false, b); err != nil {
|
if err := lu.SolveTo(&x, false, b); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var got Dense
|
var got Dense
|
||||||
got.Mul(a, &x)
|
got.Mul(a, &x)
|
||||||
if !EqualApprox(&got, b, 1e-12) {
|
if !EqualApprox(&got, b, 1e-12) {
|
||||||
t.Errorf("Solve mismatch for non-singular matrix. n = %v, bc = %v.\nWant: %v\nGot: %v", n, bc, b, got)
|
t.Errorf("SolveTo mismatch for non-singular matrix. n = %v, bc = %v.\nWant: %v\nGot: %v", n, bc, b, got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// TODO(btracey): Add testOneInput test when such a function exists.
|
// TODO(btracey): Add testOneInput test when such a function exists.
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSolveLUCond(t *testing.T) {
|
func TestLUSolveToCond(t *testing.T) {
|
||||||
for _, test := range []*Dense{
|
for _, test := range []*Dense{
|
||||||
NewDense(2, 2, []float64{1, 0, 0, 1e-20}),
|
NewDense(2, 2, []float64{1, 0, 0, 1e-20}),
|
||||||
} {
|
} {
|
||||||
@@ -150,19 +150,19 @@ func TestSolveLUCond(t *testing.T) {
|
|||||||
lu.Factorize(test)
|
lu.Factorize(test)
|
||||||
b := NewDense(m, 2, nil)
|
b := NewDense(m, 2, nil)
|
||||||
var x Dense
|
var x Dense
|
||||||
if err := lu.Solve(&x, false, b); err == nil {
|
if err := lu.SolveTo(&x, false, b); err == nil {
|
||||||
t.Error("No error for near-singular matrix in matrix solve.")
|
t.Error("No error for near-singular matrix in matrix solve.")
|
||||||
}
|
}
|
||||||
|
|
||||||
bvec := NewVecDense(m, nil)
|
bvec := NewVecDense(m, nil)
|
||||||
var xvec VecDense
|
var xvec VecDense
|
||||||
if err := lu.SolveVec(&xvec, false, bvec); err == nil {
|
if err := lu.SolveVecTo(&xvec, false, bvec); err == nil {
|
||||||
t.Error("No error for near-singular matrix in matrix solve.")
|
t.Error("No error for near-singular matrix in matrix solve.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSolveLUVec(t *testing.T) {
|
func TestLUSolveVecTo(t *testing.T) {
|
||||||
for _, n := range []int{5, 10} {
|
for _, n := range []int{5, 10} {
|
||||||
a := NewDense(n, n, nil)
|
a := NewDense(n, n, nil)
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
@@ -177,13 +177,13 @@ func TestSolveLUVec(t *testing.T) {
|
|||||||
var lu LU
|
var lu LU
|
||||||
lu.Factorize(a)
|
lu.Factorize(a)
|
||||||
var x VecDense
|
var x VecDense
|
||||||
if err := lu.SolveVec(&x, false, b); err != nil {
|
if err := lu.SolveVecTo(&x, false, b); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var got VecDense
|
var got VecDense
|
||||||
got.MulVec(a, &x)
|
got.MulVec(a, &x)
|
||||||
if !EqualApprox(&got, b, 1e-12) {
|
if !EqualApprox(&got, b, 1e-12) {
|
||||||
t.Errorf("Solve mismatch n = %v.\nWant: %v\nGot: %v", n, b, got)
|
t.Errorf("SolveTo mismatch n = %v.\nWant: %v\nGot: %v", n, b, got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// TODO(btracey): Add testOneInput test when such a function exists.
|
// TODO(btracey): Add testOneInput test when such a function exists.
|
||||||
|
28
mat/qr.go
28
mat/qr.go
@@ -136,7 +136,7 @@ func (qr *QR) QTo(dst *Dense) *Dense {
|
|||||||
return dst
|
return dst
|
||||||
}
|
}
|
||||||
|
|
||||||
// Solve finds a minimum-norm solution to a system of linear equations defined
|
// SolveTo finds a minimum-norm solution to a system of linear equations defined
|
||||||
// by the matrices A and b, where A is an m×n matrix represented in its QR factorized
|
// by the matrices A and b, where A is an m×n matrix represented in its QR factorized
|
||||||
// form. If A is singular or near-singular a Condition error is returned.
|
// form. If A is singular or near-singular a Condition error is returned.
|
||||||
// See the documentation for Condition for more information.
|
// See the documentation for Condition for more information.
|
||||||
@@ -144,8 +144,8 @@ func (qr *QR) QTo(dst *Dense) *Dense {
|
|||||||
// The minimization problem solved depends on the input parameters.
|
// The minimization problem solved depends on the input parameters.
|
||||||
// If trans == false, find X such that ||A*X - B||_2 is minimized.
|
// If trans == false, find X such that ||A*X - B||_2 is minimized.
|
||||||
// If trans == true, find the minimum norm solution of A^T * X = B.
|
// If trans == true, find the minimum norm solution of A^T * X = B.
|
||||||
// The solution matrix, X, is stored in place into m.
|
// The solution matrix, X, is stored in place into dst.
|
||||||
func (qr *QR) Solve(x *Dense, trans bool, b Matrix) error {
|
func (qr *QR) SolveTo(dst *Dense, trans bool, b Matrix) error {
|
||||||
r, c := qr.qr.Dims()
|
r, c := qr.qr.Dims()
|
||||||
br, bc := b.Dims()
|
br, bc := b.Dims()
|
||||||
|
|
||||||
@@ -157,12 +157,12 @@ func (qr *QR) Solve(x *Dense, trans bool, b Matrix) error {
|
|||||||
if c != br {
|
if c != br {
|
||||||
panic(ErrShape)
|
panic(ErrShape)
|
||||||
}
|
}
|
||||||
x.reuseAs(r, bc)
|
dst.reuseAs(r, bc)
|
||||||
} else {
|
} else {
|
||||||
if r != br {
|
if r != br {
|
||||||
panic(ErrShape)
|
panic(ErrShape)
|
||||||
}
|
}
|
||||||
x.reuseAs(c, bc)
|
dst.reuseAs(c, bc)
|
||||||
}
|
}
|
||||||
// Do not need to worry about overlap between m and b because x has its own
|
// Do not need to worry about overlap between m and b because x has its own
|
||||||
// independent storage.
|
// independent storage.
|
||||||
@@ -195,7 +195,7 @@ func (qr *QR) Solve(x *Dense, trans bool, b Matrix) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// X was set above to be the correct size for the result.
|
// X was set above to be the correct size for the result.
|
||||||
x.Copy(w)
|
dst.Copy(w)
|
||||||
putWorkspace(w)
|
putWorkspace(w)
|
||||||
if qr.cond > ConditionTolerance {
|
if qr.cond > ConditionTolerance {
|
||||||
return Condition(qr.cond)
|
return Condition(qr.cond)
|
||||||
@@ -203,10 +203,10 @@ func (qr *QR) Solve(x *Dense, trans bool, b Matrix) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SolveVec finds a minimum-norm solution to a system of linear equations,
|
// SolveVecTo finds a minimum-norm solution to a system of linear equations,
|
||||||
// Ax = b.
|
// Ax = b.
|
||||||
// See QR.Solve for the full documentation.
|
// See QR.SolveTo for the full documentation.
|
||||||
func (qr *QR) SolveVec(x *VecDense, trans bool, b Vector) error {
|
func (qr *QR) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
|
||||||
r, c := qr.qr.Dims()
|
r, c := qr.qr.Dims()
|
||||||
if _, bc := b.Dims(); bc != 1 {
|
if _, bc := b.Dims(); bc != 1 {
|
||||||
panic(ErrShape)
|
panic(ErrShape)
|
||||||
@@ -217,17 +217,17 @@ func (qr *QR) SolveVec(x *VecDense, trans bool, b Vector) error {
|
|||||||
bm := Matrix(b)
|
bm := Matrix(b)
|
||||||
if rv, ok := b.(RawVectorer); ok {
|
if rv, ok := b.(RawVectorer); ok {
|
||||||
bmat := rv.RawVector()
|
bmat := rv.RawVector()
|
||||||
if x != b {
|
if dst != b {
|
||||||
x.checkOverlap(bmat)
|
dst.checkOverlap(bmat)
|
||||||
}
|
}
|
||||||
b := VecDense{mat: bmat}
|
b := VecDense{mat: bmat}
|
||||||
bm = b.asDense()
|
bm = b.asDense()
|
||||||
}
|
}
|
||||||
if trans {
|
if trans {
|
||||||
x.reuseAs(r)
|
dst.reuseAs(r)
|
||||||
} else {
|
} else {
|
||||||
x.reuseAs(c)
|
dst.reuseAs(c)
|
||||||
}
|
}
|
||||||
return qr.Solve(x.asDense(), trans, bm)
|
return qr.SolveTo(dst.asDense(), trans, bm)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -70,7 +70,7 @@ func isOrthonormal(q *Dense, tol float64) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSolveQR(t *testing.T) {
|
func TestQRSolveTo(t *testing.T) {
|
||||||
for _, trans := range []bool{false, true} {
|
for _, trans := range []bool{false, true} {
|
||||||
for _, test := range []struct {
|
for _, test := range []struct {
|
||||||
m, n, bc int
|
m, n, bc int
|
||||||
@@ -102,7 +102,7 @@ func TestSolveQR(t *testing.T) {
|
|||||||
var x Dense
|
var x Dense
|
||||||
var qr QR
|
var qr QR
|
||||||
qr.Factorize(a)
|
qr.Factorize(a)
|
||||||
qr.Solve(&x, trans, b)
|
qr.SolveTo(&x, trans, b)
|
||||||
|
|
||||||
// Test that the normal equations hold.
|
// Test that the normal equations hold.
|
||||||
// A^T * A * x = A^T * b if !trans
|
// A^T * A * x = A^T * b if !trans
|
||||||
@@ -128,7 +128,7 @@ func TestSolveQR(t *testing.T) {
|
|||||||
// TODO(btracey): Add in testOneInput when it exists.
|
// TODO(btracey): Add in testOneInput when it exists.
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSolveQRVec(t *testing.T) {
|
func TestQRSolveVecTo(t *testing.T) {
|
||||||
for _, trans := range []bool{false, true} {
|
for _, trans := range []bool{false, true} {
|
||||||
for _, test := range []struct {
|
for _, test := range []struct {
|
||||||
m, n int
|
m, n int
|
||||||
@@ -155,7 +155,7 @@ func TestSolveQRVec(t *testing.T) {
|
|||||||
var x VecDense
|
var x VecDense
|
||||||
var qr QR
|
var qr QR
|
||||||
qr.Factorize(a)
|
qr.Factorize(a)
|
||||||
qr.SolveVec(&x, trans, b)
|
qr.SolveVecTo(&x, trans, b)
|
||||||
|
|
||||||
// Test that the normal equations hold.
|
// Test that the normal equations hold.
|
||||||
// A^T * A * x = A^T * b if !trans
|
// A^T * A * x = A^T * b if !trans
|
||||||
@@ -181,7 +181,7 @@ func TestSolveQRVec(t *testing.T) {
|
|||||||
// TODO(btracey): Add in testOneInput when it exists.
|
// TODO(btracey): Add in testOneInput when it exists.
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSolveQRCond(t *testing.T) {
|
func TestQRSolveCondTo(t *testing.T) {
|
||||||
for _, test := range []*Dense{
|
for _, test := range []*Dense{
|
||||||
NewDense(2, 2, []float64{1, 0, 0, 1e-20}),
|
NewDense(2, 2, []float64{1, 0, 0, 1e-20}),
|
||||||
NewDense(3, 2, []float64{1, 0, 0, 1e-20, 0, 0}),
|
NewDense(3, 2, []float64{1, 0, 0, 1e-20, 0, 0}),
|
||||||
@@ -191,13 +191,13 @@ func TestSolveQRCond(t *testing.T) {
|
|||||||
qr.Factorize(test)
|
qr.Factorize(test)
|
||||||
b := NewDense(m, 2, nil)
|
b := NewDense(m, 2, nil)
|
||||||
var x Dense
|
var x Dense
|
||||||
if err := qr.Solve(&x, false, b); err == nil {
|
if err := qr.SolveTo(&x, false, b); err == nil {
|
||||||
t.Error("No error for near-singular matrix in matrix solve.")
|
t.Error("No error for near-singular matrix in matrix solve.")
|
||||||
}
|
}
|
||||||
|
|
||||||
bvec := NewVecDense(m, nil)
|
bvec := NewVecDense(m, nil)
|
||||||
var xvec VecDense
|
var xvec VecDense
|
||||||
if err := qr.SolveVec(&xvec, false, bvec); err == nil {
|
if err := qr.SolveVecTo(&xvec, false, bvec); err == nil {
|
||||||
t.Error("No error for near-singular matrix in matrix solve.")
|
t.Error("No error for near-singular matrix in matrix solve.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -91,15 +91,15 @@ func (m *Dense) Solve(a, b Matrix) error {
|
|||||||
}
|
}
|
||||||
var lu LU
|
var lu LU
|
||||||
lu.Factorize(a)
|
lu.Factorize(a)
|
||||||
return lu.Solve(m, false, b)
|
return lu.SolveTo(m, false, b)
|
||||||
case ar > ac:
|
case ar > ac:
|
||||||
var qr QR
|
var qr QR
|
||||||
qr.Factorize(a)
|
qr.Factorize(a)
|
||||||
return qr.Solve(m, false, b)
|
return qr.SolveTo(m, false, b)
|
||||||
default:
|
default:
|
||||||
var lq LQ
|
var lq LQ
|
||||||
lq.Factorize(a)
|
lq.Factorize(a)
|
||||||
return lq.Solve(m, false, b)
|
return lq.SolveTo(m, false, b)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -153,7 +153,7 @@ func (n *Newton) NextDirection(loc *Location, dir []float64) (stepSize float64)
|
|||||||
pd := n.chol.Factorize(n.hess)
|
pd := n.chol.Factorize(n.hess)
|
||||||
if pd {
|
if pd {
|
||||||
// Store the solution in d's backing array, dir.
|
// Store the solution in d's backing array, dir.
|
||||||
n.chol.SolveVec(d, grad)
|
n.chol.SolveVecTo(d, grad)
|
||||||
d.ScaleVec(-1, d)
|
d.ScaleVec(-1, d)
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
@@ -132,7 +132,7 @@ func (w *Wishart) logProbSymChol(cholX *mat.Cholesky) float64 {
|
|||||||
cholX.UTo(&u)
|
cholX.UTo(&u)
|
||||||
|
|
||||||
var vinvx mat.Dense
|
var vinvx mat.Dense
|
||||||
err := w.cholv.Solve(&vinvx, u.T())
|
err := w.cholv.SolveTo(&vinvx, u.T())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return math.Inf(-1)
|
return math.Inf(-1)
|
||||||
}
|
}
|
||||||
|
@@ -343,7 +343,7 @@ func (n *Normal) ScoreInput(score, x []float64) []float64 {
|
|||||||
copy(tmp, x)
|
copy(tmp, x)
|
||||||
floats.Sub(tmp, n.mu)
|
floats.Sub(tmp, n.mu)
|
||||||
|
|
||||||
n.chol.SolveVec(mat.NewVecDense(len(score), score), mat.NewVecDense(len(tmp), tmp))
|
n.chol.SolveVecTo(mat.NewVecDense(len(score), score), mat.NewVecDense(len(tmp), tmp))
|
||||||
floats.Scale(-1, score)
|
floats.Scale(-1, score)
|
||||||
return score
|
return score
|
||||||
}
|
}
|
||||||
|
@@ -195,7 +195,7 @@ func (KullbackLeibler) DistNormal(l, r *Normal) float64 {
|
|||||||
var u mat.TriDense
|
var u mat.TriDense
|
||||||
l.chol.UTo(&u)
|
l.chol.UTo(&u)
|
||||||
var m mat.Dense
|
var m mat.Dense
|
||||||
err := r.chol.Solve(&m, u.T())
|
err := r.chol.SolveTo(&m, u.T())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return math.NaN()
|
return math.NaN()
|
||||||
}
|
}
|
||||||
|
@@ -160,7 +160,7 @@ func studentsTConditional(observed []int, values []float64, nu float64, mu []flo
|
|||||||
// Compute mu_1 + sigma_{2,1}^T * sigma_{2,2}^-1 (v - mu_2).
|
// Compute mu_1 + sigma_{2,1}^T * sigma_{2,2}^-1 (v - mu_2).
|
||||||
v := mat.NewVecDense(ob, mu2)
|
v := mat.NewVecDense(ob, mu2)
|
||||||
var tmp, tmp2 mat.VecDense
|
var tmp, tmp2 mat.VecDense
|
||||||
err := chol.SolveVec(&tmp, v)
|
err := chol.SolveVecTo(&tmp, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return math.NaN(), nil, nil
|
return math.NaN(), nil, nil
|
||||||
}
|
}
|
||||||
@@ -173,7 +173,7 @@ func studentsTConditional(observed []int, values []float64, nu float64, mu []flo
|
|||||||
// Compute tmp4 = sigma_{2,1}^T * sigma_{2,2}^-1 * sigma_{2,1}.
|
// Compute tmp4 = sigma_{2,1}^T * sigma_{2,2}^-1 * sigma_{2,1}.
|
||||||
// TODO(btracey): Should this be a method of SymDense?
|
// TODO(btracey): Should this be a method of SymDense?
|
||||||
var tmp3, tmp4 mat.Dense
|
var tmp3, tmp4 mat.Dense
|
||||||
err = chol.Solve(&tmp3, sigma21)
|
err = chol.SolveTo(&tmp3, sigma21)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return math.NaN(), nil, nil
|
return math.NaN(), nil, nil
|
||||||
}
|
}
|
||||||
|
@@ -138,7 +138,7 @@ func Mahalanobis(x, y mat.Vector, chol *mat.Cholesky) float64 {
|
|||||||
var diff mat.VecDense
|
var diff mat.VecDense
|
||||||
diff.SubVec(x, y)
|
diff.SubVec(x, y)
|
||||||
var tmp mat.VecDense
|
var tmp mat.VecDense
|
||||||
err := chol.SolveVec(&tmp, &diff)
|
err := chol.SolveVecTo(&tmp, &diff)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return math.NaN()
|
return math.NaN()
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user