Files
gonum/mat/solve.go
Jonathan Chan Kwan Yin 2d7eec07c1 mat: fixed typo in mat.Solve/mat.SolveVec
It seems the extra `A` here is unintended.
2022-11-17 22:01:33 +10:30

171 lines
4.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mat
import (
"gonum.org/v1/gonum/blas"
"gonum.org/v1/gonum/blas/blas64"
"gonum.org/v1/gonum/lapack/lapack64"
)
// Solve solves the linear least squares problem
//
//	minimize over x |b - A*x|_2
//
// where A is an m×n matrix, b is a given m element vector and x is an n element
// solution vector. Solve assumes that A has full rank, that is
//
//	rank(A) = min(m,n)
//
// If m >= n, Solve finds the unique least squares solution of an overdetermined
// system.
//
// If m < n, there is an infinite number of solutions that satisfy b-A*x=0. In
// this case Solve finds the unique solution of an underdetermined system that
// minimizes |x|_2.
//
// Several right-hand side vectors b and solution vectors x can be handled in a
// single call. Vectors b are stored in the columns of the m×k matrix B. Vectors
// x will be stored in-place into the n×k receiver.
//
// If A does not have full rank, a Condition error is returned. See the
// documentation for Condition for more information.
//
// Solve panics with ErrShape if a and b do not have the same number of rows.
func (m *Dense) Solve(a, b Matrix) error {
	ar, ac := a.Dims()
	br, bc := b.Dims()
	if ar != br {
		panic(ErrShape)
	}
	// Size the receiver as the n×k solution without zeroing it; each branch
	// below writes its result into m.
	m.reuseAsNonZeroed(ac, bc)

	// TODO(btracey): Add special cases for SymDense, etc.
	// Look through transpose wrappers so the concrete types can be inspected.
	// NOTE(review): untranspose is not shown here; presumably aTrans/bTrans
	// report whether a transpose wrapper was removed — confirm.
	aU, aTrans := untranspose(a)
	bU, bTrans := untranspose(b)
	switch rma := aU.(type) {
	case RawTriangular:
		// Triangular A: solve op(A)*X = B in place in the receiver with a
		// single left-side triangular solve, where op(A) is A or Aᵀ.
		side := blas.Left
		tA := blas.NoTrans
		if aTrans {
			tA = blas.Trans
		}

		// Trsm overwrites its right-hand side, so B must first be placed in
		// the receiver. When the receiver and b share storage, the copy goes
		// through a temporary workspace.
		switch rm := bU.(type) {
		case RawMatrixer:
			// When m is b itself and b is not transposed, B is already in place.
			if m != bU || bTrans {
				if m == bU || m.checkOverlap(rm.RawMatrix()) {
					tmp := getDenseWorkspace(br, bc, false)
					tmp.Copy(b)
					m.Copy(tmp)
					putDenseWorkspace(tmp)
					break
				}
				m.Copy(b)
			}
		default:
			if m != bU {
				m.Copy(b)
			} else if bTrans {
				// m and b share data so Copy cannot be used directly.
				tmp := getDenseWorkspace(br, bc, false)
				tmp.Copy(b)
				m.Copy(tmp)
				putDenseWorkspace(tmp)
			}
		}
		rm := rma.RawTriangular()
		blas64.Trsm(side, tA, 1, rm, m.mat)

		work := getFloat64s(3*rm.N, false)
		iwork := getInts(rm.N, false)
		rcond := lapack64.Trcon(CondNorm, rm, work, iwork)
		putFloat64s(work)
		putInts(iwork)
		// Trcon estimates the reciprocal of the condition number (a value in
		// [0, 1]), so it must be inverted before being compared with
		// ConditionTolerance and reported through Condition. A singular
		// factor is expected to give rcond == 0, i.e. cond == +Inf, which is
		// then reported as a Condition error.
		cond := 1 / rcond
		if cond > ConditionTolerance {
			return Condition(cond)
		}
		return nil
	}

	switch {
	case ar == ac:
		if a == b {
			// A*X = A is solved by X = I, so write the identity directly.
			if ar == 1 {
				m.mat.Data[0] = 1
				return nil
			}
			for i := 0; i < ar; i++ {
				v := m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+ac]
				zero(v)
				v[i] = 1
			}
			return nil
		}
		// Square A: solve via an LU factorization.
		var lu LU
		lu.Factorize(a)
		return lu.SolveTo(m, false, b)
	case ar > ac:
		// Overdetermined A: least squares solution via QR.
		var qr QR
		qr.Factorize(a)
		return qr.SolveTo(m, false, b)
	default:
		// Underdetermined A: minimum-norm solution via LQ.
		var lq LQ
		lq.Factorize(a)
		return lq.SolveTo(m, false, b)
	}
}
// SolveVec solves the linear least squares problem
//
//	minimize over x |b - A*x|_2
//
// where A is an m×n matrix, b is a given m element vector and x is an n element
// solution vector. SolveVec assumes that A has full rank, that is
//
//	rank(A) = min(m,n)
//
// If m >= n, SolveVec finds the unique least squares solution of an
// overdetermined system.
//
// If m < n, there is an infinite number of solutions that satisfy b-A*x=0. In
// this case SolveVec finds the unique solution of an underdetermined system
// that minimizes |x|_2.
//
// The solution vector x will be stored in-place into the receiver.
//
// If A does not have full rank, a Condition error is returned. See the
// documentation for Condition for more information.
//
// SolveVec panics with ErrShape if b is not a column vector. The work is
// delegated to Dense.Solve, which performs the remaining shape checks.
func (v *VecDense) SolveVec(a Matrix, b Vector) error {
	if _, bc := b.Dims(); bc != 1 {
		panic(ErrShape)
	}
	// The solution has one element per column of A.
	_, c := a.Dims()

	// The Solve implementation is non-trivial, so rather than duplicate the code,
	// instead recast the VecDenses as Dense and call the matrix code.
	if rv, ok := b.(RawVectorer); ok {
		bmat := rv.RawVector()
		if v != b {
			// Only the overlap check is wanted here; the bool result is
			// unused. NOTE(review): presumably checkOverlap panics when v and
			// b partially share storage — confirm.
			v.checkOverlap(bmat)
		}
		v.reuseAsNonZeroed(c)
		m := v.asDense()
		// We conditionally create bm as m when b and v are identical
		// to prevent the overlap detection code from identifying m
		// and bm as overlapping but not identical.
		bm := m
		if v != b {
			bv := VecDense{mat: bmat}
			bm = bv.asDense()
		}
		return m.Solve(a, bm)
	}
	v.reuseAsNonZeroed(c)
	m := v.asDense()
	return m.Solve(a, b)
}