diff --git a/lapack/gonum/dgesvd.go b/lapack/gonum/dgesvd.go index fc309e03..afb73152 100644 --- a/lapack/gonum/dgesvd.go +++ b/lapack/gonum/dgesvd.go @@ -95,7 +95,7 @@ func (impl Implementation) Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float panic(nLT0) case lda < max(1, n): panic(badLdA) - case ldu < 1 || (wantua && ldu < m) || (wantus && ldu < minmn): + case ldu < 1, wantua && ldu < m, wantus && ldu < minmn: panic(badLdU) case ldvt < 1 || (wantvas && ldvt < n): panic(badLdVT) diff --git a/lapack/gonum/dggsvd3.go b/lapack/gonum/dggsvd3.go index abdaf9b0..f73198ba 100644 --- a/lapack/gonum/dggsvd3.go +++ b/lapack/gonum/dggsvd3.go @@ -128,11 +128,11 @@ func (impl Implementation) Dggsvd3(jobU, jobV, jobQ lapack.GSVDJob, m, n, p int, panic(badLdA) case ldb < max(1, n): panic(badLdB) - case ldu < 1 || (wantu && ldu < m): + case ldu < 1, wantu && ldu < m: panic(badLdU) - case ldv < 1 || (wantv && ldv < p): + case ldv < 1, wantv && ldv < p: panic(badLdV) - case ldq < 1 || (wantq && ldq < n): + case ldq < 1, wantq && ldq < n: panic(badLdQ) case len(iwork) < n: panic(shortWork) diff --git a/lapack/gonum/dggsvp3.go b/lapack/gonum/dggsvp3.go index 84905160..5b428be7 100644 --- a/lapack/gonum/dggsvp3.go +++ b/lapack/gonum/dggsvp3.go @@ -74,11 +74,11 @@ func (impl Implementation) Dggsvp3(jobU, jobV, jobQ lapack.GSVDJob, m, p, n int, panic(badLdA) case ldb < max(1, n): panic(badLdB) - case ldu < 1 || (wantu && ldu < m): + case ldu < 1, wantu && ldu < m: panic(badLdU) - case ldv < 1 || (wantv && ldv < p): + case ldv < 1, wantv && ldv < p: panic(badLdV) - case ldq < 1 || (wantq && ldq < n): + case ldq < 1, wantq && ldq < n: panic(badLdQ) case len(iwork) != n: panic(shortWork) diff --git a/lapack/gonum/dhseqr.go b/lapack/gonum/dhseqr.go index bfbbd941..11091c27 100644 --- a/lapack/gonum/dhseqr.go +++ b/lapack/gonum/dhseqr.go @@ -135,7 +135,7 @@ func (impl Implementation) Dhseqr(job lapack.SchurJob, compz lapack.SchurComp, n panic(badIhi) case ldh < max(1, n): panic(badLdH) - case ldz < 1 || (wantz && ldz < max(1, n)): + case ldz < 1, wantz && ldz < n: panic(badLdZ) case lwork < max(1, n) && lwork != -1: panic(badWork) diff --git a/lapack/gonum/dlahqr.go b/lapack/gonum/dlahqr.go index 44fea325..3dd168c3 100644 --- a/lapack/gonum/dlahqr.go +++ b/lapack/gonum/dlahqr.go @@ -75,7 +75,7 @@ func (impl Implementation) Dlahqr(wantt, wantz bool, n, ilo, ihi int, h []float6 switch { case n < 0: panic(nLT0) - case ilo < 0 || max(0, ihi) < ilo: + case ilo < 0, max(0, ihi) < ilo: panic(badIlo) case ihi >= n: panic(badIhi) @@ -85,7 +85,7 @@ func (impl Implementation) Dlahqr(wantt, wantz bool, n, ilo, ihi int, h []float6 panic("lapack: iloz out of range") case wantz && (ihiz < ihi || n <= ihiz): panic("lapack: ihiz out of range") - case ldz < 1 || (wantz && ldz < max(1, n)): + case ldz < 1, wantz && ldz < n: panic(badLdZ) } diff --git a/lapack/gonum/dlaqr04.go b/lapack/gonum/dlaqr04.go index 033f67e6..1bb2f9a0 100644 --- a/lapack/gonum/dlaqr04.go +++ b/lapack/gonum/dlaqr04.go @@ -135,7 +135,7 @@ func (impl Implementation) Dlaqr04(wantt, wantz bool, n, ilo, ihi int, h []float panic("lapack: invalid value of iloz") case wantz && (ihiz < ihi || n <= ihiz): panic("lapack: invalid value of ihiz") - case ldz < 1 || (wantz && ldz < max(1, n)): + case ldz < 1, wantz && ldz < n: panic(badLdZ) case lwork < 1 && lwork != -1: panic(badWork) diff --git a/lapack/gonum/dlaqr23.go b/lapack/gonum/dlaqr23.go index f375aabc..e8e2f211 100644 --- a/lapack/gonum/dlaqr23.go +++ b/lapack/gonum/dlaqr23.go @@ -5,7 +5,6 @@ package gonum import ( - "fmt" "math" "gonum.org/v1/gonum/blas" @@ -88,7 +87,6 @@ func (impl Implementation) Dlaqr23(wantt, wantz bool, n, ktop, kbot, nw int, h [ case kbot < min(ktop, n-1) || n <= kbot: panic("lapack: invalid value of kbot") case nw < 0 || kbot-ktop+1+1 < nw: - fmt.Println(nw, kbot, ktop) panic("lapack: invalid value of nw") case ldh < max(1, n): panic(badLdH) @@ -96,7 +94,7 @@ func (impl Implementation) Dlaqr23(wantt, wantz bool, n, ktop, kbot, nw int, h [ panic("lapack: invalid value of iloz") case wantz && (ihiz < kbot || n <= ihiz): panic("lapack: invalid value of ihiz") - case ldz < 1 || (wantz && ldz < max(1, n)): + case ldz < 1, wantz && ldz < n: panic(badLdZ) case ldv < max(1, nw): panic(badLdV) diff --git a/lapack/gonum/dlaqr5.go b/lapack/gonum/dlaqr5.go index f4e83dbd..ca65ac78 100644 --- a/lapack/gonum/dlaqr5.go +++ b/lapack/gonum/dlaqr5.go @@ -91,7 +91,7 @@ func (impl Implementation) Dlaqr5(wantt, wantz bool, kacc22 int, n, ktop, kbot, panic("lapack: invalid value of ihiz") case wantz && iloz < 0 || ihiz < iloz: panic("lapack: invalid value of iloz") - case ldz < 1 || (wantz && ldz < max(1, n)): + case ldz < 1, wantz && ldz < n: panic(badLdZ) case wantz && len(z) < (n-1)*ldz+n: panic(shortZ) diff --git a/lapack/gonum/dorg2l.go b/lapack/gonum/dorg2l.go index 32071986..fdc8f46f 100644 --- a/lapack/gonum/dorg2l.go +++ b/lapack/gonum/dorg2l.go @@ -22,23 +22,34 @@ import ( // // Dorg2l is an internal routine. It is exported for testing purposes. func (impl Implementation) Dorg2l(m, n, k int, a []float64, lda int, tau, work []float64) { - checkMatrix(m, n, a, lda) - if len(tau) < k { - panic(badTau) - } - if len(work) < n { - panic(badWork) - } - if m < n { - panic(mLTN) - } - if k > n { + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case n > m: + panic(nGTM) + case k < 0: + panic(kLT0) + case k > n: panic(kGTN) + case lda < max(1, n): + panic(badLdA) } + if n == 0 { return } + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(tau) < k: + panic(badTau) + case len(work) < n: + panic(shortWork) + } + // Initialize columns 0:n-k to columns of the unit matrix. for j := 0; j < n-k; j++ { for l := 0; l < m; l++ { diff --git a/lapack/gonum/dorg2r.go b/lapack/gonum/dorg2r.go index d7525017..c3de651c 100644 --- a/lapack/gonum/dorg2r.go +++ b/lapack/gonum/dorg2r.go @@ -17,26 +17,36 @@ import ( // // Dorg2r is an internal routine. It is exported for testing purposes. func (impl Implementation) Dorg2r(m, n, k int, a []float64, lda int, tau []float64, work []float64) { - checkMatrix(m, n, a, lda) - if len(tau) < k { - panic(badTau) - } - if len(work) < n { - panic(badWork) - } - if k > n { + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case n > m: + panic(nGTM) + case k < 0: + panic(kLT0) + case k > n: panic(kGTN) + case lda < max(1, n): + panic(badLdA) } - if n > m { - panic(mLTN) - } - if len(work) < n { - panic(badWork) - } + if n == 0 { return } + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(tau) < k: + panic(badTau) + case len(work) < n: + panic(shortWork) + } + bi := blas64.Implementation() + // Initialize columns k+1:n to columns of the unit matrix. for l := 0; l < m; l++ { for j := k; j < n; j++ { diff --git a/lapack/gonum/dorgbr.go b/lapack/gonum/dorgbr.go index f1406281..80491d79 100644 --- a/lapack/gonum/dorgbr.go +++ b/lapack/gonum/dorgbr.go @@ -20,38 +20,45 @@ import "gonum.org/v1/gonum/lapack" // // Dorgbr is an internal routine. It is exported for testing purposes. func (impl Implementation) Dorgbr(vect lapack.GenOrtho, m, n, k int, a []float64, lda int, tau, work []float64, lwork int) { + wantq := vect == lapack.GenerateQ mn := min(m, n) - var wantq bool - switch vect { - case lapack.GenerateQ: - wantq = true - case lapack.GeneratePT: - default: + switch { + case vect != lapack.GenerateQ && vect != lapack.GeneratePT: panic(badGenOrtho) + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case k < 0: + panic(kLT0) + case wantq && n > m: + panic(nGTM) + case wantq && n < min(m, k): + panic("lapack: n < min(m,k)") + case !wantq && m > n: + panic(mGTN) + case !wantq && m < min(n, k): + panic("lapack: m < min(n,k)") + case lda < max(1, n) && lwork != -1: + // Normally, we follow the reference and require the leading + // dimension to be always valid, even in case of workspace + // queries. However, if a caller provided a placeholder value + // for lda (and a) when doing a workspace query that didn't + // fulfill the condition here, it would cause a panic. This is + // exactly what Dgesvd does. + panic(badLdA) + case lwork < max(1, mn) && lwork != -1: + panic(badWork) + case len(work) < max(1, lwork): + panic(shortWork) } - if wantq { - if m < n || n < min(m, k) || m < min(m, k) { - panic(badDims) - } - } else { - if n < m || m < min(n, k) || n < min(n, k) { - panic(badDims) - } - } - if wantq { - if m >= k { - checkMatrix(m, k, a, lda) - } else { - checkMatrix(m, m, a, lda) - } - } else { - if n >= k { - checkMatrix(k, n, a, lda) - } else { - checkMatrix(n, n, a, lda) - } - } + + // Quick return if possible. work[0] = 1 + if m == 0 || n == 0 { + return + } + if wantq { if m >= k { impl.Dorgqr(m, n, k, a, lda, tau, work, -1) @@ -71,16 +78,16 @@ func (impl Implementation) Dorgbr(vect lapack.GenOrtho, m, n, k int, a []float64 work[0] = float64(lworkopt) return } - if len(work) < lwork { - panic(badWork) - } - if lwork < mn { - panic(badWork) - } - if m == 0 || n == 0 { - work[0] = 1 - return + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case wantq && len(tau) < min(m, k): + panic(badTau) + case !wantq && len(tau) < min(n, k): + panic(badTau) } + if wantq { // Form Q, determined by a call to Dgebrd to reduce an m×k matrix. if m >= k { diff --git a/lapack/gonum/dorghr.go b/lapack/gonum/dorghr.go index b7ea7b2c..27d5aaac 100644 --- a/lapack/gonum/dorghr.go +++ b/lapack/gonum/dorghr.go @@ -33,29 +33,37 @@ package gonum // // Dorghr is an internal routine. It is exported for testing purposes. func (impl Implementation) Dorghr(n, ilo, ihi int, a []float64, lda int, tau, work []float64, lwork int) { - checkMatrix(n, n, a, lda) nh := ihi - ilo switch { case ilo < 0 || max(1, n) <= ilo: panic(badIlo) case ihi < min(ilo, n-1) || n <= ihi: panic(badIhi) + case lda < max(1, n): + panic(badLdA) case lwork < max(1, nh) && lwork != -1: panic(badWork) case len(work) < max(1, lwork): panic(shortWork) } + // Quick return if possible. + if n == 0 { + work[0] = 1 + return + } + lwkopt := max(1, nh) * impl.Ilaenv(1, "DORGQR", " ", nh, nh, nh, -1) if lwork == -1 { work[0] = float64(lwkopt) return } - // Quick return if possible. - if n == 0 { - work[0] = 1 - return + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(tau) < n-1: + panic(badTau) } // Shift the vectors which define the elementary reflectors one column diff --git a/lapack/gonum/dorgl2.go b/lapack/gonum/dorgl2.go index 06303812..fe3673ce 100644 --- a/lapack/gonum/dorgl2.go +++ b/lapack/gonum/dorgl2.go @@ -17,26 +17,34 @@ import ( // // Dorgl2 is an internal routine. It is exported for testing purposes. func (impl Implementation) Dorgl2(m, n, k int, a []float64, lda int, tau, work []float64) { - checkMatrix(m, n, a, lda) - if len(tau) < k { - panic(badTau) - } - if k > m { - panic(kGTM) - } - if k > m { - panic(kGTM) - } - if m > n { + switch { + case m < 0: + panic(mLT0) + case n < m: panic(nLTM) + case k < 0: + panic(kLT0) + case k > m: + panic(kGTM) + case lda < max(1, m): + panic(badLdA) } - if len(work) < m { - panic(badWork) - } + if m == 0 { return } + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(tau) < k: + panic(badTau) + case len(work) < m: + panic(shortWork) + } + bi := blas64.Implementation() + if k < m { for i := k; i < m; i++ { for j := 0; j < n; j++ { diff --git a/lapack/gonum/dorglq.go b/lapack/gonum/dorglq.go index 2416d973..2f55a26e 100644 --- a/lapack/gonum/dorglq.go +++ b/lapack/gonum/dorglq.go @@ -27,12 +27,14 @@ import ( // Dorglq is an internal routine. It is exported for testing purposes. func (impl Implementation) Dorglq(m, n, k int, a []float64, lda int, tau, work []float64, lwork int) { switch { - case k < 0: - panic(kLT0) - case m < k: - panic(kGTM) + case m < 0: + panic(mLT0) case n < m: panic(nLTM) + case k < 0: + panic(kLT0) + case k > m: + panic(kGTM) case lda < max(1, n): panic(badLdA) case lwork < max(1, m) && lwork != -1: @@ -52,10 +54,10 @@ func (impl Implementation) Dorglq(m, n, k int, a []float64, lda int, tau, work [ return } - if len(a) < (m-1)*lda+n { - panic("lapack: insufficient length of a") - } - if len(tau) < k { + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(tau) < k: panic(badTau) } diff --git a/lapack/gonum/dorgql.go b/lapack/gonum/dorgql.go index c48eb69a..281975a0 100644 --- a/lapack/gonum/dorgql.go +++ b/lapack/gonum/dorgql.go @@ -33,10 +33,12 @@ import ( // Dorgql is an internal routine. It is exported for testing purposes. func (impl Implementation) Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int) { switch { + case m < 0: + panic(mLT0) case n < 0: panic(nLT0) - case m < n: - panic(mLTN) + case n > m: + panic(nGTM) case k < 0: panic(kLT0) case k > n: @@ -61,10 +63,10 @@ func (impl Implementation) Dorgql(m, n, k int, a []float64, lda int, tau, work [ return } - if len(a) < (m-1)*lda+n { - panic("lapack: insufficient length of a") - } - if len(tau) < k { + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(tau) < k: panic(badTau) } diff --git a/lapack/gonum/dorgqr.go b/lapack/gonum/dorgqr.go index 0cd63595..d996df27 100644 --- a/lapack/gonum/dorgqr.go +++ b/lapack/gonum/dorgqr.go @@ -29,12 +29,24 @@ import ( // Dorgqr is an internal routine. It is exported for testing purposes. func (impl Implementation) Dorgqr(m, n, k int, a []float64, lda int, tau, work []float64, lwork int) { switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case n > m: + panic(nGTM) case k < 0: panic(kLT0) - case n < k: + case k > n: panic(kGTN) - case m < n: - panic(mLTN) + case lda < max(1, n) && lwork != -1: + // Normally, we follow the reference and require the leading + // dimension to be always valid, even in case of workspace + // queries. However, if a caller provided a placeholder value + // for lda (and a) when doing a workspace query that didn't + // fulfill the condition here, it would cause a panic. This is + // exactly what Dgesvd does. + panic(badLdA) case lwork < max(1, n) && lwork != -1: panic(badWork) case len(work) < max(1, lwork): @@ -54,10 +66,8 @@ func (impl Implementation) Dorgqr(m, n, k int, a []float64, lda int, tau, work [ } switch { - case lda < max(1, n): - panic(badLdA) case len(a) < (m-1)*lda+n: - panic("lapack: insuffcient length of a") + panic(shortA) case len(tau) < k: panic(badTau) } diff --git a/lapack/gonum/dorgtr.go b/lapack/gonum/dorgtr.go index 6984ff55..fefbf9b9 100644 --- a/lapack/gonum/dorgtr.go +++ b/lapack/gonum/dorgtr.go @@ -25,19 +25,17 @@ import "gonum.org/v1/gonum/blas" // // Dorgtr is an internal routine. It is exported for testing purposes. func (impl Implementation) Dorgtr(uplo blas.Uplo, n int, a []float64, lda int, tau, work []float64, lwork int) { - checkMatrix(n, n, a, lda) - if len(tau) < n-1 { - panic(badTau) - } - if len(work) < lwork { - panic(badWork) - } - if lwork < n-1 && lwork != -1 { - panic(badWork) - } - upper := uplo == blas.Upper - if !upper && uplo != blas.Lower { + switch { + case uplo != blas.Upper && uplo != blas.Lower: panic(badUplo) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case lwork < max(1, n-1) && lwork != -1: + panic(badWork) + case len(work) < max(1, lwork): + panic(shortWork) } if n == 0 { @@ -46,7 +44,7 @@ func (impl Implementation) Dorgtr(uplo blas.Uplo, n int, a []float64, lda int, t } var nb int - if upper { + if uplo == blas.Upper { nb = impl.Ilaenv(1, "DORGQL", " ", n-1, n-1, n-1, -1) } else { nb = impl.Ilaenv(1, "DORGQR", " ", n-1, n-1, n-1, -1) @@ -57,7 +55,14 @@ func (impl Implementation) Dorgtr(uplo blas.Uplo, n int, a []float64, lda int, t return } - if upper { + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(tau) < n-1: + panic(badTau) + } + + if uplo == blas.Upper { // Q was determined by a call to Dsytrd with uplo == blas.Upper. // Shift the vectors which define the elementary reflectors one column // to the left, and set the last row and column of Q to those of the unit diff --git a/lapack/gonum/dorm2r.go b/lapack/gonum/dorm2r.go index e8fb1d4d..30012cc1 100644 --- a/lapack/gonum/dorm2r.go +++ b/lapack/gonum/dorm2r.go @@ -23,37 +23,50 @@ import "gonum.org/v1/gonum/blas" // // Dorm2r is an internal routine. It is exported for testing purposes. func (impl Implementation) Dorm2r(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64) { - if side != blas.Left && side != blas.Right { + left := side == blas.Left + switch { + case !left && side != blas.Right: panic(badSide) - } - if trans != blas.Trans && trans != blas.NoTrans { + case trans != blas.Trans && trans != blas.NoTrans: panic(badTrans) + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case k < 0: + panic(kLT0) + case left && k > m: + panic(kGTM) + case !left && k > n: + panic(kGTN) + case lda < max(1, k): + panic(badLdA) + case ldc < max(1, n): + panic(badLdC) } - left := side == blas.Left - notran := trans == blas.NoTrans - if left { - // Q is m x m - checkMatrix(m, k, a, lda) - if len(work) < n { - panic(badWork) - } - } else { - // Q is n x n - checkMatrix(n, k, a, lda) - if len(work) < m { - panic(badWork) - } - } - checkMatrix(m, n, c, ldc) + // Quick return if possible. if m == 0 || n == 0 || k == 0 { return } - if len(tau) < k { + + switch { + case left && len(a) < (m-1)*lda+k: + panic(shortA) + case !left && len(a) < (n-1)*lda+k: + panic(shortA) + case len(c) < (m-1)*ldc+n: + panic(shortC) + case len(tau) < k: panic(badTau) + case left && len(work) < n: + panic(shortWork) + case !left && len(work) < m: + panic(badWork) } + if left { - if notran { + if trans == blas.NoTrans { for i := k - 1; i >= 0; i-- { aii := a[i*lda+i] a[i*lda+i] = 1 @@ -70,7 +83,7 @@ func (impl Implementation) Dorm2r(side blas.Side, trans blas.Transpose, m, n, k } return } - if notran { + if trans == blas.NoTrans { for i := 0; i < k; i++ { aii := a[i*lda+i] a[i*lda+i] = 1 diff --git a/lapack/gonum/dormbr.go b/lapack/gonum/dormbr.go index cbd65149..1b3c4123 100644 --- a/lapack/gonum/dormbr.go +++ b/lapack/gonum/dormbr.go @@ -45,40 +45,43 @@ import ( // // Dormbr is an internal routine. It is exported for testing purposes. func (impl Implementation) Dormbr(vect lapack.ApplyOrtho, side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) { - if side != blas.Left && side != blas.Right { - panic(badSide) - } - if trans != blas.NoTrans && trans != blas.Trans { - panic(badTrans) - } - if vect != lapack.ApplyP && vect != lapack.ApplyQ { - panic(badApplyOrtho) - } nq := n nw := m if side == blas.Left { nq = m nw = n } - if vect == lapack.ApplyQ { - checkMatrix(nq, min(nq, k), a, lda) - } else { - checkMatrix(min(nq, k), nq, a, lda) - } - if len(tau) < min(nq, k) { - panic(badTau) - } - checkMatrix(m, n, c, ldc) - if len(work) < lwork { + applyQ := vect == lapack.ApplyQ + switch { + case !applyQ && vect != lapack.ApplyP: + panic(badApplyOrtho) + case side != blas.Left && side != blas.Right: + panic(badSide) + case trans != blas.NoTrans && trans != blas.Trans: + panic(badTrans) + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case k < 0: + panic(kLT0) + case applyQ && lda < max(1, min(nq, k)): + panic(badLdA) + case !applyQ && lda < max(1, nq): + panic(badLdA) + case ldc < max(1, n): + panic(badLdC) + case lwork < max(1, nw) && lwork != -1: + panic(badWork) + case len(work) < max(1, lwork): panic(shortWork) } - if lwork < max(1, nw) && lwork != -1 { - panic(badWork) - } - applyQ := vect == lapack.ApplyQ - left := side == blas.Left - var nb int + // Quick return if possible. + if m == 0 || n == 0 { + work[0] = 1 + return + } // The current implementation does not use opts, but a future change may // use these options so construct them. @@ -93,14 +96,15 @@ func (impl Implementation) Dormbr(vect lapack.ApplyOrtho, side blas.Side, trans } else { opts += "N" } + var nb int if applyQ { - if left { + if side == blas.Left { nb = impl.Ilaenv(1, "DORMQR", opts, m-1, n, m-1, -1) } else { nb = impl.Ilaenv(1, "DORMQR", opts, m, n-1, n-1, -1) } } else { - if left { + if side == blas.Left { nb = impl.Ilaenv(1, "DORMLQ", opts, m-1, n, m-1, -1) } else { nb = impl.Ilaenv(1, "DORMLQ", opts, m, n-1, n-1, -1) @@ -109,7 +113,21 @@ func (impl Implementation) Dormbr(vect lapack.ApplyOrtho, side blas.Side, trans lworkopt := max(1, nw) * nb if lwork == -1 { work[0] = float64(lworkopt) + return } + + minnqk := min(nq, k) + switch { + case applyQ && len(a) < (nq-1)*lda+minnqk: + panic(shortA) + case !applyQ && len(a) < (minnqk-1)*lda+nq: + panic(shortA) + case len(tau) < minnqk: + panic(badTau) + case len(c) < (m-1)*ldc+n: + panic(shortC) + } + if applyQ { // Change the operation to get Q depending on the size of the initial // matrix to Dgebrd. The size matters due to the storage location of @@ -121,7 +139,7 @@ func (impl Implementation) Dormbr(vect lapack.ApplyOrtho, side blas.Side, trans ni := n - 1 i1 := 0 i2 := 1 - if left { + if side == blas.Left { mi = m - 1 ni = n i1 = 1 @@ -132,10 +150,12 @@ func (impl Implementation) Dormbr(vect lapack.ApplyOrtho, side blas.Side, trans work[0] = float64(lworkopt) return } + transt := blas.Trans if trans == blas.Trans { transt = blas.NoTrans } + // Change the operation to get P depending on the size of the initial // matrix to Dgebrd. The size matters due to the storage location of // the off-diagonal elements. @@ -146,7 +166,7 @@ func (impl Implementation) Dormbr(vect lapack.ApplyOrtho, side blas.Side, trans ni := n - 1 i1 := 0 i2 := 1 - if left { + if side == blas.Left { mi = m - 1 ni = n i1 = 1 diff --git a/lapack/gonum/dormhr.go b/lapack/gonum/dormhr.go index f6cb1b26..c912d864 100644 --- a/lapack/gonum/dormhr.go +++ b/lapack/gonum/dormhr.go @@ -50,39 +50,37 @@ import "gonum.org/v1/gonum/blas" // // Dormhr is an internal routine. It is exported for testing purposes. func (impl Implementation) Dormhr(side blas.Side, trans blas.Transpose, m, n, ilo, ihi int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) { - var ( - nq int // The order of Q. - nw int // The minimum length of work. - ) - switch side { - case blas.Left: + nq := n // The order of Q. + nw := m // The minimum length of work. + if side == blas.Left { nq = m nw = n - case blas.Right: - nq = n - nw = m - default: - panic(badSide) } switch { + case side != blas.Left && side != blas.Right: + panic(badSide) case trans != blas.NoTrans && trans != blas.Trans: panic(badTrans) + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) case ilo < 0 || max(1, nq) <= ilo: panic(badIlo) case ihi < min(ilo, nq-1) || nq <= ihi: panic(badIhi) + case lda < max(1, nq): + panic(badLdA) case lwork < max(1, nw) && lwork != -1: panic(badWork) case len(work) < max(1, lwork): panic(shortWork) } - if lwork != -1 { - checkMatrix(m, n, c, ldc) - checkMatrix(nq, nq, a, lda) - if len(tau) != nq-1 && nq > 0 { - panic(badTau) - } + // Quick return if possible. + if m == 0 || n == 0 { + work[0] = 1 + return } nh := ihi - ilo @@ -106,10 +104,20 @@ func (impl Implementation) Dormhr(side blas.Side, trans blas.Transpose, m, n, il return } - if m == 0 || n == 0 || nh == 0 { + if nh == 0 { work[0] = 1 return } + + switch { + case len(a) < (nq-1)*lda+nq: + panic(shortA) + case len(c) < (m-1)*ldc+n: + panic(shortC) + case len(tau) != nq-1: + panic(badTau) + } + if side == blas.Left { impl.Dormqr(side, trans, nh, n, nh, a[(ilo+1)*lda+ilo:], lda, tau[ilo:ihi], c[(ilo+1)*ldc:], ldc, work, lwork) diff --git a/lapack/gonum/dorml2.go b/lapack/gonum/dorml2.go index 1c217b5b..4e04deee 100644 --- a/lapack/gonum/dorml2.go +++ b/lapack/gonum/dorml2.go @@ -23,32 +23,51 @@ import "gonum.org/v1/gonum/blas" // // Dorml2 is an internal routine. It is exported for testing purposes. func (impl Implementation) Dorml2(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64) { - if side != blas.Left && side != blas.Right { + left := side == blas.Left + switch { + case !left && side != blas.Right: panic(badSide) - } - if trans != blas.Trans && trans != blas.NoTrans { + case trans != blas.Trans && trans != blas.NoTrans: panic(badTrans) + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case k < 0: + panic(kLT0) + case left && k > m: + panic(kGTM) + case !left && k > n: + panic(kGTN) + case left && lda < max(1, m): + panic(badLdA) + case !left && lda < max(1, n): + panic(badLdA) } - left := side == blas.Left - notran := trans == blas.NoTrans - if left { - checkMatrix(k, m, a, lda) - if len(work) < n { - panic(badWork) - } - } else { - checkMatrix(k, n, a, lda) - if len(work) < m { - panic(badWork) - } - } - checkMatrix(m, n, c, ldc) + // Quick return if possible. if m == 0 || n == 0 || k == 0 { return } + switch { - case left && notran: + case left && len(a) < (k-1)*lda+m: + panic(shortA) + case !left && len(a) < (k-1)*lda+n: + panic(shortA) + case len(tau) < k: + panic(badTau) + case len(c) < (m-1)*ldc+n: + panic(shortC) + case left && len(work) < n: + panic(shortWork) + case !left && len(work) < m: + panic(shortWork) + } + + notrans := trans == blas.NoTrans + switch { + case left && notrans: for i := 0; i < k; i++ { aii := a[i*lda+i] a[i*lda+i] = 1 @@ -56,7 +75,7 @@ func (impl Implementation) Dorml2(side blas.Side, trans blas.Transpose, m, n, k a[i*lda+i] = aii } - case left && !notran: + case left && !notrans: for i := k - 1; i >= 0; i-- { aii := a[i*lda+i] a[i*lda+i] = 1 @@ -64,7 +83,7 @@ func (impl Implementation) Dorml2(side blas.Side, trans blas.Transpose, m, n, k a[i*lda+i] = aii } - case !left && notran: + case !left && notrans: for i := k - 1; i >= 0; i-- { aii := a[i*lda+i] a[i*lda+i] = 1 @@ -72,7 +91,7 @@ func (impl Implementation) Dorml2(side blas.Side, trans blas.Transpose, m, n, k a[i*lda+i] = aii } - case !left && !notran: + case !left && !notrans: for i := 0; i < k; i++ { aii := a[i*lda+i] a[i*lda+i] = 1 diff --git a/lapack/gonum/dormlq.go b/lapack/gonum/dormlq.go index d7a27643..6ad71aae 100644 --- a/lapack/gonum/dormlq.go +++ b/lapack/gonum/dormlq.go @@ -28,33 +28,37 @@ import ( // tau contains the Householder scales and must have length at least k, and // this function will panic otherwise. func (impl Implementation) Dormlq(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) { - if side != blas.Left && side != blas.Right { - panic(badSide) - } - if trans != blas.Trans && trans != blas.NoTrans { - panic(badTrans) - } left := side == blas.Left - if left { - checkMatrix(k, m, a, lda) - } else { - checkMatrix(k, n, a, lda) - } - checkMatrix(m, n, c, ldc) - if len(tau) < k { - panic(badTau) - } - if len(work) < lwork { - panic(shortWork) - } nw := m if left { nw = n } - if lwork < max(1, nw) && lwork != -1 { + switch { + case !left && side != blas.Right: + panic(badSide) + case trans != blas.Trans && trans != blas.NoTrans: + panic(badTrans) + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case k < 0: + panic(kLT0) + case left && k > m: + panic(kGTM) + case !left && k > n: + panic(kGTN) + case left && lda < max(1, m): + panic(badLdA) + case !left && lda < max(1, n): + panic(badLdA) + case lwork < max(1, nw) && lwork != -1: panic(badWork) + case len(work) < max(1, lwork): + panic(shortWork) } + // Quick return if possible. if m == 0 || n == 0 || k == 0 { work[0] = 1 return @@ -73,6 +77,17 @@ func (impl Implementation) Dormlq(side blas.Side, trans blas.Transpose, m, n, k return } + switch { + case left && len(a) < (k-1)*lda+m: + panic(shortA) + case !left && len(a) < (k-1)*lda+n: + panic(shortA) + case len(tau) < k: + panic(badTau) + case len(c) < (m-1)*ldc+n: + panic(shortC) + } + nbmin := 2 if 1 < nb && nb < k { iws := nw*nb + tsize @@ -92,14 +107,14 @@ func (impl Implementation) Dormlq(side blas.Side, trans blas.Transpose, m, n, k wrk := work[tsize:] ldwrk := nb - notran := trans == blas.NoTrans + notrans := trans == blas.NoTrans transt := blas.NoTrans - if notran { + if notrans { transt = blas.Trans } switch { - case left && notran: + case left && notrans: for i := 0; i < k; i += nb { ib := min(nb, k-i) impl.Dlarft(lapack.Forward, lapack.RowWise, m-i, ib, @@ -113,7 +128,7 @@ func (impl Implementation) Dormlq(side blas.Side, trans blas.Transpose, m, n, k wrk, ldwrk) } - case left && !notran: + case left && !notrans: for i := ((k - 1) / nb) * nb; i >= 0; i -= nb { ib := min(nb, k-i) impl.Dlarft(lapack.Forward, lapack.RowWise, m-i, ib, @@ -127,7 +142,7 @@ func (impl Implementation) Dormlq(side blas.Side, trans blas.Transpose, m, n, k wrk, ldwrk) } - case !left && notran: + case !left && notrans: for i := ((k - 1) / nb) * nb; i >= 0; i -= nb { ib := min(nb, k-i) impl.Dlarft(lapack.Forward, lapack.RowWise, n-i, ib, @@ -141,7 +156,7 @@ func (impl Implementation) Dormlq(side blas.Side, trans blas.Transpose, m, n, k wrk, ldwrk) } - case !left && !notran: + case !left && !notrans: for i := 0; i < k; i += nb { ib := min(nb, k-i) impl.Dlarft(lapack.Forward, lapack.RowWise, n-i, ib, diff --git a/lapack/gonum/dormqr.go b/lapack/gonum/dormqr.go index 3fa9009f..5a22371d 100644 --- a/lapack/gonum/dormqr.go +++ b/lapack/gonum/dormqr.go @@ -37,37 +37,39 @@ import ( // If lwork is -1, instead of performing Dormqr, the optimal workspace size will // be stored into work[0]. func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) { - var nq, nw int - switch side { - default: - panic(badSide) - case blas.Left: + left := side == blas.Left + nq := n + nw := m + if left { nq = m nw = n - case blas.Right: - nq = n - nw = m } switch { + case !left && side != blas.Right: + panic(badSide) case trans != blas.NoTrans && trans != blas.Trans: panic(badTrans) - case m < 0 || n < 0: - panic(negDimension) - case k < 0 || nq < k: - panic("lapack: invalid value of k") - case len(work) < lwork: - panic(shortWork) + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case k < 0: + panic(kLT0) + case left && k > m: + panic(kGTM) + case !left && k > n: + panic(kGTN) + case lda < max(1, k): + panic(badLdA) + case ldc < max(1, n): + panic(badLdC) case lwork < max(1, nw) && lwork != -1: panic(badWork) - } - if lwork != -1 { - checkMatrix(nq, k, a, lda) - checkMatrix(m, n, c, ldc) - if len(tau) != k { - panic(badTau) - } + case len(work) < max(1, lwork): + panic(shortWork) } + // Quick return if possible. if m == 0 || n == 0 || k == 0 { work[0] = 1 return @@ -86,6 +88,15 @@ func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k return } + switch { + case len(a) < (nq-1)*lda+k: + panic(shortA) + case len(tau) != k: + panic(badTau) + case len(c) < (m-1)*ldc+n: + panic(shortC) + } + nbmin := 2 if 1 < nb && nb < k { if lwork < nw*nb+tsize { @@ -102,12 +113,11 @@ func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k } var ( - ldwork = nb - left = side == blas.Left - notran = trans == blas.NoTrans + ldwork = nb + notrans = trans == blas.NoTrans ) switch { - case left && notran: + case left && notrans: for i := ((k - 1) / nb) * nb; i >= 0; i -= nb { ib := min(nb, k-i) impl.Dlarft(lapack.Forward, lapack.ColumnWise, m-i, ib, @@ -121,7 +131,7 @@ func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k work[tsize:], ldwork) } - case left && !notran: + case left && !notrans: for i := 0; i < k; i += nb { ib := min(nb, k-i) impl.Dlarft(lapack.Forward, lapack.ColumnWise, m-i, ib, @@ -135,7 +145,7 @@ func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k work[tsize:], ldwork) } - case !left && notran: + case !left && notrans: for i := 0; i < k; i += nb { ib := min(nb, k-i) impl.Dlarft(lapack.Forward, lapack.ColumnWise, n-i, ib, @@ -149,7 +159,7 @@ func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k work[tsize:], ldwork) } - case !left && !notran: + case !left && !notrans: for i := ((k - 1) / nb) * nb; i >= 0; i -= nb { ib := min(nb, k-i) impl.Dlarft(lapack.Forward, lapack.ColumnWise, n-i, ib, diff --git a/lapack/gonum/dormr2.go b/lapack/gonum/dormr2.go index 3a6b4330..d43f0a7a 100644 --- a/lapack/gonum/dormr2.go +++ b/lapack/gonum/dormr2.go @@ -23,42 +23,52 @@ import "gonum.org/v1/gonum/blas" // // Dormr2 is an internal routine. It is exported for testing purposes. func (impl Implementation) Dormr2(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64) { - if side != blas.Left && side != blas.Right { - panic(badSide) - } - if trans != blas.Trans && trans != blas.NoTrans { - panic(badTrans) - } - left := side == blas.Left - notran := trans == blas.NoTrans + nq := n + nw := m if left { - if k > m { - panic(kGTM) - } - checkMatrix(k, m, a, lda) - if len(work) < n { - panic(badWork) - } - } else { - if k > n { - panic(kGTN) - } - checkMatrix(k, n, a, lda) - if len(work) < m { - panic(badWork) - } + nq = m + nw = n } - if len(tau) < k { - panic(badTau) + switch { + case !left && side != blas.Right: + panic(badSide) + case trans != blas.NoTrans && trans != blas.Trans: + panic(badTrans) + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case k < 0: + panic(kLT0) + case left && k > m: + panic(kGTM) + case !left && k > n: + panic(kGTN) + case lda < max(1, nq): + panic(badLdA) + case ldc < max(1, n): + panic(badLdC) } - checkMatrix(m, n, c, ldc) + // Quick return if possible. if m == 0 || n == 0 || k == 0 { return } + + switch { + case len(a) < (k-1)*lda+nq: + panic(shortA) + case len(tau) < k: + panic(badTau) + case len(c) < (m-1)*ldc+n: + panic(shortC) + case len(work) < nw: + panic(shortWork) + } + if left { - if notran { + if trans == blas.NoTrans { for i := k - 1; i >= 0; i-- { aii := a[i*lda+(m-k+i)] a[i*lda+(m-k+i)] = 1 @@ -75,7 +85,7 @@ func (impl Implementation) Dormr2(side blas.Side, trans blas.Transpose, m, n, k } return } - if notran { + if trans == blas.NoTrans { for i := 0; i < k; i++ { aii := a[i*lda+(n-k+i)] a[i*lda+(n-k+i)] = 1 diff --git a/lapack/gonum/dpbtf2.go b/lapack/gonum/dpbtf2.go index 0c60385b..a5beb80b 100644 --- a/lapack/gonum/dpbtf2.go +++ b/lapack/gonum/dpbtf2.go @@ -48,14 +48,27 @@ import ( // // Dpbtf2 is an internal routine, exported for testing purposes. func (Implementation) Dpbtf2(ul blas.Uplo, n, kd int, ab []float64, ldab int) (ok bool) { - if ul != blas.Upper && ul != blas.Lower { + switch { + case ul != blas.Upper && ul != blas.Lower: panic(badUplo) + case n < 0: + panic(nLT0) + case kd < 0: + panic(kdLT0) + case ldab < kd+1: + panic(badLdA) } - checkSymBanded(ab, n, kd, ldab) + if n == 0 { return } + + if len(ab) < (n-1)*ldab+kd { + panic(shortAB) + } + bi := blas64.Implementation() + kld := max(1, ldab-1) if ul == blas.Upper { for j := 0; j < n; j++ { diff --git a/lapack/gonum/dpocon.go b/lapack/gonum/dpocon.go index 98d6c02b..7af4c187 100644 --- a/lapack/gonum/dpocon.go +++ b/lapack/gonum/dpocon.go @@ -21,41 +21,55 @@ import ( // // iwork is a temporary data slice of length at least n and Dpocon will panic otherwise. func (impl Implementation) Dpocon(uplo blas.Uplo, n int, a []float64, lda int, anorm float64, work []float64, iwork []int) float64 { - checkMatrix(n, n, a, lda) - if uplo != blas.Upper && uplo != blas.Lower { + switch { + case uplo != blas.Upper && uplo != blas.Lower: panic(badUplo) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case anorm < 0: + panic(negANorm) } - if len(work) < 3*n { - panic(badWork) - } - if len(iwork) < n { - panic(badWork) - } - var rcond float64 + + // Quick return if possible. if n == 0 { return 1 } + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(work) < 3*n: + panic(shortWork) + case len(iwork) < n: + panic(shortIWork) + } + if anorm == 0 { - return rcond + return 0 } bi := blas64.Implementation() - var ainvnm float64 - smlnum := dlamchS - upper := uplo == blas.Upper - var kase int - var normin bool - isave := new([3]int) - var sl, su float64 + + var ( + smlnum = dlamchS + rcond float64 + sl, su float64 + normin bool + ainvnm float64 + kase int + isave [3]int + ) for { - ainvnm, kase = impl.Dlacn2(n, work[n:], work, iwork, ainvnm, kase, isave) + ainvnm, kase = impl.Dlacn2(n, work[n:], work, iwork, ainvnm, kase, &isave) if kase == 0 { if ainvnm != 0 { rcond = (1 / ainvnm) / anorm } return rcond } - if upper { + if uplo == blas.Upper { sl = impl.Dlatrs(blas.Upper, blas.Trans, blas.NonUnit, normin, n, a, lda, work, work[2*n:]) normin = true su = impl.Dlatrs(blas.Upper, blas.NoTrans, blas.NonUnit, normin, n, a, lda, work, work[2*n:]) diff --git a/lapack/gonum/dpotf2.go b/lapack/gonum/dpotf2.go index 3d1cfb68..5d3327c2 100644 --- a/lapack/gonum/dpotf2.go +++ b/lapack/gonum/dpotf2.go @@ -19,16 +19,26 @@ import ( // // Dpotf2 is an internal routine. It is exported for testing purposes. func (Implementation) Dpotf2(ul blas.Uplo, n int, a []float64, lda int) (ok bool) { - if ul != blas.Upper && ul != blas.Lower { + switch { + case ul != blas.Upper && ul != blas.Lower: panic(badUplo) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) } - checkMatrix(n, n, a, lda) + // Quick return if possible. if n == 0 { return true } + if len(a) < (n-1)*lda+n { + panic(shortA) + } + bi := blas64.Implementation() + if ul == blas.Upper { for j := 0; j < n; j++ { ajj := a[j*lda+j] diff --git a/lapack/gonum/dpotrf.go b/lapack/gonum/dpotrf.go index 0ff3afcc..21241687 100644 --- a/lapack/gonum/dpotrf.go +++ b/lapack/gonum/dpotrf.go @@ -15,15 +15,24 @@ import ( // is computed and stored in-place into a. If a is not positive definite, false // is returned. This is the blocked version of the algorithm. func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool) { - if ul != blas.Upper && ul != blas.Lower { + switch { + case ul != blas.Upper && ul != blas.Lower: panic(badUplo) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) } - checkMatrix(n, n, a, lda) + // Quick return if possible. if n == 0 { return true } + if len(a) < (n-1)*lda+n { + panic(shortA) + } + nb := impl.Ilaenv(1, "DPOTRF", string(ul), n, -1, -1, -1) if nb <= 1 || n <= nb { return impl.Dpotf2(ul, n, a, lda) diff --git a/lapack/gonum/dpotri.go b/lapack/gonum/dpotri.go index fc4ffa9a..2394775c 100644 --- a/lapack/gonum/dpotri.go +++ b/lapack/gonum/dpotri.go @@ -21,8 +21,6 @@ func (impl Implementation) Dpotri(uplo blas.Uplo, n int, a []float64, lda int) ( panic(nLT0) case lda < max(1, n): panic(badLdA) - case len(a) < (n-1)*lda+n: - panic("lapack: a has insufficient length") } // Quick return if possible. @@ -30,6 +28,10 @@ func (impl Implementation) Dpotri(uplo blas.Uplo, n int, a []float64, lda int) ( return true } + if len(a) < (n-1)*lda+n { + panic(shortA) + } + // Invert the triangular Cholesky factor U or L. ok = impl.Dtrtri(uplo, blas.NonUnit, n, a, lda) if !ok { diff --git a/lapack/gonum/dpotrs.go b/lapack/gonum/dpotrs.go index 3c12423b..689e0439 100644 --- a/lapack/gonum/dpotrs.go +++ b/lapack/gonum/dpotrs.go @@ -17,17 +17,33 @@ import ( // as computed by Dpotrf. On entry, B contains the right-hand side matrix B, on // return it contains the solution matrix X. func (Implementation) Dpotrs(uplo blas.Uplo, n, nrhs int, a []float64, lda int, b []float64, ldb int) { - if uplo != blas.Upper && uplo != blas.Lower { + switch { + case uplo != blas.Upper && uplo != blas.Lower: panic(badUplo) + case n < 0: + panic(nLT0) + case nrhs < 0: + panic(nrhsLT0) + case lda < max(1, n): + panic(badLdA) + case ldb < max(1, nrhs): + panic(badLdB) } - checkMatrix(n, n, a, lda) - checkMatrix(n, nrhs, b, ldb) + // Quick return if possible. if n == 0 || nrhs == 0 { return } + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(b) < (n-1)*ldb+nrhs: + panic(shortB) + } + bi := blas64.Implementation() + if uplo == blas.Upper { // Solve U^T * U * X = B where U is stored in the upper triangle of A. diff --git a/lapack/gonum/drscl.go b/lapack/gonum/drscl.go index 302c3230..b2772dbc 100644 --- a/lapack/gonum/drscl.go +++ b/lapack/gonum/drscl.go @@ -15,8 +15,24 @@ import ( // // Drscl is an internal routine. It is exported for testing purposes. func (impl Implementation) Drscl(n int, a float64, x []float64, incX int) { - checkVector(n, x, incX) + switch { + case n < 0: + panic(nLT0) + case incX <= 0: + panic(badIncX) + } + + // Quick return if possible. + if n == 0 { + return + } + + if len(x) < 1+(n-1)*incX { + panic(shortX) + } + bi := blas64.Implementation() + cden := a cnum := 1.0 smlnum := dlamchS diff --git a/lapack/gonum/dsteqr.go b/lapack/gonum/dsteqr.go index 90bf37d6..a2b91903 100644 --- a/lapack/gonum/dsteqr.go +++ b/lapack/gonum/dsteqr.go @@ -37,23 +37,29 @@ import ( // // Dsteqr is an internal routine. It is exported for testing purposes. func (impl Implementation) Dsteqr(compz lapack.EVComp, n int, d, e, z []float64, ldz int, work []float64) (ok bool) { - if n < 0 { - panic(nLT0) - } - if len(d) < n { - panic(badD) - } - if len(e) < n-1 { - panic(badE) - } - if compz != lapack.EVCompNone && compz != lapack.EVTridiag && compz != lapack.EVOrig { + switch { + case compz != lapack.EVCompNone && compz != lapack.EVTridiag && compz != lapack.EVOrig: panic(badEVComp) + case n < 0: + panic(nLT0) + case ldz < 1, compz != lapack.EVCompNone && ldz < n: + panic(badLdZ) } - if compz != lapack.EVCompNone { - if len(work) < max(1, 2*n-2) { - panic(badWork) - } - checkMatrix(n, n, z, ldz) + + // Quick return if possible. + if n == 0 { + return true + } + + switch { + case len(d) < n: + panic(badD) + case len(e) < n-1: + panic(badE) + case compz != lapack.EVCompNone && len(z) < (n-1)*ldz+n: + panic(shortZ) + case compz != lapack.EVCompNone && len(work) < max(1, 2*n-2): + panic(shortWork) } var icompz int @@ -63,9 +69,6 @@ func (impl Implementation) Dsteqr(compz lapack.EVComp, n int, d, e, z []float64, icompz = 2 } - if n == 0 { - return true - } if n == 1 { if icompz == 2 { z[0] = 1 diff --git a/lapack/gonum/dsterf.go b/lapack/gonum/dsterf.go index 636cf1eb..f9320d15 100644 --- a/lapack/gonum/dsterf.go +++ b/lapack/gonum/dsterf.go @@ -26,16 +26,23 @@ func (impl Implementation) Dsterf(n int, d, e []float64) (ok bool) { if n < 0 { panic(nLT0) } + + // Quick return if possible. if n == 0 { return true } - if len(d) < n { + + switch { + case len(d) < n: panic(badD) - } - if len(e) < n-1 { + case len(e) < n-1: panic(badE) } + if n == 1 { + return true + } + const ( none = 0 // The values are not scaled. down = 1 // The values are scaled below ssfmax threshold. diff --git a/lapack/gonum/dsyev.go b/lapack/gonum/dsyev.go index 29a78319..3cf090af 100644 --- a/lapack/gonum/dsyev.go +++ b/lapack/gonum/dsyev.go @@ -28,45 +28,55 @@ import ( // limited by the usable length. If lwork == -1, instead of computing Dsyev the // optimal work length is stored into work[0]. func (impl Implementation) Dsyev(jobz lapack.EVJob, uplo blas.Uplo, n int, a []float64, lda int, w, work []float64, lwork int) (ok bool) { - checkMatrix(n, n, a, lda) - upper := uplo == blas.Upper - var wantz bool - switch jobz { - default: + switch { + case jobz != lapack.EVNone && jobz != lapack.EVCompute: panic(badEVJob) - case lapack.EVCompute: - wantz = true - case lapack.EVNone: + case uplo != blas.Upper && uplo != blas.Lower: + panic(badUplo) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case lwork < max(1, 3*n-1) && lwork != -1: + panic(badWork) + case len(work) < max(1, lwork): + panic(shortWork) } + + // Quick return if possible. + if n == 0 { + return true + } + var opts string - if upper { + if uplo == blas.Upper { opts = "U" } else { opts = "L" } nb := impl.Ilaenv(1, "DSYTRD", opts, n, -1, -1, -1) lworkopt := max(1, (nb+2)*n) - work[0] = float64(lworkopt) if lwork == -1 { + work[0] = float64(lworkopt) return } - if len(work) < lwork { - panic(badWork) - } - if lwork < 3*n-1 { - panic(badWork) - } - if n == 0 { - return true + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(w) < n: + panic(shortW) } + if n == 1 { w[0] = a[0] work[0] = 2 - if wantz { + if jobz == lapack.EVCompute { a[0] = 1 } return true } + safmin := dlamchS eps := dlamchP smlnum := safmin / eps @@ -87,7 +97,7 @@ func (impl Implementation) Dsyev(jobz lapack.EVJob, uplo blas.Uplo, n int, a []f } if scaled { kind := lapack.LowerTri - if upper { + if uplo == blas.Upper { kind = lapack.UpperTri } impl.Dlascl(kind, 0, 0, 1, sigma, n, n, a, lda) @@ -100,7 +110,7 @@ func (impl Implementation) Dsyev(jobz lapack.EVJob, uplo blas.Uplo, n int, a []f // For eigenvalues only, call Dsterf. For eigenvectors, first call Dorgtr // to generate the orthogonal matrix, then call Dsteqr. - if !wantz { + if jobz == lapack.EVNone { ok = impl.Dsterf(n, w, work[inde:]) } else { impl.Dorgtr(uplo, n, a, lda, work[indtau:], work[indwork:], llwork) diff --git a/lapack/gonum/dsytd2.go b/lapack/gonum/dsytd2.go index b6dc60c0..2a6a6156 100644 --- a/lapack/gonum/dsytd2.go +++ b/lapack/gonum/dsytd2.go @@ -49,20 +49,33 @@ import ( // // Dsytd2 is an internal routine. It is exported for testing purposes. func (impl Implementation) Dsytd2(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau []float64) { - checkMatrix(n, n, a, lda) - if len(d) < n { - panic(badD) + switch { + case uplo != blas.Upper && uplo != blas.Lower: + panic(badUplo) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) } - if len(e) < n-1 { - panic(badE) - } - if len(tau) < n-1 { - panic(badTau) - } - if n <= 0 { + + // Quick return if possible. + if n == 0 { return } + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(d) < n: + panic(badD) + case len(e) < n-1: + panic(badE) + case len(tau) < n-1: + panic(badTau) + } + bi := blas64.Implementation() + if uplo == blas.Upper { // Reduce the upper triangle of A. for i := n - 2; i >= 0; i-- { diff --git a/lapack/gonum/dsytrd.go b/lapack/gonum/dsytrd.go index 2fee1e5f..e95aab33 100644 --- a/lapack/gonum/dsytrd.go +++ b/lapack/gonum/dsytrd.go @@ -55,16 +55,9 @@ import ( // // Dsytrd is an internal routine. It is exported for testing purposes. func (impl Implementation) Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau, work []float64, lwork int) { - var opts string - switch uplo { - case blas.Upper: - opts = "U" - case blas.Lower: - opts = "L" - default: - panic(badUplo) - } switch { + case uplo != blas.Upper && uplo != blas.Lower: + panic(badUplo) case n < 0: panic(nLT0) case lda < max(1, n): @@ -75,12 +68,13 @@ func (impl Implementation) Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d panic(shortWork) } + // Quick return if possible. if n == 0 { work[0] = 1 return } - nb := impl.Ilaenv(1, "DSYTRD", opts, n, -1, -1, -1) + nb := impl.Ilaenv(1, "DSYTRD", string(uplo), n, -1, -1, -1) lworkopt := n * nb if lwork == -1 { work[0] = float64(lworkopt) @@ -89,7 +83,7 @@ func (impl Implementation) Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d switch { case len(a) < (n-1)*lda+n: - panic("lapack: insufficient length of a") + panic(shortA) case len(d) < n: panic(badD) case len(e) < n-1: @@ -98,14 +92,15 @@ func (impl Implementation) Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d panic(badTau) } + bi := blas64.Implementation() + nx := n iws := 1 - bi := blas64.Implementation() var ldwork int if 1 < nb && nb < n { // Determine when to cross over from blocked to unblocked code. The last // block is always handled by unblocked code. - nx = max(nb, impl.Ilaenv(3, "DSYTRD", opts, n, -1, -1, -1)) + nx = max(nb, impl.Ilaenv(3, "DSYTRD", string(uplo), n, -1, -1, -1)) if nx < n { // Determine if workspace is large enough for blocked code. ldwork = nb @@ -115,7 +110,7 @@ func (impl Implementation) Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d // value of nb and reduce nb or force use of unblocked code by // setting nx = n. nb = max(lwork/n, 1) - nbmin := impl.Ilaenv(2, "DSYTRD", opts, n, -1, -1, -1) + nbmin := impl.Ilaenv(2, "DSYTRD", string(uplo), n, -1, -1, -1) if nb < nbmin { nx = n } diff --git a/lapack/gonum/dtgsja.go b/lapack/gonum/dtgsja.go index 8a1beefe..d063727a 100644 --- a/lapack/gonum/dtgsja.go +++ b/lapack/gonum/dtgsja.go @@ -159,45 +159,61 @@ import ( func (impl Implementation) Dtgsja(jobU, jobV, jobQ lapack.GSVDJob, m, p, n, k, l int, a []float64, lda int, b []float64, ldb int, tola, tolb float64, alpha, beta, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, work []float64) (cycles int, ok bool) { const maxit = 40 - checkMatrix(m, n, a, lda) - checkMatrix(p, n, b, ldb) - - if len(alpha) != n { - panic(badAlpha) - } - if len(beta) != n { - panic(badBeta) - } - initu := jobU == lapack.GSVDUnit wantu := initu || jobU == lapack.GSVDU - if !initu && !wantu && jobU != lapack.GSVDNone { - panic(badGSVDJob + "U") - } - if jobU != lapack.GSVDNone { - checkMatrix(m, m, u, ldu) - } initv := jobV == lapack.GSVDUnit wantv := initv || jobV == lapack.GSVDV - if !initv && !wantv && jobV != lapack.GSVDNone { - panic(badGSVDJob + "V") - } - if jobV != lapack.GSVDNone { - checkMatrix(p, p, v, ldv) - } initq := jobQ == lapack.GSVDUnit wantq := initq || jobQ == lapack.GSVDQ - if !initq && !wantq && jobQ != lapack.GSVDNone { - panic(badGSVDJob + "Q") - } - if jobQ != lapack.GSVDNone { - checkMatrix(n, n, q, ldq) - } - if len(work) < 2*n { - panic(badWork) + switch { + case !initu && !wantu && jobU != lapack.GSVDNone: + panic(badGSVDJob + "U") + case !initv && !wantv && jobV != lapack.GSVDNone: + panic(badGSVDJob + "V") + case !initq && !wantq && jobQ != lapack.GSVDNone: + panic(badGSVDJob + "Q") + case m < 0: + panic(mLT0) + case p < 0: + panic(pLT0) + case n < 0: + panic(nLT0) + + case lda < max(1, n): + panic(badLdA) + case len(a) < (m-1)*lda+n: + panic(shortA) + + case ldb < max(1, n): + panic(badLdB) + case len(b) < (p-1)*ldb+n: + panic(shortB) + + case len(alpha) != n: + panic(badAlpha) + case len(beta) != n: + panic(badBeta) + + case ldu < 1, wantu && ldu < m: + panic(badLdU) + case wantu && len(u) < (m-1)*ldu+m: + panic(shortU) + + case ldv < 1, wantv && ldv < p: + panic(badLdV) + case wantv && len(v) < (p-1)*ldv+p: + panic(shortV) + + case ldq < 1, wantq && ldq < n: + panic(badLdQ) + case wantq && len(q) < (n-1)*ldq+n: + panic(shortQ) + + case len(work) < 2*n: + panic(shortWork) } // Initialize U, V and Q, if necessary diff --git a/lapack/gonum/dtrcon.go b/lapack/gonum/dtrcon.go index 42d9648f..899c95dd 100644 --- a/lapack/gonum/dtrcon.go +++ b/lapack/gonum/dtrcon.go @@ -19,24 +19,32 @@ import ( // // iwork is a temporary data slice of length at least n and Dtrcon will panic otherwise. func (impl Implementation) Dtrcon(norm lapack.MatrixNorm, uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int, work []float64, iwork []int) float64 { - if norm != lapack.MaxColumnSum && norm != lapack.MaxRowSum { + switch { + case norm != lapack.MaxColumnSum && norm != lapack.MaxRowSum: panic(badNorm) - } - if uplo != blas.Upper && uplo != blas.Lower { + case uplo != blas.Upper && uplo != blas.Lower: panic(badUplo) - } - if diag != blas.NonUnit && diag != blas.Unit { + case diag != blas.NonUnit && diag != blas.Unit: panic(badDiag) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) } - if len(work) < 3*n { - panic(badWork) - } - if len(iwork) < n { - panic(badWork) - } + if n == 0 { return 1 } + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(work) < 3*n: + panic(shortWork) + case len(iwork) < n: + panic(shortIWork) + } + bi := blas64.Implementation() var rcond float64 diff --git a/lapack/gonum/dtrevc3.go b/lapack/gonum/dtrevc3.go index 5d8ed29c..ad01ac4f 100644 --- a/lapack/gonum/dtrevc3.go +++ b/lapack/gonum/dtrevc3.go @@ -106,86 +106,110 @@ import ( // // Dtrevc3 is an internal routine. It is exported for testing purposes. func (impl Implementation) Dtrevc3(side lapack.EVSide, howmny lapack.EVHowMany, selected []bool, n int, t []float64, ldt int, vl []float64, ldvl int, vr []float64, ldvr int, mm int, work []float64, lwork int) (m int) { - switch side { - default: - panic(badEVSide) - case lapack.EVRight, lapack.EVLeft, lapack.EVBoth: - } - switch howmny { - default: - panic(badEVHowMany) - case lapack.EVAll, lapack.EVAllMulQ, lapack.EVSelected: - } + bothv := side == lapack.EVBoth + rightv := side == lapack.EVRight || bothv + leftv := side == lapack.EVLeft || bothv switch { + case !rightv && !leftv: + panic(badEVSide) + case howmny != lapack.EVAll && howmny != lapack.EVAllMulQ && howmny != lapack.EVSelected: + panic(badEVHowMany) case n < 0: panic(nLT0) - case len(work) < lwork: - panic(shortWork) + case ldt < max(1, n): + panic(badLdT) + case mm < 0: + panic(mmLT0) + case ldvl < 1: + // ldvl and ldvr are also checked below after the computation of + // m (number of columns of VL and VR) in case of howmny == EVSelected. + panic(badLdVL) + case ldvr < 1: + panic(badLdVR) case lwork < max(1, 3*n) && lwork != -1: panic(badWork) - } - if lwork != -1 { - if howmny == lapack.EVSelected { - if len(selected) != n { - panic("lapack: bad selected length") - } - // Set m to the number of columns required to store the - // selected eigenvectors, and standardize the slice - // selected. - for j := 0; j < n; { - if j == n-1 || t[(j+1)*ldt+j] == 0 { - // Diagonal 1×1 block corresponding to a - // real eigenvalue. - if selected[j] { - m++ - } - j++ - } else { - // Diagonal 2×2 block corresponding to a - // complex eigenvalue. - if selected[j] || selected[j+1] { - selected[j] = true - selected[j+1] = false - m += 2 - } - j += 2 - } - } - } else { - m = n - } - if m > mm { - panic("lapack: insufficient number of columns") - } - checkMatrix(n, n, t, ldt) - if (side == lapack.EVRight || side == lapack.EVBoth) && m > 0 { - checkMatrix(n, m, vr, ldvr) - } - if (side == lapack.EVLeft || side == lapack.EVBoth) && m > 0 { - checkMatrix(n, m, vl, ldvl) - } + case len(work) < max(1, lwork): + panic(shortWork) } // Quick return if possible. if n == 0 { work[0] = 1 - return m + return 0 } - const ( - nbmin = 8 - nbmax = 128 - ) - nb := impl.Ilaenv(1, "DTREVC", string(side)+string(howmny), n, -1, -1, -1) + // Normally we don't check slice lengths until after the workspace + // query. However, even in case of the workspace query we need to + // compute and return the value of m, and since the computation accesses t, + // we put the length check of t here. + if len(t) < (n-1)*ldt+n { + panic(shortT) + } + + if howmny == lapack.EVSelected { + if len(selected) != n { + panic(badSelected) + } + // Set m to the number of columns required to store the selected + // eigenvectors, and standardize the slice selected. + // Each selected real eigenvector occupies one column and each + // selected complex eigenvector occupies two columns. + for j := 0; j < n; { + if j == n-1 || t[(j+1)*ldt+j] == 0 { + // Diagonal 1×1 block corresponding to a + // real eigenvalue. + if selected[j] { + m++ + } + j++ + } else { + // Diagonal 2×2 block corresponding to a + // complex eigenvalue. + if selected[j] || selected[j+1] { + selected[j] = true + selected[j+1] = false + m += 2 + } + j += 2 + } + } + } else { + m = n + } + if mm < m { + panic(badMM) + } // Quick return in case of a workspace query. + nb := impl.Ilaenv(1, "DTREVC", string(side)+string(howmny), n, -1, -1, -1) if lwork == -1 { work[0] = float64(n + 2*n*nb) return m } + // Quick return if no eigenvectors were selected. + if m == 0 { + return 0 + } + + switch { + case leftv && ldvl < mm: + panic(badLdVL) + case leftv && len(vl) < (n-1)*ldvl+mm: + panic(shortVL) + + case rightv && ldvr < mm: + panic(badLdVR) + case rightv && len(vr) < (n-1)*ldvr+mm: + panic(shortVR) + } + // Use blocked version of back-transformation if sufficient workspace. // Zero-out the workspace to avoid potential NaN propagation. + const ( + nbmin = 8 + nbmax = 128 + ) if howmny == lapack.EVAllMulQ && lwork >= n+2*n*nbmin { nb = min((lwork-n)/(2*n), nbmax) impl.Dlaset(blas.All, n, 1+2*nb, 0, 0, work[:n+2*nb*n], 1+2*nb) diff --git a/lapack/gonum/dtrexc.go b/lapack/gonum/dtrexc.go index 1953fca9..9f3f90ba 100644 --- a/lapack/gonum/dtrexc.go +++ b/lapack/gonum/dtrexc.go @@ -46,31 +46,37 @@ import "gonum.org/v1/gonum/lapack" // // Dtrexc is an internal routine. It is exported for testing purposes. func (impl Implementation) Dtrexc(compq lapack.UpdateSchurComp, n int, t []float64, ldt int, q []float64, ldq int, ifst, ilst int, work []float64) (ifstOut, ilstOut int, ok bool) { - checkMatrix(n, n, t, ldt) - var wantq bool - switch compq { - default: - panic("lapack: bad value of compq") - case lapack.UpdateSchurNone: - // Nothing to do because wantq is already false. - case lapack.UpdateSchur: - wantq = true - checkMatrix(n, n, q, ldq) + switch { + case compq != lapack.UpdateSchur && compq != lapack.UpdateSchurNone: + panic(badUpdateSchurComp) + case n < 0: + panic(nLT0) + case ldt < max(1, n): + panic(badLdT) + case ldq < 1, compq == lapack.UpdateSchur && ldq < n: + panic(badLdQ) + case (ifst < 0 || n <= ifst) && n > 0: + panic(badIfst) + case (ilst < 0 || n <= ilst) && n > 0: + panic(badIlst) } - if (ifst < 0 || n <= ifst) && n > 0 { - panic("lapack: ifst out of range") - } - if (ilst < 0 || n <= ilst) && n > 0 { - panic("lapack: ilst out of range") - } - if len(work) < n { - panic(badWork) - } - - ok = true // Quick return if possible. - if n <= 1 { + if n == 0 { + return ifst, ilst, true + } + + switch { + case len(t) < (n-1)*ldt+n: + panic(shortT) + case compq == lapack.UpdateSchur && len(q) < (n-1)*ldq+n: + panic(shortQ) + case len(work) < n: + panic(shortWork) + } + + // Quick return if possible. + if n == 1 { return ifst, ilst, true } @@ -93,6 +99,9 @@ func (impl Implementation) Dtrexc(compq lapack.UpdateSchurComp, n int, t []float nbl = 2 } + ok = true + wantq := compq == lapack.UpdateSchur + switch { case ifst == ilst: return ifst, ilst, true diff --git a/lapack/gonum/dtrti2.go b/lapack/gonum/dtrti2.go index a43efe6f..efc24b65 100644 --- a/lapack/gonum/dtrti2.go +++ b/lapack/gonum/dtrti2.go @@ -14,13 +14,25 @@ import ( // // Dtrti2 is an internal routine. It is exported for testing purposes. func (impl Implementation) Dtrti2(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) { - checkMatrix(n, n, a, lda) - if uplo != blas.Upper && uplo != blas.Lower { + switch { + case uplo != blas.Upper && uplo != blas.Lower: panic(badUplo) - } - if diag != blas.NonUnit && diag != blas.Unit { + case diag != blas.NonUnit && diag != blas.Unit: panic(badDiag) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) } + + if n == 0 { + return + } + + if len(a) < (n-1)*lda+n { + panic(shortA) + } + bi := blas64.Implementation() nonUnit := diag == blas.NonUnit diff --git a/lapack/gonum/dtrtri.go b/lapack/gonum/dtrtri.go index 95f1b3be..6ec3663c 100644 --- a/lapack/gonum/dtrtri.go +++ b/lapack/gonum/dtrtri.go @@ -16,18 +16,26 @@ import ( // Dtrtri will not perform the inversion if the matrix is singular, and returns // a boolean indicating whether the inversion was successful. func (impl Implementation) Dtrtri(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) (ok bool) { - checkMatrix(n, n, a, lda) - if uplo != blas.Upper && uplo != blas.Lower { + switch { + case uplo != blas.Upper && uplo != blas.Lower: panic(badUplo) - } - if diag != blas.NonUnit && diag != blas.Unit { + case diag != blas.NonUnit && diag != blas.Unit: panic(badDiag) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) } + if n == 0 { - return false + return true } - nonUnit := diag == blas.NonUnit - if nonUnit { + + if len(a) < (n-1)*lda+n { + panic(shortA) + } + + if diag == blas.NonUnit { for i := 0; i < n; i++ { if a[i*lda+i] == 0 { return false diff --git a/lapack/gonum/dtrtrs.go b/lapack/gonum/dtrtrs.go index e1782d23..1752dc5c 100644 --- a/lapack/gonum/dtrtrs.go +++ b/lapack/gonum/dtrtrs.go @@ -12,11 +12,36 @@ import ( // Dtrtrs solves a triangular system of the form A * X = B or A^T * X = B. Dtrtrs // returns whether the solve completed successfully. If A is singular, no solve is performed. func (impl Implementation) Dtrtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, nrhs int, a []float64, lda int, b []float64, ldb int) (ok bool) { - nounit := diag == blas.NonUnit - if n == 0 { - return false + switch { + case uplo != blas.Upper && uplo != blas.Lower: + panic(badUplo) + case trans != blas.NoTrans && trans != blas.Trans && trans != blas.ConjTrans: + panic(badTrans) + case diag != blas.NonUnit && diag != blas.Unit: + panic(badDiag) + case n < 0: + panic(nLT0) + case nrhs < 0: + panic(nrhsLT0) + case lda < max(1, n): + panic(badLdA) + case ldb < max(1, nrhs): + panic(badLdB) } + + if n == 0 { + return true + } + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(b) < (n-1)*ldb+nrhs: + panic(shortB) + } + // Check for singularity. + nounit := diag == blas.NonUnit if nounit { for i := 0; i < n; i++ { if a[i*lda+i] == 0 { diff --git a/lapack/gonum/errors.go b/lapack/gonum/errors.go index f17ab3e5..3f2d1990 100644 --- a/lapack/gonum/errors.go +++ b/lapack/gonum/errors.go @@ -25,6 +25,8 @@ const ( badIlo = "lapack: ilo out of range" badIhi = "lapack: ihi out of range" badIpiv = "lapack: bad permutation length" + badIfst = "lapack: ifst out of range" + badIlst = "lapack: ilst out of range" badBalanceJob = "lapack: bad BalanceJob" badJ1 = "lapack: j1 out of range" badK1 = "lapack: k1 out of range" @@ -49,6 +51,7 @@ const ( badLdX = "lapack: bad leading dimension of X" badLdY = "lapack: bad leading dimension of Y" badLdZ = "lapack: bad leading dimension of Z" + badMM = "lapack: bad value of mm" badNb = "lapack: nb out of range" badNorm = "lapack: bad norm" badPivot = "lapack: bad pivot" @@ -56,6 +59,7 @@ const ( badS = "lapack: s has insufficient length" badSchurComp = "lapack: bad SchurComp" badSchurJob = "lapack: bad SchurJob" + badSelected = "lapack: bad length of selected" badShifts = "lapack: bad shifts" badSide = "lapack: bad side" badSlice = "lapack: bad input slice length" @@ -79,9 +83,14 @@ const ( kGTN = "lapack: k > n" kLT0 = "lapack: k < 0" kLT1 = "lapack: k < 1" + kdLT0 = "lapack: kd < 0" mLT0 = "lapack: m < 0" mLTN = "lapack: m < n" + mmLT0 = "lapack: mm < 0" + nGTM = "lapack: n > m" + mGTN = "lapack: m > n" nanScale = "lapack: NaN scale factor" + negANorm = "lapack: anorm < 0" negDimension = "lapack: negative matrix dimension" negZ = "lapack: negative z value" nLT0 = "lapack: n < 0" @@ -96,6 +105,7 @@ const ( offsetLT0 = "lapack: offset < 0" offsetGTM = "lapack: offset > m" shortA = "lapack: insufficient length of a" + shortAB = "lapack: insufficient length of ab" shortB = "lapack: insufficient length of b" shortC = "lapack: insufficient length of c" shortCNorm = "lapack: insufficient length of cnorm" @@ -121,6 +131,9 @@ const ( shortX = "lapack: insufficient length of x" shortY = "lapack: insufficient length of y" shortZ = "lapack: insufficient length of z" + shortIWork = "lapack: insufficient length of iwork" shortWork = "lapack: working array shorter than declared" zeroDiv = "lapack: zero divisor" + + badUpdateSchurComp = "lapack: bad UpdateSchurComp" ) diff --git a/lapack/gonum/general.go b/lapack/gonum/general.go index 86cd01a9..434da02d 100644 --- a/lapack/gonum/general.go +++ b/lapack/gonum/general.go @@ -13,46 +13,6 @@ type Implementation struct{} var _ lapack.Float64 = Implementation{} -// checkMatrix verifies the parameters of a matrix input. -func checkMatrix(m, n int, a []float64, lda int) { - if m < 0 { - panic("lapack: has negative number of rows") - } - if n < 0 { - panic("lapack: has negative number of columns") - } - if lda < n { - panic("lapack: stride less than number of columns") - } - if len(a) < (m-1)*lda+n { - panic("lapack: insufficient matrix slice length") - } -} - -func checkVector(n int, v []float64, inc int) { - if n < 0 { - panic("lapack: negative vector length") - } - if (inc > 0 && (n-1)*inc >= len(v)) || (inc < 0 && (1-n)*inc >= len(v)) { - panic("lapack: insufficient vector slice length") - } -} - -func checkSymBanded(ab []float64, n, kd, lda int) { - if n < 0 { - panic("lapack: negative banded length") - } - if kd < 0 { - panic("lapack: negative bandwidth value") - } - if lda < kd+1 { - panic("lapack: stride less than number of bands") - } - if len(ab) < (n-1)*lda+kd { - panic("lapack: insufficient banded vector length") - } -} - func min(a, b int) int { if a < b { return a diff --git a/lapack/gonum/iladlc.go b/lapack/gonum/iladlc.go index bd0e4d8f..b251d726 100644 --- a/lapack/gonum/iladlc.go +++ b/lapack/gonum/iladlc.go @@ -9,10 +9,22 @@ package gonum // // Iladlc is an internal routine. It is exported for testing purposes. func (Implementation) Iladlc(m, n int, a []float64, lda int) int { - if n == 0 || m == 0 { - return n - 1 + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + } + + if n == 0 || m == 0 { + return -1 + } + + if len(a) < (m-1)*lda+n { + panic(shortA) } - checkMatrix(m, n, a, lda) // Test common case where corner is non-zero. if a[n-1] != 0 || a[(m-1)*lda+(n-1)] != 0 { diff --git a/lapack/gonum/iladlr.go b/lapack/gonum/iladlr.go index 9f9e0d93..b73fe18e 100644 --- a/lapack/gonum/iladlr.go +++ b/lapack/gonum/iladlr.go @@ -9,11 +9,22 @@ package gonum // // Iladlr is an internal routine. It is exported for testing purposes. func (Implementation) Iladlr(m, n int, a []float64, lda int) int { - if m == 0 { - return m - 1 + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) } - checkMatrix(m, n, a, lda) + if n == 0 || m == 0 { + return -1 + } + + if len(a) < (m-1)*lda+n { + panic(shortA) + } // Check the common case where the corner is non-zero if a[(m-1)*lda] != 0 || a[(m-1)*lda+n-1] != 0 { diff --git a/lapack/testlapack/dtrevc3.go b/lapack/testlapack/dtrevc3.go index 9ea31594..0715e6df 100644 --- a/lapack/testlapack/dtrevc3.go +++ b/lapack/testlapack/dtrevc3.go @@ -109,12 +109,12 @@ func testDtrevc3(t *testing.T, impl Dtrevc3er, side lapack.EVSide, howmny lapack work := make([]float64, max(1, 3*n)) if optwork { impl.Dtrevc3(side, howmny, selected, n, tmat.Data, tmat.Stride, - vl.Data, vl.Stride, vr.Data, vr.Stride, mWant, work, -1) + vl.Data, max(1, vl.Stride), vr.Data, max(1, vr.Stride), mWant, work, -1) work = make([]float64, int(work[0])) } m := impl.Dtrevc3(side, howmny, selected, n, tmat.Data, tmat.Stride, - vl.Data, vl.Stride, vr.Data, vr.Stride, mWant, work, len(work)) + vl.Data, max(1, vl.Stride), vr.Data, max(1, vr.Stride), mWant, work, len(work)) prefix := fmt.Sprintf("Case side=%v, howmny=%v, n=%v, extra=%v, optwk=%v", side, howmny, n, extra, optwork) diff --git a/lapack/testlapack/dtrexc.go b/lapack/testlapack/dtrexc.go index f684f9ad..92c29efd 100644 --- a/lapack/testlapack/dtrexc.go +++ b/lapack/testlapack/dtrexc.go @@ -63,7 +63,7 @@ func testDtrexc(t *testing.T, impl Dtrexcer, compq lapack.UpdateSchurComp, tmat work := nanSlice(n) - ifstGot, ilstGot, ok := impl.Dtrexc(compq, n, tmat.Data, tmat.Stride, q.Data, q.Stride, ifst, ilst, work) + ifstGot, ilstGot, ok := impl.Dtrexc(compq, n, tmat.Data, tmat.Stride, q.Data, max(1, q.Stride), ifst, ilst, work) prefix := fmt.Sprintf("Case compq=%v, n=%v, ifst=%v, nbf=%v, ilst=%v, nbl=%v, extra=%v", compq, n, ifst, fstSize, ilst, lstSize, extra)