diff --git a/lapack/gonum/dgghrd.go b/lapack/gonum/dgghrd.go index da09ea9e..c9d6b4d1 100644 --- a/lapack/gonum/dgghrd.go +++ b/lapack/gonum/dgghrd.go @@ -10,57 +10,45 @@ import ( "gonum.org/v1/gonum/lapack" ) -// Dgghrd reduces a pair of real matrices (A,B) to generalized upper -// Hessenberg form using orthogonal transformations, where A is a -// general matrix and B is upper triangular. The form of the -// generalized eigenvalue problem is +// Dgghrd reduces a pair of real matrices (A,B) to generalized upper Hessenberg +// form using orthogonal transformations, where A is a general matrix and B is +// upper triangular. // -// A*x = lambda*B*x, +// This subroutine simultaneously reduces A to a Hessenberg matrix H // -// and B is typically made upper triangular by computing its QR -// factorization and moving the orthogonal matrix Q to the left side -// of the equation. -// This subroutine simultaneously reduces A to a Hessenberg matrix H: +// Qᵀ*A*Z = H, // -// Qᵀ*A*Z = H +// and transforms B to another upper triangular matrix T // -// and transforms B to another upper triangular matrix T: -// -// Qᵀ*B*Z = T -// -// in order to reduce the problem to its standard form -// -// H*y = lambda*T*y -// -// where y = Zᵀ*x. +// Qᵀ*B*Z = T. // // The orthogonal matrices Q and Z are determined as products of Givens -// rotations. They may either be formed explicitly, or they may be -// postmultiplied into input matrices Q1 and Z1, so that +// rotations. They may either be formed explicitly (lapack.OrthoExplicit), or +// they may be postmultiplied into input matrices Q1 and Z1 +// (lapack.OrthoPostmul), so that // -// Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ -// Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ +// Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ, +// Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ. // -// If Q1 is the orthogonal matrix from the QR factorization of B in the -// original equation A*x = lambda*B*x, then Dgghrd reduces the original -// problem to generalized Hessenberg form. +// ilo and ihi determine the block of A that will be reduced. It must hold that +// +// - 0 <= ilo <= ihi < n if n > 0, +// - ilo == 0 and ihi == -1 if n == 0, +// +// otherwise Dgghrd will panic. // // Dgghrd is an internal routine. It is exported for testing purposes. func (impl Implementation) Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int, a []float64, lda int, b []float64, ldb int, q []float64, ldq int, z []float64, ldz int) { switch { - case compq != lapack.OrthoNone && compq != lapack.OrthoEntry && compq != lapack.OrthoUnit: + case compq != lapack.OrthoNone && compq != lapack.OrthoExplicit && compq != lapack.OrthoPostmul: panic(badOrthoComp) - case compz != lapack.OrthoNone && compz != lapack.OrthoEntry && compz != lapack.OrthoUnit: + case compz != lapack.OrthoNone && compz != lapack.OrthoExplicit && compz != lapack.OrthoPostmul: panic(badOrthoComp) - case len(a) < (n-1)*lda+n: - panic(shortA) - case len(b) < (n-1)*ldb+n: - panic(shortB) case n < 0: panic(nLT0) - case ilo < 0: + case ilo < 0 || max(0, n-1) < ilo: panic(badIlo) - case ihi < ilo-1 || ihi >= n: + case ihi < min(ilo, n-1) || n <= ihi: panic(badIhi) case lda < max(1, n): panic(badLdA) @@ -70,20 +58,34 @@ func (impl Implementation) Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int panic(badLdQ) case (compz != lapack.OrthoNone && ldz < n) || ldz < 1: panic(badLdZ) + } + + // Quick return if possible. + if n == 0 { + return + } + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(b) < (n-1)*ldb+n: + panic(shortB) case compq != lapack.OrthoNone && len(q) < (n-1)*ldq+n: panic(shortQ) case compz != lapack.OrthoNone && len(z) < (n-1)*ldz+n: panic(shortZ) } - if compq == lapack.OrthoUnit { + if compq == lapack.OrthoExplicit { impl.Dlaset(blas.All, n, n, 0, 1, q, ldq) } - if compz == lapack.OrthoUnit { + if compz == lapack.OrthoExplicit { impl.Dlaset(blas.All, n, n, 0, 1, z, ldz) } - if n <= 1 { - return // Quick return if possible. + + // Quick return if possible. + if n == 1 { + return } // Zero out lower triangle of B. @@ -96,21 +98,19 @@ func (impl Implementation) Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int // Reduce A and B. for jcol := ilo; jcol <= ihi-2; jcol++ { for jrow := ihi; jrow >= jcol+2; jrow-- { - // Step 1: rotate rows JROW-1, JROW to kill A(JROW,JCOL). + // Step 1: rotate rows jrow-1, jrow to kill A[jrow,jcol]. var c, s float64 c, s, a[(jrow-1)*lda+jcol] = impl.Dlartg(a[(jrow-1)*lda+jcol], a[jrow*lda+jcol]) a[jrow*lda+jcol] = 0 - bi.Drot(n-jcol-1, a[(jrow-1)*lda+jcol+1:], 1, - a[jrow*lda+jcol+1:], 1, c, s) - bi.Drot(n+2-jrow-1, b[(jrow-1)*ldb+jrow-1:], 1, - b[jrow*ldb+jrow-1:], 1, c, s) + bi.Drot(n-jcol-1, a[(jrow-1)*lda+jcol+1:], 1, a[jrow*lda+jcol+1:], 1, c, s) + bi.Drot(n+2-jrow-1, b[(jrow-1)*ldb+jrow-1:], 1, b[jrow*ldb+jrow-1:], 1, c, s) if compq != lapack.OrthoNone { bi.Drot(n, q[jrow-1:], ldq, q[jrow:], ldq, c, s) } - // Step 2: rotate columns JROW, JROW-1 to kill B(JROW,JROW-1). + // Step 2: rotate columns jrow, jrow-1 to kill B[jrow,jrow-1]. c, s, b[jrow*ldb+jrow] = impl.Dlartg(b[jrow*ldb+jrow], b[jrow*ldb+jrow-1]) b[jrow*ldb+jrow-1] = 0 diff --git a/lapack/lapack.go b/lapack/lapack.go index 72111d66..18790462 100644 --- a/lapack/lapack.go +++ b/lapack/lapack.go @@ -231,7 +231,7 @@ const ( type OrthoComp byte const ( - OrthoNone OrthoComp = 'N' // Do not compute orthogonal matrix. - OrthoUnit OrthoComp = 'I' // Argument is initialized to the unit matrix and the orthogonal matrix is returned. - OrthoEntry OrthoComp = 'V' // Argument Q contains orthogonal matrix Q1 on entry and the product Q1*Q is returned. + OrthoNone OrthoComp = 'N' // Do not compute the orthogonal matrix. + OrthoExplicit OrthoComp = 'I' // The orthogonal matrix is formed explicitly and returned in the argument. + OrthoPostmul OrthoComp = 'V' // The orthogonal matrix is post-multiplied into the matrix stored in the argument on entry. ) diff --git a/lapack/testlapack/dgghrd.go b/lapack/testlapack/dgghrd.go index 02b0bc86..606ade96 100644 --- a/lapack/testlapack/dgghrd.go +++ b/lapack/testlapack/dgghrd.go @@ -5,7 +5,7 @@ package testlapack import ( - "math" + "fmt" "testing" "golang.org/x/exp/rand" @@ -20,112 +20,115 @@ type Dgghrder interface { } func DgghrdTest(t *testing.T, impl Dgghrder) { - const tol = 1e-13 - const ldAdd = 5 rnd := rand.New(rand.NewSource(1)) - comps := []lapack.OrthoComp{lapack.OrthoUnit, lapack.OrthoNone, lapack.OrthoEntry} + comps := []lapack.OrthoComp{lapack.OrthoExplicit, lapack.OrthoNone, lapack.OrthoPostmul} for _, compq := range comps { for _, compz := range comps { - for _, n := range []int{2, 0, 1, 4, 15} { - ldMin := max(1, n) - for _, lda := range []int{ldMin, ldMin + ldAdd} { - for _, ldb := range []int{ldMin, ldMin + ldAdd} { - for _, ldq := range []int{ldMin, ldMin + ldAdd} { - for _, ldz := range []int{ldMin, ldMin + ldAdd} { - testDgghrd(t, impl, rnd, tol, compq, compz, n, 0, n-1, lda, ldb, ldq, ldz) - } - } - } + for _, n := range []int{0, 1, 2, 3, 4, 15} { + for _, ld := range []int{max(1, n), n + 5} { + testDgghrd(t, impl, rnd, compq, compz, n, 0, n-1, ld, ld, ld, ld) } } } } } -func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, tol float64, compq, compz lapack.OrthoComp, n, ilo, ihi, lda, ldb, ldq, ldz int) { +func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, compq, compz lapack.OrthoComp, n, ilo, ihi, lda, ldb, ldq, ldz int) { + const tol = 1e-13 + a := randomGeneral(n, n, lda, rnd) - b := blockedUpperTriGeneral(n, n, 0, n, ldb, false, rnd) - var q, q1, z, z1 blas64.General - if compq == lapack.OrthoEntry { + b := randomGeneral(n, n, ldb, rnd) + + var q, q1 blas64.General + switch compq { + case lapack.OrthoExplicit: + // Initialize q to a non-orthogonal matrix, Dgghrd should overwrite it + // with an orthogonal Q. + q = randomGeneral(n, n, ldq, rnd) + case lapack.OrthoPostmul: + // Initialize q to an orthogonal matrix Q1, so that the result Q1*Q is + // again orthogonal. q = randomOrthogonal(n, rnd) q1 = cloneGeneral(q) - } else { - q = nanGeneral(n, n, ldq) } - if compz == lapack.OrthoEntry { + + var z, z1 blas64.General + switch compz { + case lapack.OrthoExplicit: + z = randomGeneral(n, n, ldz, rnd) + case lapack.OrthoPostmul: z = randomOrthogonal(n, rnd) z1 = cloneGeneral(z) - } else { - z = nanGeneral(n, n, ldz) } hGot := cloneGeneral(a) tGot := cloneGeneral(b) - for i := 1; i < n; i++ { - for j := 0; j < i; j++ { - // Set all lower tri elems to NaN to catch bad implementations. - tGot.Data[i*tGot.Stride+j] = math.NaN() - } - } - impl.Dgghrd(compq, compz, n, ilo, ihi, hGot.Data, hGot.Stride, tGot.Data, tGot.Stride, q.Data, q.Stride, z.Data, z.Stride) + impl.Dgghrd(compq, compz, n, ilo, ihi, hGot.Data, hGot.Stride, tGot.Data, tGot.Stride, q.Data, max(1, q.Stride), z.Data, max(1, z.Stride)) + if n == 0 { return } + + name := fmt.Sprintf("Case compq=%v,compz=%v,n=%v,ilo=%v,ihi=%v", compq, compz, n, ilo, ihi) + if !isUpperHessenberg(hGot) { - t.Error("H is not upper Hessenberg") - } - if !isNaNFree(tGot) || !isNaNFree(hGot) { - t.Error("T or H is/or not NaN free") + t.Errorf("%v: H is not upper Hessenberg", name) } if !isUpperTriangular(tGot) { - t.Error("T is not upper triangular") + t.Errorf("%v: T is not upper triangular", name) } - if compq == lapack.OrthoNone { - if !isAllNaN(q.Data) { - t.Errorf("Q is not NaN") + if compq != lapack.OrthoNone { + if resid := residualOrthogonal(q, true); resid > tol { + t.Errorf("%v: Q is not orthogonal, resid=%v", name, resid) } - return } - if compz == lapack.OrthoNone { - if !isAllNaN(z.Data) { - t.Errorf("Z is not NaN") + if compz != lapack.OrthoNone { + if resid := residualOrthogonal(z, true); resid > tol { + t.Errorf("%v: Z is not orthogonal, resid=%v", name, resid) } - return } - if compq != compz { - return // Do not handle mixed case - } - comp := compq - aux := zeros(n, n, n) - switch comp { - case lapack.OrthoUnit: + if compq != compz { + // Verify reduction only when both Q and Z are computed. + return + } + + // Zero out the lower triangle of B. + for i := 1; i < n; i++ { + for j := 0; j < i; j++ { + b.Data[i*b.Stride+j] = 0 + } + } + + aux := zeros(n, n, n) + switch compq { + case lapack.OrthoExplicit: // Qᵀ*A*Z = H hCalc := zeros(n, n, n) blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, a, 0, aux) blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, hCalc) if !equalApproxGeneral(hGot, hCalc, tol) { - t.Errorf("Qᵀ*A*Z != H") + t.Errorf("%v: Qᵀ*A*Z != H", name) } // Qᵀ*B*Z = T tCalc := zeros(n, n, n) blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, b, 0, aux) blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, tCalc) - if !equalApproxGeneral(hGot, hCalc, tol) { - t.Errorf("Qᵀ*B*Z != T") + if !equalApproxGeneral(tGot, tCalc, tol) { + t.Errorf("%v: Qᵀ*B*Z != T", name) } - case lapack.OrthoEntry: + case lapack.OrthoPostmul: // Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ lhs := zeros(n, n, n) blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, a, 0, aux) - blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs) // lhs = Q1 * A * Z1ᵀ + blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs) rhs := zeros(n, n, n) blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, hGot, 0, aux) blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs) if !equalApproxGeneral(lhs, rhs, tol) { - t.Errorf("Q1 * A * Z1ᵀ != (Q1*Q) * H * (Z1*Z)ᵀ") + t.Errorf("%v: Q1 * A * Z1ᵀ != (Q1*Q) * H * (Z1*Z)ᵀ", name) } // Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ @@ -135,7 +138,7 @@ func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, tol float64, compq, blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, tGot, 0, aux) blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs) if !equalApproxGeneral(lhs, rhs, tol) { - t.Errorf("Q1 * B * Z1ᵀ != (Q1*Q) * T * (Z1*Z)ᵀ") + t.Errorf("%v: Q1 * B * Z1ᵀ != (Q1*Q) * T * (Z1*Z)ᵀ", name) } } } diff --git a/lapack/testlapack/general.go b/lapack/testlapack/general.go index 2291dcab..fb16a4d8 100644 --- a/lapack/testlapack/general.go +++ b/lapack/testlapack/general.go @@ -1198,23 +1198,13 @@ func isUpperHessenberg(h blas64.General) bool { // isUpperTriangular returns whether a contains only zeros below the diagonal. func isUpperTriangular(a blas64.General) bool { + if a.Rows != a.Cols { + panic("matrix not square") + } n := a.Rows for i := 1; i < n; i++ { for j := 0; j < i; j++ { - v := a.Data[i*a.Stride+j] - if v != 0 || math.IsNaN(v) { - return false - } - } - } - return true -} - -// isNaNFree returns whether a does not contain NaN elements in reachable elements. -func isNaNFree(a blas64.General) bool { - for i := 0; i < a.Rows; i++ { - for j := 0; j < a.Cols; j++ { - if math.IsNaN(a.Data[i*a.Stride+j]) { + if a.Data[i*a.Stride+j] != 0 { return false } }