mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 23:26:52 +08:00
lapack/gonum: unify parameter checks
This commit is contained in:

committed by
Vladimír Chalupecký

parent
454e9ef3f4
commit
70a1e933af
@@ -50,17 +50,38 @@ import (
|
|||||||
//
|
//
|
||||||
// Dbdsqr is an internal routine. It is exported for testing purposes.
|
// Dbdsqr is an internal routine. It is exported for testing purposes.
|
||||||
func (impl Implementation) Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, vt []float64, ldvt int, u []float64, ldu int, c []float64, ldc int, work []float64) (ok bool) {
|
func (impl Implementation) Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, vt []float64, ldvt int, u []float64, ldu int, c []float64, ldc int, work []float64) (ok bool) {
|
||||||
if uplo != blas.Upper && uplo != blas.Lower {
|
switch {
|
||||||
|
case uplo != blas.Upper && uplo != blas.Lower:
|
||||||
panic(badUplo)
|
panic(badUplo)
|
||||||
|
case n < 0:
|
||||||
|
panic(nLT0)
|
||||||
|
case ncvt < 0:
|
||||||
|
panic(ncvtLT0)
|
||||||
|
case nru < 0:
|
||||||
|
panic(nruLT0)
|
||||||
|
case ncc < 0:
|
||||||
|
panic(nccLT0)
|
||||||
|
case ldvt < max(1, ncvt):
|
||||||
|
panic(badLdVT)
|
||||||
|
case (ldu < max(1, n) && nru > 0) || (ldu < 1 && nru == 0):
|
||||||
|
panic(badLdU)
|
||||||
|
case ldc < max(1, ncc):
|
||||||
|
panic(badLdC)
|
||||||
}
|
}
|
||||||
if ncvt != 0 {
|
|
||||||
checkMatrix(n, ncvt, vt, ldvt)
|
// Quick return if possible.
|
||||||
|
if n == 0 {
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
if nru != 0 {
|
|
||||||
checkMatrix(nru, n, u, ldu)
|
if len(vt) < (n-1)*ldvt+ncvt && ncvt != 0 {
|
||||||
|
panic(shortVT)
|
||||||
}
|
}
|
||||||
if ncc != 0 {
|
if len(u) < (nru-1)*ldu+n && nru != 0 {
|
||||||
checkMatrix(n, ncc, c, ldc)
|
panic(shortU)
|
||||||
|
}
|
||||||
|
if len(c) < (n-1)*ldc+ncc && ncc != 0 {
|
||||||
|
panic(shortC)
|
||||||
}
|
}
|
||||||
if len(d) < n {
|
if len(d) < n {
|
||||||
panic(badD)
|
panic(badD)
|
||||||
@@ -71,14 +92,11 @@ func (impl Implementation) Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, v
|
|||||||
if len(work) < 4*(n-1) {
|
if len(work) < 4*(n-1) {
|
||||||
panic(badWork)
|
panic(badWork)
|
||||||
}
|
}
|
||||||
|
|
||||||
var info int
|
var info int
|
||||||
bi := blas64.Implementation()
|
bi := blas64.Implementation()
|
||||||
const (
|
const maxIter = 6
|
||||||
maxIter = 6
|
|
||||||
)
|
|
||||||
if n == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if n != 1 {
|
if n != 1 {
|
||||||
// If the singular vectors do not need to be computed, use qd algorithm.
|
// If the singular vectors do not need to be computed, use qd algorithm.
|
||||||
if !(ncvt > 0 || nru > 0 || ncc > 0) {
|
if !(ncvt > 0 || nru > 0 || ncc > 0) {
|
||||||
|
@@ -21,26 +21,37 @@ import (
|
|||||||
//
|
//
|
||||||
// Dgebak is an internal routine. It is exported for testing purposes.
|
// Dgebak is an internal routine. It is exported for testing purposes.
|
||||||
func (impl Implementation) Dgebak(job lapack.BalanceJob, side lapack.EVSide, n, ilo, ihi int, scale []float64, m int, v []float64, ldv int) {
|
func (impl Implementation) Dgebak(job lapack.BalanceJob, side lapack.EVSide, n, ilo, ihi int, scale []float64, m int, v []float64, ldv int) {
|
||||||
switch job {
|
|
||||||
default:
|
|
||||||
panic(badBalanceJob)
|
|
||||||
case lapack.BalanceNone, lapack.Permute, lapack.Scale, lapack.PermuteScale:
|
|
||||||
}
|
|
||||||
switch side {
|
|
||||||
default:
|
|
||||||
panic(badEVSide)
|
|
||||||
case lapack.EVLeft, lapack.EVRight:
|
|
||||||
}
|
|
||||||
checkMatrix(n, m, v, ldv)
|
|
||||||
switch {
|
switch {
|
||||||
|
case job != lapack.BalanceNone && job != lapack.Permute && job != lapack.Scale && job != lapack.PermuteScale:
|
||||||
|
panic(badBalanceJob)
|
||||||
|
case side != lapack.EVLeft && side != lapack.EVRight:
|
||||||
|
panic(badEVSide)
|
||||||
|
case n < 0:
|
||||||
|
panic(nLT0)
|
||||||
case ilo < 0 || max(0, n-1) < ilo:
|
case ilo < 0 || max(0, n-1) < ilo:
|
||||||
panic(badIlo)
|
panic(badIlo)
|
||||||
case ihi < min(ilo, n-1) || n <= ihi:
|
case ihi < min(ilo, n-1) || n <= ihi:
|
||||||
panic(badIhi)
|
panic(badIhi)
|
||||||
|
case m < 0:
|
||||||
|
panic(mLT0)
|
||||||
|
case ldv < max(1, m):
|
||||||
|
panic(badLdV)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Quick return if possible.
|
// Quick return if possible.
|
||||||
if n == 0 || m == 0 || job == lapack.BalanceNone {
|
if n == 0 || m == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(scale) < n {
|
||||||
|
panic(shortScale)
|
||||||
|
}
|
||||||
|
if len(v) < (n-1)*ldv+m {
|
||||||
|
panic(shortV)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quick return if possible.
|
||||||
|
if job == lapack.BalanceNone {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -55,26 +55,37 @@ import (
|
|||||||
//
|
//
|
||||||
// Dgebal is an internal routine. It is exported for testing purposes.
|
// Dgebal is an internal routine. It is exported for testing purposes.
|
||||||
func (impl Implementation) Dgebal(job lapack.BalanceJob, n int, a []float64, lda int, scale []float64) (ilo, ihi int) {
|
func (impl Implementation) Dgebal(job lapack.BalanceJob, n int, a []float64, lda int, scale []float64) (ilo, ihi int) {
|
||||||
switch job {
|
switch {
|
||||||
default:
|
case job != lapack.BalanceNone && job != lapack.Permute && job != lapack.Scale && job != lapack.PermuteScale:
|
||||||
panic(badBalanceJob)
|
panic(badBalanceJob)
|
||||||
case lapack.BalanceNone, lapack.Permute, lapack.Scale, lapack.PermuteScale:
|
case n < 0:
|
||||||
}
|
panic(nLT0)
|
||||||
checkMatrix(n, n, a, lda)
|
case lda < max(1, n):
|
||||||
if len(scale) != n {
|
panic(badLdA)
|
||||||
panic("lapack: bad length of scale")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ilo = 0
|
ilo = 0
|
||||||
ihi = n - 1
|
ihi = n - 1
|
||||||
|
|
||||||
if n == 0 || job == lapack.BalanceNone {
|
if n == 0 {
|
||||||
|
return ilo, ihi
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(scale) != n {
|
||||||
|
panic(shortScale)
|
||||||
|
}
|
||||||
|
|
||||||
|
if job == lapack.BalanceNone {
|
||||||
for i := range scale {
|
for i := range scale {
|
||||||
scale[i] = 1
|
scale[i] = 1
|
||||||
}
|
}
|
||||||
return ilo, ihi
|
return ilo, ihi
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(a) < (n-1)*lda+n {
|
||||||
|
panic(shortA)
|
||||||
|
}
|
||||||
|
|
||||||
bi := blas64.Implementation()
|
bi := blas64.Implementation()
|
||||||
swapped := true
|
swapped := true
|
||||||
|
|
||||||
|
@@ -15,22 +15,34 @@ import "gonum.org/v1/gonum/blas"
|
|||||||
//
|
//
|
||||||
// Dgebd2 is an internal routine. It is exported for testing purposes.
|
// Dgebd2 is an internal routine. It is exported for testing purposes.
|
||||||
func (impl Implementation) Dgebd2(m, n int, a []float64, lda int, d, e, tauQ, tauP, work []float64) {
|
func (impl Implementation) Dgebd2(m, n int, a []float64, lda int, d, e, tauQ, tauP, work []float64) {
|
||||||
checkMatrix(m, n, a, lda)
|
switch {
|
||||||
if len(d) < min(m, n) {
|
case m < 0:
|
||||||
|
panic(mLT0)
|
||||||
|
case n < 0:
|
||||||
|
panic(nLT0)
|
||||||
|
case lda < max(1, n):
|
||||||
|
panic(badLdA)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quick return if possible.
|
||||||
|
minmn := min(m, n)
|
||||||
|
if minmn == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(d) < minmn:
|
||||||
panic(badD)
|
panic(badD)
|
||||||
}
|
case len(e) < minmn-1:
|
||||||
if len(e) < min(m, n)-1 {
|
|
||||||
panic(badE)
|
panic(badE)
|
||||||
}
|
case len(tauQ) < minmn:
|
||||||
if len(tauQ) < min(m, n) {
|
|
||||||
panic(badTauQ)
|
panic(badTauQ)
|
||||||
}
|
case len(tauP) < minmn:
|
||||||
if len(tauP) < min(m, n) {
|
|
||||||
panic(badTauP)
|
panic(badTauP)
|
||||||
}
|
case len(work) < max(m, n):
|
||||||
if len(work) < max(m, n) {
|
|
||||||
panic(badWork)
|
panic(badWork)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m >= n {
|
if m >= n {
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
a[i*lda+i], tauQ[i] = impl.Dlarfg(m-i, a[i*lda+i], a[min(i+1, m-1)*lda+i:], lda)
|
a[i*lda+i], tauQ[i] = impl.Dlarfg(m-i, a[i*lda+i], a[min(i+1, m-1)*lda+i:], lda)
|
||||||
|
@@ -24,20 +24,31 @@ import (
|
|||||||
//
|
//
|
||||||
// iwork is a temporary data slice of length at least n and Dgecon will panic otherwise.
|
// iwork is a temporary data slice of length at least n and Dgecon will panic otherwise.
|
||||||
func (impl Implementation) Dgecon(norm lapack.MatrixNorm, n int, a []float64, lda int, anorm float64, work []float64, iwork []int) float64 {
|
func (impl Implementation) Dgecon(norm lapack.MatrixNorm, n int, a []float64, lda int, anorm float64, work []float64, iwork []int) float64 {
|
||||||
checkMatrix(n, n, a, lda)
|
switch {
|
||||||
if norm != lapack.MaxColumnSum && norm != lapack.MaxRowSum {
|
case norm != lapack.MaxColumnSum && norm != lapack.MaxRowSum:
|
||||||
panic(badNorm)
|
panic(badNorm)
|
||||||
|
case n < 0:
|
||||||
|
panic(nLT0)
|
||||||
|
case lda < max(1, n):
|
||||||
|
panic(badLdA)
|
||||||
}
|
}
|
||||||
if len(work) < 4*n {
|
|
||||||
|
// Quick return if possible.
|
||||||
|
if n == 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(a) < (n-1)*lda+n:
|
||||||
|
panic(shortA)
|
||||||
|
case len(work) < 4*n:
|
||||||
panic(badWork)
|
panic(badWork)
|
||||||
}
|
case len(iwork) < n:
|
||||||
if len(iwork) < n {
|
|
||||||
panic(badWork)
|
panic(badWork)
|
||||||
}
|
}
|
||||||
|
|
||||||
if n == 0 {
|
// Quick return if possible.
|
||||||
return 1
|
if anorm == 0 {
|
||||||
} else if anorm == 0 {
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -61,50 +61,31 @@ import (
|
|||||||
// computed and wr[first:] and wi[first:] contain those eigenvalues which have
|
// computed and wr[first:] and wi[first:] contain those eigenvalues which have
|
||||||
// converged.
|
// converged.
|
||||||
func (impl Implementation) Dgeev(jobvl lapack.LeftEVJob, jobvr lapack.RightEVJob, n int, a []float64, lda int, wr, wi []float64, vl []float64, ldvl int, vr []float64, ldvr int, work []float64, lwork int) (first int) {
|
func (impl Implementation) Dgeev(jobvl lapack.LeftEVJob, jobvr lapack.RightEVJob, n int, a []float64, lda int, wr, wi []float64, vl []float64, ldvl int, vr []float64, ldvr int, work []float64, lwork int) (first int) {
|
||||||
var wantvl bool
|
wantvl := jobvl == lapack.LeftEVCompute
|
||||||
switch jobvl {
|
wantvr := jobvr == lapack.RightEVCompute
|
||||||
default:
|
|
||||||
panic("lapack: invalid LeftEVJob")
|
|
||||||
case lapack.LeftEVCompute:
|
|
||||||
wantvl = true
|
|
||||||
case lapack.LeftEVNone:
|
|
||||||
}
|
|
||||||
var wantvr bool
|
|
||||||
switch jobvr {
|
|
||||||
default:
|
|
||||||
panic("lapack: invalid RightEVJob")
|
|
||||||
case lapack.RightEVCompute:
|
|
||||||
wantvr = true
|
|
||||||
case lapack.RightEVNone:
|
|
||||||
}
|
|
||||||
switch {
|
|
||||||
case n < 0:
|
|
||||||
panic(nLT0)
|
|
||||||
case len(work) < lwork:
|
|
||||||
panic(shortWork)
|
|
||||||
}
|
|
||||||
var minwrk int
|
var minwrk int
|
||||||
if wantvl || wantvr {
|
if wantvl || wantvr {
|
||||||
minwrk = max(1, 4*n)
|
minwrk = max(1, 4*n)
|
||||||
} else {
|
} else {
|
||||||
minwrk = max(1, 3*n)
|
minwrk = max(1, 3*n)
|
||||||
}
|
}
|
||||||
if lwork != -1 {
|
switch {
|
||||||
checkMatrix(n, n, a, lda)
|
case jobvl != lapack.LeftEVCompute && jobvl != lapack.LeftEVNone:
|
||||||
if wantvl {
|
panic("lapack: invalid LeftEVJob")
|
||||||
checkMatrix(n, n, vl, ldvl)
|
case jobvr != lapack.RightEVCompute && jobvr != lapack.RightEVNone:
|
||||||
}
|
panic("lapack: invalid RightEVJob")
|
||||||
if wantvr {
|
case n < 0:
|
||||||
checkMatrix(n, n, vr, ldvr)
|
panic(nLT0)
|
||||||
}
|
case lda < max(1, n):
|
||||||
switch {
|
panic(badLdA)
|
||||||
case len(wr) != n:
|
case ldvl < 1 || (ldvl < n && wantvl):
|
||||||
panic("lapack: bad length of wr")
|
panic(badLdVL)
|
||||||
case len(wi) != n:
|
case ldvr < 1 || (ldvr < n && wantvr):
|
||||||
panic("lapack: bad length of wi")
|
panic(badLdVR)
|
||||||
case lwork < minwrk:
|
case lwork < minwrk && lwork != -1:
|
||||||
panic(badWork)
|
panic(badWork)
|
||||||
}
|
case len(work) < lwork:
|
||||||
|
panic(shortWork)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Quick return if possible.
|
// Quick return if possible.
|
||||||
@@ -139,6 +120,19 @@ func (impl Implementation) Dgeev(jobvl lapack.LeftEVJob, jobvr lapack.RightEVJob
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(a) < (n-1)*lda+n:
|
||||||
|
panic(shortA)
|
||||||
|
case len(wr) != n:
|
||||||
|
panic("lapack: bad length of wr")
|
||||||
|
case len(wi) != n:
|
||||||
|
panic("lapack: bad length of wi")
|
||||||
|
case len(vl) < (n-1)*ldvl+n && wantvl:
|
||||||
|
panic(shortVL)
|
||||||
|
case len(vr) < (n-1)*ldvr+n && wantvr:
|
||||||
|
panic(shortVR)
|
||||||
|
}
|
||||||
|
|
||||||
// Get machine constants.
|
// Get machine constants.
|
||||||
smlnum := math.Sqrt(dlamchS) / dlamchP
|
smlnum := math.Sqrt(dlamchS) / dlamchP
|
||||||
bignum := 1 / smlnum
|
bignum := 1 / smlnum
|
||||||
|
@@ -56,16 +56,29 @@ import "gonum.org/v1/gonum/blas"
|
|||||||
//
|
//
|
||||||
// Dgehd2 is an internal routine. It is exported for testing purposes.
|
// Dgehd2 is an internal routine. It is exported for testing purposes.
|
||||||
func (impl Implementation) Dgehd2(n, ilo, ihi int, a []float64, lda int, tau, work []float64) {
|
func (impl Implementation) Dgehd2(n, ilo, ihi int, a []float64, lda int, tau, work []float64) {
|
||||||
checkMatrix(n, n, a, lda)
|
|
||||||
switch {
|
switch {
|
||||||
case ilo < 0 || ilo > max(0, n-1):
|
case n < 0:
|
||||||
|
panic(nLT0)
|
||||||
|
case ilo < 0 || max(0, n-1) < ilo:
|
||||||
panic(badIlo)
|
panic(badIlo)
|
||||||
case ihi < min(ilo, n-1) || ihi >= n:
|
case ihi < min(ilo, n-1) || n <= ihi:
|
||||||
panic(badIhi)
|
panic(badIhi)
|
||||||
|
case lda < max(1, n):
|
||||||
|
panic(badLdA)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quick return if possible.
|
||||||
|
if n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(a) < (n-1)*lda+n:
|
||||||
|
panic(shortA)
|
||||||
case len(tau) != n-1:
|
case len(tau) != n-1:
|
||||||
panic(badTau)
|
panic(badTau)
|
||||||
case len(work) < n:
|
case len(work) < n:
|
||||||
panic(badWork)
|
panic(shortWork)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := ilo; i < ihi; i++ {
|
for i := ilo; i < ihi; i++ {
|
||||||
|
@@ -81,6 +81,12 @@ func (impl Implementation) Dgehrd(n, ilo, ihi int, a []float64, lda int, tau, wo
|
|||||||
panic(shortWork)
|
panic(shortWork)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Quick return if possible.
|
||||||
|
if n == 0 {
|
||||||
|
work[0] = 1
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
nbmax = 64
|
nbmax = 64
|
||||||
ldt = nbmax + 1
|
ldt = nbmax + 1
|
||||||
@@ -95,9 +101,9 @@ func (impl Implementation) Dgehrd(n, ilo, ihi int, a []float64, lda int, tau, wo
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(a) < (n-1)*lda+n {
|
if len(a) < (n-1)*lda+n {
|
||||||
panic("lapack: insufficient length of a")
|
panic(shortA)
|
||||||
}
|
}
|
||||||
if len(tau) != n-1 && n > 0 {
|
if len(tau) != n-1 {
|
||||||
panic(badTau)
|
panic(badTau)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -25,14 +25,30 @@ import "gonum.org/v1/gonum/blas"
|
|||||||
//
|
//
|
||||||
// Dgelq2 is an internal routine. It is exported for testing purposes.
|
// Dgelq2 is an internal routine. It is exported for testing purposes.
|
||||||
func (impl Implementation) Dgelq2(m, n int, a []float64, lda int, tau, work []float64) {
|
func (impl Implementation) Dgelq2(m, n int, a []float64, lda int, tau, work []float64) {
|
||||||
checkMatrix(m, n, a, lda)
|
switch {
|
||||||
|
case m < 0:
|
||||||
|
panic(mLT0)
|
||||||
|
case n < 0:
|
||||||
|
panic(nLT0)
|
||||||
|
case lda < max(1, n):
|
||||||
|
panic(badLdA)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quick return if possible.
|
||||||
k := min(m, n)
|
k := min(m, n)
|
||||||
if len(tau) < k {
|
if k == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(a) < (m-1)*lda+n:
|
||||||
|
panic(shortA)
|
||||||
|
case len(tau) < k:
|
||||||
panic(badTau)
|
panic(badTau)
|
||||||
|
case len(work) < m:
|
||||||
|
panic(shortWork)
|
||||||
}
|
}
|
||||||
if len(work) < m {
|
|
||||||
panic(badWork)
|
|
||||||
}
|
|
||||||
for i := 0; i < k; i++ {
|
for i := 0; i < k; i++ {
|
||||||
a[i*lda+i], tau[i] = impl.Dlarfg(n-i, a[i*lda+i], a[i*lda+min(i+1, n-1):], 1)
|
a[i*lda+i], tau[i] = impl.Dlarfg(n-i, a[i*lda+i], a[i*lda+min(i+1, n-1):], 1)
|
||||||
if i < m-1 {
|
if i < m-1 {
|
||||||
|
@@ -47,7 +47,7 @@ func (impl Implementation) Dgelqf(m, n int, a []float64, lda int, tau, work []fl
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(a) < (m-1)*lda+n {
|
if len(a) < (m-1)*lda+n {
|
||||||
panic("lapack: insufficient length of a")
|
panic(shortA)
|
||||||
}
|
}
|
||||||
if len(tau) < k {
|
if len(tau) < k {
|
||||||
panic(badTau)
|
panic(badTau)
|
||||||
|
@@ -39,46 +39,63 @@ import (
|
|||||||
// In the special case that lwork == -1, work[0] will be set to the optimal working
|
// In the special case that lwork == -1, work[0] will be set to the optimal working
|
||||||
// length.
|
// length.
|
||||||
func (impl Implementation) Dgels(trans blas.Transpose, m, n, nrhs int, a []float64, lda int, b []float64, ldb int, work []float64, lwork int) bool {
|
func (impl Implementation) Dgels(trans blas.Transpose, m, n, nrhs int, a []float64, lda int, b []float64, ldb int, work []float64, lwork int) bool {
|
||||||
notran := trans == blas.NoTrans
|
|
||||||
checkMatrix(m, n, a, lda)
|
|
||||||
mn := min(m, n)
|
mn := min(m, n)
|
||||||
checkMatrix(max(m, n), nrhs, b, ldb)
|
minwrk := mn + max(mn, nrhs)
|
||||||
|
switch {
|
||||||
|
case trans != blas.NoTrans && trans != blas.Trans && trans != blas.ConjTrans:
|
||||||
|
panic(badTrans)
|
||||||
|
case m < 0:
|
||||||
|
panic(mLT0)
|
||||||
|
case n < 0:
|
||||||
|
panic(nLT0)
|
||||||
|
case nrhs < 0:
|
||||||
|
panic(nrhsLT0)
|
||||||
|
case lda < max(1, n):
|
||||||
|
panic(badLdA)
|
||||||
|
case ldb < max(1, nrhs):
|
||||||
|
panic(badLdB)
|
||||||
|
case lwork < max(1, minwrk) && lwork != -1:
|
||||||
|
panic(badWork)
|
||||||
|
case len(work) < max(1, lwork):
|
||||||
|
panic(shortWork)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quick return if possible.
|
||||||
|
if mn == 0 || nrhs == 0 {
|
||||||
|
impl.Dlaset(blas.All, max(m, n), nrhs, 0, 0, b, ldb)
|
||||||
|
work[0] = 1
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// Find optimal block size.
|
// Find optimal block size.
|
||||||
tpsd := true
|
|
||||||
if notran {
|
|
||||||
tpsd = false
|
|
||||||
}
|
|
||||||
var nb int
|
var nb int
|
||||||
if m >= n {
|
if m >= n {
|
||||||
nb = impl.Ilaenv(1, "DGEQRF", " ", m, n, -1, -1)
|
nb = impl.Ilaenv(1, "DGEQRF", " ", m, n, -1, -1)
|
||||||
if tpsd {
|
if trans != blas.NoTrans {
|
||||||
nb = max(nb, impl.Ilaenv(1, "DORMQR", "LN", m, nrhs, n, -1))
|
nb = max(nb, impl.Ilaenv(1, "DORMQR", "LN", m, nrhs, n, -1))
|
||||||
} else {
|
} else {
|
||||||
nb = max(nb, impl.Ilaenv(1, "DORMQR", "LT", m, nrhs, n, -1))
|
nb = max(nb, impl.Ilaenv(1, "DORMQR", "LT", m, nrhs, n, -1))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
nb = impl.Ilaenv(1, "DGELQF", " ", m, n, -1, -1)
|
nb = impl.Ilaenv(1, "DGELQF", " ", m, n, -1, -1)
|
||||||
if tpsd {
|
if trans != blas.NoTrans {
|
||||||
nb = max(nb, impl.Ilaenv(1, "DORMLQ", "LT", n, nrhs, m, -1))
|
nb = max(nb, impl.Ilaenv(1, "DORMLQ", "LT", n, nrhs, m, -1))
|
||||||
} else {
|
} else {
|
||||||
nb = max(nb, impl.Ilaenv(1, "DORMLQ", "LN", n, nrhs, m, -1))
|
nb = max(nb, impl.Ilaenv(1, "DORMLQ", "LN", n, nrhs, m, -1))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
wsize := max(1, mn+max(mn, nrhs)*nb)
|
||||||
|
work[0] = float64(wsize)
|
||||||
|
|
||||||
if lwork == -1 {
|
if lwork == -1 {
|
||||||
work[0] = float64(max(1, mn+max(mn, nrhs)*nb))
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(work) < lwork {
|
switch {
|
||||||
panic(shortWork)
|
case len(a) < (m-1)*lda+n:
|
||||||
}
|
panic(shortA)
|
||||||
if lwork < mn+max(mn, nrhs) {
|
case len(b) < (max(m, n)-1)*ldb+nrhs:
|
||||||
panic(badWork)
|
panic(shortB)
|
||||||
}
|
|
||||||
if m == 0 || n == 0 || nrhs == 0 {
|
|
||||||
impl.Dlaset(blas.All, max(m, n), nrhs, 0, 0, b, ldb)
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scale the input matrices if they contain extreme values.
|
// Scale the input matrices if they contain extreme values.
|
||||||
@@ -97,7 +114,7 @@ func (impl Implementation) Dgels(trans blas.Transpose, m, n, nrhs int, a []float
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
brow := m
|
brow := m
|
||||||
if tpsd {
|
if trans != blas.NoTrans {
|
||||||
brow = n
|
brow = n
|
||||||
}
|
}
|
||||||
bnrm := impl.Dlange(lapack.MaxAbs, brow, nrhs, b, ldb, nil)
|
bnrm := impl.Dlange(lapack.MaxAbs, brow, nrhs, b, ldb, nil)
|
||||||
@@ -114,7 +131,7 @@ func (impl Implementation) Dgels(trans blas.Transpose, m, n, nrhs int, a []float
|
|||||||
var scllen int
|
var scllen int
|
||||||
if m >= n {
|
if m >= n {
|
||||||
impl.Dgeqrf(m, n, a, lda, work, work[mn:], lwork-mn)
|
impl.Dgeqrf(m, n, a, lda, work, work[mn:], lwork-mn)
|
||||||
if !tpsd {
|
if trans == blas.NoTrans {
|
||||||
impl.Dormqr(blas.Left, blas.Trans, m, nrhs, n,
|
impl.Dormqr(blas.Left, blas.Trans, m, nrhs, n,
|
||||||
a, lda,
|
a, lda,
|
||||||
work[:n],
|
work[:n],
|
||||||
@@ -148,7 +165,7 @@ func (impl Implementation) Dgels(trans blas.Transpose, m, n, nrhs int, a []float
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
impl.Dgelqf(m, n, a, lda, work, work[mn:], lwork-mn)
|
impl.Dgelqf(m, n, a, lda, work, work[mn:], lwork-mn)
|
||||||
if !tpsd {
|
if trans == blas.NoTrans {
|
||||||
ok := impl.Dtrtrs(blas.Lower, blas.NoTrans, blas.NonUnit,
|
ok := impl.Dtrtrs(blas.Lower, blas.NoTrans, blas.NonUnit,
|
||||||
m, nrhs,
|
m, nrhs,
|
||||||
a, lda,
|
a, lda,
|
||||||
@@ -196,5 +213,7 @@ func (impl Implementation) Dgels(trans blas.Transpose, m, n, nrhs int, a []float
|
|||||||
if ibscl == 2 {
|
if ibscl == 2 {
|
||||||
impl.Dlascl(lapack.General, 0, 0, bignum, bnrm, scllen, nrhs, b, ldb)
|
impl.Dlascl(lapack.General, 0, 0, bignum, bnrm, scllen, nrhs, b, ldb)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
work[0] = float64(wsize)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@@ -24,14 +24,30 @@ import "gonum.org/v1/gonum/blas"
|
|||||||
//
|
//
|
||||||
// Dgeql2 is an internal routine. It is exported for testing purposes.
|
// Dgeql2 is an internal routine. It is exported for testing purposes.
|
||||||
func (impl Implementation) Dgeql2(m, n int, a []float64, lda int, tau, work []float64) {
|
func (impl Implementation) Dgeql2(m, n int, a []float64, lda int, tau, work []float64) {
|
||||||
checkMatrix(m, n, a, lda)
|
switch {
|
||||||
if len(tau) < min(m, n) {
|
case m < 0:
|
||||||
panic(badTau)
|
panic(mLT0)
|
||||||
}
|
case n < 0:
|
||||||
if len(work) < n {
|
panic(nLT0)
|
||||||
panic(badWork)
|
case lda < max(1, n):
|
||||||
|
panic(badLdA)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Quick return if possible.
|
||||||
k := min(m, n)
|
k := min(m, n)
|
||||||
|
if k == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(a) < (m-1)*lda+n:
|
||||||
|
panic(shortA)
|
||||||
|
case len(tau) < k:
|
||||||
|
panic(badTau)
|
||||||
|
case len(work) < n:
|
||||||
|
panic(shortWork)
|
||||||
|
}
|
||||||
|
|
||||||
var aii float64
|
var aii float64
|
||||||
for i := k - 1; i >= 0; i-- {
|
for i := k - 1; i >= 0; i-- {
|
||||||
// Generate elementary reflector H_i to annihilate A[0:m-k+i-1, n-k+i].
|
// Generate elementary reflector H_i to annihilate A[0:m-k+i-1, n-k+i].
|
||||||
|
@@ -30,7 +30,12 @@ const (
|
|||||||
badK2 = "lapack: k2 out of range"
|
badK2 = "lapack: k2 out of range"
|
||||||
badKperm = "lapack: incorrect permutation length"
|
badKperm = "lapack: incorrect permutation length"
|
||||||
badLdA = "lapack: bad leading dimension of A"
|
badLdA = "lapack: bad leading dimension of A"
|
||||||
|
badLdB = "lapack: bad leading dimension of B"
|
||||||
|
badLdC = "lapack: bad leading dimension of C"
|
||||||
badLdU = "lapack: bad leading dimension of U"
|
badLdU = "lapack: bad leading dimension of U"
|
||||||
|
badLdV = "lapack: bad leading dimension of V"
|
||||||
|
badLdVL = "lapack: bad leading dimension of VL"
|
||||||
|
badLdVR = "lapack: bad leading dimension of VR"
|
||||||
badLdVT = "lapack: bad leading dimension of VT"
|
badLdVT = "lapack: bad leading dimension of VT"
|
||||||
badNb = "lapack: nb out of range"
|
badNb = "lapack: nb out of range"
|
||||||
badNorm = "lapack: bad norm"
|
badNorm = "lapack: bad norm"
|
||||||
@@ -62,7 +67,20 @@ const (
|
|||||||
negZ = "lapack: negative z value"
|
negZ = "lapack: negative z value"
|
||||||
nLT0 = "lapack: n < 0"
|
nLT0 = "lapack: n < 0"
|
||||||
nLTM = "lapack: n < m"
|
nLTM = "lapack: n < m"
|
||||||
|
nrhsLT0 = "lapack: nrhs < 0"
|
||||||
offsetGTM = "lapack: offset > m"
|
offsetGTM = "lapack: offset > m"
|
||||||
shortWork = "lapack: working array shorter than declared"
|
ncvtLT0 = "lapack: ncvt < 0"
|
||||||
|
nruLT0 = "lapack: nru < 0"
|
||||||
|
nccLT0 = "lapack: ncc < 0"
|
||||||
|
shortA = "lapack: insufficient length of a"
|
||||||
|
shortB = "lapack: insufficient length of b"
|
||||||
|
shortC = "lapack: insufficient length of c"
|
||||||
|
shortScale = "lapack: insufficient length of scale"
|
||||||
|
shortU = "lapack: insufficient length of u"
|
||||||
|
shortV = "lapack: insufficient length of v"
|
||||||
|
shortVL = "lapack: insufficient length of vl"
|
||||||
|
shortVR = "lapack: insufficient length of vr"
|
||||||
|
shortVT = "lapack: insufficient length of vt"
|
||||||
|
shortWork = "lapack: working slice shorter than declared"
|
||||||
zeroDiv = "lapack: zero divisor"
|
zeroDiv = "lapack: zero divisor"
|
||||||
)
|
)
|
||||||
|
@@ -23,7 +23,6 @@ type Dbdsqrer interface {
|
|||||||
func DbdsqrTest(t *testing.T, impl Dbdsqrer) {
|
func DbdsqrTest(t *testing.T, impl Dbdsqrer) {
|
||||||
rnd := rand.New(rand.NewSource(1))
|
rnd := rand.New(rand.NewSource(1))
|
||||||
bi := blas64.Implementation()
|
bi := blas64.Implementation()
|
||||||
_ = bi
|
|
||||||
for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
|
for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
|
||||||
for _, test := range []struct {
|
for _, test := range []struct {
|
||||||
n, ncvt, nru, ncc, ldvt, ldu, ldc int
|
n, ncvt, nru, ncc, ldvt, ldu, ldc int
|
||||||
@@ -49,13 +48,13 @@ func DbdsqrTest(t *testing.T, impl Dbdsqrer) {
|
|||||||
ldu := test.ldu
|
ldu := test.ldu
|
||||||
ldc := test.ldc
|
ldc := test.ldc
|
||||||
if ldvt == 0 {
|
if ldvt == 0 {
|
||||||
ldvt = ncvt
|
ldvt = max(1, ncvt)
|
||||||
}
|
}
|
||||||
if ldu == 0 {
|
if ldu == 0 {
|
||||||
ldu = n
|
ldu = max(1, n)
|
||||||
}
|
}
|
||||||
if ldc == 0 {
|
if ldc == 0 {
|
||||||
ldc = ncc
|
ldc = max(1, ncc)
|
||||||
}
|
}
|
||||||
|
|
||||||
d := make([]float64, n)
|
d := make([]float64, n)
|
||||||
@@ -92,7 +91,7 @@ func DbdsqrTest(t *testing.T, impl Dbdsqrer) {
|
|||||||
pt[i*ldpt+i] = 1
|
pt[i*ldpt+i] = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
ok := impl.Dbdsqr(uplo, n, n, n, 0, d, e, pt, ldpt, q, ldq, nil, 0, work)
|
ok := impl.Dbdsqr(uplo, n, n, n, 0, d, e, pt, ldpt, q, ldq, nil, 1, work)
|
||||||
|
|
||||||
isUpper := uplo == blas.Upper
|
isUpper := uplo == blas.Upper
|
||||||
errStr := fmt.Sprintf("isUpper = %v, n = %v, ncvt = %v, nru = %v, ncc = %v", isUpper, n, ncvt, nru, ncc)
|
errStr := fmt.Sprintf("isUpper = %v, n = %v, ncvt = %v, nru = %v, ncc = %v", isUpper, n, ncvt, nru, ncc)
|
||||||
|
@@ -559,11 +559,15 @@ func testDgeev(t *testing.T, impl Dgeever, tc string, test dgeevTest, jobvl lapa
|
|||||||
var vl blas64.General
|
var vl blas64.General
|
||||||
if jobvl == lapack.LeftEVCompute {
|
if jobvl == lapack.LeftEVCompute {
|
||||||
vl = nanGeneral(n, n, n)
|
vl = nanGeneral(n, n, n)
|
||||||
|
} else {
|
||||||
|
vl.Stride = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
var vr blas64.General
|
var vr blas64.General
|
||||||
if jobvr == lapack.RightEVCompute {
|
if jobvr == lapack.RightEVCompute {
|
||||||
vr = nanGeneral(n, n, n)
|
vr = nanGeneral(n, n, n)
|
||||||
|
} else {
|
||||||
|
vr.Stride = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
wr := make([]float64, n)
|
wr := make([]float64, n)
|
||||||
|
@@ -156,10 +156,15 @@ func (e *Eigen) Factorize(a Matrix, left, right bool) (ok bool) {
|
|||||||
if left {
|
if left {
|
||||||
vl = *NewDense(r, r, nil)
|
vl = *NewDense(r, r, nil)
|
||||||
jobvl = lapack.LeftEVCompute
|
jobvl = lapack.LeftEVCompute
|
||||||
|
} else {
|
||||||
|
vl.mat.Stride = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
if right {
|
if right {
|
||||||
vr = *NewDense(c, c, nil)
|
vr = *NewDense(c, c, nil)
|
||||||
jobvr = lapack.RightEVCompute
|
jobvr = lapack.RightEVCompute
|
||||||
|
} else {
|
||||||
|
vr.mat.Stride = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
wr := getFloats(c, false)
|
wr := getFloats(c, false)
|
||||||
|
Reference in New Issue
Block a user