cgo,native: fix optimal work length return

This commit is contained in:
kortschak
2017-01-26 10:16:48 +10:30
parent 6a2c601bf2
commit 85af0b1105
2 changed files with 10 additions and 8 deletions

View File

@@ -894,18 +894,18 @@ func (impl Implementation) Dgeqr2(m, n int, a []float64, lda int, tau, work []fl
// //
// The C interface does not support providing temporary storage. To provide compatibility // The C interface does not support providing temporary storage. To provide compatibility
// with native, lwork == -1 will not run Dgeqrf but will instead write the minimum // with native, lwork == -1 will not run Dgeqrf but will instead write the minimum
// work necessary to work[0]. If len(work) < lwork, Dgeqrf will panic. // work necessary to work[0]. If len(work) < max(1, lwork), Dgeqrf will panic.
// //
// tau must have length at least min(m,n), and this function will panic otherwise. // tau must have length at least min(m,n), and this function will panic otherwise.
func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) { func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) {
if len(work) < max(1, lwork) {
panic(shortWork)
}
if lwork == -1 { if lwork == -1 {
work[0] = float64(n) work[0] = float64(n)
return return
} }
checkMatrix(m, n, a, lda) checkMatrix(m, n, a, lda)
if len(work) < lwork {
panic(shortWork)
}
if lwork < n { if lwork < n {
panic(badWork) panic(badWork)
} }

View File

@@ -14,7 +14,7 @@ import (
// parameters at entry and exit. // parameters at entry and exit.
// //
// work is temporary storage, and lwork specifies the usable memory length. // work is temporary storage, and lwork specifies the usable memory length.
// At minimum, lwork >= m and this function will panic otherwise. // At minimum, lwork >= max(1, m) and this function will panic otherwise.
// Dgeqrf is a blocked QR factorization, but the block size is limited // Dgeqrf is a blocked QR factorization, but the block size is limited
// by the temporary space available. If lwork == -1, instead of performing Dgeqrf, // by the temporary space available. If lwork == -1, instead of performing Dgeqrf,
// the optimal work length will be stored into work[0]. // the optimal work length will be stored into work[0].
@@ -25,14 +25,14 @@ func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []fl
nb := impl.Ilaenv(1, "DGEQRF", " ", m, n, -1, -1) nb := impl.Ilaenv(1, "DGEQRF", " ", m, n, -1, -1)
lworkopt := n * max(nb, 1) lworkopt := n * max(nb, 1)
lworkopt = max(n, lworkopt) lworkopt = max(n, lworkopt)
if len(work) < max(1, lwork) {
panic(shortWork)
}
if lwork == -1 { if lwork == -1 {
work[0] = float64(lworkopt) work[0] = float64(lworkopt)
return return
} }
checkMatrix(m, n, a, lda) checkMatrix(m, n, a, lda)
if len(work) < lwork {
panic(shortWork)
}
if lwork < n { if lwork < n {
panic(badWork) panic(badWork)
} }
@@ -41,6 +41,7 @@ func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []fl
panic(badTau) panic(badTau)
} }
if k == 0 { if k == 0 {
work[0] = float64(lworkopt)
return return
} }
nbmin := 2 // Minimal block size. nbmin := 2 // Minimal block size.
@@ -93,4 +94,5 @@ func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []fl
if i < k { if i < k {
impl.Dgeqr2(m-i, n-i, a[i*lda+i:], lda, tau[i:], work) impl.Dgeqr2(m-i, n-i, a[i*lda+i:], lda, tau[i:], work)
} }
work[0] = float64(lworkopt)
} }