lapack/gonum: clean up parameter checks in Dorgqr

This commit is contained in:
Vladimir Chalupecky
2018-11-19 18:51:02 +01:00
committed by Vladimír Chalupecký
parent 1cc352d23d
commit c8ad599984

View File

@@ -28,39 +28,45 @@ import (
// //
// Dorgqr is an internal routine. It is exported for testing purposes. // Dorgqr is an internal routine. It is exported for testing purposes.
func (impl Implementation) Dorgqr(m, n, k int, a []float64, lda int, tau, work []float64, lwork int) { func (impl Implementation) Dorgqr(m, n, k int, a []float64, lda int, tau, work []float64, lwork int) {
switch {
case k < 0:
panic(kLT0)
case n < k:
panic(kGTN)
case m < n:
panic(mLTN)
case lda < max(1, n):
panic(badLdA)
case lwork < max(1, n) && lwork != -1:
panic(badWork)
case len(work) < max(1, lwork):
panic(shortWork)
}
if n == 0 {
work[0] = 1
return
}
nb := impl.Ilaenv(1, "DORGQR", " ", m, n, k, -1) nb := impl.Ilaenv(1, "DORGQR", " ", m, n, k, -1)
// work is treated as an n×nb matrix // work is treated as an n×nb matrix
if lwork == -1 { if lwork == -1 {
work[0] = float64(max(1, n) * nb) work[0] = float64(n * nb)
return return
} }
checkMatrix(m, n, a, lda)
if k < 0 { if len(a) < (m-1)*lda+n {
panic(kLT0) panic("lapack: insuffcient length of a")
}
if k > n {
panic(kGTN)
}
if n > m {
panic(mLTN)
} }
if len(tau) < k { if len(tau) < k {
panic(badTau) panic(badTau)
} }
if len(work) < lwork {
panic(shortWork) nbmin := 2 // Minimum block size
} var nx int // Crossover size from blocked to unblocked code
if lwork < n {
panic(badWork)
}
if n == 0 {
return
}
nbmin := 2 // Minimum number of blocks
var nx int // Minimum number of rows
iws := n // Length of work needed iws := n // Length of work needed
var ldwork int var ldwork int
if nb > 1 && nb < k { if 1 < nb && nb < k {
nx = max(0, impl.Ilaenv(3, "DORGQR", " ", m, n, k, -1)) nx = max(0, impl.Ilaenv(3, "DORGQR", " ", m, n, k, -1))
if nx < k { if nx < k {
ldwork = nb ldwork = nb
@@ -73,14 +79,12 @@ func (impl Implementation) Dorgqr(m, n, k int, a []float64, lda int, tau, work [
} }
} }
var ki, kk int var ki, kk int
if nb >= nbmin && nb < k && nx < k { if nbmin <= nb && nb < k && nx < k {
// The first kk columns are handled by the blocked method. // The first kk columns are handled by the blocked method.
// Note: lapack has nx here, but this means the last nx rows are handled ki = ((k - nx - 1) / nb) * nb
// serially which could be quite different than nb.
ki = ((k - nb - 1) / nb) * nb
kk = min(k, ki+nb) kk = min(k, ki+nb)
for j := kk; j < n; j++ { for i := 0; i < kk; i++ {
for i := 0; i < kk; i++ { for j := kk; j < n; j++ {
a[i*lda+j] = 0 a[i*lda+j] = 0
} }
} }
@@ -89,32 +93,32 @@ func (impl Implementation) Dorgqr(m, n, k int, a []float64, lda int, tau, work [
// Perform the operation on columns kk to the end. // Perform the operation on columns kk to the end.
impl.Dorg2r(m-kk, n-kk, k-kk, a[kk*lda+kk:], lda, tau[kk:], work) impl.Dorg2r(m-kk, n-kk, k-kk, a[kk*lda+kk:], lda, tau[kk:], work)
} }
if kk == 0 { if kk > 0 {
return // Perform the operation on column-blocks.
} for i := ki; i >= 0; i -= nb {
// Perform the operation on column-blocks ib := min(nb, k-i)
for i := ki; i >= 0; i -= nb { if i+ib < n {
ib := min(nb, k-i) impl.Dlarft(lapack.Forward, lapack.ColumnWise,
if i+ib < n { m-i, ib,
impl.Dlarft(lapack.Forward, lapack.ColumnWise, a[i*lda+i:], lda,
m-i, ib, tau[i:],
a[i*lda+i:], lda, work, ldwork)
tau[i:],
work, ldwork)
impl.Dlarfb(blas.Left, blas.NoTrans, lapack.Forward, lapack.ColumnWise, impl.Dlarfb(blas.Left, blas.NoTrans, lapack.Forward, lapack.ColumnWise,
m-i, n-i-ib, ib, m-i, n-i-ib, ib,
a[i*lda+i:], lda, a[i*lda+i:], lda,
work, ldwork, work, ldwork,
a[i*lda+i+ib:], lda, a[i*lda+i+ib:], lda,
work[ib*ldwork:], ldwork) work[ib*ldwork:], ldwork)
} }
impl.Dorg2r(m-i, ib, ib, a[i*lda+i:], lda, tau[i:], work) impl.Dorg2r(m-i, ib, ib, a[i*lda+i:], lda, tau[i:], work)
// Set rows 0:i-1 of current block to zero // Set rows 0:i-1 of current block to zero.
for j := i; j < i+ib; j++ { for j := i; j < i+ib; j++ {
for l := 0; l < i; l++ { for l := 0; l < i; l++ {
a[l*lda+j] = 0 a[l*lda+j] = 0
}
} }
} }
} }
work[0] = float64(iws)
} }