From b28cd3dcad20540e02115df95cc0bb7347f2aeef Mon Sep 17 00:00:00 2001 From: Vladimir Chalupecky Date: Wed, 24 Aug 2016 14:59:50 +0900 Subject: [PATCH] native: clean up Dormqr --- native/dormqr.go | 82 +++++++++++++++++++++++++++--------------------- 1 file changed, 47 insertions(+), 35 deletions(-) diff --git a/native/dormqr.go b/native/dormqr.go index c0e82094..935c614a 100644 --- a/native/dormqr.go +++ b/native/dormqr.go @@ -15,8 +15,8 @@ import ( // C = Q^T * C if side == blas.Left and trans == blas.Trans // C = C * Q if side == blas.Right and trans == blas.NoTrans // C = C * Q^T if side == blas.Right and trans == blas.Trans -// If side == blas.Left, A is a matrix of side k×m, and if side == blas.Right -// A is of size k×n. This uses a blocked algorithm. +// If side == blas.Left, A is a matrix of side m×k, and if side == blas.Right +// A is of size n×k. This uses a blocked algorithm. // // work is temporary storage, and lwork specifies the usable memory length. // At minimum, lwork >= m if side == blas.Left and lwork >= n if side == blas.Right, @@ -28,28 +28,47 @@ import ( // tau contains the Householder scales and must have length at least k, and // this function will panic otherwise. func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) { - left := side == blas.Left - notran := trans == blas.NoTrans - if left { - checkMatrix(m, k, a, lda) - } else { - checkMatrix(n, k, a, lda) + var nq, nw int + switch side { + default: + panic(badSide) + case blas.Left: + nq = m + nw = n + case blas.Right: + nq = n + nw = m + } + switch { + case trans != blas.NoTrans && trans != blas.Trans: + panic(badTrans) + case m < 0 || n < 0: + panic(negDimension) + case k < 0 || nq < k: + panic("lapack: invalid value of k") + case len(work) < lwork: + panic(shortWork) + case lwork < max(1, nw) && lwork != -1: + panic(badWork) + } + if lwork != -1 { + checkMatrix(nq, k, a, lda) + checkMatrix(m, n, c, ldc) + if len(tau) < k { + panic(badTau) + } } - checkMatrix(m, n, c, ldc) - if len(tau) < k { - panic(badTau) + if m == 0 || n == 0 || k == 0 { + work[0] = 1 + return } const ( nbmax = 64 ldt = nbmax - tsize = nbmax * nbmax + tsize = nbmax * ldt ) - nw := n - if side == blas.Right { - nw = m - } opts := string(side) + string(trans) nb := min(nbmax, impl.Ilaenv(1, "DORMQR", opts, m, n, k, -1)) lworkopt := max(1, nw)*nb + tsize @@ -57,35 +76,27 @@ func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k work[0] = float64(lworkopt) return } - if len(work) < lwork { - panic(badWork) - } - if left { - if lwork < n { - panic(badWork) - } - } else { - if lwork < m { - panic(badWork) - } - } - if m == 0 || n == 0 || k == 0 { - return - } - nbmin := 2 - if nb > 1 && nb < k { + nbmin := 2 + if 1 < nb && nb < k { if lwork < nw*nb+tsize { nb = (lwork - tsize) / nw nbmin = max(2, impl.Ilaenv(2, "DORMQR", opts, m, n, k, -1)) } } - if nb < nbmin || nb >= k { + + if nb < nbmin || k <= nb { // Call unblocked code. impl.Dorm2r(side, trans, m, n, k, a, lda, tau, c, ldc, work) + work[0] = float64(lworkopt) return } - ldwork := nb + + var ( + ldwork = nb + left = side == blas.Left + notran = trans == blas.NoTrans + ) switch { case left && notran: for i := ((k - 1) / nb) * nb; i >= 0; i -= nb { @@ -143,4 +154,5 @@ func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k work[tsize:], ldwork) } } + work[0] = float64(lworkopt) }