From 85af0b1105121ca2cc2ff8b5513f5e0fededb526 Mon Sep 17 00:00:00 2001 From: kortschak Date: Thu, 26 Jan 2017 10:16:48 +1030 Subject: [PATCH] cgo,native: fix optimal work length return --- cgo/lapack.go | 8 ++++---- native/dgeqrf.go | 10 ++++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/cgo/lapack.go b/cgo/lapack.go index ec4665d5..4a2e21b6 100644 --- a/cgo/lapack.go +++ b/cgo/lapack.go @@ -894,18 +894,18 @@ func (impl Implementation) Dgeqr2(m, n int, a []float64, lda int, tau, work []fl // // The C interface does not support providing temporary storage. To provide compatibility // with native, lwork == -1 will not run Dgeqrf but will instead write the minimum -// work necessary to work[0]. If len(work) < lwork, Dgeqrf will panic. +// work necessary to work[0]. If len(work) < max(1, lwork), Dgeqrf will panic. // // tau must have length at least min(m,n), and this function will panic otherwise. func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) { + if len(work) < max(1, lwork) { + panic(shortWork) + } if lwork == -1 { work[0] = float64(n) return } checkMatrix(m, n, a, lda) - if len(work) < lwork { - panic(shortWork) - } if lwork < n { panic(badWork) } diff --git a/native/dgeqrf.go b/native/dgeqrf.go index 5922ee9d..56297f05 100644 --- a/native/dgeqrf.go +++ b/native/dgeqrf.go @@ -14,7 +14,7 @@ import ( // parameters at entry and exit. // // work is temporary storage, and lwork specifies the usable memory length. -// At minimum, lwork >= m and this function will panic otherwise. +// At minimum, lwork >= max(1, m) and this function will panic otherwise. // Dgeqrf is a blocked QR factorization, but the block size is limited // by the temporary space available. If lwork == -1, instead of performing Dgeqrf, // the optimal work length will be stored into work[0]. @@ -25,14 +25,14 @@ func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []fl nb := impl.Ilaenv(1, "DGEQRF", " ", m, n, -1, -1) lworkopt := n * max(nb, 1) lworkopt = max(n, lworkopt) + if len(work) < max(1, lwork) { + panic(shortWork) + } if lwork == -1 { work[0] = float64(lworkopt) return } checkMatrix(m, n, a, lda) - if len(work) < lwork { - panic(shortWork) - } if lwork < n { panic(badWork) } @@ -41,6 +41,7 @@ func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []fl panic(badTau) } if k == 0 { + work[0] = float64(lworkopt) return } nbmin := 2 // Minimal block size. @@ -93,4 +94,5 @@ func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []fl if i < k { impl.Dgeqr2(m-i, n-i, a[i*lda+i:], lda, tau[i:], work) } + work[0] = float64(lworkopt) }