mat: generalise SolveVec vector parameters

This commit is contained in:
kortschak
2017-12-27 22:01:40 +10:30
committed by Dan Kortschak
parent 98fb1ed640
commit a361656bfc
6 changed files with 119 additions and 74 deletions

View File

@@ -207,29 +207,34 @@ func (a *Cholesky) SolveChol(m *Dense, b *Cholesky) error {
// SolveVec finds the vector v that solves A * v = b where A is represented // SolveVec finds the vector v that solves A * v = b where A is represented
// by the Cholesky decomposition, placing the result in v. // by the Cholesky decomposition, placing the result in v.
func (c *Cholesky) SolveVec(v, b *VecDense) error { func (c *Cholesky) SolveVec(v *VecDense, b Vector) error {
if !c.valid() { if !c.valid() {
panic(badCholesky) panic(badCholesky)
} }
n := c.chol.mat.N n := c.chol.mat.N
vn := b.Len() if br, bc := b.Dims(); br != n || bc != 1 {
if vn != n {
panic(ErrShape) panic(ErrShape)
} }
if v != b { switch rv := b.(type) {
v.checkOverlap(b.mat) default:
v.reuseAs(n)
return c.Solve(v.asDense(), b)
case RawVectorer:
bmat := rv.RawVector()
if v != b {
v.checkOverlap(bmat)
}
v.reuseAs(n)
if v != b {
v.CopyVec(b)
}
blas64.Trsv(blas.Trans, c.chol.mat, v.mat)
blas64.Trsv(blas.NoTrans, c.chol.mat, v.mat)
if c.cond > ConditionTolerance {
return Condition(c.cond)
}
return nil
} }
v.reuseAs(n)
if v != b {
v.CopyVec(b)
}
blas64.Trsv(blas.Trans, c.chol.mat, v.mat)
blas64.Trsv(blas.NoTrans, c.chol.mat, v.mat)
if c.cond > ConditionTolerance {
return Condition(c.cond)
}
return nil
} }
// RawU returns the Triangular matrix used to store the Cholesky decomposition of // RawU returns the Triangular matrix used to store the Cholesky decomposition of

View File

@@ -209,17 +209,27 @@ func (lq *LQ) Solve(m *Dense, trans bool, b Matrix) error {
// SolveVec finds a minimum-norm solution to a system of linear equations. // SolveVec finds a minimum-norm solution to a system of linear equations.
// Please see LQ.Solve for the full documentation. // Please see LQ.Solve for the full documentation.
func (lq *LQ) SolveVec(v *VecDense, trans bool, b *VecDense) error { func (lq *LQ) SolveVec(v *VecDense, trans bool, b Vector) error {
if v != b {
v.checkOverlap(b.mat)
}
r, c := lq.lq.Dims() r, c := lq.lq.Dims()
if _, bc := b.Dims(); bc != 1 {
panic(ErrShape)
}
// The Solve implementation is non-trivial, so rather than duplicate the code, // The Solve implementation is non-trivial, so rather than duplicate the code,
// instead recast the VecDenses as Dense and call the matrix code. // instead recast the VecDenses as Dense and call the matrix code.
bm := Matrix(b)
if rv, ok := b.(RawVectorer); ok {
bmat := rv.RawVector()
if v != b {
v.checkOverlap(bmat)
}
b := VecDense{mat: bmat, n: b.Len()}
bm = b.asDense()
}
if trans { if trans {
v.reuseAs(r) v.reuseAs(r)
} else { } else {
v.reuseAs(c) v.reuseAs(c)
} }
return lq.Solve(v.asDense(), trans, b.asDense()) return lq.Solve(v.asDense(), trans, bm)
} }

View File

@@ -333,41 +333,46 @@ func (lu *LU) Solve(m *Dense, trans bool, b Matrix) error {
// //
// If A is singular or near-singular a Condition error is returned. Please see // If A is singular or near-singular a Condition error is returned. Please see
// the documentation for Condition for more information. // the documentation for Condition for more information.
func (lu *LU) SolveVec(v *VecDense, trans bool, b *VecDense) error { func (lu *LU) SolveVec(v *VecDense, trans bool, b Vector) error {
_, n := lu.lu.Dims() _, n := lu.lu.Dims()
bn := b.Len() if br, bc := b.Dims(); br != n || bc != 1 {
if bn != n {
panic(ErrShape) panic(ErrShape)
} }
if v != b { switch rv := b.(type) {
v.checkOverlap(b.mat) default:
} v.reuseAs(n)
// TODO(btracey): Should test the condition number instead of testing that return lu.Solve(v.asDense(), trans, b)
// the determinant is exactly zero. case RawVectorer:
if lu.Det() == 0 { if v != b {
return Condition(math.Inf(1)) v.checkOverlap(rv.RawVector())
} }
// TODO(btracey): Should test the condition number instead of testing that
// the determinant is exactly zero.
if lu.Det() == 0 {
return Condition(math.Inf(1))
}
v.reuseAs(n) v.reuseAs(n)
var restore func() var restore func()
if v == b { if v == b {
v, restore = v.isolatedWorkspace(b) v, restore = v.isolatedWorkspace(b)
defer restore() defer restore()
}
v.CopyVec(b)
vMat := blas64.General{
Rows: n,
Cols: 1,
Stride: v.mat.Inc,
Data: v.mat.Data,
}
t := blas.NoTrans
if trans {
t = blas.Trans
}
lapack64.Getrs(t, lu.lu.mat, vMat, lu.pivot)
if lu.cond > ConditionTolerance {
return Condition(lu.cond)
}
return nil
} }
v.CopyVec(b)
vMat := blas64.General{
Rows: n,
Cols: 1,
Stride: v.mat.Inc,
Data: v.mat.Data,
}
t := blas.NoTrans
if trans {
t = blas.Trans
}
lapack64.Getrs(t, lu.lu.mat, vMat, lu.pivot)
if lu.cond > ConditionTolerance {
return Condition(lu.cond)
}
return nil
} }

View File

@@ -203,19 +203,31 @@ func (qr *QR) Solve(m *Dense, trans bool, b Matrix) error {
return nil return nil
} }
// SolveVec finds a minimum-norm solution to a system of linear equations. // SolveVec finds a minimum-norm solution to a system of linear equations,
// Ax = b.
// Please see QR.Solve for the full documentation. // Please see QR.Solve for the full documentation.
func (qr *QR) SolveVec(v *VecDense, trans bool, b *VecDense) error { func (qr *QR) SolveVec(v *VecDense, trans bool, b Vector) error {
if v != b {
v.checkOverlap(b.mat)
}
r, c := qr.qr.Dims() r, c := qr.qr.Dims()
if _, bc := b.Dims(); bc != 1 {
panic(ErrShape)
}
// The Solve implementation is non-trivial, so rather than duplicate the code, // The Solve implementation is non-trivial, so rather than duplicate the code,
// instead recast the VecDenses as Dense and call the matrix code. // instead recast the VecDenses as Dense and call the matrix code.
bm := Matrix(b)
if rv, ok := b.(RawVectorer); ok {
bmat := rv.RawVector()
if v != b {
v.checkOverlap(bmat)
}
b := VecDense{mat: bmat, n: b.Len()}
bm = b.asDense()
}
if trans { if trans {
v.reuseAs(r) v.reuseAs(r)
} else { } else {
v.reuseAs(c) v.reuseAs(c)
} }
return qr.Solve(v.asDense(), trans, b.asDense()) return qr.Solve(v.asDense(), trans, bm)
} }

View File

@@ -104,24 +104,37 @@ func (m *Dense) Solve(a, b Matrix) error {
} }
// SolveVec finds a minimum-norm solution to a system of linear equations defined // SolveVec finds a minimum-norm solution to a system of linear equations defined
// by the matrix a and the right-hand side vector b. If A is singular or // by the matrix a and the right-hand side column vector b. If A is singular or
// near-singular, a Condition error is returned. Please see the documentation for // near-singular, a Condition error is returned. Please see the documentation for
// Dense.Solve for more information. // Dense.Solve for more information.
func (v *VecDense) SolveVec(a Matrix, b *VecDense) error { func (v *VecDense) SolveVec(a Matrix, b Vector) error {
if v != b { if _, bc := b.Dims(); bc != 1 {
v.checkOverlap(b.mat) panic(ErrShape)
} }
_, c := a.Dims() _, c := a.Dims()
// The Solve implementation is non-trivial, so rather than duplicate the code, // The Solve implementation is non-trivial, so rather than duplicate the code,
// instead recast the VecDenses as Dense and call the matrix code. // instead recast the VecDenses as Dense and call the matrix code.
if rv, ok := b.(RawVectorer); ok {
bmat := rv.RawVector()
if v != b {
v.checkOverlap(bmat)
}
v.reuseAs(c)
m := v.asDense()
// We conditionally create bm as m when b and v are identical
// to prevent the overlap detection code from identifying m
// and bm as overlapping but not identical.
bm := m
if v != b {
b := VecDense{mat: bmat, n: b.Len()}
bm = b.asDense()
}
return m.Solve(a, bm)
}
v.reuseAs(c) v.reuseAs(c)
m := v.asDense() m := v.asDense()
// We conditionally create bm as m when b and v are identical return m.Solve(a, b)
// to prevent the overlap detection code from identifying m
// and bm as overlapping but not identical.
bm := m
if v != b {
bm = b.asDense()
}
return m.Solve(a, bm)
} }

View File

@@ -285,13 +285,13 @@ func TestSolveVec(t *testing.T) {
// Use testTwoInput // Use testTwoInput
method := func(receiver, a, b Matrix) { method := func(receiver, a, b Matrix) {
type SolveVecer interface { type SolveVecer interface {
SolveVec(a Matrix, b *VecDense) error SolveVec(a Matrix, b Vector) error
} }
rd := receiver.(SolveVecer) rd := receiver.(SolveVecer)
rd.SolveVec(a, b.(*VecDense)) rd.SolveVec(a, b.(Vector))
} }
denseComparison := func(receiver, a, b *Dense) { denseComparison := func(receiver, a, b *Dense) {
receiver.Solve(a, b) receiver.Solve(a, b)
} }
testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVecDense, legalSizeSolve, 1e-12) testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVector, legalSizeSolve, 1e-12)
} }