mirror of
https://github.com/gonum/gonum.git
synced 2025-10-06 07:37:03 +08:00
mat: generalise SolveVec vector parameters
This commit is contained in:
@@ -207,29 +207,34 @@ func (a *Cholesky) SolveChol(m *Dense, b *Cholesky) error {
|
|||||||
|
|
||||||
// SolveVec finds the vector v that solves A * v = b where A is represented
|
// SolveVec finds the vector v that solves A * v = b where A is represented
|
||||||
// by the Cholesky decomposition, placing the result in v.
|
// by the Cholesky decomposition, placing the result in v.
|
||||||
func (c *Cholesky) SolveVec(v, b *VecDense) error {
|
func (c *Cholesky) SolveVec(v *VecDense, b Vector) error {
|
||||||
if !c.valid() {
|
if !c.valid() {
|
||||||
panic(badCholesky)
|
panic(badCholesky)
|
||||||
}
|
}
|
||||||
n := c.chol.mat.N
|
n := c.chol.mat.N
|
||||||
vn := b.Len()
|
if br, bc := b.Dims(); br != n || bc != 1 {
|
||||||
if vn != n {
|
|
||||||
panic(ErrShape)
|
panic(ErrShape)
|
||||||
}
|
}
|
||||||
if v != b {
|
switch rv := b.(type) {
|
||||||
v.checkOverlap(b.mat)
|
default:
|
||||||
|
v.reuseAs(n)
|
||||||
|
return c.Solve(v.asDense(), b)
|
||||||
|
case RawVectorer:
|
||||||
|
bmat := rv.RawVector()
|
||||||
|
if v != b {
|
||||||
|
v.checkOverlap(bmat)
|
||||||
|
}
|
||||||
|
v.reuseAs(n)
|
||||||
|
if v != b {
|
||||||
|
v.CopyVec(b)
|
||||||
|
}
|
||||||
|
blas64.Trsv(blas.Trans, c.chol.mat, v.mat)
|
||||||
|
blas64.Trsv(blas.NoTrans, c.chol.mat, v.mat)
|
||||||
|
if c.cond > ConditionTolerance {
|
||||||
|
return Condition(c.cond)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
v.reuseAs(n)
|
|
||||||
if v != b {
|
|
||||||
v.CopyVec(b)
|
|
||||||
}
|
|
||||||
blas64.Trsv(blas.Trans, c.chol.mat, v.mat)
|
|
||||||
blas64.Trsv(blas.NoTrans, c.chol.mat, v.mat)
|
|
||||||
if c.cond > ConditionTolerance {
|
|
||||||
return Condition(c.cond)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RawU returns the Triangular matrix used to store the Cholesky decomposition of
|
// RawU returns the Triangular matrix used to store the Cholesky decomposition of
|
||||||
|
20
mat/lq.go
20
mat/lq.go
@@ -209,17 +209,27 @@ func (lq *LQ) Solve(m *Dense, trans bool, b Matrix) error {
|
|||||||
|
|
||||||
// SolveVec finds a minimum-norm solution to a system of linear equations.
|
// SolveVec finds a minimum-norm solution to a system of linear equations.
|
||||||
// Please see LQ.Solve for the full documentation.
|
// Please see LQ.Solve for the full documentation.
|
||||||
func (lq *LQ) SolveVec(v *VecDense, trans bool, b *VecDense) error {
|
func (lq *LQ) SolveVec(v *VecDense, trans bool, b Vector) error {
|
||||||
if v != b {
|
|
||||||
v.checkOverlap(b.mat)
|
|
||||||
}
|
|
||||||
r, c := lq.lq.Dims()
|
r, c := lq.lq.Dims()
|
||||||
|
if _, bc := b.Dims(); bc != 1 {
|
||||||
|
panic(ErrShape)
|
||||||
|
}
|
||||||
|
|
||||||
// The Solve implementation is non-trivial, so rather than duplicate the code,
|
// The Solve implementation is non-trivial, so rather than duplicate the code,
|
||||||
// instead recast the VecDenses as Dense and call the matrix code.
|
// instead recast the VecDenses as Dense and call the matrix code.
|
||||||
|
bm := Matrix(b)
|
||||||
|
if rv, ok := b.(RawVectorer); ok {
|
||||||
|
bmat := rv.RawVector()
|
||||||
|
if v != b {
|
||||||
|
v.checkOverlap(bmat)
|
||||||
|
}
|
||||||
|
b := VecDense{mat: bmat, n: b.Len()}
|
||||||
|
bm = b.asDense()
|
||||||
|
}
|
||||||
if trans {
|
if trans {
|
||||||
v.reuseAs(r)
|
v.reuseAs(r)
|
||||||
} else {
|
} else {
|
||||||
v.reuseAs(c)
|
v.reuseAs(c)
|
||||||
}
|
}
|
||||||
return lq.Solve(v.asDense(), trans, b.asDense())
|
return lq.Solve(v.asDense(), trans, bm)
|
||||||
}
|
}
|
||||||
|
69
mat/lu.go
69
mat/lu.go
@@ -333,41 +333,46 @@ func (lu *LU) Solve(m *Dense, trans bool, b Matrix) error {
|
|||||||
//
|
//
|
||||||
// If A is singular or near-singular a Condition error is returned. Please see
|
// If A is singular or near-singular a Condition error is returned. Please see
|
||||||
// the documentation for Condition for more information.
|
// the documentation for Condition for more information.
|
||||||
func (lu *LU) SolveVec(v *VecDense, trans bool, b *VecDense) error {
|
func (lu *LU) SolveVec(v *VecDense, trans bool, b Vector) error {
|
||||||
_, n := lu.lu.Dims()
|
_, n := lu.lu.Dims()
|
||||||
bn := b.Len()
|
if br, bc := b.Dims(); br != n || bc != 1 {
|
||||||
if bn != n {
|
|
||||||
panic(ErrShape)
|
panic(ErrShape)
|
||||||
}
|
}
|
||||||
if v != b {
|
switch rv := b.(type) {
|
||||||
v.checkOverlap(b.mat)
|
default:
|
||||||
}
|
v.reuseAs(n)
|
||||||
// TODO(btracey): Should test the condition number instead of testing that
|
return lu.Solve(v.asDense(), trans, b)
|
||||||
// the determinant is exactly zero.
|
case RawVectorer:
|
||||||
if lu.Det() == 0 {
|
if v != b {
|
||||||
return Condition(math.Inf(1))
|
v.checkOverlap(rv.RawVector())
|
||||||
}
|
}
|
||||||
|
// TODO(btracey): Should test the condition number instead of testing that
|
||||||
|
// the determinant is exactly zero.
|
||||||
|
if lu.Det() == 0 {
|
||||||
|
return Condition(math.Inf(1))
|
||||||
|
}
|
||||||
|
|
||||||
v.reuseAs(n)
|
v.reuseAs(n)
|
||||||
var restore func()
|
var restore func()
|
||||||
if v == b {
|
if v == b {
|
||||||
v, restore = v.isolatedWorkspace(b)
|
v, restore = v.isolatedWorkspace(b)
|
||||||
defer restore()
|
defer restore()
|
||||||
|
}
|
||||||
|
v.CopyVec(b)
|
||||||
|
vMat := blas64.General{
|
||||||
|
Rows: n,
|
||||||
|
Cols: 1,
|
||||||
|
Stride: v.mat.Inc,
|
||||||
|
Data: v.mat.Data,
|
||||||
|
}
|
||||||
|
t := blas.NoTrans
|
||||||
|
if trans {
|
||||||
|
t = blas.Trans
|
||||||
|
}
|
||||||
|
lapack64.Getrs(t, lu.lu.mat, vMat, lu.pivot)
|
||||||
|
if lu.cond > ConditionTolerance {
|
||||||
|
return Condition(lu.cond)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
v.CopyVec(b)
|
|
||||||
vMat := blas64.General{
|
|
||||||
Rows: n,
|
|
||||||
Cols: 1,
|
|
||||||
Stride: v.mat.Inc,
|
|
||||||
Data: v.mat.Data,
|
|
||||||
}
|
|
||||||
t := blas.NoTrans
|
|
||||||
if trans {
|
|
||||||
t = blas.Trans
|
|
||||||
}
|
|
||||||
lapack64.Getrs(t, lu.lu.mat, vMat, lu.pivot)
|
|
||||||
if lu.cond > ConditionTolerance {
|
|
||||||
return Condition(lu.cond)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
24
mat/qr.go
24
mat/qr.go
@@ -203,19 +203,31 @@ func (qr *QR) Solve(m *Dense, trans bool, b Matrix) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SolveVec finds a minimum-norm solution to a system of linear equations.
|
// SolveVec finds a minimum-norm solution to a system of linear equations,
|
||||||
|
// Ax = b.
|
||||||
// Please see QR.Solve for the full documentation.
|
// Please see QR.Solve for the full documentation.
|
||||||
func (qr *QR) SolveVec(v *VecDense, trans bool, b *VecDense) error {
|
func (qr *QR) SolveVec(v *VecDense, trans bool, b Vector) error {
|
||||||
if v != b {
|
|
||||||
v.checkOverlap(b.mat)
|
|
||||||
}
|
|
||||||
r, c := qr.qr.Dims()
|
r, c := qr.qr.Dims()
|
||||||
|
if _, bc := b.Dims(); bc != 1 {
|
||||||
|
panic(ErrShape)
|
||||||
|
}
|
||||||
|
|
||||||
// The Solve implementation is non-trivial, so rather than duplicate the code,
|
// The Solve implementation is non-trivial, so rather than duplicate the code,
|
||||||
// instead recast the VecDenses as Dense and call the matrix code.
|
// instead recast the VecDenses as Dense and call the matrix code.
|
||||||
|
bm := Matrix(b)
|
||||||
|
if rv, ok := b.(RawVectorer); ok {
|
||||||
|
bmat := rv.RawVector()
|
||||||
|
if v != b {
|
||||||
|
v.checkOverlap(bmat)
|
||||||
|
}
|
||||||
|
b := VecDense{mat: bmat, n: b.Len()}
|
||||||
|
bm = b.asDense()
|
||||||
|
}
|
||||||
if trans {
|
if trans {
|
||||||
v.reuseAs(r)
|
v.reuseAs(r)
|
||||||
} else {
|
} else {
|
||||||
v.reuseAs(c)
|
v.reuseAs(c)
|
||||||
}
|
}
|
||||||
return qr.Solve(v.asDense(), trans, b.asDense())
|
return qr.Solve(v.asDense(), trans, bm)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
37
mat/solve.go
37
mat/solve.go
@@ -104,24 +104,37 @@ func (m *Dense) Solve(a, b Matrix) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SolveVec finds a minimum-norm solution to a system of linear equations defined
|
// SolveVec finds a minimum-norm solution to a system of linear equations defined
|
||||||
// by the matrix a and the right-hand side vector b. If A is singular or
|
// by the matrix a and the right-hand side column vector b. If A is singular or
|
||||||
// near-singular, a Condition error is returned. Please see the documentation for
|
// near-singular, a Condition error is returned. Please see the documentation for
|
||||||
// Dense.Solve for more information.
|
// Dense.Solve for more information.
|
||||||
func (v *VecDense) SolveVec(a Matrix, b *VecDense) error {
|
func (v *VecDense) SolveVec(a Matrix, b Vector) error {
|
||||||
if v != b {
|
if _, bc := b.Dims(); bc != 1 {
|
||||||
v.checkOverlap(b.mat)
|
panic(ErrShape)
|
||||||
}
|
}
|
||||||
_, c := a.Dims()
|
_, c := a.Dims()
|
||||||
|
|
||||||
// The Solve implementation is non-trivial, so rather than duplicate the code,
|
// The Solve implementation is non-trivial, so rather than duplicate the code,
|
||||||
// instead recast the VecDenses as Dense and call the matrix code.
|
// instead recast the VecDenses as Dense and call the matrix code.
|
||||||
|
|
||||||
|
if rv, ok := b.(RawVectorer); ok {
|
||||||
|
bmat := rv.RawVector()
|
||||||
|
if v != b {
|
||||||
|
v.checkOverlap(bmat)
|
||||||
|
}
|
||||||
|
v.reuseAs(c)
|
||||||
|
m := v.asDense()
|
||||||
|
// We conditionally create bm as m when b and v are identical
|
||||||
|
// to prevent the overlap detection code from identifying m
|
||||||
|
// and bm as overlapping but not identical.
|
||||||
|
bm := m
|
||||||
|
if v != b {
|
||||||
|
b := VecDense{mat: bmat, n: b.Len()}
|
||||||
|
bm = b.asDense()
|
||||||
|
}
|
||||||
|
return m.Solve(a, bm)
|
||||||
|
}
|
||||||
|
|
||||||
v.reuseAs(c)
|
v.reuseAs(c)
|
||||||
m := v.asDense()
|
m := v.asDense()
|
||||||
// We conditionally create bm as m when b and v are identical
|
return m.Solve(a, b)
|
||||||
// to prevent the overlap detection code from identifying m
|
|
||||||
// and bm as overlapping but not identical.
|
|
||||||
bm := m
|
|
||||||
if v != b {
|
|
||||||
bm = b.asDense()
|
|
||||||
}
|
|
||||||
return m.Solve(a, bm)
|
|
||||||
}
|
}
|
||||||
|
@@ -285,13 +285,13 @@ func TestSolveVec(t *testing.T) {
|
|||||||
// Use testTwoInput
|
// Use testTwoInput
|
||||||
method := func(receiver, a, b Matrix) {
|
method := func(receiver, a, b Matrix) {
|
||||||
type SolveVecer interface {
|
type SolveVecer interface {
|
||||||
SolveVec(a Matrix, b *VecDense) error
|
SolveVec(a Matrix, b Vector) error
|
||||||
}
|
}
|
||||||
rd := receiver.(SolveVecer)
|
rd := receiver.(SolveVecer)
|
||||||
rd.SolveVec(a, b.(*VecDense))
|
rd.SolveVec(a, b.(Vector))
|
||||||
}
|
}
|
||||||
denseComparison := func(receiver, a, b *Dense) {
|
denseComparison := func(receiver, a, b *Dense) {
|
||||||
receiver.Solve(a, b)
|
receiver.Solve(a, b)
|
||||||
}
|
}
|
||||||
testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVecDense, legalSizeSolve, 1e-12)
|
testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVector, legalSizeSolve, 1e-12)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user