mirror of
https://github.com/gonum/gonum.git
synced 2025-12-24 13:47:56 +08:00
mat: make QR satisfy Matrix
This commit is contained in:
committed by
Vladimír Chalupecký
parent
aef3c5f344
commit
45b74210d6
@@ -27,6 +27,7 @@ type Float64 interface {
|
||||
Dlansy(norm MatrixNorm, uplo blas.Uplo, n int, a []float64, lda int, work []float64) float64
|
||||
Dlapmr(forward bool, m, n int, x []float64, ldx int, k []int)
|
||||
Dlapmt(forward bool, m, n int, x []float64, ldx int, k []int)
|
||||
Dorgqr(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
|
||||
Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
|
||||
Dormlq(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
|
||||
Dpbcon(uplo blas.Uplo, n, kd int, ab []float64, ldab int, anorm float64, work []float64, iwork []int) float64
|
||||
|
||||
@@ -694,6 +694,28 @@ func Ormlq(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64
|
||||
lapack64.Dormlq(side, trans, c.Rows, c.Cols, a.Rows, a.Data, max(1, a.Stride), tau, c.Data, max(1, c.Stride), work, lwork)
|
||||
}
|
||||
|
||||
// Orgqr generates an m×n matrix Q with orthonormal columns defined by the
|
||||
// product of elementary reflectors
|
||||
//
|
||||
// Q = H_0 * H_1 * ... * H_{k-1}
|
||||
//
|
||||
// as computed by Geqrf.
|
||||
//
|
||||
// k is determined by the length of tau.
|
||||
//
|
||||
// The length of work must be at least n and it also must be that 0 <= k <= n
|
||||
// and 0 <= n <= m.
|
||||
//
|
||||
// work is temporary storage, and lwork specifies the usable memory length. At
|
||||
// minimum, lwork >= n, and the amount of blocking is limited by the usable
|
||||
// length. If lwork == -1, instead of computing Orgqr the optimal work length
|
||||
// is stored into work[0].
|
||||
//
|
||||
// Orgqr will panic if the conditions on input values are not met.
|
||||
func Orgqr(a blas64.General, tau []float64, work []float64, lwork int) {
|
||||
lapack64.Dorgqr(a.Rows, a.Cols, len(tau), a.Data, a.Stride, tau, work, lwork)
|
||||
}
|
||||
|
||||
// Ormqr multiplies an m×n matrix C by an orthogonal matrix Q as
|
||||
//
|
||||
// C = Q * C if side == blas.Left and trans == blas.NoTrans,
|
||||
@@ -705,12 +727,13 @@ func Ormlq(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64
|
||||
//
|
||||
// Q = H_0 * H_1 * ... * H_{k-1}.
|
||||
//
|
||||
// k is determined by the length of tau.
|
||||
//
|
||||
// If side == blas.Left, A is an m×k matrix and 0 <= k <= m.
|
||||
// If side == blas.Right, A is an n×k matrix and 0 <= k <= n.
|
||||
// The ith column of A contains the vector which defines the elementary
|
||||
// reflector H_i and tau[i] contains its scalar factor. tau must have length k
|
||||
// and Ormqr will panic otherwise. Geqrf returns A and tau in the required
|
||||
// form.
|
||||
// reflector H_i and tau[i] contains its scalar factor. Geqrf returns A and tau
|
||||
// in the required form.
|
||||
//
|
||||
// work must have length at least max(1,lwork), and lwork must be at least n if
|
||||
// side == blas.Left and at least m if side == blas.Right, otherwise Ormqr will
|
||||
@@ -725,7 +748,7 @@ func Ormlq(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64
|
||||
// If lwork is -1, instead of performing Ormqr, the optimal workspace size will
|
||||
// be stored into work[0].
|
||||
func Ormqr(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64, c blas64.General, work []float64, lwork int) {
|
||||
lapack64.Dormqr(side, trans, c.Rows, c.Cols, a.Cols, a.Data, max(1, a.Stride), tau, c.Data, max(1, c.Stride), work, lwork)
|
||||
lapack64.Dormqr(side, trans, c.Rows, c.Cols, len(tau), a.Data, max(1, a.Stride), tau, c.Data, max(1, c.Stride), work, lwork)
|
||||
}
|
||||
|
||||
// Pocon estimates the reciprocal of the condition number of a positive-definite
|
||||
|
||||
66
mat/qr.go
66
mat/qr.go
@@ -18,10 +18,42 @@ const badQR = "mat: invalid QR factorization"
|
||||
// QR is a type for creating and using the QR factorization of a matrix.
|
||||
type QR struct {
|
||||
qr *Dense
|
||||
q *Dense
|
||||
tau []float64
|
||||
cond float64
|
||||
}
|
||||
|
||||
// Dims returns the dimensions of the matrix.
|
||||
func (qr *QR) Dims() (r, c int) {
|
||||
if qr.qr == nil {
|
||||
return 0, 0
|
||||
}
|
||||
return qr.qr.Dims()
|
||||
}
|
||||
|
||||
// At returns the element at row i, column j.
|
||||
func (qr *QR) At(i, j int) float64 {
|
||||
m, n := qr.Dims()
|
||||
if uint(i) >= uint(m) {
|
||||
panic(ErrRowAccess)
|
||||
}
|
||||
if uint(j) >= uint(n) {
|
||||
panic(ErrColAccess)
|
||||
}
|
||||
|
||||
var val float64
|
||||
for k := 0; k <= j; k++ {
|
||||
val += qr.q.at(i, k) * qr.qr.at(k, j)
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// T performs an implicit transpose by returning the receiver inside a
|
||||
// Transpose.
|
||||
func (qr *QR) T() Matrix {
|
||||
return Transpose{qr}
|
||||
}
|
||||
|
||||
func (qr *QR) updateCond(norm lapack.MatrixNorm) {
|
||||
// Since A = Q*R, and Q is orthogonal, we get for the condition number κ
|
||||
// κ(A) := |A| |A^-1| = |Q*R| |(Q*R)^-1| = |R| |R^-1 * Qᵀ|
|
||||
@@ -55,18 +87,34 @@ func (qr *QR) factorize(a Matrix, norm lapack.MatrixNorm) {
|
||||
if m < n {
|
||||
panic(ErrShape)
|
||||
}
|
||||
k := min(m, n)
|
||||
if qr.qr == nil {
|
||||
qr.qr = &Dense{}
|
||||
}
|
||||
qr.qr.CloneFrom(a)
|
||||
work := []float64{0}
|
||||
qr.tau = make([]float64, k)
|
||||
qr.tau = make([]float64, n)
|
||||
lapack64.Geqrf(qr.qr.mat, qr.tau, work, -1)
|
||||
work = getFloat64s(int(work[0]), false)
|
||||
lapack64.Geqrf(qr.qr.mat, qr.tau, work, len(work))
|
||||
putFloat64s(work)
|
||||
qr.updateCond(norm)
|
||||
qr.updateQ()
|
||||
}
|
||||
|
||||
func (qr *QR) updateQ() {
|
||||
m, _ := qr.Dims()
|
||||
if qr.q == nil {
|
||||
qr.q = NewDense(m, m, nil)
|
||||
} else {
|
||||
qr.q.reuseAsNonZeroed(m, m)
|
||||
}
|
||||
// Construct Q from the elementary reflectors.
|
||||
qr.q.Copy(qr.qr)
|
||||
work := []float64{0}
|
||||
lapack64.Orgqr(qr.q.mat, qr.tau, work, -1)
|
||||
work = getFloat64s(int(work[0]), false)
|
||||
lapack64.Orgqr(qr.q.mat, qr.tau, work, len(work))
|
||||
putFloat64s(work)
|
||||
}
|
||||
|
||||
// isValid returns whether the receiver contains a factorization.
|
||||
@@ -143,20 +191,8 @@ func (qr *QR) QTo(dst *Dense) {
|
||||
if r != r2 || r != c2 {
|
||||
panic(ErrShape)
|
||||
}
|
||||
dst.Zero()
|
||||
}
|
||||
|
||||
// Set Q = I.
|
||||
for i := 0; i < r*r; i += r + 1 {
|
||||
dst.mat.Data[i] = 1
|
||||
}
|
||||
|
||||
// Construct Q from the elementary reflectors.
|
||||
work := []float64{0}
|
||||
lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, dst.mat, work, -1)
|
||||
work = getFloat64s(int(work[0]), false)
|
||||
lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, dst.mat, work, len(work))
|
||||
putFloat64s(work)
|
||||
dst.Copy(qr.q)
|
||||
}
|
||||
|
||||
// SolveTo finds a minimum-norm solution to a system of linear equations defined
|
||||
|
||||
@@ -42,6 +42,13 @@ func TestQR(t *testing.T) {
|
||||
t.Errorf("Q is not orthonormal: m = %v, n = %v", m, n)
|
||||
}
|
||||
|
||||
if !EqualApprox(a, &qr, 1e-14) {
|
||||
t.Errorf("m=%d,n=%d: A and QR are not equal", m, n)
|
||||
}
|
||||
if !EqualApprox(a.T(), qr.T(), 1e-14) {
|
||||
t.Errorf("m=%d,n=%d: Aᵀ and (QR)ᵀ are not equal", m, n)
|
||||
}
|
||||
|
||||
qr.RTo(&r)
|
||||
|
||||
var got Dense
|
||||
|
||||
Reference in New Issue
Block a user