diff --git a/mat/cholesky.go b/mat/cholesky.go index e6530fd2..d95533cf 100644 --- a/mat/cholesky.go +++ b/mat/cholesky.go @@ -207,29 +207,34 @@ func (a *Cholesky) SolveChol(m *Dense, b *Cholesky) error { // SolveVec finds the vector v that solves A * v = b where A is represented // by the Cholesky decomposition, placing the result in v. -func (c *Cholesky) SolveVec(v, b *VecDense) error { +func (c *Cholesky) SolveVec(v *VecDense, b Vector) error { if !c.valid() { panic(badCholesky) } n := c.chol.mat.N - vn := b.Len() - if vn != n { + if br, bc := b.Dims(); br != n || bc != 1 { panic(ErrShape) } - if v != b { - v.checkOverlap(b.mat) + switch rv := b.(type) { + default: + v.reuseAs(n) + return c.Solve(v.asDense(), b) + case RawVectorer: + bmat := rv.RawVector() + if v != b { + v.checkOverlap(bmat) + } + v.reuseAs(n) + if v != b { + v.CopyVec(b) + } + blas64.Trsv(blas.Trans, c.chol.mat, v.mat) + blas64.Trsv(blas.NoTrans, c.chol.mat, v.mat) + if c.cond > ConditionTolerance { + return Condition(c.cond) + } + return nil } - v.reuseAs(n) - if v != b { - v.CopyVec(b) - } - blas64.Trsv(blas.Trans, c.chol.mat, v.mat) - blas64.Trsv(blas.NoTrans, c.chol.mat, v.mat) - if c.cond > ConditionTolerance { - return Condition(c.cond) - } - return nil - } // RawU returns the Triangular matrix used to store the Cholesky decomposition of diff --git a/mat/lq.go b/mat/lq.go index c8b797a6..bc7eb95d 100644 --- a/mat/lq.go +++ b/mat/lq.go @@ -209,17 +209,27 @@ func (lq *LQ) Solve(m *Dense, trans bool, b Matrix) error { // SolveVec finds a minimum-norm solution to a system of linear equations. // Please see LQ.Solve for the full documentation. -func (lq *LQ) SolveVec(v *VecDense, trans bool, b *VecDense) error { - if v != b { - v.checkOverlap(b.mat) - } +func (lq *LQ) SolveVec(v *VecDense, trans bool, b Vector) error { r, c := lq.lq.Dims() + if _, bc := b.Dims(); bc != 1 { + panic(ErrShape) + } + // The Solve implementation is non-trivial, so rather than duplicate the code, // instead recast the VecDenses as Dense and call the matrix code. + bm := Matrix(b) + if rv, ok := b.(RawVectorer); ok { + bmat := rv.RawVector() + if v != b { + v.checkOverlap(bmat) + } + b := VecDense{mat: bmat, n: b.Len()} + bm = b.asDense() + } if trans { v.reuseAs(r) } else { v.reuseAs(c) } - return lq.Solve(v.asDense(), trans, b.asDense()) + return lq.Solve(v.asDense(), trans, bm) } diff --git a/mat/lu.go b/mat/lu.go index 4f9a462c..5bef1e36 100644 --- a/mat/lu.go +++ b/mat/lu.go @@ -333,41 +333,46 @@ func (lu *LU) Solve(m *Dense, trans bool, b Matrix) error { // // If A is singular or near-singular a Condition error is returned. Please see // the documentation for Condition for more information. -func (lu *LU) SolveVec(v *VecDense, trans bool, b *VecDense) error { +func (lu *LU) SolveVec(v *VecDense, trans bool, b Vector) error { _, n := lu.lu.Dims() - bn := b.Len() - if bn != n { + if br, bc := b.Dims(); br != n || bc != 1 { panic(ErrShape) } - if v != b { - v.checkOverlap(b.mat) - } - // TODO(btracey): Should test the condition number instead of testing that - // the determinant is exactly zero. - if lu.Det() == 0 { - return Condition(math.Inf(1)) - } + switch rv := b.(type) { + default: + v.reuseAs(n) + return lu.Solve(v.asDense(), trans, b) + case RawVectorer: + if v != b { + v.checkOverlap(rv.RawVector()) + } + // TODO(btracey): Should test the condition number instead of testing that + // the determinant is exactly zero. + if lu.Det() == 0 { + return Condition(math.Inf(1)) + } - v.reuseAs(n) - var restore func() - if v == b { - v, restore = v.isolatedWorkspace(b) - defer restore() + v.reuseAs(n) + var restore func() + if v == b { + v, restore = v.isolatedWorkspace(b) + defer restore() + } + v.CopyVec(b) + vMat := blas64.General{ + Rows: n, + Cols: 1, + Stride: v.mat.Inc, + Data: v.mat.Data, + } + t := blas.NoTrans + if trans { + t = blas.Trans + } + lapack64.Getrs(t, lu.lu.mat, vMat, lu.pivot) + if lu.cond > ConditionTolerance { + return Condition(lu.cond) + } + return nil } - v.CopyVec(b) - vMat := blas64.General{ - Rows: n, - Cols: 1, - Stride: v.mat.Inc, - Data: v.mat.Data, - } - t := blas.NoTrans - if trans { - t = blas.Trans - } - lapack64.Getrs(t, lu.lu.mat, vMat, lu.pivot) - if lu.cond > ConditionTolerance { - return Condition(lu.cond) - } - return nil } diff --git a/mat/qr.go b/mat/qr.go index 1d65ec9c..9ec0a28c 100644 --- a/mat/qr.go +++ b/mat/qr.go @@ -203,19 +203,31 @@ func (qr *QR) Solve(m *Dense, trans bool, b Matrix) error { return nil } -// SolveVec finds a minimum-norm solution to a system of linear equations. +// SolveVec finds a minimum-norm solution to a system of linear equations, +// Ax = b. // Please see QR.Solve for the full documentation. -func (qr *QR) SolveVec(v *VecDense, trans bool, b *VecDense) error { - if v != b { - v.checkOverlap(b.mat) - } +func (qr *QR) SolveVec(v *VecDense, trans bool, b Vector) error { r, c := qr.qr.Dims() + if _, bc := b.Dims(); bc != 1 { + panic(ErrShape) + } + // The Solve implementation is non-trivial, so rather than duplicate the code, // instead recast the VecDenses as Dense and call the matrix code. + bm := Matrix(b) + if rv, ok := b.(RawVectorer); ok { + bmat := rv.RawVector() + if v != b { + v.checkOverlap(bmat) + } + b := VecDense{mat: bmat, n: b.Len()} + bm = b.asDense() + } if trans { v.reuseAs(r) } else { v.reuseAs(c) } - return qr.Solve(v.asDense(), trans, b.asDense()) + return qr.Solve(v.asDense(), trans, bm) + } diff --git a/mat/solve.go b/mat/solve.go index e88812bd..c1e5b7eb 100644 --- a/mat/solve.go +++ b/mat/solve.go @@ -104,24 +104,37 @@ func (m *Dense) Solve(a, b Matrix) error { } // SolveVec finds a minimum-norm solution to a system of linear equations defined -// by the matrix a and the right-hand side vector b. If A is singular or +// by the matrix a and the right-hand side column vector b. If A is singular or // near-singular, a Condition error is returned. Please see the documentation for // Dense.Solve for more information. -func (v *VecDense) SolveVec(a Matrix, b *VecDense) error { - if v != b { - v.checkOverlap(b.mat) +func (v *VecDense) SolveVec(a Matrix, b Vector) error { + if _, bc := b.Dims(); bc != 1 { + panic(ErrShape) } _, c := a.Dims() + // The Solve implementation is non-trivial, so rather than duplicate the code, // instead recast the VecDenses as Dense and call the matrix code. + + if rv, ok := b.(RawVectorer); ok { + bmat := rv.RawVector() + if v != b { + v.checkOverlap(bmat) + } + v.reuseAs(c) + m := v.asDense() + // We conditionally create bm as m when b and v are identical + // to prevent the overlap detection code from identifying m + // and bm as overlapping but not identical. + bm := m + if v != b { + b := VecDense{mat: bmat, n: b.Len()} + bm = b.asDense() + } + return m.Solve(a, bm) + } + v.reuseAs(c) m := v.asDense() - // We conditionally create bm as m when b and v are identical - // to prevent the overlap detection code from identifying m - // and bm as overlapping but not identical. - bm := m - if v != b { - bm = b.asDense() - } - return m.Solve(a, bm) + return m.Solve(a, b) } diff --git a/mat/solve_test.go b/mat/solve_test.go index d6bd6164..7f4c7362 100644 --- a/mat/solve_test.go +++ b/mat/solve_test.go @@ -285,13 +285,13 @@ func TestSolveVec(t *testing.T) { // Use testTwoInput method := func(receiver, a, b Matrix) { type SolveVecer interface { - SolveVec(a Matrix, b *VecDense) error + SolveVec(a Matrix, b Vector) error } rd := receiver.(SolveVecer) - rd.SolveVec(a, b.(*VecDense)) + rd.SolveVec(a, b.(Vector)) } denseComparison := func(receiver, a, b *Dense) { receiver.Solve(a, b) } - testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVecDense, legalSizeSolve, 1e-12) + testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVector, legalSizeSolve, 1e-12) }