diff --git a/lapack/gonum/dbdsqr.go b/lapack/gonum/dbdsqr.go index b227820b..835186a9 100644 --- a/lapack/gonum/dbdsqr.go +++ b/lapack/gonum/dbdsqr.go @@ -50,17 +50,38 @@ import ( // // Dbdsqr is an internal routine. It is exported for testing purposes. func (impl Implementation) Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, vt []float64, ldvt int, u []float64, ldu int, c []float64, ldc int, work []float64) (ok bool) { - if uplo != blas.Upper && uplo != blas.Lower { + switch { + case uplo != blas.Upper && uplo != blas.Lower: panic(badUplo) + case n < 0: + panic(nLT0) + case ncvt < 0: + panic(ncvtLT0) + case nru < 0: + panic(nruLT0) + case ncc < 0: + panic(nccLT0) + case ldvt < max(1, ncvt): + panic(badLdVT) + case (ldu < max(1, n) && nru > 0) || (ldu < 1 && nru == 0): + panic(badLdU) + case ldc < max(1, ncc): + panic(badLdC) } - if ncvt != 0 { - checkMatrix(n, ncvt, vt, ldvt) + + // Quick return if possible. + if n == 0 { + return true } - if nru != 0 { - checkMatrix(nru, n, u, ldu) + + if len(vt) < (n-1)*ldvt+ncvt && ncvt != 0 { + panic(shortVT) } - if ncc != 0 { - checkMatrix(n, ncc, c, ldc) + if len(u) < (nru-1)*ldu+n && nru != 0 { + panic(shortU) + } + if len(c) < (n-1)*ldc+ncc && ncc != 0 { + panic(shortC) } if len(d) < n { panic(badD) @@ -71,14 +92,11 @@ func (impl Implementation) Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, v if len(work) < 4*(n-1) { panic(badWork) } + var info int bi := blas64.Implementation() - const ( - maxIter = 6 - ) - if n == 0 { - return true - } + const maxIter = 6 + if n != 1 { // If the singular vectors do not need to be computed, use qd algorithm. if !(ncvt > 0 || nru > 0 || ncc > 0) { diff --git a/lapack/gonum/dgebak.go b/lapack/gonum/dgebak.go index 136522ad..7caa0b17 100644 --- a/lapack/gonum/dgebak.go +++ b/lapack/gonum/dgebak.go @@ -21,26 +21,37 @@ import ( // // Dgebak is an internal routine. It is exported for testing purposes. func (impl Implementation) Dgebak(job lapack.BalanceJob, side lapack.EVSide, n, ilo, ihi int, scale []float64, m int, v []float64, ldv int) { - switch job { - default: - panic(badBalanceJob) - case lapack.BalanceNone, lapack.Permute, lapack.Scale, lapack.PermuteScale: - } - switch side { - default: - panic(badEVSide) - case lapack.EVLeft, lapack.EVRight: - } - checkMatrix(n, m, v, ldv) switch { + case job != lapack.BalanceNone && job != lapack.Permute && job != lapack.Scale && job != lapack.PermuteScale: + panic(badBalanceJob) + case side != lapack.EVLeft && side != lapack.EVRight: + panic(badEVSide) + case n < 0: + panic(nLT0) case ilo < 0 || max(0, n-1) < ilo: panic(badIlo) case ihi < min(ilo, n-1) || n <= ihi: panic(badIhi) + case m < 0: + panic(mLT0) + case ldv < max(1, m): + panic(badLdV) } // Quick return if possible. - if n == 0 || m == 0 || job == lapack.BalanceNone { + if n == 0 || m == 0 { + return + } + + if len(scale) < n { + panic(shortScale) + } + if len(v) < (n-1)*ldv+m { + panic(shortV) + } + + // Quick return if possible. + if job == lapack.BalanceNone { return } diff --git a/lapack/gonum/dgebal.go b/lapack/gonum/dgebal.go index cb591a84..6fb5170c 100644 --- a/lapack/gonum/dgebal.go +++ b/lapack/gonum/dgebal.go @@ -55,26 +55,37 @@ import ( // // Dgebal is an internal routine. It is exported for testing purposes. func (impl Implementation) Dgebal(job lapack.BalanceJob, n int, a []float64, lda int, scale []float64) (ilo, ihi int) { - switch job { - default: + switch { + case job != lapack.BalanceNone && job != lapack.Permute && job != lapack.Scale && job != lapack.PermuteScale: panic(badBalanceJob) - case lapack.BalanceNone, lapack.Permute, lapack.Scale, lapack.PermuteScale: - } - checkMatrix(n, n, a, lda) - if len(scale) != n { - panic("lapack: bad length of scale") + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) } ilo = 0 ihi = n - 1 - if n == 0 || job == lapack.BalanceNone { + if n == 0 { + return ilo, ihi + } + + if len(scale) != n { + panic(shortScale) + } + + if job == lapack.BalanceNone { for i := range scale { scale[i] = 1 } return ilo, ihi } + if len(a) < (n-1)*lda+n { + panic(shortA) + } + bi := blas64.Implementation() swapped := true diff --git a/lapack/gonum/dgebd2.go b/lapack/gonum/dgebd2.go index a8e4aacb..1daba7da 100644 --- a/lapack/gonum/dgebd2.go +++ b/lapack/gonum/dgebd2.go @@ -15,22 +15,34 @@ import "gonum.org/v1/gonum/blas" // // Dgebd2 is an internal routine. It is exported for testing purposes. func (impl Implementation) Dgebd2(m, n int, a []float64, lda int, d, e, tauQ, tauP, work []float64) { - checkMatrix(m, n, a, lda) - if len(d) < min(m, n) { + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + } + + // Quick return if possible. + minmn := min(m, n) + if minmn == 0 { + return + } + + switch { + case len(d) < minmn: panic(badD) - } - if len(e) < min(m, n)-1 { + case len(e) < minmn-1: panic(badE) - } - if len(tauQ) < min(m, n) { + case len(tauQ) < minmn: panic(badTauQ) - } - if len(tauP) < min(m, n) { + case len(tauP) < minmn: panic(badTauP) - } - if len(work) < max(m, n) { + case len(work) < max(m, n): panic(badWork) } + if m >= n { for i := 0; i < n; i++ { a[i*lda+i], tauQ[i] = impl.Dlarfg(m-i, a[i*lda+i], a[min(i+1, m-1)*lda+i:], lda) diff --git a/lapack/gonum/dgecon.go b/lapack/gonum/dgecon.go index 04e01535..35f6f847 100644 --- a/lapack/gonum/dgecon.go +++ b/lapack/gonum/dgecon.go @@ -24,20 +24,31 @@ import ( // // iwork is a temporary data slice of length at least n and Dgecon will panic otherwise. func (impl Implementation) Dgecon(norm lapack.MatrixNorm, n int, a []float64, lda int, anorm float64, work []float64, iwork []int) float64 { - checkMatrix(n, n, a, lda) - if norm != lapack.MaxColumnSum && norm != lapack.MaxRowSum { + switch { + case norm != lapack.MaxColumnSum && norm != lapack.MaxRowSum: panic(badNorm) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) } - if len(work) < 4*n { + + // Quick return if possible. + if n == 0 { + return 1 + } + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(work) < 4*n: panic(badWork) - } - if len(iwork) < n { + case len(iwork) < n: panic(badWork) } - if n == 0 { - return 1 - } else if anorm == 0 { + // Quick return if possible. + if anorm == 0 { return 0 } diff --git a/lapack/gonum/dgeev.go b/lapack/gonum/dgeev.go index baa5333c..7b275f72 100644 --- a/lapack/gonum/dgeev.go +++ b/lapack/gonum/dgeev.go @@ -61,50 +61,31 @@ import ( // computed and wr[first:] and wi[first:] contain those eigenvalues which have // converged. func (impl Implementation) Dgeev(jobvl lapack.LeftEVJob, jobvr lapack.RightEVJob, n int, a []float64, lda int, wr, wi []float64, vl []float64, ldvl int, vr []float64, ldvr int, work []float64, lwork int) (first int) { - var wantvl bool - switch jobvl { - default: - panic("lapack: invalid LeftEVJob") - case lapack.LeftEVCompute: - wantvl = true - case lapack.LeftEVNone: - } - var wantvr bool - switch jobvr { - default: - panic("lapack: invalid RightEVJob") - case lapack.RightEVCompute: - wantvr = true - case lapack.RightEVNone: - } - switch { - case n < 0: - panic(nLT0) - case len(work) < lwork: - panic(shortWork) - } + wantvl := jobvl == lapack.LeftEVCompute + wantvr := jobvr == lapack.RightEVCompute var minwrk int if wantvl || wantvr { minwrk = max(1, 4*n) } else { minwrk = max(1, 3*n) } - if lwork != -1 { - checkMatrix(n, n, a, lda) - if wantvl { - checkMatrix(n, n, vl, ldvl) - } - if wantvr { - checkMatrix(n, n, vr, ldvr) - } - switch { - case len(wr) != n: - panic("lapack: bad length of wr") - case len(wi) != n: - panic("lapack: bad length of wi") - case lwork < minwrk: - panic(badWork) - } + switch { + case jobvl != lapack.LeftEVCompute && jobvl != lapack.LeftEVNone: + panic("lapack: invalid LeftEVJob") + case jobvr != lapack.RightEVCompute && jobvr != lapack.RightEVNone: + panic("lapack: invalid RightEVJob") + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case ldvl < 1 || (ldvl < n && wantvl): + panic(badLdVL) + case ldvr < 1 || (ldvr < n && wantvr): + panic(badLdVR) + case lwork < minwrk && lwork != -1: + panic(badWork) + case len(work) < lwork: + panic(shortWork) } // Quick return if possible. @@ -139,6 +120,19 @@ func (impl Implementation) Dgeev(jobvl lapack.LeftEVJob, jobvr lapack.RightEVJob return 0 } + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(wr) != n: + panic("lapack: bad length of wr") + case len(wi) != n: + panic("lapack: bad length of wi") + case len(vl) < (n-1)*ldvl+n && wantvl: + panic(shortVL) + case len(vr) < (n-1)*ldvr+n && wantvr: + panic(shortVR) + } + // Get machine constants. smlnum := math.Sqrt(dlamchS) / dlamchP bignum := 1 / smlnum diff --git a/lapack/gonum/dgehd2.go b/lapack/gonum/dgehd2.go index 36828378..a6c9f184 100644 --- a/lapack/gonum/dgehd2.go +++ b/lapack/gonum/dgehd2.go @@ -56,16 +56,29 @@ import "gonum.org/v1/gonum/blas" // // Dgehd2 is an internal routine. It is exported for testing purposes. func (impl Implementation) Dgehd2(n, ilo, ihi int, a []float64, lda int, tau, work []float64) { - checkMatrix(n, n, a, lda) switch { - case ilo < 0 || ilo > max(0, n-1): + case n < 0: + panic(nLT0) + case ilo < 0 || max(0, n-1) < ilo: panic(badIlo) - case ihi < min(ilo, n-1) || ihi >= n: + case ihi < min(ilo, n-1) || n <= ihi: panic(badIhi) + case lda < max(1, n): + panic(badLdA) + } + + // Quick return if possible. + if n == 0 { + return + } + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) case len(tau) != n-1: panic(badTau) case len(work) < n: - panic(badWork) + panic(shortWork) } for i := ilo; i < ihi; i++ { diff --git a/lapack/gonum/dgehrd.go b/lapack/gonum/dgehrd.go index bac849d0..3891c918 100644 --- a/lapack/gonum/dgehrd.go +++ b/lapack/gonum/dgehrd.go @@ -81,6 +81,12 @@ func (impl Implementation) Dgehrd(n, ilo, ihi int, a []float64, lda int, tau, wo panic(shortWork) } + // Quick return if possible. + if n == 0 { + work[0] = 1 + return + } + const ( nbmax = 64 ldt = nbmax + 1 @@ -95,9 +101,9 @@ func (impl Implementation) Dgehrd(n, ilo, ihi int, a []float64, lda int, tau, wo } if len(a) < (n-1)*lda+n { - panic("lapack: insufficient length of a") + panic(shortA) } - if len(tau) != n-1 && n > 0 { + if len(tau) != n-1 { panic(badTau) } diff --git a/lapack/gonum/dgelq2.go b/lapack/gonum/dgelq2.go index 05b3ce45..b2d2747e 100644 --- a/lapack/gonum/dgelq2.go +++ b/lapack/gonum/dgelq2.go @@ -25,14 +25,30 @@ import "gonum.org/v1/gonum/blas" // // Dgelq2 is an internal routine. It is exported for testing purposes. func (impl Implementation) Dgelq2(m, n int, a []float64, lda int, tau, work []float64) { - checkMatrix(m, n, a, lda) + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + } + + // Quick return if possible. k := min(m, n) - if len(tau) < k { + if k == 0 { + return + } + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(tau) < k: panic(badTau) + case len(work) < m: + panic(shortWork) } - if len(work) < m { - panic(badWork) - } + for i := 0; i < k; i++ { a[i*lda+i], tau[i] = impl.Dlarfg(n-i, a[i*lda+i], a[i*lda+min(i+1, n-1):], 1) if i < m-1 { diff --git a/lapack/gonum/dgelqf.go b/lapack/gonum/dgelqf.go index eb7d58dd..479d2566 100644 --- a/lapack/gonum/dgelqf.go +++ b/lapack/gonum/dgelqf.go @@ -47,7 +47,7 @@ func (impl Implementation) Dgelqf(m, n int, a []float64, lda int, tau, work []fl } if len(a) < (m-1)*lda+n { - panic("lapack: insufficient length of a") + panic(shortA) } if len(tau) < k { panic(badTau) diff --git a/lapack/gonum/dgels.go b/lapack/gonum/dgels.go index 214b9663..679ef2cb 100644 --- a/lapack/gonum/dgels.go +++ b/lapack/gonum/dgels.go @@ -39,46 +39,63 @@ import ( // In the special case that lwork == -1, work[0] will be set to the optimal working // length. func (impl Implementation) Dgels(trans blas.Transpose, m, n, nrhs int, a []float64, lda int, b []float64, ldb int, work []float64, lwork int) bool { - notran := trans == blas.NoTrans - checkMatrix(m, n, a, lda) mn := min(m, n) - checkMatrix(max(m, n), nrhs, b, ldb) + minwrk := mn + max(mn, nrhs) + switch { + case trans != blas.NoTrans && trans != blas.Trans && trans != blas.ConjTrans: + panic(badTrans) + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case nrhs < 0: + panic(nrhsLT0) + case lda < max(1, n): + panic(badLdA) + case ldb < max(1, nrhs): + panic(badLdB) + case lwork < max(1, minwrk) && lwork != -1: + panic(badWork) + case len(work) < max(1, lwork): + panic(shortWork) + } + + // Quick return if possible. + if mn == 0 || nrhs == 0 { + impl.Dlaset(blas.All, max(m, n), nrhs, 0, 0, b, ldb) + work[0] = 1 + return true + } // Find optimal block size. - tpsd := true - if notran { - tpsd = false - } var nb int if m >= n { nb = impl.Ilaenv(1, "DGEQRF", " ", m, n, -1, -1) - if tpsd { + if trans != blas.NoTrans { nb = max(nb, impl.Ilaenv(1, "DORMQR", "LN", m, nrhs, n, -1)) } else { nb = max(nb, impl.Ilaenv(1, "DORMQR", "LT", m, nrhs, n, -1)) } } else { nb = impl.Ilaenv(1, "DGELQF", " ", m, n, -1, -1) - if tpsd { + if trans != blas.NoTrans { nb = max(nb, impl.Ilaenv(1, "DORMLQ", "LT", n, nrhs, m, -1)) } else { nb = max(nb, impl.Ilaenv(1, "DORMLQ", "LN", n, nrhs, m, -1)) } } + wsize := max(1, mn+max(mn, nrhs)*nb) + work[0] = float64(wsize) + if lwork == -1 { - work[0] = float64(max(1, mn+max(mn, nrhs)*nb)) return true } - if len(work) < lwork { - panic(shortWork) - } - if lwork < mn+max(mn, nrhs) { - panic(badWork) - } - if m == 0 || n == 0 || nrhs == 0 { - impl.Dlaset(blas.All, max(m, n), nrhs, 0, 0, b, ldb) - return true + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(b) < (max(m, n)-1)*ldb+nrhs: + panic(shortB) } // Scale the input matrices if they contain extreme values. @@ -97,7 +114,7 @@ func (impl Implementation) Dgels(trans blas.Transpose, m, n, nrhs int, a []float return true } brow := m - if tpsd { + if trans != blas.NoTrans { brow = n } bnrm := impl.Dlange(lapack.MaxAbs, brow, nrhs, b, ldb, nil) @@ -114,7 +131,7 @@ func (impl Implementation) Dgels(trans blas.Transpose, m, n, nrhs int, a []float var scllen int if m >= n { impl.Dgeqrf(m, n, a, lda, work, work[mn:], lwork-mn) - if !tpsd { + if trans == blas.NoTrans { impl.Dormqr(blas.Left, blas.Trans, m, nrhs, n, a, lda, work[:n], @@ -148,7 +165,7 @@ func (impl Implementation) Dgels(trans blas.Transpose, m, n, nrhs int, a []float } } else { impl.Dgelqf(m, n, a, lda, work, work[mn:], lwork-mn) - if !tpsd { + if trans == blas.NoTrans { ok := impl.Dtrtrs(blas.Lower, blas.NoTrans, blas.NonUnit, m, nrhs, a, lda, @@ -196,5 +213,7 @@ func (impl Implementation) Dgels(trans blas.Transpose, m, n, nrhs int, a []float if ibscl == 2 { impl.Dlascl(lapack.General, 0, 0, bignum, bnrm, scllen, nrhs, b, ldb) } + + work[0] = float64(wsize) return true } diff --git a/lapack/gonum/dgeql2.go b/lapack/gonum/dgeql2.go index 6d9b7413..0a92b194 100644 --- a/lapack/gonum/dgeql2.go +++ b/lapack/gonum/dgeql2.go @@ -24,14 +24,30 @@ import "gonum.org/v1/gonum/blas" // // Dgeql2 is an internal routine. It is exported for testing purposes. func (impl Implementation) Dgeql2(m, n int, a []float64, lda int, tau, work []float64) { - checkMatrix(m, n, a, lda) - if len(tau) < min(m, n) { - panic(badTau) - } - if len(work) < n { - panic(badWork) + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) } + + // Quick return if possible. k := min(m, n) + if k == 0 { + return + } + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(tau) < k: + panic(badTau) + case len(work) < n: + panic(shortWork) + } + var aii float64 for i := k - 1; i >= 0; i-- { // Generate elementary reflector H_i to annihilate A[0:m-k+i-1, n-k+i]. diff --git a/lapack/gonum/errors.go b/lapack/gonum/errors.go index 33c4a4bf..a58b0764 100644 --- a/lapack/gonum/errors.go +++ b/lapack/gonum/errors.go @@ -30,7 +30,12 @@ const ( badK2 = "lapack: k2 out of range" badKperm = "lapack: incorrect permutation length" badLdA = "lapack: bad leading dimension of A" + badLdB = "lapack: bad leading dimension of B" + badLdC = "lapack: bad leading dimension of C" badLdU = "lapack: bad leading dimension of U" + badLdV = "lapack: bad leading dimension of V" + badLdVL = "lapack: bad leading dimension of VL" + badLdVR = "lapack: bad leading dimension of VR" badLdVT = "lapack: bad leading dimension of VT" badNb = "lapack: nb out of range" badNorm = "lapack: bad norm" @@ -62,7 +67,20 @@ const ( negZ = "lapack: negative z value" nLT0 = "lapack: n < 0" nLTM = "lapack: n < m" + nrhsLT0 = "lapack: nrhs < 0" offsetGTM = "lapack: offset > m" - shortWork = "lapack: working array shorter than declared" + ncvtLT0 = "lapack: ncvt < 0" + nruLT0 = "lapack: nru < 0" + nccLT0 = "lapack: ncc < 0" + shortA = "lapack: insufficient length of a" + shortB = "lapack: insufficient length of b" + shortC = "lapack: insufficient length of c" + shortScale = "lapack: insufficient length of scale" + shortU = "lapack: insufficient length of u" + shortV = "lapack: insufficient length of v" + shortVL = "lapack: insufficient length of vl" + shortVR = "lapack: insufficient length of vr" + shortVT = "lapack: insufficient length of vt" + shortWork = "lapack: working slice shorter than declared" zeroDiv = "lapack: zero divisor" ) diff --git a/lapack/testlapack/dbdsqr.go b/lapack/testlapack/dbdsqr.go index 72c84c6d..71bbc879 100644 --- a/lapack/testlapack/dbdsqr.go +++ b/lapack/testlapack/dbdsqr.go @@ -23,7 +23,6 @@ type Dbdsqrer interface { func DbdsqrTest(t *testing.T, impl Dbdsqrer) { rnd := rand.New(rand.NewSource(1)) bi := blas64.Implementation() - _ = bi for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { for _, test := range []struct { n, ncvt, nru, ncc, ldvt, ldu, ldc int @@ -49,13 +48,13 @@ func DbdsqrTest(t *testing.T, impl Dbdsqrer) { ldu := test.ldu ldc := test.ldc if ldvt == 0 { - ldvt = ncvt + ldvt = max(1, ncvt) } if ldu == 0 { - ldu = n + ldu = max(1, n) } if ldc == 0 { - ldc = ncc + ldc = max(1, ncc) } d := make([]float64, n) @@ -92,7 +91,7 @@ func DbdsqrTest(t *testing.T, impl Dbdsqrer) { pt[i*ldpt+i] = 1 } - ok := impl.Dbdsqr(uplo, n, n, n, 0, d, e, pt, ldpt, q, ldq, nil, 0, work) + ok := impl.Dbdsqr(uplo, n, n, n, 0, d, e, pt, ldpt, q, ldq, nil, 1, work) isUpper := uplo == blas.Upper errStr := fmt.Sprintf("isUpper = %v, n = %v, ncvt = %v, nru = %v, ncc = %v", isUpper, n, ncvt, nru, ncc) diff --git a/lapack/testlapack/dgeev.go b/lapack/testlapack/dgeev.go index f4101434..bc67f9a3 100644 --- a/lapack/testlapack/dgeev.go +++ b/lapack/testlapack/dgeev.go @@ -559,11 +559,15 @@ func testDgeev(t *testing.T, impl Dgeever, tc string, test dgeevTest, jobvl lapa var vl blas64.General if jobvl == lapack.LeftEVCompute { vl = nanGeneral(n, n, n) + } else { + vl.Stride = 1 } var vr blas64.General if jobvr == lapack.RightEVCompute { vr = nanGeneral(n, n, n) + } else { + vr.Stride = 1 } wr := make([]float64, n) diff --git a/mat/eigen.go b/mat/eigen.go index eadf2843..90a94ae9 100644 --- a/mat/eigen.go +++ b/mat/eigen.go @@ -156,10 +156,15 @@ func (e *Eigen) Factorize(a Matrix, left, right bool) (ok bool) { if left { vl = *NewDense(r, r, nil) jobvl = lapack.LeftEVCompute + } else { + vl.mat.Stride = 1 } + if right { vr = *NewDense(c, c, nil) jobvr = lapack.RightEVCompute + } else { + vr.mat.Stride = 1 } wr := getFloats(c, false)