lapack/gonum: require exact length of tau in QR routines

This commit is contained in:
Vladimir Chalupecky
2023-10-05 19:30:15 +02:00
committed by Vladimír Chalupecký
parent bd767ae5eb
commit 6e2f5c5890
17 changed files with 57 additions and 55 deletions

View File

@@ -131,7 +131,7 @@ func (impl Implementation) Dgels(trans blas.Transpose, m, n, nrhs int, a []float
// Solve the minimization problem using a QR or an LQ decomposition.
var scllen int
if m >= n {
impl.Dgeqrf(m, n, a, lda, work, work[mn:], lwork-mn)
impl.Dgeqrf(m, n, a, lda, work[:n], work[mn:], lwork-mn)
if trans == blas.NoTrans {
impl.Dormqr(blas.Left, blas.Trans, m, nrhs, n,
a, lda,

View File

@@ -123,7 +123,7 @@ func (impl Implementation) Dgeqp3(m, n int, a []float64, lda int, jpvt []int, ta
// Compute the QR factorization of nfxd columns and update remaining columns.
if nfxd > 0 {
na := min(m, nfxd)
impl.Dgeqrf(m, na, a, lda, tau, work, lwork)
impl.Dgeqrf(m, na, a, lda, tau[:na], work, lwork)
iws = max(iws, int(work[0]))
if na < n {
impl.Dormqr(blas.Left, blas.Trans, m, n-na, na, a, lda, tau[:na], a[na:], lda,

View File

@@ -14,7 +14,7 @@ import "gonum.org/v1/gonum/blas"
// A is modified to contain the information to construct Q and R.
// The upper triangle of a contains the matrix R. The lower triangular elements
// (not including the diagonal) contain the elementary reflectors. tau is modified
// to contain the reflector scales. tau must have length at least min(m,n), and
// to contain the reflector scales. tau must have length min(m,n), and
// this function will panic otherwise.
//
// The ith elementary reflector can be explicitly constructed by first extracting
@@ -57,8 +57,8 @@ func (impl Implementation) Dgeqr2(m, n int, a []float64, lda int, tau, work []fl
switch {
case len(a) < (m-1)*lda+n:
panic(shortA)
case len(tau) < k:
panic(shortTau)
case len(tau) != k:
panic(badLenTau)
}
for i := 0; i < k; i++ {

View File

@@ -20,7 +20,7 @@ import (
// by the temporary space available. If lwork == -1, instead of performing Dgeqrf,
// the optimal work length will be stored into work[0].
//
// tau must have length at least min(m,n), and this function will panic otherwise.
// tau must have length min(m,n), and this function will panic otherwise.
func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) {
switch {
case m < 0:
@@ -52,8 +52,8 @@ func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []fl
if len(a) < (m-1)*lda+n {
panic(shortA)
}
if len(tau) < k {
panic(shortTau)
if len(tau) != k {
panic(badLenTau)
}
nbmin := 2 // Minimal block size.
@@ -83,7 +83,7 @@ func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []fl
for i = 0; i < k-nx; i += nb {
ib := min(k-i, nb)
// Compute the QR factorization of the current block.
impl.Dgeqr2(m-i, ib, a[i*lda+i:], lda, tau[i:], work)
impl.Dgeqr2(m-i, ib, a[i*lda+i:], lda, tau[i:i+ib], work)
if i+ib < n {
// Form the triangular factor of the block reflector and apply Hᵀ
// In Dlarft, work becomes the T matrix.

View File

@@ -406,7 +406,7 @@ func (impl Implementation) Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float
iwork := itau + n
// Compute A = Q * R.
impl.Dgeqrf(m, n, a, lda, work[itau:], work[iwork:], lwork-iwork)
impl.Dgeqrf(m, n, a, lda, work[itau:itau+n], work[iwork:], lwork-iwork)
// Zero out below R.
impl.Dlaset(blas.Lower, n-1, n-1, 0, 0, a[lda:], lda)
@@ -455,14 +455,14 @@ func (impl Implementation) Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float
itau := ir + ldworkr*n
iwork := itau + n
// Compute A = Q * R.
impl.Dgeqrf(m, n, a, lda, work[itau:], work[iwork:], lwork-iwork)
impl.Dgeqrf(m, n, a, lda, work[itau:itau+n], work[iwork:], lwork-iwork)
// Copy R to work[ir:], zeroing out below it.
impl.Dlacpy(blas.Upper, n, n, a, lda, work[ir:], ldworkr)
impl.Dlaset(blas.Lower, n-1, n-1, 0, 0, work[ir+ldworkr:], ldworkr)
// Generate Q in A.
impl.Dorgqr(m, n, n, a, lda, work[itau:], work[iwork:], lwork-iwork)
impl.Dorgqr(m, n, n, a, lda, work[itau:itau+n], work[iwork:], lwork-iwork)
ie := itau
itauq := ie + n
itaup := itauq + n
@@ -492,11 +492,11 @@ func (impl Implementation) Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float
iwork := itau + n
// Compute A = Q*R, copying result to U.
impl.Dgeqrf(m, n, a, lda, work[itau:], work[iwork:], lwork-iwork)
impl.Dgeqrf(m, n, a, lda, work[itau:itau+n], work[iwork:], lwork-iwork)
impl.Dlacpy(blas.Lower, m, n, a, lda, u, ldu)
// Generate Q in U.
impl.Dorgqr(m, n, n, u, ldu, work[itau:], work[iwork:], lwork-iwork)
impl.Dorgqr(m, n, n, u, ldu, work[itau:itau+n], work[iwork:], lwork-iwork)
ie := itau
itauq := ie + n
itaup := itauq + n
@@ -537,13 +537,13 @@ func (impl Implementation) Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float
iwork := itau + n
// Compute A = Q * R.
impl.Dgeqrf(m, n, a, lda, work[itau:], work[iwork:], lwork-iwork)
impl.Dgeqrf(m, n, a, lda, work[itau:itau+n], work[iwork:], lwork-iwork)
// Copy R to work[iu:], zeroing out below it.
impl.Dlacpy(blas.Upper, n, n, a, lda, work[iu:], ldworku)
impl.Dlaset(blas.Lower, n-1, n-1, 0, 0, work[iu+ldworku:], ldworku)
// Generate Q in A.
impl.Dorgqr(m, n, n, a, lda, work[itau:], work[iwork:], lwork-iwork)
impl.Dorgqr(m, n, n, a, lda, work[itau:itau+n], work[iwork:], lwork-iwork)
ie := itau
itauq := ie + n
@@ -580,11 +580,11 @@ func (impl Implementation) Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float
iwork := itau + n
// Compute A = Q * R, copying result to U.
impl.Dgeqrf(m, n, a, lda, work[itau:], work[iwork:], lwork-iwork)
impl.Dgeqrf(m, n, a, lda, work[itau:itau+n], work[iwork:], lwork-iwork)
impl.Dlacpy(blas.Lower, m, n, a, lda, u, ldu)
// Generate Q in U.
impl.Dorgqr(m, n, n, u, ldu, work[itau:], work[iwork:], lwork-iwork)
impl.Dorgqr(m, n, n, u, ldu, work[itau:itau+n], work[iwork:], lwork-iwork)
// Copy R to VT, zeroing out below it.
impl.Dlacpy(blas.Upper, n, n, a, lda, vt, ldvt)
@@ -631,7 +631,7 @@ func (impl Implementation) Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float
iwork := itau + n
// Compute A = Q*R, copying result to U.
impl.Dgeqrf(m, n, a, lda, work[itau:], work[iwork:], lwork-iwork)
impl.Dgeqrf(m, n, a, lda, work[itau:itau+n], work[iwork:], lwork-iwork)
impl.Dlacpy(blas.Lower, m, n, a, lda, u, ldu)
// Copy R to work[ir:], zeroing out below it.
@@ -639,7 +639,7 @@ func (impl Implementation) Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float
impl.Dlaset(blas.Lower, n-1, n-1, 0, 0, work[ir+ldworkr:], ldworkr)
// Generate Q in U.
impl.Dorgqr(m, m, n, u, ldu, work[itau:], work[iwork:], lwork-iwork)
impl.Dorgqr(m, m, n, u, ldu, work[itau:itau+n], work[iwork:], lwork-iwork)
ie := itau
itauq := ie + n
itaup := itauq + n
@@ -672,11 +672,11 @@ func (impl Implementation) Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float
iwork := itau + n
// Compute A = Q*R, copying result to U.
impl.Dgeqrf(m, n, a, lda, work[itau:], work[iwork:], lwork-iwork)
impl.Dgeqrf(m, n, a, lda, work[itau:itau+n], work[iwork:], lwork-iwork)
impl.Dlacpy(blas.Lower, m, n, a, lda, u, ldu)
// Generate Q in U.
impl.Dorgqr(m, m, n, u, ldu, work[itau:], work[iwork:], lwork-iwork)
impl.Dorgqr(m, m, n, u, ldu, work[itau:itau+n], work[iwork:], lwork-iwork)
ie := itau
itauq := ie + n
itaup := itauq + n
@@ -717,11 +717,11 @@ func (impl Implementation) Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float
iwork := itau + n
// Compute A = Q * R, copying result to U.
impl.Dgeqrf(m, n, a, lda, work[itau:], work[iwork:], lwork-iwork)
impl.Dgeqrf(m, n, a, lda, work[itau:itau+n], work[iwork:], lwork-iwork)
impl.Dlacpy(blas.Lower, m, n, a, lda, u, ldu)
// Generate Q in U.
impl.Dorgqr(m, m, n, u, ldu, work[itau:], work[iwork:], lwork-iwork)
impl.Dorgqr(m, m, n, u, ldu, work[itau:itau+n], work[iwork:], lwork-iwork)
// Copy R to work[iu:], zeroing out below it.
impl.Dlacpy(blas.Upper, n, n, a, lda, work[iu:], ldworku)
@@ -786,11 +786,11 @@ func (impl Implementation) Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float
iwork := itau + n
// Compute A = Q*R, copying result to U.
impl.Dgeqrf(m, n, a, lda, work[itau:], work[iwork:], lwork-iwork)
impl.Dgeqrf(m, n, a, lda, work[itau:itau+n], work[iwork:], lwork-iwork)
impl.Dlacpy(blas.Lower, m, n, a, lda, u, ldu)
// Generate Q in U.
impl.Dorgqr(m, m, n, u, ldu, work[itau:], work[iwork:], lwork-iwork)
impl.Dorgqr(m, m, n, u, ldu, work[itau:itau+n], work[iwork:], lwork-iwork)
// Copy R from A to VT, zeroing out below it.
impl.Dlacpy(blas.Upper, n, n, a, lda, vt, ldvt)

View File

@@ -154,7 +154,7 @@ func (impl Implementation) Dggsvp3(jobU, jobV, jobQ lapack.GSVDJob, m, p, n int,
if p > 1 {
impl.Dlacpy(blas.Lower, p-1, min(p, n), b[ldb:], ldb, v[ldv:], ldv)
}
impl.Dorg2r(p, p, min(p, n), v, ldv, tau, work)
impl.Dorg2r(p, p, min(p, n), v, ldv, tau[:min(p, n)], work)
}
// Clean up B.
@@ -216,7 +216,7 @@ func (impl Implementation) Dggsvp3(jobU, jobV, jobQ lapack.GSVDJob, m, p, n int,
}
// Update A12 := Uᵀ*A12, where A12 = A[0:m, n-l:n].
impl.Dorm2r(blas.Left, blas.Trans, m, l, min(m, n-l), a, lda, tau, a[n-l:], lda, work)
impl.Dorm2r(blas.Left, blas.Trans, m, l, min(m, n-l), a, lda, tau[:min(m, n-l)], a[n-l:], lda, work)
if wantu {
// Copy the details of U, and form U.
@@ -224,7 +224,8 @@ func (impl Implementation) Dggsvp3(jobU, jobV, jobQ lapack.GSVDJob, m, p, n int,
if m > 1 {
impl.Dlacpy(blas.Lower, m-1, min(m, n-l), a[lda:], lda, u[ldu:], ldu)
}
impl.Dorg2r(m, m, min(m, n-l), u, ldu, tau, work)
k := min(m, n-l)
impl.Dorg2r(m, m, k, u, ldu, tau[:k], work)
}
if wantq {
@@ -250,7 +251,7 @@ func (impl Implementation) Dggsvp3(jobU, jobV, jobQ lapack.GSVDJob, m, p, n int,
if wantq {
// Update Q[0:n, 0:n-l] := Q[0:n, 0:n-l]*Z1ᵀ.
impl.Dorm2r(blas.Right, blas.Trans, n, n-l, k, a, lda, tau, q, ldq, work)
impl.Dorm2r(blas.Right, blas.Trans, n, n-l, k, a, lda, tau[:k], q, ldq, work)
}
// Clean up A.
@@ -265,10 +266,10 @@ func (impl Implementation) Dggsvp3(jobU, jobV, jobQ lapack.GSVDJob, m, p, n int,
if m > k {
// QR factorization of A[k:m, n-l:n].
impl.Dgeqr2(m-k, l, a[k*lda+n-l:], lda, tau, work)
impl.Dgeqr2(m-k, l, a[k*lda+n-l:], lda, tau[:min(m-k, l)], work)
if wantu {
// Update U[:, k:m] := U[:, k:m]*U1.
impl.Dorm2r(blas.Right, blas.NoTrans, m, m-k, min(m-k, l), a[k*lda+n-l:], lda, tau, u[k:], ldu, work)
impl.Dorm2r(blas.Right, blas.NoTrans, m, m-k, min(m-k, l), a[k*lda+n-l:], lda, tau[:min(m-k, l)], u[k:], ldu, work)
}
// Clean up A.

View File

@@ -14,7 +14,7 @@ import (
//
// Q = H_0 * H_1 * ... * H_{k-1}
//
// len(tau) >= k, 0 <= k <= n, 0 <= n <= m, len(work) >= n.
// len(tau) = k, 0 <= k <= n, 0 <= n <= m, len(work) >= n.
// Dorg2r will panic if these conditions are not met.
//
// Dorg2r is an internal routine. It is exported for testing purposes.
@@ -41,8 +41,8 @@ func (impl Implementation) Dorg2r(m, n, k int, a []float64, lda int, tau []float
switch {
case len(a) < (m-1)*lda+n:
panic(shortA)
case len(tau) < k:
panic(shortTau)
case len(tau) != k:
panic(badLenTau)
case len(work) < n:
panic(shortWork)
}

View File

@@ -91,7 +91,7 @@ func (impl Implementation) Dorgbr(vect lapack.GenOrtho, m, n, k int, a []float64
if wantq {
// Form Q, determined by a call to Dgebrd to reduce an m×k matrix.
if m >= k {
impl.Dorgqr(m, n, k, a, lda, tau, work, lwork)
impl.Dorgqr(m, n, k, a, lda, tau[:k], work, lwork)
} else {
// Shift the vectors which define the elementary reflectors one
// column to the right, and set the first row and column of Q to
@@ -108,7 +108,7 @@ func (impl Implementation) Dorgbr(vect lapack.GenOrtho, m, n, k int, a []float64
}
if m > 1 {
// Form Q[1:m-1, 1:m-1]
impl.Dorgqr(m-1, m-1, m-1, a[lda+1:], lda, tau, work, lwork)
impl.Dorgqr(m-1, m-1, m-1, a[lda+1:], lda, tau[:m-1], work, lwork)
}
}
} else {

View File

@@ -18,7 +18,7 @@ import (
// Dorgqr is the blocked version of Dorg2r that makes greater use of level-3 BLAS
// routines.
//
// The length of tau must be at least k, and the length of work must be at least n.
// The length of tau must be k, and the length of work must be at least n.
// It also must be that 0 <= k <= n and 0 <= n <= m.
//
// work is temporary storage, and lwork specifies the usable memory length. At
@@ -70,8 +70,8 @@ func (impl Implementation) Dorgqr(m, n, k int, a []float64, lda int, tau, work [
switch {
case len(a) < (m-1)*lda+n:
panic(shortA)
case len(tau) < k:
panic(shortTau)
case len(tau) != k:
panic(badLenTau)
}
nbmin := 2 // Minimum block size
@@ -123,7 +123,7 @@ func (impl Implementation) Dorgqr(m, n, k int, a []float64, lda int, tau, work [
a[i*lda+i+ib:], lda,
work[ib*ldwork:], ldwork)
}
impl.Dorg2r(m-i, ib, ib, a[i*lda+i:], lda, tau[i:], work)
impl.Dorg2r(m-i, ib, ib, a[i*lda+i:], lda, tau[i:i+ib], work)
// Set rows 0:i-1 of current block to zero.
for j := i; j < i+ib; j++ {
for l := 0; l < i; l++ {

View File

@@ -99,7 +99,7 @@ func (impl Implementation) Dorgtr(uplo blas.Uplo, n int, a []float64, lda int, t
}
if n > 1 {
// Generate Q[1:n, 1:n].
impl.Dorgqr(n-1, n-1, n-1, a[lda+1:], lda, tau, work, lwork)
impl.Dorgqr(n-1, n-1, n-1, a[lda+1:], lda, tau[:n-1], work, lwork)
}
}
work[0] = float64(lworkopt)

View File

@@ -17,7 +17,7 @@ import "gonum.org/v1/gonum/blas"
// If side == blas.Left, a is a matrix of size m×k, and if side == blas.Right
// a is of size n×k.
//
// tau contains the Householder factors and is of length at least k and this function
// tau contains the Householder factors and must have length k; this function
// will panic otherwise.
//
// work is temporary storage of length at least n if side == blas.Left
@@ -59,8 +59,8 @@ func (impl Implementation) Dorm2r(side blas.Side, trans blas.Transpose, m, n, k
panic(shortA)
case len(c) < (m-1)*ldc+n:
panic(shortC)
case len(tau) < k:
panic(shortTau)
case len(tau) != k:
panic(badLenTau)
case left && len(work) < n:
panic(shortWork)
case !left && len(work) < m:

View File

@@ -267,7 +267,7 @@ func Geqp3(a blas64.General, jpvt []int, tau, work []float64, lwork int) {
// algorithm. A is modified to contain the information to construct Q and R.
// The upper triangle of a contains the matrix R. The lower triangular elements
// (not including the diagonal) contain the elementary reflectors. tau is modified
// to contain the reflector scales. tau must have length at least min(m,n), and
// to contain the reflector scales. tau must have length min(m,n), and
// this function will panic otherwise.
//
// The ith elementary reflector can be explicitly constructed by first extracting

View File

@@ -62,8 +62,9 @@ func DgeqrfTest(t *testing.T, impl Dgeqrfer) {
// Allocate a slice for scalar factors of elementary reflectors
// and fill it with random numbers.
tau := make([]float64, n)
for i := 0; i < n; i++ {
k := min(m, n)
tau := make([]float64, k)
for i := range tau {
tau[i] = rnd.Float64()
}

View File

@@ -84,7 +84,7 @@ func DlarfbTest(t *testing.T, impl Dlarfber) {
}
// Use dgeqr2 to find the v vectors
tau := make([]float64, na)
tau := make([]float64, k)
work := make([]float64, na)
impl.Dgeqr2(ma, k, a, lda, tau, work)

View File

@@ -52,7 +52,8 @@ func DlarftTest(t *testing.T, impl Dlarfter) {
}
}
// Use dgeqr2 to find the v vectors
tau := make([]float64, n)
k := min(m, n)
tau := make([]float64, k)
work := make([]float64, n)
impl.Dgeqr2(m, n, a, lda, tau, work)
@@ -64,7 +65,6 @@ func DlarftTest(t *testing.T, impl Dlarfter) {
h := constructH(tau, vMat, store, direct)
k := min(m, n)
ldt := test.ldt
if ldt == 0 {
ldt = k

View File

@@ -59,7 +59,7 @@ func Dorg2rTest(t *testing.T, impl Dorg2rer) {
q := constructQK("QR", m, n, k, a, lda, tau)
// Compute the matrix Q using Dorg2r.
impl.Dorg2r(m, n, k, a, lda, tau, work)
impl.Dorg2r(m, n, k, a, lda, tau[:k], work)
// Check that the first n columns of both results match.
same := true

View File

@@ -57,7 +57,7 @@ func DorgqrTest(t *testing.T, impl Dorgqrer) {
a[i] = rnd.Float64()
}
work := make([]float64, 1)
tau := make([]float64, n)
tau := make([]float64, min(m, n))
for i := range tau {
tau[i] = math.NaN()
}
@@ -71,12 +71,12 @@ func DorgqrTest(t *testing.T, impl Dorgqrer) {
for i := range work {
work[i] = math.NaN()
}
impl.Dorg2r(m, n, k, aUnblocked, lda, tau, work)
impl.Dorg2r(m, n, k, aUnblocked, lda, tau[:k], work)
// make sure work isn't used before initialized
for i := range work {
work[i] = math.NaN()
}
impl.Dorgqr(m, n, k, a, lda, tau, work, len(work))
impl.Dorgqr(m, n, k, a, lda, tau[:k], work, len(work))
if !floats.EqualApprox(a, aUnblocked, 1e-10) {
t.Errorf("Q Mismatch. m = %d, n = %d, k = %d, lda = %d", m, n, k, lda)
}