mat: generalise SolveVec vector parameters

This commit is contained in:
kortschak
2017-12-27 22:01:40 +10:30
committed by Dan Kortschak
parent 98fb1ed640
commit a361656bfc
6 changed files with 119 additions and 74 deletions

View File

@@ -207,17 +207,22 @@ func (a *Cholesky) SolveChol(m *Dense, b *Cholesky) error {
// SolveVec finds the vector v that solves A * v = b where A is represented // SolveVec finds the vector v that solves A * v = b where A is represented
// by the Cholesky decomposition, placing the result in v. // by the Cholesky decomposition, placing the result in v.
func (c *Cholesky) SolveVec(v, b *VecDense) error { func (c *Cholesky) SolveVec(v *VecDense, b Vector) error {
if !c.valid() { if !c.valid() {
panic(badCholesky) panic(badCholesky)
} }
n := c.chol.mat.N n := c.chol.mat.N
vn := b.Len() if br, bc := b.Dims(); br != n || bc != 1 {
if vn != n {
panic(ErrShape) panic(ErrShape)
} }
switch rv := b.(type) {
default:
v.reuseAs(n)
return c.Solve(v.asDense(), b)
case RawVectorer:
bmat := rv.RawVector()
if v != b { if v != b {
v.checkOverlap(b.mat) v.checkOverlap(bmat)
} }
v.reuseAs(n) v.reuseAs(n)
if v != b { if v != b {
@@ -229,7 +234,7 @@ func (c *Cholesky) SolveVec(v, b *VecDense) error {
return Condition(c.cond) return Condition(c.cond)
} }
return nil return nil
}
} }
// RawU returns the Triangular matrix used to store the Cholesky decomposition of // RawU returns the Triangular matrix used to store the Cholesky decomposition of

View File

@@ -209,17 +209,27 @@ func (lq *LQ) Solve(m *Dense, trans bool, b Matrix) error {
// SolveVec finds a minimum-norm solution to a system of linear equations. // SolveVec finds a minimum-norm solution to a system of linear equations.
// Please see LQ.Solve for the full documentation. // Please see LQ.Solve for the full documentation.
func (lq *LQ) SolveVec(v *VecDense, trans bool, b *VecDense) error { func (lq *LQ) SolveVec(v *VecDense, trans bool, b Vector) error {
if v != b {
v.checkOverlap(b.mat)
}
r, c := lq.lq.Dims() r, c := lq.lq.Dims()
if _, bc := b.Dims(); bc != 1 {
panic(ErrShape)
}
// The Solve implementation is non-trivial, so rather than duplicate the code, // The Solve implementation is non-trivial, so rather than duplicate the code,
// instead recast the VecDenses as Dense and call the matrix code. // instead recast the VecDenses as Dense and call the matrix code.
bm := Matrix(b)
if rv, ok := b.(RawVectorer); ok {
bmat := rv.RawVector()
if v != b {
v.checkOverlap(bmat)
}
b := VecDense{mat: bmat, n: b.Len()}
bm = b.asDense()
}
if trans { if trans {
v.reuseAs(r) v.reuseAs(r)
} else { } else {
v.reuseAs(c) v.reuseAs(c)
} }
return lq.Solve(v.asDense(), trans, b.asDense()) return lq.Solve(v.asDense(), trans, bm)
} }

View File

@@ -333,14 +333,18 @@ func (lu *LU) Solve(m *Dense, trans bool, b Matrix) error {
// //
// If A is singular or near-singular a Condition error is returned. Please see // If A is singular or near-singular a Condition error is returned. Please see
// the documentation for Condition for more information. // the documentation for Condition for more information.
func (lu *LU) SolveVec(v *VecDense, trans bool, b *VecDense) error { func (lu *LU) SolveVec(v *VecDense, trans bool, b Vector) error {
_, n := lu.lu.Dims() _, n := lu.lu.Dims()
bn := b.Len() if br, bc := b.Dims(); br != n || bc != 1 {
if bn != n {
panic(ErrShape) panic(ErrShape)
} }
switch rv := b.(type) {
default:
v.reuseAs(n)
return lu.Solve(v.asDense(), trans, b)
case RawVectorer:
if v != b { if v != b {
v.checkOverlap(b.mat) v.checkOverlap(rv.RawVector())
} }
// TODO(btracey): Should test the condition number instead of testing that // TODO(btracey): Should test the condition number instead of testing that
// the determinant is exactly zero. // the determinant is exactly zero.
@@ -371,3 +375,4 @@ func (lu *LU) SolveVec(v *VecDense, trans bool, b *VecDense) error {
} }
return nil return nil
} }
}

View File

@@ -203,19 +203,31 @@ func (qr *QR) Solve(m *Dense, trans bool, b Matrix) error {
return nil return nil
} }
// SolveVec finds a minimum-norm solution to a system of linear equations. // SolveVec finds a minimum-norm solution to a system of linear equations,
// Ax = b.
// Please see QR.Solve for the full documentation. // Please see QR.Solve for the full documentation.
func (qr *QR) SolveVec(v *VecDense, trans bool, b *VecDense) error { func (qr *QR) SolveVec(v *VecDense, trans bool, b Vector) error {
if v != b {
v.checkOverlap(b.mat)
}
r, c := qr.qr.Dims() r, c := qr.qr.Dims()
if _, bc := b.Dims(); bc != 1 {
panic(ErrShape)
}
// The Solve implementation is non-trivial, so rather than duplicate the code, // The Solve implementation is non-trivial, so rather than duplicate the code,
// instead recast the VecDenses as Dense and call the matrix code. // instead recast the VecDenses as Dense and call the matrix code.
bm := Matrix(b)
if rv, ok := b.(RawVectorer); ok {
bmat := rv.RawVector()
if v != b {
v.checkOverlap(bmat)
}
b := VecDense{mat: bmat, n: b.Len()}
bm = b.asDense()
}
if trans { if trans {
v.reuseAs(r) v.reuseAs(r)
} else { } else {
v.reuseAs(c) v.reuseAs(c)
} }
return qr.Solve(v.asDense(), trans, b.asDense()) return qr.Solve(v.asDense(), trans, bm)
} }

View File

@@ -104,16 +104,23 @@ func (m *Dense) Solve(a, b Matrix) error {
} }
// SolveVec finds a minimum-norm solution to a system of linear equations defined // SolveVec finds a minimum-norm solution to a system of linear equations defined
// by the matrix a and the right-hand side vector b. If A is singular or // by the matrix a and the right-hand side column vector b. If A is singular or
// near-singular, a Condition error is returned. Please see the documentation for // near-singular, a Condition error is returned. Please see the documentation for
// Dense.Solve for more information. // Dense.Solve for more information.
func (v *VecDense) SolveVec(a Matrix, b *VecDense) error { func (v *VecDense) SolveVec(a Matrix, b Vector) error {
if v != b { if _, bc := b.Dims(); bc != 1 {
v.checkOverlap(b.mat) panic(ErrShape)
} }
_, c := a.Dims() _, c := a.Dims()
// The Solve implementation is non-trivial, so rather than duplicate the code, // The Solve implementation is non-trivial, so rather than duplicate the code,
// instead recast the VecDenses as Dense and call the matrix code. // instead recast the VecDenses as Dense and call the matrix code.
if rv, ok := b.(RawVectorer); ok {
bmat := rv.RawVector()
if v != b {
v.checkOverlap(bmat)
}
v.reuseAs(c) v.reuseAs(c)
m := v.asDense() m := v.asDense()
// We conditionally create bm as m when b and v are identical // We conditionally create bm as m when b and v are identical
@@ -121,7 +128,13 @@ func (v *VecDense) SolveVec(a Matrix, b *VecDense) error {
// and bm as overlapping but not identical. // and bm as overlapping but not identical.
bm := m bm := m
if v != b { if v != b {
b := VecDense{mat: bmat, n: b.Len()}
bm = b.asDense() bm = b.asDense()
} }
return m.Solve(a, bm) return m.Solve(a, bm)
} }
v.reuseAs(c)
m := v.asDense()
return m.Solve(a, b)
}

View File

@@ -285,13 +285,13 @@ func TestSolveVec(t *testing.T) {
// Use testTwoInput // Use testTwoInput
method := func(receiver, a, b Matrix) { method := func(receiver, a, b Matrix) {
type SolveVecer interface { type SolveVecer interface {
SolveVec(a Matrix, b *VecDense) error SolveVec(a Matrix, b Vector) error
} }
rd := receiver.(SolveVecer) rd := receiver.(SolveVecer)
rd.SolveVec(a, b.(*VecDense)) rd.SolveVec(a, b.(Vector))
} }
denseComparison := func(receiver, a, b *Dense) { denseComparison := func(receiver, a, b *Dense) {
receiver.Solve(a, b) receiver.Solve(a, b)
} }
testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVecDense, legalSizeSolve, 1e-12) testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVector, legalSizeSolve, 1e-12)
} }