mat: harmonise parameter naming and documentation

This commit is contained in:
Dan Kortschak
2018-02-15 14:31:01 +10:30
committed by Dan Kortschak
parent 8d2aaa4a38
commit 64df69126c
5 changed files with 101 additions and 101 deletions

View File

@@ -159,9 +159,9 @@ func (c *Cholesky) LogDet() float64 {
return det return det
} }
// Solve finds the matrix m that solves A * m = b where A is represented // Solve finds the matrix x that solves A * X = B where A is represented
// by the Cholesky decomposition, placing the result in m. // by the Cholesky decomposition, placing the result in x.
func (c *Cholesky) Solve(m *Dense, b Matrix) error { func (c *Cholesky) Solve(x *Dense, b Matrix) error {
if !c.valid() { if !c.valid() {
panic(badCholesky) panic(badCholesky)
} }
@@ -171,21 +171,21 @@ func (c *Cholesky) Solve(m *Dense, b Matrix) error {
panic(ErrShape) panic(ErrShape)
} }
m.reuseAs(bm, bn) x.reuseAs(bm, bn)
if b != m { if b != x {
m.Copy(b) x.Copy(b)
} }
blas64.Trsm(blas.Left, blas.Trans, 1, c.chol.mat, m.mat) blas64.Trsm(blas.Left, blas.Trans, 1, c.chol.mat, x.mat)
blas64.Trsm(blas.Left, blas.NoTrans, 1, c.chol.mat, m.mat) blas64.Trsm(blas.Left, blas.NoTrans, 1, c.chol.mat, x.mat)
if c.cond > ConditionTolerance { if c.cond > ConditionTolerance {
return Condition(c.cond) return Condition(c.cond)
} }
return nil return nil
} }
// SolveChol finds the matrix m that solves A * m = B where A and B are represented // SolveChol finds the matrix x that solves A * X = B where A and B are represented
// by their Cholesky decompositions a and b, placing the result in the receiver. // by their Cholesky decompositions a and b, placing the result in the receiver.
func (a *Cholesky) SolveChol(m *Dense, b *Cholesky) error { func (a *Cholesky) SolveChol(x *Dense, b *Cholesky) error {
if !a.valid() || !b.valid() { if !a.valid() || !b.valid() {
panic(badCholesky) panic(badCholesky)
} }
@@ -194,20 +194,20 @@ func (a *Cholesky) SolveChol(m *Dense, b *Cholesky) error {
panic(ErrShape) panic(ErrShape)
} }
m.reuseAsZeroed(bn, bn) x.reuseAsZeroed(bn, bn)
m.Copy(b.chol.T()) x.Copy(b.chol.T())
blas64.Trsm(blas.Left, blas.Trans, 1, a.chol.mat, m.mat) blas64.Trsm(blas.Left, blas.Trans, 1, a.chol.mat, x.mat)
blas64.Trsm(blas.Left, blas.NoTrans, 1, a.chol.mat, m.mat) blas64.Trsm(blas.Left, blas.NoTrans, 1, a.chol.mat, x.mat)
blas64.Trmm(blas.Right, blas.NoTrans, 1, b.chol.mat, m.mat) blas64.Trmm(blas.Right, blas.NoTrans, 1, b.chol.mat, x.mat)
if a.cond > ConditionTolerance { if a.cond > ConditionTolerance {
return Condition(a.cond) return Condition(a.cond)
} }
return nil return nil
} }
// SolveVec finds the vector v that solves A * v = b where A is represented // SolveVec finds the vector x that solves A * x = b where A is represented
// by the Cholesky decomposition, placing the result in v. // by the Cholesky decomposition, placing the result in x.
func (c *Cholesky) SolveVec(v *VecDense, b Vector) error { func (c *Cholesky) SolveVec(x *VecDense, b Vector) error {
if !c.valid() { if !c.valid() {
panic(badCholesky) panic(badCholesky)
} }
@@ -217,19 +217,19 @@ func (c *Cholesky) SolveVec(v *VecDense, b Vector) error {
} }
switch rv := b.(type) { switch rv := b.(type) {
default: default:
v.reuseAs(n) x.reuseAs(n)
return c.Solve(v.asDense(), b) return c.Solve(x.asDense(), b)
case RawVectorer: case RawVectorer:
bmat := rv.RawVector() bmat := rv.RawVector()
if v != b { if x != b {
v.checkOverlap(bmat) x.checkOverlap(bmat)
} }
v.reuseAs(n) x.reuseAs(n)
if v != b { if x != b {
v.CopyVec(b) x.CopyVec(b)
} }
blas64.Trsv(blas.Trans, c.chol.mat, v.mat) blas64.Trsv(blas.Trans, c.chol.mat, x.mat)
blas64.Trsv(blas.NoTrans, c.chol.mat, v.mat) blas64.Trsv(blas.NoTrans, c.chol.mat, x.mat)
if c.cond > ConditionTolerance { if c.cond > ConditionTolerance {
return Condition(c.cond) return Condition(c.cond)
} }

View File

@@ -146,61 +146,61 @@ func (lq *LQ) QTo(dst *Dense) *Dense {
// See the documentation for Condition for more information. // See the documentation for Condition for more information.
// //
// The minimization problem solved depends on the input parameters. // The minimization problem solved depends on the input parameters.
// If trans == false, find the minimum norm solution of A * X = b. // If trans == false, find the minimum norm solution of A * X = B.
// If trans == true, find X such that ||A*X - b||_2 is minimized. // If trans == true, find X such that ||A*X - B||_2 is minimized.
// The solution matrix, X, is stored in place into m. // The solution matrix, X, is stored in place into x.
func (lq *LQ) Solve(m *Dense, trans bool, b Matrix) error { func (lq *LQ) Solve(x *Dense, trans bool, b Matrix) error {
r, c := lq.lq.Dims() r, c := lq.lq.Dims()
br, bc := b.Dims() br, bc := b.Dims()
// The LQ solve algorithm stores the result in-place into the right hand side. // The LQ solve algorithm stores the result in-place into the right hand side.
// The storage for the answer must be large enough to hold both b and x. // The storage for the answer must be large enough to hold both b and x.
// However, this method's receiver must be the size of x. Copy b, and then // However, this method's receiver must be the size of x. Copy b, and then
// copy the result into m at the end. // copy the result into x at the end.
if trans { if trans {
if c != br { if c != br {
panic(ErrShape) panic(ErrShape)
} }
m.reuseAs(r, bc) x.reuseAs(r, bc)
} else { } else {
if r != br { if r != br {
panic(ErrShape) panic(ErrShape)
} }
m.reuseAs(c, bc) x.reuseAs(c, bc)
} }
// Do not need to worry about overlap between m and b because x has its own // Do not need to worry about overlap between x and b because w has its own
// independent storage. // independent storage.
x := getWorkspace(max(r, c), bc, false) w := getWorkspace(max(r, c), bc, false)
x.Copy(b) w.Copy(b)
t := lq.lq.asTriDense(lq.lq.mat.Rows, blas.NonUnit, blas.Lower).mat t := lq.lq.asTriDense(lq.lq.mat.Rows, blas.NonUnit, blas.Lower).mat
if trans { if trans {
work := []float64{0} work := []float64{0}
lapack64.Ormlq(blas.Left, blas.NoTrans, lq.lq.mat, lq.tau, x.mat, work, -1) lapack64.Ormlq(blas.Left, blas.NoTrans, lq.lq.mat, lq.tau, w.mat, work, -1)
work = getFloats(int(work[0]), false) work = getFloats(int(work[0]), false)
lapack64.Ormlq(blas.Left, blas.NoTrans, lq.lq.mat, lq.tau, x.mat, work, len(work)) lapack64.Ormlq(blas.Left, blas.NoTrans, lq.lq.mat, lq.tau, w.mat, work, len(work))
putFloats(work) putFloats(work)
ok := lapack64.Trtrs(blas.Trans, t, x.mat) ok := lapack64.Trtrs(blas.Trans, t, w.mat)
if !ok { if !ok {
return Condition(math.Inf(1)) return Condition(math.Inf(1))
} }
} else { } else {
ok := lapack64.Trtrs(blas.NoTrans, t, x.mat) ok := lapack64.Trtrs(blas.NoTrans, t, w.mat)
if !ok { if !ok {
return Condition(math.Inf(1)) return Condition(math.Inf(1))
} }
for i := r; i < c; i++ { for i := r; i < c; i++ {
zero(x.mat.Data[i*x.mat.Stride : i*x.mat.Stride+bc]) zero(w.mat.Data[i*w.mat.Stride : i*w.mat.Stride+bc])
} }
work := []float64{0} work := []float64{0}
lapack64.Ormlq(blas.Left, blas.Trans, lq.lq.mat, lq.tau, x.mat, work, -1) lapack64.Ormlq(blas.Left, blas.Trans, lq.lq.mat, lq.tau, w.mat, work, -1)
work = getFloats(int(work[0]), false) work = getFloats(int(work[0]), false)
lapack64.Ormlq(blas.Left, blas.Trans, lq.lq.mat, lq.tau, x.mat, work, len(work)) lapack64.Ormlq(blas.Left, blas.Trans, lq.lq.mat, lq.tau, w.mat, work, len(work))
putFloats(work) putFloats(work)
} }
// M was set above to be the correct size for the result. // x was set above to be the correct size for the result.
m.Copy(x) x.Copy(w)
putWorkspace(x) putWorkspace(w)
if lq.cond > ConditionTolerance { if lq.cond > ConditionTolerance {
return Condition(lq.cond) return Condition(lq.cond)
} }
@@ -209,7 +209,7 @@ func (lq *LQ) Solve(m *Dense, trans bool, b Matrix) error {
// SolveVec finds a minimum-norm solution to a system of linear equations. // SolveVec finds a minimum-norm solution to a system of linear equations.
// See LQ.Solve for the full documentation. // See LQ.Solve for the full documentation.
func (lq *LQ) SolveVec(v *VecDense, trans bool, b Vector) error { func (lq *LQ) SolveVec(x *VecDense, trans bool, b Vector) error {
r, c := lq.lq.Dims() r, c := lq.lq.Dims()
if _, bc := b.Dims(); bc != 1 { if _, bc := b.Dims(); bc != 1 {
panic(ErrShape) panic(ErrShape)
@@ -220,16 +220,16 @@ func (lq *LQ) SolveVec(v *VecDense, trans bool, b Vector) error {
bm := Matrix(b) bm := Matrix(b)
if rv, ok := b.(RawVectorer); ok { if rv, ok := b.(RawVectorer); ok {
bmat := rv.RawVector() bmat := rv.RawVector()
if v != b { if x != b {
v.checkOverlap(bmat) x.checkOverlap(bmat)
} }
b := VecDense{mat: bmat, n: b.Len()} b := VecDense{mat: bmat, n: b.Len()}
bm = b.asDense() bm = b.asDense()
} }
if trans { if trans {
v.reuseAs(r) x.reuseAs(r)
} else { } else {
v.reuseAs(c) x.reuseAs(c)
} }
return lq.Solve(v.asDense(), trans, bm) return lq.Solve(x.asDense(), trans, bm)
} }

View File

@@ -283,14 +283,14 @@ func (m *Dense) Permutation(r int, swaps []int) {
// Solve solves a system of linear equations using the LU decomposition of a matrix. // Solve solves a system of linear equations using the LU decomposition of a matrix.
// It computes // It computes
// A * x = b if trans == false // A * X = B if trans == false
// A^T * x = b if trans == true // A^T * X = B if trans == true
// In both cases, A is represented in LU factorized form, and the matrix x is // In both cases, A is represented in LU factorized form, and the matrix X is
// stored into m. // stored into x.
// //
// If A is singular or near-singular a Condition error is returned. See // If A is singular or near-singular a Condition error is returned. See
// the documentation for Condition for more information. // the documentation for Condition for more information.
func (lu *LU) Solve(m *Dense, trans bool, b Matrix) error { func (lu *LU) Solve(x *Dense, trans bool, b Matrix) error {
_, n := lu.lu.Dims() _, n := lu.lu.Dims()
br, bc := b.Dims() br, bc := b.Dims()
if br != n { if br != n {
@@ -302,22 +302,22 @@ func (lu *LU) Solve(m *Dense, trans bool, b Matrix) error {
return Condition(math.Inf(1)) return Condition(math.Inf(1))
} }
m.reuseAs(n, bc) x.reuseAs(n, bc)
bU, _ := untranspose(b) bU, _ := untranspose(b)
var restore func() var restore func()
if m == bU { if x == bU {
m, restore = m.isolatedWorkspace(bU) x, restore = x.isolatedWorkspace(bU)
defer restore() defer restore()
} else if rm, ok := bU.(RawMatrixer); ok { } else if rm, ok := bU.(RawMatrixer); ok {
m.checkOverlap(rm.RawMatrix()) x.checkOverlap(rm.RawMatrix())
} }
m.Copy(b) x.Copy(b)
t := blas.NoTrans t := blas.NoTrans
if trans { if trans {
t = blas.Trans t = blas.Trans
} }
lapack64.Getrs(t, lu.lu.mat, m.mat, lu.pivot) lapack64.Getrs(t, lu.lu.mat, x.mat, lu.pivot)
if lu.cond > ConditionTolerance { if lu.cond > ConditionTolerance {
return Condition(lu.cond) return Condition(lu.cond)
} }
@@ -328,23 +328,23 @@ func (lu *LU) Solve(m *Dense, trans bool, b Matrix) error {
// It computes // It computes
// A * x = b if trans == false // A * x = b if trans == false
// A^T * x = b if trans == true // A^T * x = b if trans == true
// In both cases, A is represented in LU factorized form, and the matrix x is // In both cases, A is represented in LU factorized form, and the vector x is
// stored into v. // stored into x.
// //
// If A is singular or near-singular a Condition error is returned. See // If A is singular or near-singular a Condition error is returned. See
// the documentation for Condition for more information. // the documentation for Condition for more information.
func (lu *LU) SolveVec(v *VecDense, trans bool, b Vector) error { func (lu *LU) SolveVec(x *VecDense, trans bool, b Vector) error {
_, n := lu.lu.Dims() _, n := lu.lu.Dims()
if br, bc := b.Dims(); br != n || bc != 1 { if br, bc := b.Dims(); br != n || bc != 1 {
panic(ErrShape) panic(ErrShape)
} }
switch rv := b.(type) { switch rv := b.(type) {
default: default:
v.reuseAs(n) x.reuseAs(n)
return lu.Solve(v.asDense(), trans, b) return lu.Solve(x.asDense(), trans, b)
case RawVectorer: case RawVectorer:
if v != b { if x != b {
v.checkOverlap(rv.RawVector()) x.checkOverlap(rv.RawVector())
} }
// TODO(btracey): Should test the condition number instead of testing that // TODO(btracey): Should test the condition number instead of testing that
// the determinant is exactly zero. // the determinant is exactly zero.
@@ -352,18 +352,18 @@ func (lu *LU) SolveVec(v *VecDense, trans bool, b Vector) error {
return Condition(math.Inf(1)) return Condition(math.Inf(1))
} }
v.reuseAs(n) x.reuseAs(n)
var restore func() var restore func()
if v == b { if x == b {
v, restore = v.isolatedWorkspace(b) x, restore = x.isolatedWorkspace(b)
defer restore() defer restore()
} }
v.CopyVec(b) x.CopyVec(b)
vMat := blas64.General{ vMat := blas64.General{
Rows: n, Rows: n,
Cols: 1, Cols: 1,
Stride: v.mat.Inc, Stride: x.mat.Inc,
Data: v.mat.Data, Data: x.mat.Data,
} }
t := blas.NoTrans t := blas.NoTrans
if trans { if trans {

View File

@@ -142,10 +142,10 @@ func (qr *QR) QTo(dst *Dense) *Dense {
// See the documentation for Condition for more information. // See the documentation for Condition for more information.
// //
// The minimization problem solved depends on the input parameters. // The minimization problem solved depends on the input parameters.
// If trans == false, find X such that ||A*X - b||_2 is minimized. // If trans == false, find X such that ||A*X - B||_2 is minimized.
// If trans == true, find the minimum norm solution of A^T * X = b. // If trans == true, find the minimum norm solution of A^T * X = B.
// The solution matrix, X, is stored in place into m. // The solution matrix, X, is stored in place into m.
func (qr *QR) Solve(m *Dense, trans bool, b Matrix) error { func (qr *QR) Solve(x *Dense, trans bool, b Matrix) error {
r, c := qr.qr.Dims() r, c := qr.qr.Dims()
br, bc := b.Dims() br, bc := b.Dims()
@@ -157,46 +157,46 @@ func (qr *QR) Solve(m *Dense, trans bool, b Matrix) error {
if c != br { if c != br {
panic(ErrShape) panic(ErrShape)
} }
m.reuseAs(r, bc) x.reuseAs(r, bc)
} else { } else {
if r != br { if r != br {
panic(ErrShape) panic(ErrShape)
} }
m.reuseAs(c, bc) x.reuseAs(c, bc)
} }
// Do not need to worry about overlap between m and b because x has its own // Do not need to worry about overlap between m and b because x has its own
// independent storage. // independent storage.
x := getWorkspace(max(r, c), bc, false) w := getWorkspace(max(r, c), bc, false)
x.Copy(b) w.Copy(b)
t := qr.qr.asTriDense(qr.qr.mat.Cols, blas.NonUnit, blas.Upper).mat t := qr.qr.asTriDense(qr.qr.mat.Cols, blas.NonUnit, blas.Upper).mat
if trans { if trans {
ok := lapack64.Trtrs(blas.Trans, t, x.mat) ok := lapack64.Trtrs(blas.Trans, t, w.mat)
if !ok { if !ok {
return Condition(math.Inf(1)) return Condition(math.Inf(1))
} }
for i := c; i < r; i++ { for i := c; i < r; i++ {
zero(x.mat.Data[i*x.mat.Stride : i*x.mat.Stride+bc]) zero(w.mat.Data[i*w.mat.Stride : i*w.mat.Stride+bc])
} }
work := []float64{0} work := []float64{0}
lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, x.mat, work, -1) lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, w.mat, work, -1)
work = getFloats(int(work[0]), false) work = getFloats(int(work[0]), false)
lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, x.mat, work, len(work)) lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, w.mat, work, len(work))
putFloats(work) putFloats(work)
} else { } else {
work := []float64{0} work := []float64{0}
lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, x.mat, work, -1) lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, w.mat, work, -1)
work = getFloats(int(work[0]), false) work = getFloats(int(work[0]), false)
lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, x.mat, work, len(work)) lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, w.mat, work, len(work))
putFloats(work) putFloats(work)
ok := lapack64.Trtrs(blas.NoTrans, t, x.mat) ok := lapack64.Trtrs(blas.NoTrans, t, w.mat)
if !ok { if !ok {
return Condition(math.Inf(1)) return Condition(math.Inf(1))
} }
} }
// M was set above to be the correct size for the result. // X was set above to be the correct size for the result.
m.Copy(x) x.Copy(w)
putWorkspace(x) putWorkspace(w)
if qr.cond > ConditionTolerance { if qr.cond > ConditionTolerance {
return Condition(qr.cond) return Condition(qr.cond)
} }
@@ -206,7 +206,7 @@ func (qr *QR) Solve(m *Dense, trans bool, b Matrix) error {
// SolveVec finds a minimum-norm solution to a system of linear equations, // SolveVec finds a minimum-norm solution to a system of linear equations,
// Ax = b. // Ax = b.
// See QR.Solve for the full documentation. // See QR.Solve for the full documentation.
func (qr *QR) SolveVec(v *VecDense, trans bool, b Vector) error { func (qr *QR) SolveVec(x *VecDense, trans bool, b Vector) error {
r, c := qr.qr.Dims() r, c := qr.qr.Dims()
if _, bc := b.Dims(); bc != 1 { if _, bc := b.Dims(); bc != 1 {
panic(ErrShape) panic(ErrShape)
@@ -217,17 +217,17 @@ func (qr *QR) SolveVec(v *VecDense, trans bool, b Vector) error {
bm := Matrix(b) bm := Matrix(b)
if rv, ok := b.(RawVectorer); ok { if rv, ok := b.(RawVectorer); ok {
bmat := rv.RawVector() bmat := rv.RawVector()
if v != b { if x != b {
v.checkOverlap(bmat) x.checkOverlap(bmat)
} }
b := VecDense{mat: bmat, n: b.Len()} b := VecDense{mat: bmat, n: b.Len()}
bm = b.asDense() bm = b.asDense()
} }
if trans { if trans {
v.reuseAs(r) x.reuseAs(r)
} else { } else {
v.reuseAs(c) x.reuseAs(c)
} }
return qr.Solve(v.asDense(), trans, bm) return qr.Solve(x.asDense(), trans, bm)
} }

View File

@@ -11,7 +11,7 @@ import (
) )
// Solve finds a minimum-norm solution to a system of linear equations defined // Solve finds a minimum-norm solution to a system of linear equations defined
// by the matrices a and b. If A is singular or near-singular, a Condition error // by the matrices A and B. If A is singular or near-singular, a Condition error
// is returned. See the documentation for Condition for more information. // is returned. See the documentation for Condition for more information.
// //
// The minimization problem solved depends on the input parameters: // The minimization problem solved depends on the input parameters: