lapack: add SchurComp and SchurUpdateComp types and consts, rename EVComp consts

This commit is contained in:
Vladimir Chalupecky
2018-09-14 13:12:30 +02:00
committed by Vladimír Chalupecký
parent 6d5ac7aa26
commit 8ecf638470
9 changed files with 60 additions and 52 deletions

View File

@@ -116,7 +116,7 @@ func (impl Implementation) Dgeev(jobvl lapack.LeftEVJob, jobvr lapack.RightEVJob
maxwrk := 2*n + n*impl.Ilaenv(1, "DGEHRD", " ", n, 1, n, 0) maxwrk := 2*n + n*impl.Ilaenv(1, "DGEHRD", " ", n, 1, n, 0)
if wantvl || wantvr { if wantvl || wantvr {
maxwrk = max(maxwrk, 2*n+(n-1)*impl.Ilaenv(1, "DORGHR", " ", n, 1, n, -1)) maxwrk = max(maxwrk, 2*n+(n-1)*impl.Ilaenv(1, "DORGHR", " ", n, 1, n, -1))
impl.Dhseqr(lapack.EigenvaluesAndSchur, lapack.OriginalEV, n, 0, n-1, impl.Dhseqr(lapack.EigenvaluesAndSchur, lapack.SchurOrig, n, 0, n-1,
nil, 1, nil, nil, nil, 1, work, -1) nil, 1, nil, nil, nil, 1, work, -1)
maxwrk = max(maxwrk, max(n+1, n+int(work[0]))) maxwrk = max(maxwrk, max(n+1, n+int(work[0])))
side := lapack.EVLeft side := lapack.EVLeft
@@ -176,7 +176,7 @@ func (impl Implementation) Dgeev(jobvl lapack.LeftEVJob, jobvr lapack.RightEVJob
impl.Dorghr(n, ilo, ihi, vl, ldvl, tau, work[iwrk:], lwork-iwrk) impl.Dorghr(n, ilo, ihi, vl, ldvl, tau, work[iwrk:], lwork-iwrk)
// Perform QR iteration, accumulating Schur vectors in VL. // Perform QR iteration, accumulating Schur vectors in VL.
iwrk = n iwrk = n
first = impl.Dhseqr(lapack.EigenvaluesAndSchur, lapack.OriginalEV, n, ilo, ihi, first = impl.Dhseqr(lapack.EigenvaluesAndSchur, lapack.SchurOrig, n, ilo, ihi,
a, lda, wr, wi, vl, ldvl, work[iwrk:], lwork-iwrk) a, lda, wr, wi, vl, ldvl, work[iwrk:], lwork-iwrk)
if wantvr { if wantvr {
// Want left and right eigenvectors. // Want left and right eigenvectors.
@@ -192,7 +192,7 @@ func (impl Implementation) Dgeev(jobvl lapack.LeftEVJob, jobvr lapack.RightEVJob
impl.Dorghr(n, ilo, ihi, vr, ldvr, tau, work[iwrk:], lwork-iwrk) impl.Dorghr(n, ilo, ihi, vr, ldvr, tau, work[iwrk:], lwork-iwrk)
// Perform QR iteration, accumulating Schur vectors in VR. // Perform QR iteration, accumulating Schur vectors in VR.
iwrk = n iwrk = n
first = impl.Dhseqr(lapack.EigenvaluesAndSchur, lapack.OriginalEV, n, ilo, ihi, first = impl.Dhseqr(lapack.EigenvaluesAndSchur, lapack.SchurOrig, n, ilo, ihi,
a, lda, wr, wi, vr, ldvr, work[iwrk:], lwork-iwrk) a, lda, wr, wi, vr, ldvr, work[iwrk:], lwork-iwrk)
} else { } else {
// Compute eigenvalues only. // Compute eigenvalues only.

View File

@@ -27,11 +27,11 @@ import (
// be computed. // be computed.
// For other values of job Dhseqr will panic. // For other values of job Dhseqr will panic.
// //
// If compz == lapack.None, no Schur vectors will be computed and Z will not be // If compz == lapack.SchurNone, no Schur vectors will be computed and Z will not be
// referenced. // referenced.
// If compz == lapack.HessEV, on return Z will contain the matrix of Schur // If compz == lapack.SchurHess, on return Z will contain the matrix of Schur
// vectors of H. // vectors of H.
// If compz == lapack.OriginalEV, on entry z is assumed to contain the orthogonal // If compz == lapack.SchurOrig, on entry z is assumed to contain the orthogonal
// matrix Q that is the identity except for the submatrix // matrix Q that is the identity except for the submatrix
// Q[ilo:ihi+1,ilo:ihi+1]. On return z will be updated to the product Q*Z. // Q[ilo:ihi+1,ilo:ihi+1]. On return z will be updated to the product Q*Z.
// //
@@ -96,11 +96,11 @@ import (
// where U is an orthogonal matrix. The final H is upper Hessenberg and // where U is an orthogonal matrix. The final H is upper Hessenberg and
// H[unconverged:ihi+1,unconverged:ihi+1] is upper quasi-triangular. // H[unconverged:ihi+1,unconverged:ihi+1] is upper quasi-triangular.
// //
// If unconverged > 0 and compz == lapack.OriginalEV, then on return // If unconverged > 0 and compz == lapack.SchurOrig, then on return
// (final Z) = (initial Z) U, // (final Z) = (initial Z) U,
// where U is the orthogonal matrix in (*) regardless of the value of job. // where U is the orthogonal matrix in (*) regardless of the value of job.
// //
// If unconverged > 0 and compz == lapack.HessEV, then on return // If unconverged > 0 and compz == lapack.SchurHess, then on return
// (final Z) = U, // (final Z) = U,
// where U is the orthogonal matrix in (*) regardless of the value of job. // where U is the orthogonal matrix in (*) regardless of the value of job.
// //
@@ -118,7 +118,7 @@ import (
// URL: http://dx.doi.org/10.1137/S0895479801384585 // URL: http://dx.doi.org/10.1137/S0895479801384585
// //
// Dhseqr is an internal routine. It is exported for testing purposes. // Dhseqr is an internal routine. It is exported for testing purposes.
func (impl Implementation) Dhseqr(job lapack.SchurJob, compz lapack.EVComp, n, ilo, ihi int, h []float64, ldh int, wr, wi []float64, z []float64, ldz int, work []float64, lwork int) (unconverged int) { func (impl Implementation) Dhseqr(job lapack.SchurJob, compz lapack.SchurComp, n, ilo, ihi int, h []float64, ldh int, wr, wi []float64, z []float64, ldz int, work []float64, lwork int) (unconverged int) {
var wantt bool var wantt bool
switch job { switch job {
default: default:
@@ -130,9 +130,9 @@ func (impl Implementation) Dhseqr(job lapack.SchurJob, compz lapack.EVComp, n, i
var wantz bool var wantz bool
switch compz { switch compz {
default: default:
panic(badEVComp) panic(badSchurComp)
case lapack.None: case lapack.SchurNone:
case lapack.HessEV, lapack.OriginalEV: case lapack.SchurHess, lapack.SchurOrig:
wantz = true wantz = true
} }
switch { switch {
@@ -197,7 +197,7 @@ func (impl Implementation) Dhseqr(job lapack.SchurJob, compz lapack.EVComp, n, i
} }
// Initialize Z to identity matrix if requested. // Initialize Z to identity matrix if requested.
if compz == lapack.HessEV { if compz == lapack.SchurHess {
impl.Dlaset(blas.All, n, n, 0, 1, z, ldz) impl.Dlaset(blas.All, n, n, 0, 1, z, ldz)
} }

View File

@@ -26,11 +26,11 @@ import (
// Dsteqr will panic otherwise. // Dsteqr will panic otherwise.
// //
// z, on entry, contains the n×n orthogonal matrix used in the reduction to // z, on entry, contains the n×n orthogonal matrix used in the reduction to
// tridiagonal form if compz == lapack.OriginalEV. On exit, if // tridiagonal form if compz == lapack.EVOriginal. On exit, if
// compz == lapack.OriginalEV, z contains the orthonormal eigenvectors of the // compz == lapack.EVOriginal, z contains the orthonormal eigenvectors of the
// original symmetric matrix, and if compz == lapack.TridiagEV, z contains the // original symmetric matrix, and if compz == lapack.EVTridiag, z contains the
// orthonormal eigenvectors of the symmetric tridiagonal matrix. z is not used // orthonormal eigenvectors of the symmetric tridiagonal matrix. z is not used
// if compz == lapack.None. // if compz == lapack.EVCompNone.
// //
// work must have length at least max(1, 2*n-2) if the eigenvectors are computed, // work must have length at least max(1, 2*n-2) if the eigenvectors are computed,
// and Dsteqr will panic otherwise. // and Dsteqr will panic otherwise.
@@ -46,10 +46,10 @@ func (impl Implementation) Dsteqr(compz lapack.EVComp, n int, d, e, z []float64,
if len(e) < n-1 { if len(e) < n-1 {
panic(badE) panic(badE)
} }
if compz != lapack.None && compz != lapack.TridiagEV && compz != lapack.OriginalEV { if compz != lapack.EVCompNone && compz != lapack.EVTridiag && compz != lapack.EVOrig {
panic(badEVComp) panic(badEVComp)
} }
if compz != lapack.None { if compz != lapack.EVCompNone {
if len(work) < max(1, 2*n-2) { if len(work) < max(1, 2*n-2) {
panic(badWork) panic(badWork)
} }
@@ -57,9 +57,9 @@ func (impl Implementation) Dsteqr(compz lapack.EVComp, n int, d, e, z []float64,
} }
var icompz int var icompz int
if compz == lapack.OriginalEV { if compz == lapack.EVOrig {
icompz = 1 icompz = 1
} else if compz == lapack.TridiagEV { } else if compz == lapack.EVTridiag {
icompz = 2 icompz = 2
} }

View File

@@ -18,8 +18,8 @@ import "gonum.org/v1/gonum/lapack"
// as Z^T*T*Z, and will be again in Schur canonical form. // as Z^T*T*Z, and will be again in Schur canonical form.
// //
// If compq is lapack.UpdateSchur, on return the matrix Q of Schur vectors will be // If compq is lapack.UpdateSchur, on return the matrix Q of Schur vectors will be
// updated by postmultiplying it with Z. // updated by post-multiplying it with Z.
// If compq is lapack.None, the matrix Q is not referenced and will not be // If compq is lapack.UpdateSchurNone, the matrix Q is not referenced and will not be
// updated. // updated.
// For other values of compq Dtrexc will panic. // For other values of compq Dtrexc will panic.
// //
@@ -45,13 +45,13 @@ import "gonum.org/v1/gonum/lapack"
// work must have length at least n, otherwise Dtrexc will panic. // work must have length at least n, otherwise Dtrexc will panic.
// //
// Dtrexc is an internal routine. It is exported for testing purposes. // Dtrexc is an internal routine. It is exported for testing purposes.
func (impl Implementation) Dtrexc(compq lapack.EVComp, n int, t []float64, ldt int, q []float64, ldq int, ifst, ilst int, work []float64) (ifstOut, ilstOut int, ok bool) { func (impl Implementation) Dtrexc(compq lapack.UpdateSchurComp, n int, t []float64, ldt int, q []float64, ldq int, ifst, ilst int, work []float64) (ifstOut, ilstOut int, ok bool) {
checkMatrix(n, n, t, ldt) checkMatrix(n, n, t, ldt)
var wantq bool var wantq bool
switch compq { switch compq {
default: default:
panic("lapack: bad value of compq") panic("lapack: bad value of compq")
case lapack.None: case lapack.UpdateSchurNone:
// Nothing to do because wantq is already false. // Nothing to do because wantq is already false.
case lapack.UpdateSchur: case lapack.UpdateSchur:
wantq = true wantq = true

View File

@@ -45,6 +45,7 @@ const (
badNorm = "lapack: bad norm" badNorm = "lapack: bad norm"
badPivot = "lapack: bad pivot" badPivot = "lapack: bad pivot"
badS = "lapack: s has insufficient length" badS = "lapack: s has insufficient length"
badSchurComp = "lapack: bad SchurComp"
badSchurJob = "lapack: bad SchurJob" badSchurJob = "lapack: bad SchurJob"
badShifts = "lapack: bad shifts" badShifts = "lapack: bad shifts"
badSide = "lapack: bad side" badSide = "lapack: bad side"

View File

@@ -127,23 +127,13 @@ const (
GSVDNone GSVDJob = 'N' // Do not compute orthogonal matrix GSVDNone GSVDJob = 'N' // Do not compute orthogonal matrix
) )
// EVComp specifies how eigenvectors are computed. // EVComp specifies how eigenvectors are computed in Dsteqr.
type EVComp byte type EVComp byte
const ( const (
// OriginalEV specifies to compute the eigenvectors of the original EVOrig EVComp = 'V' // Compute eigenvectors of the original symmetric matrix.
// matrix. EVTridiag EVComp = 'I' // Compute eigenvectors of the tridiagonal matrix.
OriginalEV EVComp = 'V' EVCompNone EVComp = 'N' // Do not compute eigenvectors.
// TridiagEV specifies to compute both the eigenvectors of the input
// tridiagonal matrix.
TridiagEV EVComp = 'I'
// HessEV specifies to compute both the eigenvectors of the input upper
// Hessenberg matrix.
HessEV EVComp = 'I'
// UpdateSchur specifies that the matrix of Schur vectors will be
// updated by Dtrexc.
UpdateSchur EVComp = 'V'
) )
// EVJob specifies whether eigenvectors are computed in Dsyev. // EVJob specifies whether eigenvectors are computed in Dsyev.
@@ -188,6 +178,23 @@ const (
EigenvaluesAndSchur SchurJob = 'S' EigenvaluesAndSchur SchurJob = 'S'
) )
// SchurComp specifies whether and how the Schur vectors are computed in Dhseqr.
type SchurComp byte
const (
SchurNone SchurComp = 'N' // Schur vectors are not computed.
SchurHess SchurComp = 'I' // Schur vectors of the upper Hessenberg marix are computed.
SchurOrig SchurComp = 'V' // Schur vectors of the original matrix are computed.
)
// UpdateSchurComp specifies whether the matrix of Schur vectors is updated in Dtrexc.
type UpdateSchurComp byte
const (
UpdateSchur UpdateSchurComp = 'V' // The matrix of Schur vectors is updated.
UpdateSchurNone UpdateSchurComp = 'N' // The matrix of Schur vectors is not updated.
)
// EVSide specifies what eigenvectors are computed in Dtrevc3. // EVSide specifies what eigenvectors are computed in Dtrevc3.
type EVSide byte type EVSide byte

View File

@@ -16,7 +16,7 @@ import (
) )
type Dhseqrer interface { type Dhseqrer interface {
Dhseqr(job lapack.SchurJob, compz lapack.EVComp, n, ilo, ihi int, h []float64, ldh int, wr, wi []float64, Dhseqr(job lapack.SchurJob, compz lapack.SchurComp, n, ilo, ihi int, h []float64, ldh int, wr, wi []float64,
z []float64, ldz int, work []float64, lwork int) int z []float64, ldz int, work []float64, lwork int) int
} }
@@ -57,11 +57,11 @@ func testDhseqr(t *testing.T, impl Dhseqrer, i int, test dhseqrTest, job lapack.
copyGeneral(h, blas64.General{Rows: n, Cols: n, Stride: max(1, n), Data: test.h}) copyGeneral(h, blas64.General{Rows: n, Cols: n, Stride: max(1, n), Data: test.h})
hCopy := cloneGeneral(h) hCopy := cloneGeneral(h)
var compz lapack.EVComp = lapack.None compz := lapack.SchurNone
z := blas64.General{Stride: max(1, n)} z := blas64.General{Stride: max(1, n)}
if wantz { if wantz {
// First, let Dhseqr initialize Z to the identity matrix. // First, let Dhseqr initialize Z to the identity matrix.
compz = lapack.HessEV compz = lapack.SchurHess
z = nanGeneral(n, n, n+extra) z = nanGeneral(n, n, n+extra)
} }
@@ -70,7 +70,7 @@ func testDhseqr(t *testing.T, impl Dhseqrer, i int, test dhseqrTest, job lapack.
work := nanSlice(max(1, n)) work := nanSlice(max(1, n))
if optwork { if optwork {
impl.Dhseqr(job, lapack.HessEV, n, ilo, ihi, nil, h.Stride, nil, nil, nil, z.Stride, work, -1) impl.Dhseqr(job, lapack.SchurHess, n, ilo, ihi, nil, h.Stride, nil, nil, nil, z.Stride, work, -1)
work = nanSlice(int(work[0])) work = nanSlice(int(work[0]))
} }
@@ -196,7 +196,7 @@ func testDhseqr(t *testing.T, impl Dhseqrer, i int, test dhseqrTest, job lapack.
copyGeneral(h, hCopy) copyGeneral(h, hCopy)
// Call Dhseqr again with the identity matrix given explicitly in Q. // Call Dhseqr again with the identity matrix given explicitly in Q.
q := eye(n, n+extra) q := eye(n, n+extra)
impl.Dhseqr(job, lapack.OriginalEV, n, ilo, ihi, h.Data, h.Stride, wr, wi, q.Data, q.Stride, work, len(work)) impl.Dhseqr(job, lapack.SchurOrig, n, ilo, ihi, h.Data, h.Stride, wr, wi, q.Data, q.Stride, work, len(work))
if !equalApproxGeneral(z, q, 0) { if !equalApproxGeneral(z, q, 0) {
t.Errorf("%v: Z and Q are not equal", prefix) t.Errorf("%v: Z and Q are not equal", prefix)
} }

View File

@@ -22,7 +22,7 @@ type Dsteqrer interface {
func DsteqrTest(t *testing.T, impl Dsteqrer) { func DsteqrTest(t *testing.T, impl Dsteqrer) {
rnd := rand.New(rand.NewSource(1)) rnd := rand.New(rand.NewSource(1))
for _, compz := range []lapack.EVComp{lapack.OriginalEV, lapack.TridiagEV} { for _, compz := range []lapack.EVComp{lapack.EVOrig, lapack.EVTridiag} {
for _, test := range []struct { for _, test := range []struct {
n, lda int n, lda int
}{ }{
@@ -59,7 +59,7 @@ func DsteqrTest(t *testing.T, impl Dsteqrer) {
copy(eCopy, e) copy(eCopy, e)
aCopy := make([]float64, len(a)) aCopy := make([]float64, len(a))
copy(aCopy, a) copy(aCopy, a)
if compz == lapack.OriginalEV { if compz == lapack.EVOrig {
uplo := blas.Upper uplo := blas.Upper
tau := make([]float64, n) tau := make([]float64, n)
work := make([]float64, 1) work := make([]float64, 1)
@@ -92,7 +92,7 @@ func DsteqrTest(t *testing.T, impl Dsteqrer) {
copy(dAns, d) copy(dAns, d)
var truth blas64.General var truth blas64.General
if compz == lapack.OriginalEV { if compz == lapack.EVOrig {
truth = blas64.General{ truth = blas64.General{
Rows: n, Rows: n,
Cols: n, Cols: n,
@@ -130,7 +130,7 @@ func DsteqrTest(t *testing.T, impl Dsteqrer) {
} }
if !eigenDecompCorrect(d, truth, V) { if !eigenDecompCorrect(d, truth, V) {
t.Errorf("Eigen reconstruction mismatch. fromFull = %v, n = %v", t.Errorf("Eigen reconstruction mismatch. fromFull = %v, n = %v",
compz == lapack.OriginalEV, n) compz == lapack.EVOrig, n)
} }
// Compare eigenvalues when not computing eigenvectors. // Compare eigenvalues when not computing eigenvectors.

View File

@@ -17,13 +17,13 @@ import (
) )
type Dtrexcer interface { type Dtrexcer interface {
Dtrexc(compq lapack.EVComp, n int, t []float64, ldt int, q []float64, ldq int, ifst, ilst int, work []float64) (ifstOut, ilstOut int, ok bool) Dtrexc(compq lapack.UpdateSchurComp, n int, t []float64, ldt int, q []float64, ldq int, ifst, ilst int, work []float64) (ifstOut, ilstOut int, ok bool)
} }
func DtrexcTest(t *testing.T, impl Dtrexcer) { func DtrexcTest(t *testing.T, impl Dtrexcer) {
rnd := rand.New(rand.NewSource(1)) rnd := rand.New(rand.NewSource(1))
for _, compq := range []lapack.EVComp{lapack.None, lapack.UpdateSchur} { for _, compq := range []lapack.UpdateSchurComp{lapack.UpdateSchurNone, lapack.UpdateSchur} {
for _, n := range []int{1, 2, 3, 4, 5, 6, 10, 18, 31, 53} { for _, n := range []int{1, 2, 3, 4, 5, 6, 10, 18, 31, 53} {
for _, extra := range []int{0, 1, 11} { for _, extra := range []int{0, 1, 11} {
for cas := 0; cas < 100; cas++ { for cas := 0; cas < 100; cas++ {
@@ -36,7 +36,7 @@ func DtrexcTest(t *testing.T, impl Dtrexcer) {
} }
} }
for _, compq := range []lapack.EVComp{lapack.None, lapack.UpdateSchur} { for _, compq := range []lapack.UpdateSchurComp{lapack.UpdateSchurNone, lapack.UpdateSchur} {
for _, extra := range []int{0, 1, 11} { for _, extra := range []int{0, 1, 11} {
tmat := randomSchurCanonical(0, extra, rnd) tmat := randomSchurCanonical(0, extra, rnd)
testDtrexc(t, impl, compq, tmat, 0, 0, extra, rnd) testDtrexc(t, impl, compq, tmat, 0, 0, extra, rnd)
@@ -44,7 +44,7 @@ func DtrexcTest(t *testing.T, impl Dtrexcer) {
} }
} }
func testDtrexc(t *testing.T, impl Dtrexcer, compq lapack.EVComp, tmat blas64.General, ifst, ilst, extra int, rnd *rand.Rand) { func testDtrexc(t *testing.T, impl Dtrexcer, compq lapack.UpdateSchurComp, tmat blas64.General, ifst, ilst, extra int, rnd *rand.Rand) {
const tol = 1e-13 const tol = 1e-13
n := tmat.Rows n := tmat.Rows