Files
gonum/mat/solve.go
Brendan Tracey 3fa9374bd4 matrix: rename matrix to mat, and merge with mat64 and cmat128.
This merges the three packages, matrix, mat64, and cmat128. It then renames this big package to mat. It fixes the import statements and corresponding code
2017-06-13 10:26:10 -06:00

128 lines
3.1 KiB
Go

// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mat
import (
"gonum.org/v1/gonum/blas"
"gonum.org/v1/gonum/blas/blas64"
"gonum.org/v1/gonum/lapack/lapack64"
)
// Solve finds a minimum-norm solution to a system of linear equations defined
// by the matrices a and b, and stores the solution matrix X in the receiver.
// If A is singular or near-singular, a Condition error is returned. Please see
// the documentation for Condition for more information.
//
// For an r×c matrix A, the minimization problem solved depends on its shape:
//   - if r >= c, find X such that ||A*X - B||_2 is minimized,
//   - if r < c, find the minimum norm solution of A * X = B.
//
// X has c rows and as many columns as b. Solve panics with ErrShape if a and
// b do not have the same number of rows.
//
// Some inputs take special paths:
//   - If a, after removing any transpose, is a RawTriangular, the system is
//     solved directly with a triangular solve. The condition number is then
//     estimated from the triangular factor.
//   - If A is square and a and b are the same Matrix value, X is set to the
//     identity without any conditioning check, so no Condition error is
//     reported in that case.
func (m *Dense) Solve(a, b Matrix) error {
	ar, ac := a.Dims()
	br, bc := b.Dims()
	if ar != br {
		panic(ErrShape)
	}
	m.reuseAs(ac, bc)

	// TODO(btracey): Add special cases for SymDense, etc.
	aU, aTrans := untranspose(a)
	bU, bTrans := untranspose(b)
	if rma, ok := aU.(RawTriangular); ok {
		side := blas.Left
		tA := blas.NoTrans
		if aTrans {
			tA = blas.Trans
		}

		// Place B in m so that Trsm can overwrite it with X in place.
		switch rm := bU.(type) {
		case RawMatrixer:
			// Nothing to do when b is exactly m (not transposed).
			if m != bU || bTrans {
				if m == bU || m.checkOverlap(rm.RawMatrix()) {
					// b shares storage with m, so a direct Copy could read
					// elements it has already overwritten. Stage the copy
					// through a workspace instead.
					tmp := getWorkspace(br, bc, false)
					tmp.Copy(b)
					m.Copy(tmp)
					putWorkspace(tmp)
				} else {
					m.Copy(b)
				}
			}
		default:
			// m is a *Dense and therefore a RawMatrixer, so bU cannot be m
			// here. Aliasing with m via b is handled in the case above.
			m.Copy(b)
		}

		rm := rma.RawTriangular()
		blas64.Trsm(side, tA, 1, rm, m.mat)
		work := getFloats(3*rm.N, false)
		iwork := getInts(rm.N, false)
		cond := lapack64.Trcon(CondNorm, rm, work, iwork)
		putFloats(work)
		putInts(iwork)
		if cond > ConditionTolerance {
			return Condition(cond)
		}
		return nil
	}

	switch {
	case ar == ac:
		// Comparing interface values panics at run time when both hold the
		// same non-comparable dynamic type (for example a slice-backed
		// Matrix implementation). Treat such values as distinct and let
		// them take the general LU path instead of crashing.
		same := func() (eq bool) {
			defer func() { _ = recover() }()
			return a == b
		}()
		if same {
			// A*X = A is satisfied by X = I.
			for i := 0; i < ar; i++ {
				row := m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+ac]
				zero(row)
				row[i] = 1
			}
			return nil
		}
		var lu LU
		lu.Factorize(a)
		return m.SolveLU(&lu, false, b)
	case ar > ac:
		var qr QR
		qr.Factorize(a)
		return m.SolveQR(&qr, false, b)
	default:
		var lq LQ
		lq.Factorize(a)
		return m.SolveLQ(&lq, false, b)
	}
}
// SolveVec computes a minimum-norm solution x of the linear system defined by
// the matrix a and the right-hand-side vector b, and stores x in the receiver.
// A Condition error is returned when A is singular or near-singular. See the
// documentation of Dense.Solve for the exact problem being solved.
func (v *Vector) SolveVec(a Matrix, b *Vector) error {
	inPlace := v == b
	if !inPlace {
		v.checkOverlap(b.mat)
	}

	// The receiver holds one entry per column of a.
	_, cols := a.Dims()
	v.reuseAs(cols)

	// Dense.Solve is non-trivial, so reuse it by recasting the vectors as
	// Dense matrices rather than duplicating the solver here.
	x := v.asDense()
	if inPlace {
		// Pass the identical *Dense for both operands. Recasting b separately
		// would produce a distinct value that the overlap detection could
		// flag as overlapping but not identical.
		return x.Solve(a, x)
	}
	return x.Solve(a, b.asDense())
}