mirror of
https://github.com/gonum/gonum.git
synced 2025-10-06 23:52:47 +08:00
mat: generalise SolveVec vector parameters
This commit is contained in:
@@ -207,17 +207,22 @@ func (a *Cholesky) SolveChol(m *Dense, b *Cholesky) error {
|
|||||||
|
|
||||||
// SolveVec finds the vector v that solves A * v = b where A is represented
|
// SolveVec finds the vector v that solves A * v = b where A is represented
|
||||||
// by the Cholesky decomposition, placing the result in v.
|
// by the Cholesky decomposition, placing the result in v.
|
||||||
func (c *Cholesky) SolveVec(v, b *VecDense) error {
|
func (c *Cholesky) SolveVec(v *VecDense, b Vector) error {
|
||||||
if !c.valid() {
|
if !c.valid() {
|
||||||
panic(badCholesky)
|
panic(badCholesky)
|
||||||
}
|
}
|
||||||
n := c.chol.mat.N
|
n := c.chol.mat.N
|
||||||
vn := b.Len()
|
if br, bc := b.Dims(); br != n || bc != 1 {
|
||||||
if vn != n {
|
|
||||||
panic(ErrShape)
|
panic(ErrShape)
|
||||||
}
|
}
|
||||||
|
switch rv := b.(type) {
|
||||||
|
default:
|
||||||
|
v.reuseAs(n)
|
||||||
|
return c.Solve(v.asDense(), b)
|
||||||
|
case RawVectorer:
|
||||||
|
bmat := rv.RawVector()
|
||||||
if v != b {
|
if v != b {
|
||||||
v.checkOverlap(b.mat)
|
v.checkOverlap(bmat)
|
||||||
}
|
}
|
||||||
v.reuseAs(n)
|
v.reuseAs(n)
|
||||||
if v != b {
|
if v != b {
|
||||||
@@ -229,7 +234,7 @@ func (c *Cholesky) SolveVec(v, b *VecDense) error {
|
|||||||
return Condition(c.cond)
|
return Condition(c.cond)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RawU returns the Triangular matrix used to store the Cholesky decomposition of
|
// RawU returns the Triangular matrix used to store the Cholesky decomposition of
|
||||||
|
20
mat/lq.go
20
mat/lq.go
@@ -209,17 +209,27 @@ func (lq *LQ) Solve(m *Dense, trans bool, b Matrix) error {
|
|||||||
|
|
||||||
// SolveVec finds a minimum-norm solution to a system of linear equations.
|
// SolveVec finds a minimum-norm solution to a system of linear equations.
|
||||||
// Please see LQ.Solve for the full documentation.
|
// Please see LQ.Solve for the full documentation.
|
||||||
func (lq *LQ) SolveVec(v *VecDense, trans bool, b *VecDense) error {
|
func (lq *LQ) SolveVec(v *VecDense, trans bool, b Vector) error {
|
||||||
if v != b {
|
|
||||||
v.checkOverlap(b.mat)
|
|
||||||
}
|
|
||||||
r, c := lq.lq.Dims()
|
r, c := lq.lq.Dims()
|
||||||
|
if _, bc := b.Dims(); bc != 1 {
|
||||||
|
panic(ErrShape)
|
||||||
|
}
|
||||||
|
|
||||||
// The Solve implementation is non-trivial, so rather than duplicate the code,
|
// The Solve implementation is non-trivial, so rather than duplicate the code,
|
||||||
// instead recast the VecDenses as Dense and call the matrix code.
|
// instead recast the VecDenses as Dense and call the matrix code.
|
||||||
|
bm := Matrix(b)
|
||||||
|
if rv, ok := b.(RawVectorer); ok {
|
||||||
|
bmat := rv.RawVector()
|
||||||
|
if v != b {
|
||||||
|
v.checkOverlap(bmat)
|
||||||
|
}
|
||||||
|
b := VecDense{mat: bmat, n: b.Len()}
|
||||||
|
bm = b.asDense()
|
||||||
|
}
|
||||||
if trans {
|
if trans {
|
||||||
v.reuseAs(r)
|
v.reuseAs(r)
|
||||||
} else {
|
} else {
|
||||||
v.reuseAs(c)
|
v.reuseAs(c)
|
||||||
}
|
}
|
||||||
return lq.Solve(v.asDense(), trans, b.asDense())
|
return lq.Solve(v.asDense(), trans, bm)
|
||||||
}
|
}
|
||||||
|
13
mat/lu.go
13
mat/lu.go
@@ -333,14 +333,18 @@ func (lu *LU) Solve(m *Dense, trans bool, b Matrix) error {
|
|||||||
//
|
//
|
||||||
// If A is singular or near-singular a Condition error is returned. Please see
|
// If A is singular or near-singular a Condition error is returned. Please see
|
||||||
// the documentation for Condition for more information.
|
// the documentation for Condition for more information.
|
||||||
func (lu *LU) SolveVec(v *VecDense, trans bool, b *VecDense) error {
|
func (lu *LU) SolveVec(v *VecDense, trans bool, b Vector) error {
|
||||||
_, n := lu.lu.Dims()
|
_, n := lu.lu.Dims()
|
||||||
bn := b.Len()
|
if br, bc := b.Dims(); br != n || bc != 1 {
|
||||||
if bn != n {
|
|
||||||
panic(ErrShape)
|
panic(ErrShape)
|
||||||
}
|
}
|
||||||
|
switch rv := b.(type) {
|
||||||
|
default:
|
||||||
|
v.reuseAs(n)
|
||||||
|
return lu.Solve(v.asDense(), trans, b)
|
||||||
|
case RawVectorer:
|
||||||
if v != b {
|
if v != b {
|
||||||
v.checkOverlap(b.mat)
|
v.checkOverlap(rv.RawVector())
|
||||||
}
|
}
|
||||||
// TODO(btracey): Should test the condition number instead of testing that
|
// TODO(btracey): Should test the condition number instead of testing that
|
||||||
// the determinant is exactly zero.
|
// the determinant is exactly zero.
|
||||||
@@ -371,3 +375,4 @@ func (lu *LU) SolveVec(v *VecDense, trans bool, b *VecDense) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
24
mat/qr.go
24
mat/qr.go
@@ -203,19 +203,31 @@ func (qr *QR) Solve(m *Dense, trans bool, b Matrix) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SolveVec finds a minimum-norm solution to a system of linear equations.
|
// SolveVec finds a minimum-norm solution to a system of linear equations,
|
||||||
|
// Ax = b.
|
||||||
// Please see QR.Solve for the full documentation.
|
// Please see QR.Solve for the full documentation.
|
||||||
func (qr *QR) SolveVec(v *VecDense, trans bool, b *VecDense) error {
|
func (qr *QR) SolveVec(v *VecDense, trans bool, b Vector) error {
|
||||||
if v != b {
|
|
||||||
v.checkOverlap(b.mat)
|
|
||||||
}
|
|
||||||
r, c := qr.qr.Dims()
|
r, c := qr.qr.Dims()
|
||||||
|
if _, bc := b.Dims(); bc != 1 {
|
||||||
|
panic(ErrShape)
|
||||||
|
}
|
||||||
|
|
||||||
// The Solve implementation is non-trivial, so rather than duplicate the code,
|
// The Solve implementation is non-trivial, so rather than duplicate the code,
|
||||||
// instead recast the VecDenses as Dense and call the matrix code.
|
// instead recast the VecDenses as Dense and call the matrix code.
|
||||||
|
bm := Matrix(b)
|
||||||
|
if rv, ok := b.(RawVectorer); ok {
|
||||||
|
bmat := rv.RawVector()
|
||||||
|
if v != b {
|
||||||
|
v.checkOverlap(bmat)
|
||||||
|
}
|
||||||
|
b := VecDense{mat: bmat, n: b.Len()}
|
||||||
|
bm = b.asDense()
|
||||||
|
}
|
||||||
if trans {
|
if trans {
|
||||||
v.reuseAs(r)
|
v.reuseAs(r)
|
||||||
} else {
|
} else {
|
||||||
v.reuseAs(c)
|
v.reuseAs(c)
|
||||||
}
|
}
|
||||||
return qr.Solve(v.asDense(), trans, b.asDense())
|
return qr.Solve(v.asDense(), trans, bm)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
21
mat/solve.go
21
mat/solve.go
@@ -104,16 +104,23 @@ func (m *Dense) Solve(a, b Matrix) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SolveVec finds a minimum-norm solution to a system of linear equations defined
|
// SolveVec finds a minimum-norm solution to a system of linear equations defined
|
||||||
// by the matrix a and the right-hand side vector b. If A is singular or
|
// by the matrix a and the right-hand side column vector b. If A is singular or
|
||||||
// near-singular, a Condition error is returned. Please see the documentation for
|
// near-singular, a Condition error is returned. Please see the documentation for
|
||||||
// Dense.Solve for more information.
|
// Dense.Solve for more information.
|
||||||
func (v *VecDense) SolveVec(a Matrix, b *VecDense) error {
|
func (v *VecDense) SolveVec(a Matrix, b Vector) error {
|
||||||
if v != b {
|
if _, bc := b.Dims(); bc != 1 {
|
||||||
v.checkOverlap(b.mat)
|
panic(ErrShape)
|
||||||
}
|
}
|
||||||
_, c := a.Dims()
|
_, c := a.Dims()
|
||||||
|
|
||||||
// The Solve implementation is non-trivial, so rather than duplicate the code,
|
// The Solve implementation is non-trivial, so rather than duplicate the code,
|
||||||
// instead recast the VecDenses as Dense and call the matrix code.
|
// instead recast the VecDenses as Dense and call the matrix code.
|
||||||
|
|
||||||
|
if rv, ok := b.(RawVectorer); ok {
|
||||||
|
bmat := rv.RawVector()
|
||||||
|
if v != b {
|
||||||
|
v.checkOverlap(bmat)
|
||||||
|
}
|
||||||
v.reuseAs(c)
|
v.reuseAs(c)
|
||||||
m := v.asDense()
|
m := v.asDense()
|
||||||
// We conditionally create bm as m when b and v are identical
|
// We conditionally create bm as m when b and v are identical
|
||||||
@@ -121,7 +128,13 @@ func (v *VecDense) SolveVec(a Matrix, b *VecDense) error {
|
|||||||
// and bm as overlapping but not identical.
|
// and bm as overlapping but not identical.
|
||||||
bm := m
|
bm := m
|
||||||
if v != b {
|
if v != b {
|
||||||
|
b := VecDense{mat: bmat, n: b.Len()}
|
||||||
bm = b.asDense()
|
bm = b.asDense()
|
||||||
}
|
}
|
||||||
return m.Solve(a, bm)
|
return m.Solve(a, bm)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
v.reuseAs(c)
|
||||||
|
m := v.asDense()
|
||||||
|
return m.Solve(a, b)
|
||||||
|
}
|
||||||
|
@@ -285,13 +285,13 @@ func TestSolveVec(t *testing.T) {
|
|||||||
// Use testTwoInput
|
// Use testTwoInput
|
||||||
method := func(receiver, a, b Matrix) {
|
method := func(receiver, a, b Matrix) {
|
||||||
type SolveVecer interface {
|
type SolveVecer interface {
|
||||||
SolveVec(a Matrix, b *VecDense) error
|
SolveVec(a Matrix, b Vector) error
|
||||||
}
|
}
|
||||||
rd := receiver.(SolveVecer)
|
rd := receiver.(SolveVecer)
|
||||||
rd.SolveVec(a, b.(*VecDense))
|
rd.SolveVec(a, b.(Vector))
|
||||||
}
|
}
|
||||||
denseComparison := func(receiver, a, b *Dense) {
|
denseComparison := func(receiver, a, b *Dense) {
|
||||||
receiver.Solve(a, b)
|
receiver.Solve(a, b)
|
||||||
}
|
}
|
||||||
testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVecDense, legalSizeSolve, 1e-12)
|
testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVector, legalSizeSolve, 1e-12)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user