cgo,native: fix optimal work length return

This commit is contained in:
kortschak
2017-01-26 10:16:48 +10:30
parent 6a2c601bf2
commit 85af0b1105
2 changed files with 10 additions and 8 deletions

View File

@@ -894,18 +894,18 @@ func (impl Implementation) Dgeqr2(m, n int, a []float64, lda int, tau, work []fl
//
// The C interface does not support providing temporary storage. To provide compatibility
// with native, lwork == -1 will not run Dgeqrf but will instead write the minimum
// work necessary to work[0]. If len(work) < lwork, Dgeqrf will panic.
// work necessary to work[0]. If len(work) < max(1, lwork), Dgeqrf will panic.
//
// tau must have length at least min(m,n), and this function will panic otherwise.
func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) {
if len(work) < max(1, lwork) {
panic(shortWork)
}
if lwork == -1 {
work[0] = float64(n)
return
}
checkMatrix(m, n, a, lda)
if len(work) < lwork {
panic(shortWork)
}
if lwork < n {
panic(badWork)
}

View File

@@ -14,7 +14,7 @@ import (
// parameters at entry and exit.
//
// work is temporary storage, and lwork specifies the usable memory length.
// At minimum, lwork >= m and this function will panic otherwise.
// At minimum, lwork >= max(1, m) and this function will panic otherwise.
// Dgeqrf is a blocked QR factorization, but the block size is limited
// by the temporary space available. If lwork == -1, instead of performing Dgeqrf,
// the optimal work length will be stored into work[0].
@@ -25,14 +25,14 @@ func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []fl
nb := impl.Ilaenv(1, "DGEQRF", " ", m, n, -1, -1)
lworkopt := n * max(nb, 1)
lworkopt = max(n, lworkopt)
if len(work) < max(1, lwork) {
panic(shortWork)
}
if lwork == -1 {
work[0] = float64(lworkopt)
return
}
checkMatrix(m, n, a, lda)
if len(work) < lwork {
panic(shortWork)
}
if lwork < n {
panic(badWork)
}
@@ -41,6 +41,7 @@ func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []fl
panic(badTau)
}
if k == 0 {
work[0] = float64(lworkopt)
return
}
nbmin := 2 // Minimal block size.
@@ -93,4 +94,5 @@ func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []fl
if i < k {
impl.Dgeqr2(m-i, n-i, a[i*lda+i:], lda, tau[i:], work)
}
work[0] = float64(lworkopt)
}