lapack/gonum: clean up Dgghrd and its test

This commit is contained in:
Vladimir Chalupecky
2023-09-11 12:22:17 +02:00
committed by Vladimír Chalupecký
parent f0a57a452a
commit 7df15c334b
4 changed files with 111 additions and 118 deletions

View File

@@ -10,57 +10,45 @@ import (
"gonum.org/v1/gonum/lapack" "gonum.org/v1/gonum/lapack"
) )
// Dgghrd reduces a pair of real matrices (A,B) to generalized upper // Dgghrd reduces a pair of real matrices (A,B) to generalized upper Hessenberg
// Hessenberg form using orthogonal transformations, where A is a // form using orthogonal transformations, where A is a general matrix and B is
// general matrix and B is upper triangular. The form of the // upper triangular.
// generalized eigenvalue problem is
// //
// A*x = lambda*B*x, // This subroutine simultaneously reduces A to a Hessenberg matrix H
// //
// and B is typically made upper triangular by computing its QR // Qᵀ*A*Z = H,
// factorization and moving the orthogonal matrix Q to the left side
// of the equation.
// This subroutine simultaneously reduces A to a Hessenberg matrix H:
// //
// Qᵀ*A*Z = H // and transforms B to another upper triangular matrix T
// //
// and transforms B to another upper triangular matrix T: // Qᵀ*B*Z = T.
//
// Qᵀ*B*Z = T
//
// in order to reduce the problem to its standard form
//
// H*y = lambda*T*y
//
// where y = Zᵀ*x.
// //
// The orthogonal matrices Q and Z are determined as products of Givens // The orthogonal matrices Q and Z are determined as products of Givens
// rotations. They may either be formed explicitly, or they may be // rotations. They may either be formed explicitly (lapack.OrthoExplicit), or
// postmultiplied into input matrices Q1 and Z1, so that // they may be postmultiplied into input matrices Q1 and Z1
// (lapack.OrthoPostmul), so that
// //
// Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ // Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ,
// Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ // Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ.
// //
// If Q1 is the orthogonal matrix from the QR factorization of B in the // ilo and ihi determine the block of A that will be reduced. It must hold that
// original equation A*x = lambda*B*x, then Dgghrd reduces the original //
// problem to generalized Hessenberg form. // - 0 <= ilo <= ihi < n if n > 0,
// - ilo == 0 and ihi == -1 if n == 0,
//
// otherwise Dgghrd will panic.
// //
// Dgghrd is an internal routine. It is exported for testing purposes. // Dgghrd is an internal routine. It is exported for testing purposes.
func (impl Implementation) Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int, a []float64, lda int, b []float64, ldb int, q []float64, ldq int, z []float64, ldz int) { func (impl Implementation) Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int, a []float64, lda int, b []float64, ldb int, q []float64, ldq int, z []float64, ldz int) {
switch { switch {
case compq != lapack.OrthoNone && compq != lapack.OrthoEntry && compq != lapack.OrthoUnit: case compq != lapack.OrthoNone && compq != lapack.OrthoExplicit && compq != lapack.OrthoPostmul:
panic(badOrthoComp) panic(badOrthoComp)
case compz != lapack.OrthoNone && compz != lapack.OrthoEntry && compz != lapack.OrthoUnit: case compz != lapack.OrthoNone && compz != lapack.OrthoExplicit && compz != lapack.OrthoPostmul:
panic(badOrthoComp) panic(badOrthoComp)
case len(a) < (n-1)*lda+n:
panic(shortA)
case len(b) < (n-1)*ldb+n:
panic(shortB)
case n < 0: case n < 0:
panic(nLT0) panic(nLT0)
case ilo < 0: case ilo < 0 || max(0, n-1) < ilo:
panic(badIlo) panic(badIlo)
case ihi < ilo-1 || ihi >= n: case ihi < min(ilo, n-1) || n <= ihi:
panic(badIhi) panic(badIhi)
case lda < max(1, n): case lda < max(1, n):
panic(badLdA) panic(badLdA)
@@ -70,20 +58,34 @@ func (impl Implementation) Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int
panic(badLdQ) panic(badLdQ)
case (compz != lapack.OrthoNone && ldz < n) || ldz < 1: case (compz != lapack.OrthoNone && ldz < n) || ldz < 1:
panic(badLdZ) panic(badLdZ)
}
// Quick return if possible.
if n == 0 {
return
}
switch {
case len(a) < (n-1)*lda+n:
panic(shortA)
case len(b) < (n-1)*ldb+n:
panic(shortB)
case compq != lapack.OrthoNone && len(q) < (n-1)*ldq+n: case compq != lapack.OrthoNone && len(q) < (n-1)*ldq+n:
panic(shortQ) panic(shortQ)
case compz != lapack.OrthoNone && len(z) < (n-1)*ldz+n: case compz != lapack.OrthoNone && len(z) < (n-1)*ldz+n:
panic(shortZ) panic(shortZ)
} }
if compq == lapack.OrthoUnit { if compq == lapack.OrthoExplicit {
impl.Dlaset(blas.All, n, n, 0, 1, q, ldq) impl.Dlaset(blas.All, n, n, 0, 1, q, ldq)
} }
if compz == lapack.OrthoUnit { if compz == lapack.OrthoExplicit {
impl.Dlaset(blas.All, n, n, 0, 1, z, ldz) impl.Dlaset(blas.All, n, n, 0, 1, z, ldz)
} }
if n <= 1 {
return // Quick return if possible. // Quick return if possible.
if n == 1 {
return
} }
// Zero out lower triangle of B. // Zero out lower triangle of B.
@@ -96,21 +98,19 @@ func (impl Implementation) Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int
// Reduce A and B. // Reduce A and B.
for jcol := ilo; jcol <= ihi-2; jcol++ { for jcol := ilo; jcol <= ihi-2; jcol++ {
for jrow := ihi; jrow >= jcol+2; jrow-- { for jrow := ihi; jrow >= jcol+2; jrow-- {
// Step 1: rotate rows JROW-1, JROW to kill A(JROW,JCOL). // Step 1: rotate rows jrow-1, jrow to kill A[jrow,jcol].
var c, s float64 var c, s float64
c, s, a[(jrow-1)*lda+jcol] = impl.Dlartg(a[(jrow-1)*lda+jcol], a[jrow*lda+jcol]) c, s, a[(jrow-1)*lda+jcol] = impl.Dlartg(a[(jrow-1)*lda+jcol], a[jrow*lda+jcol])
a[jrow*lda+jcol] = 0 a[jrow*lda+jcol] = 0
bi.Drot(n-jcol-1, a[(jrow-1)*lda+jcol+1:], 1,
a[jrow*lda+jcol+1:], 1, c, s)
bi.Drot(n+2-jrow-1, b[(jrow-1)*ldb+jrow-1:], 1, bi.Drot(n-jcol-1, a[(jrow-1)*lda+jcol+1:], 1, a[jrow*lda+jcol+1:], 1, c, s)
b[jrow*ldb+jrow-1:], 1, c, s) bi.Drot(n+2-jrow-1, b[(jrow-1)*ldb+jrow-1:], 1, b[jrow*ldb+jrow-1:], 1, c, s)
if compq != lapack.OrthoNone { if compq != lapack.OrthoNone {
bi.Drot(n, q[jrow-1:], ldq, q[jrow:], ldq, c, s) bi.Drot(n, q[jrow-1:], ldq, q[jrow:], ldq, c, s)
} }
// Step 2: rotate columns JROW, JROW-1 to kill B(JROW,JROW-1). // Step 2: rotate columns jrow, jrow-1 to kill B[jrow,jrow-1].
c, s, b[jrow*ldb+jrow] = impl.Dlartg(b[jrow*ldb+jrow], b[jrow*ldb+jrow-1]) c, s, b[jrow*ldb+jrow] = impl.Dlartg(b[jrow*ldb+jrow], b[jrow*ldb+jrow-1])
b[jrow*ldb+jrow-1] = 0 b[jrow*ldb+jrow-1] = 0

View File

@@ -231,7 +231,7 @@ const (
type OrthoComp byte type OrthoComp byte
const ( const (
OrthoNone OrthoComp = 'N' // Do not compute orthogonal matrix. OrthoNone OrthoComp = 'N' // Do not compute the orthogonal matrix.
OrthoUnit OrthoComp = 'I' // Argument is initialized to the unit matrix and the orthogonal matrix is returned. OrthoExplicit OrthoComp = 'I' // The orthogonal matrix is formed explicitly and returned in the argument.
OrthoEntry OrthoComp = 'V' // Argument Q contains orthogonal matrix Q1 on entry and the product Q1*Q is returned. OrthoPostmul OrthoComp = 'V' // The orthogonal matrix is post-multiplied into the matrix stored in the argument on entry.
) )

View File

@@ -5,7 +5,7 @@
package testlapack package testlapack
import ( import (
"math" "fmt"
"testing" "testing"
"golang.org/x/exp/rand" "golang.org/x/exp/rand"
@@ -20,112 +20,115 @@ type Dgghrder interface {
} }
func DgghrdTest(t *testing.T, impl Dgghrder) { func DgghrdTest(t *testing.T, impl Dgghrder) {
const tol = 1e-13
const ldAdd = 5
rnd := rand.New(rand.NewSource(1)) rnd := rand.New(rand.NewSource(1))
comps := []lapack.OrthoComp{lapack.OrthoUnit, lapack.OrthoNone, lapack.OrthoEntry} comps := []lapack.OrthoComp{lapack.OrthoExplicit, lapack.OrthoNone, lapack.OrthoPostmul}
for _, compq := range comps { for _, compq := range comps {
for _, compz := range comps { for _, compz := range comps {
for _, n := range []int{2, 0, 1, 4, 15} { for _, n := range []int{0, 1, 2, 3, 4, 15} {
ldMin := max(1, n) for _, ld := range []int{max(1, n), n + 5} {
for _, lda := range []int{ldMin, ldMin + ldAdd} { testDgghrd(t, impl, rnd, compq, compz, n, 0, n-1, ld, ld, ld, ld)
for _, ldb := range []int{ldMin, ldMin + ldAdd} {
for _, ldq := range []int{ldMin, ldMin + ldAdd} {
for _, ldz := range []int{ldMin, ldMin + ldAdd} {
testDgghrd(t, impl, rnd, tol, compq, compz, n, 0, n-1, lda, ldb, ldq, ldz)
}
}
}
} }
} }
} }
} }
} }
func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, tol float64, compq, compz lapack.OrthoComp, n, ilo, ihi, lda, ldb, ldq, ldz int) { func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, compq, compz lapack.OrthoComp, n, ilo, ihi, lda, ldb, ldq, ldz int) {
const tol = 1e-13
a := randomGeneral(n, n, lda, rnd) a := randomGeneral(n, n, lda, rnd)
b := blockedUpperTriGeneral(n, n, 0, n, ldb, false, rnd) b := randomGeneral(n, n, ldb, rnd)
var q, q1, z, z1 blas64.General
if compq == lapack.OrthoEntry { var q, q1 blas64.General
switch compq {
case lapack.OrthoExplicit:
// Initialize q to a non-orthogonal matrix; Dgghrd should overwrite it
// with an orthogonal Q.
q = randomGeneral(n, n, ldq, rnd)
case lapack.OrthoPostmul:
// Initialize q to an orthogonal matrix Q1, so that the result Q1*Q is
// again orthogonal.
q = randomOrthogonal(n, rnd) q = randomOrthogonal(n, rnd)
q1 = cloneGeneral(q) q1 = cloneGeneral(q)
} else {
q = nanGeneral(n, n, ldq)
} }
if compz == lapack.OrthoEntry {
var z, z1 blas64.General
switch compz {
case lapack.OrthoExplicit:
z = randomGeneral(n, n, ldz, rnd)
case lapack.OrthoPostmul:
z = randomOrthogonal(n, rnd) z = randomOrthogonal(n, rnd)
z1 = cloneGeneral(z) z1 = cloneGeneral(z)
} else {
z = nanGeneral(n, n, ldz)
} }
hGot := cloneGeneral(a) hGot := cloneGeneral(a)
tGot := cloneGeneral(b) tGot := cloneGeneral(b)
for i := 1; i < n; i++ { impl.Dgghrd(compq, compz, n, ilo, ihi, hGot.Data, hGot.Stride, tGot.Data, tGot.Stride, q.Data, max(1, q.Stride), z.Data, max(1, z.Stride))
for j := 0; j < i; j++ {
// Set all lower tri elems to NaN to catch bad implementations.
tGot.Data[i*tGot.Stride+j] = math.NaN()
}
}
impl.Dgghrd(compq, compz, n, ilo, ihi, hGot.Data, hGot.Stride, tGot.Data, tGot.Stride, q.Data, q.Stride, z.Data, z.Stride)
if n == 0 { if n == 0 {
return return
} }
name := fmt.Sprintf("Case compq=%v,compz=%v,n=%v,ilo=%v,ihi=%v", compq, compz, n, ilo, ihi)
if !isUpperHessenberg(hGot) { if !isUpperHessenberg(hGot) {
t.Error("H is not upper Hessenberg") t.Errorf("%v: H is not upper Hessenberg", name)
}
if !isNaNFree(tGot) || !isNaNFree(hGot) {
t.Error("T or H is/or not NaN free")
} }
if !isUpperTriangular(tGot) { if !isUpperTriangular(tGot) {
t.Error("T is not upper triangular") t.Errorf("%v: T is not upper triangular", name)
} }
if compq == lapack.OrthoNone { if compq != lapack.OrthoNone {
if !isAllNaN(q.Data) { if resid := residualOrthogonal(q, true); resid > tol {
t.Errorf("Q is not NaN") t.Errorf("%v: Q is not orthogonal, resid=%v", name, resid)
} }
return
} }
if compz == lapack.OrthoNone { if compz != lapack.OrthoNone {
if !isAllNaN(z.Data) { if resid := residualOrthogonal(z, true); resid > tol {
t.Errorf("Z is not NaN") t.Errorf("%v: Z is not orthogonal, resid=%v", name, resid)
} }
return
} }
if compq != compz {
return // Do not handle mixed case
}
comp := compq
aux := zeros(n, n, n)
switch comp { if compq != compz {
case lapack.OrthoUnit: // Verify reduction only when both Q and Z are computed.
return
}
// Zero out the lower triangle of B.
for i := 1; i < n; i++ {
for j := 0; j < i; j++ {
b.Data[i*b.Stride+j] = 0
}
}
aux := zeros(n, n, n)
switch compq {
case lapack.OrthoExplicit:
// Qᵀ*A*Z = H // Qᵀ*A*Z = H
hCalc := zeros(n, n, n) hCalc := zeros(n, n, n)
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, a, 0, aux) blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, a, 0, aux)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, hCalc) blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, hCalc)
if !equalApproxGeneral(hGot, hCalc, tol) { if !equalApproxGeneral(hGot, hCalc, tol) {
t.Errorf("Qᵀ*A*Z != H") t.Errorf("%v: Qᵀ*A*Z != H", name)
} }
// Qᵀ*B*Z = T // Qᵀ*B*Z = T
tCalc := zeros(n, n, n) tCalc := zeros(n, n, n)
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, b, 0, aux) blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, b, 0, aux)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, tCalc) blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, tCalc)
if !equalApproxGeneral(hGot, hCalc, tol) { if !equalApproxGeneral(tGot, tCalc, tol) {
t.Errorf("Qᵀ*B*Z != T") t.Errorf("%v: Qᵀ*B*Z != T", name)
} }
case lapack.OrthoEntry: case lapack.OrthoPostmul:
// Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ // Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ
lhs := zeros(n, n, n) lhs := zeros(n, n, n)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, a, 0, aux) blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, a, 0, aux)
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs) // lhs = Q1 * A * Z1ᵀ blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs)
rhs := zeros(n, n, n) rhs := zeros(n, n, n)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, hGot, 0, aux) blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, hGot, 0, aux)
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs) blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs)
if !equalApproxGeneral(lhs, rhs, tol) { if !equalApproxGeneral(lhs, rhs, tol) {
t.Errorf("Q1 * A * Z1ᵀ != (Q1*Q) * H * (Z1*Z)ᵀ") t.Errorf("%v: Q1 * A * Z1ᵀ != (Q1*Q) * H * (Z1*Z)ᵀ", name)
} }
// Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ // Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ
@@ -135,7 +138,7 @@ func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, tol float64, compq,
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, tGot, 0, aux) blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, tGot, 0, aux)
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs) blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs)
if !equalApproxGeneral(lhs, rhs, tol) { if !equalApproxGeneral(lhs, rhs, tol) {
t.Errorf("Q1 * B * Z1ᵀ != (Q1*Q) * T * (Z1*Z)ᵀ") t.Errorf("%v: Q1 * B * Z1ᵀ != (Q1*Q) * T * (Z1*Z)ᵀ", name)
} }
} }
} }

View File

@@ -1198,23 +1198,13 @@ func isUpperHessenberg(h blas64.General) bool {
// isUpperTriangular returns whether a contains only zeros below the diagonal. // isUpperTriangular returns whether a contains only zeros below the diagonal.
func isUpperTriangular(a blas64.General) bool { func isUpperTriangular(a blas64.General) bool {
if a.Rows != a.Cols {
panic("matrix not square")
}
n := a.Rows n := a.Rows
for i := 1; i < n; i++ { for i := 1; i < n; i++ {
for j := 0; j < i; j++ { for j := 0; j < i; j++ {
v := a.Data[i*a.Stride+j] if a.Data[i*a.Stride+j] != 0 {
if v != 0 || math.IsNaN(v) {
return false
}
}
}
return true
}
// isNaNFree returns whether a does not contain NaN elements in reachable elements.
func isNaNFree(a blas64.General) bool {
for i := 0; i < a.Rows; i++ {
for j := 0; j < a.Cols; j++ {
if math.IsNaN(a.Data[i*a.Stride+j]) {
return false return false
} }
} }