mat: generalise SolveVec vector parameters

This commit is contained in:
kortschak
2017-12-27 22:01:40 +10:30
committed by Dan Kortschak
parent 98fb1ed640
commit a361656bfc
6 changed files with 119 additions and 74 deletions

View File

@@ -207,17 +207,22 @@ func (a *Cholesky) SolveChol(m *Dense, b *Cholesky) error {
// SolveVec finds the vector v that solves A * v = b where A is represented
// by the Cholesky decomposition, placing the result in v.
func (c *Cholesky) SolveVec(v, b *VecDense) error {
func (c *Cholesky) SolveVec(v *VecDense, b Vector) error {
if !c.valid() {
panic(badCholesky)
}
n := c.chol.mat.N
vn := b.Len()
if vn != n {
if br, bc := b.Dims(); br != n || bc != 1 {
panic(ErrShape)
}
switch rv := b.(type) {
default:
v.reuseAs(n)
return c.Solve(v.asDense(), b)
case RawVectorer:
bmat := rv.RawVector()
if v != b {
v.checkOverlap(b.mat)
v.checkOverlap(bmat)
}
v.reuseAs(n)
if v != b {
@@ -229,7 +234,7 @@ func (c *Cholesky) SolveVec(v, b *VecDense) error {
return Condition(c.cond)
}
return nil
}
}
// RawU returns the Triangular matrix used to store the Cholesky decomposition of

View File

@@ -209,17 +209,27 @@ func (lq *LQ) Solve(m *Dense, trans bool, b Matrix) error {
// SolveVec finds a minimum-norm solution to a system of linear equations.
// Please see LQ.Solve for the full documentation.
func (lq *LQ) SolveVec(v *VecDense, trans bool, b *VecDense) error {
if v != b {
v.checkOverlap(b.mat)
}
func (lq *LQ) SolveVec(v *VecDense, trans bool, b Vector) error {
r, c := lq.lq.Dims()
if _, bc := b.Dims(); bc != 1 {
panic(ErrShape)
}
// The Solve implementation is non-trivial, so rather than duplicate the code,
// instead recast the VecDenses as Dense and call the matrix code.
bm := Matrix(b)
if rv, ok := b.(RawVectorer); ok {
bmat := rv.RawVector()
if v != b {
v.checkOverlap(bmat)
}
b := VecDense{mat: bmat, n: b.Len()}
bm = b.asDense()
}
if trans {
v.reuseAs(r)
} else {
v.reuseAs(c)
}
return lq.Solve(v.asDense(), trans, b.asDense())
return lq.Solve(v.asDense(), trans, bm)
}

View File

@@ -333,14 +333,18 @@ func (lu *LU) Solve(m *Dense, trans bool, b Matrix) error {
//
// If A is singular or near-singular a Condition error is returned. Please see
// the documentation for Condition for more information.
func (lu *LU) SolveVec(v *VecDense, trans bool, b *VecDense) error {
func (lu *LU) SolveVec(v *VecDense, trans bool, b Vector) error {
_, n := lu.lu.Dims()
bn := b.Len()
if bn != n {
if br, bc := b.Dims(); br != n || bc != 1 {
panic(ErrShape)
}
switch rv := b.(type) {
default:
v.reuseAs(n)
return lu.Solve(v.asDense(), trans, b)
case RawVectorer:
if v != b {
v.checkOverlap(b.mat)
v.checkOverlap(rv.RawVector())
}
// TODO(btracey): Should test the condition number instead of testing that
// the determinant is exactly zero.
@@ -370,4 +374,5 @@ func (lu *LU) SolveVec(v *VecDense, trans bool, b *VecDense) error {
return Condition(lu.cond)
}
return nil
}
}

View File

@@ -203,19 +203,31 @@ func (qr *QR) Solve(m *Dense, trans bool, b Matrix) error {
return nil
}
// SolveVec finds a minimum-norm solution to a system of linear equations.
// SolveVec finds a minimum-norm solution to a system of linear equations,
// Ax = b.
// Please see QR.Solve for the full documentation.
func (qr *QR) SolveVec(v *VecDense, trans bool, b *VecDense) error {
if v != b {
v.checkOverlap(b.mat)
}
func (qr *QR) SolveVec(v *VecDense, trans bool, b Vector) error {
r, c := qr.qr.Dims()
if _, bc := b.Dims(); bc != 1 {
panic(ErrShape)
}
// The Solve implementation is non-trivial, so rather than duplicate the code,
// instead recast the VecDenses as Dense and call the matrix code.
bm := Matrix(b)
if rv, ok := b.(RawVectorer); ok {
bmat := rv.RawVector()
if v != b {
v.checkOverlap(bmat)
}
b := VecDense{mat: bmat, n: b.Len()}
bm = b.asDense()
}
if trans {
v.reuseAs(r)
} else {
v.reuseAs(c)
}
return qr.Solve(v.asDense(), trans, b.asDense())
return qr.Solve(v.asDense(), trans, bm)
}

View File

@@ -104,16 +104,23 @@ func (m *Dense) Solve(a, b Matrix) error {
}
// SolveVec finds a minimum-norm solution to a system of linear equations defined
// by the matrix a and the right-hand side vector b. If A is singular or
// by the matrix a and the right-hand side column vector b. If A is singular or
// near-singular, a Condition error is returned. Please see the documentation for
// Dense.Solve for more information.
func (v *VecDense) SolveVec(a Matrix, b *VecDense) error {
if v != b {
v.checkOverlap(b.mat)
func (v *VecDense) SolveVec(a Matrix, b Vector) error {
if _, bc := b.Dims(); bc != 1 {
panic(ErrShape)
}
_, c := a.Dims()
// The Solve implementation is non-trivial, so rather than duplicate the code,
// instead recast the VecDenses as Dense and call the matrix code.
if rv, ok := b.(RawVectorer); ok {
bmat := rv.RawVector()
if v != b {
v.checkOverlap(bmat)
}
v.reuseAs(c)
m := v.asDense()
// We conditionally create bm as m when b and v are identical
@@ -121,7 +128,13 @@ func (v *VecDense) SolveVec(a Matrix, b *VecDense) error {
// and bm as overlapping but not identical.
bm := m
if v != b {
b := VecDense{mat: bmat, n: b.Len()}
bm = b.asDense()
}
return m.Solve(a, bm)
}
v.reuseAs(c)
m := v.asDense()
return m.Solve(a, b)
}

View File

@@ -285,13 +285,13 @@ func TestSolveVec(t *testing.T) {
// Use testTwoInput
method := func(receiver, a, b Matrix) {
type SolveVecer interface {
SolveVec(a Matrix, b *VecDense) error
SolveVec(a Matrix, b Vector) error
}
rd := receiver.(SolveVecer)
rd.SolveVec(a, b.(*VecDense))
rd.SolveVec(a, b.(Vector))
}
denseComparison := func(receiver, a, b *Dense) {
receiver.Solve(a, b)
}
testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVecDense, legalSizeSolve, 1e-12)
testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVector, legalSizeSolve, 1e-12)
}