lapack/gonum: clean up Dgghrd and its test

This commit is contained in:
Vladimir Chalupecky
2023-09-11 12:22:17 +02:00
committed by Vladimír Chalupecký
parent f0a57a452a
commit 7df15c334b
4 changed files with 111 additions and 118 deletions

View File

@@ -10,57 +10,45 @@ import (
"gonum.org/v1/gonum/lapack"
)
// Dgghrd reduces a pair of real matrices (A,B) to generalized upper
// Hessenberg form using orthogonal transformations, where A is a
// general matrix and B is upper triangular. The form of the
// generalized eigenvalue problem is
// Dgghrd reduces a pair of real matrices (A,B) to generalized upper Hessenberg
// form using orthogonal transformations, where A is a general matrix and B is
// upper triangular.
//
// A*x = lambda*B*x,
// This subroutine simultaneously reduces A to a Hessenberg matrix H
//
// and B is typically made upper triangular by computing its QR
// factorization and moving the orthogonal matrix Q to the left side
// of the equation.
// This subroutine simultaneously reduces A to a Hessenberg matrix H:
// Qᵀ*A*Z = H,
//
// Qᵀ*A*Z = H
// and transforms B to another upper triangular matrix T
//
// and transforms B to another upper triangular matrix T:
//
// Qᵀ*B*Z = T
//
// in order to reduce the problem to its standard form
//
// H*y = lambda*T*y
//
// where y = Zᵀ*x.
// Qᵀ*B*Z = T.
//
// The orthogonal matrices Q and Z are determined as products of Givens
// rotations. They may either be formed explicitly, or they may be
// postmultiplied into input matrices Q1 and Z1, so that
// rotations. They may either be formed explicitly (lapack.OrthoExplicit), or
// they may be postmultiplied into input matrices Q1 and Z1
// (lapack.OrthoPostmul), so that
//
// Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ
// Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ
// Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ,
// Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ.
//
// If Q1 is the orthogonal matrix from the QR factorization of B in the
// original equation A*x = lambda*B*x, then Dgghrd reduces the original
// problem to generalized Hessenberg form.
// ilo and ihi determine the block of A that will be reduced. It must hold that
//
// - 0 <= ilo <= ihi < n if n > 0,
// - ilo == 0 and ihi == -1 if n == 0,
//
// otherwise Dgghrd will panic.
//
// Dgghrd is an internal routine. It is exported for testing purposes.
func (impl Implementation) Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int, a []float64, lda int, b []float64, ldb int, q []float64, ldq int, z []float64, ldz int) {
switch {
case compq != lapack.OrthoNone && compq != lapack.OrthoEntry && compq != lapack.OrthoUnit:
case compq != lapack.OrthoNone && compq != lapack.OrthoExplicit && compq != lapack.OrthoPostmul:
panic(badOrthoComp)
case compz != lapack.OrthoNone && compz != lapack.OrthoEntry && compz != lapack.OrthoUnit:
case compz != lapack.OrthoNone && compz != lapack.OrthoExplicit && compz != lapack.OrthoPostmul:
panic(badOrthoComp)
case len(a) < (n-1)*lda+n:
panic(shortA)
case len(b) < (n-1)*ldb+n:
panic(shortB)
case n < 0:
panic(nLT0)
case ilo < 0:
case ilo < 0 || max(0, n-1) < ilo:
panic(badIlo)
case ihi < ilo-1 || ihi >= n:
case ihi < min(ilo, n-1) || n <= ihi:
panic(badIhi)
case lda < max(1, n):
panic(badLdA)
@@ -70,20 +58,34 @@ func (impl Implementation) Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int
panic(badLdQ)
case (compz != lapack.OrthoNone && ldz < n) || ldz < 1:
panic(badLdZ)
}
// Quick return if possible.
if n == 0 {
return
}
switch {
case len(a) < (n-1)*lda+n:
panic(shortA)
case len(b) < (n-1)*ldb+n:
panic(shortB)
case compq != lapack.OrthoNone && len(q) < (n-1)*ldq+n:
panic(shortQ)
case compz != lapack.OrthoNone && len(z) < (n-1)*ldz+n:
panic(shortZ)
}
if compq == lapack.OrthoUnit {
if compq == lapack.OrthoExplicit {
impl.Dlaset(blas.All, n, n, 0, 1, q, ldq)
}
if compz == lapack.OrthoUnit {
if compz == lapack.OrthoExplicit {
impl.Dlaset(blas.All, n, n, 0, 1, z, ldz)
}
if n <= 1 {
return // Quick return if possible.
// Quick return if possible.
if n == 1 {
return
}
// Zero out lower triangle of B.
@@ -96,21 +98,19 @@ func (impl Implementation) Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int
// Reduce A and B.
for jcol := ilo; jcol <= ihi-2; jcol++ {
for jrow := ihi; jrow >= jcol+2; jrow-- {
// Step 1: rotate rows JROW-1, JROW to kill A(JROW,JCOL).
// Step 1: rotate rows jrow-1, jrow to kill A[jrow,jcol].
var c, s float64
c, s, a[(jrow-1)*lda+jcol] = impl.Dlartg(a[(jrow-1)*lda+jcol], a[jrow*lda+jcol])
a[jrow*lda+jcol] = 0
bi.Drot(n-jcol-1, a[(jrow-1)*lda+jcol+1:], 1,
a[jrow*lda+jcol+1:], 1, c, s)
bi.Drot(n+2-jrow-1, b[(jrow-1)*ldb+jrow-1:], 1,
b[jrow*ldb+jrow-1:], 1, c, s)
bi.Drot(n-jcol-1, a[(jrow-1)*lda+jcol+1:], 1, a[jrow*lda+jcol+1:], 1, c, s)
bi.Drot(n+2-jrow-1, b[(jrow-1)*ldb+jrow-1:], 1, b[jrow*ldb+jrow-1:], 1, c, s)
if compq != lapack.OrthoNone {
bi.Drot(n, q[jrow-1:], ldq, q[jrow:], ldq, c, s)
}
// Step 2: rotate columns JROW, JROW-1 to kill B(JROW,JROW-1).
// Step 2: rotate columns jrow, jrow-1 to kill B[jrow,jrow-1].
c, s, b[jrow*ldb+jrow] = impl.Dlartg(b[jrow*ldb+jrow], b[jrow*ldb+jrow-1])
b[jrow*ldb+jrow-1] = 0

View File

@@ -231,7 +231,7 @@ const (
type OrthoComp byte
const (
OrthoNone OrthoComp = 'N' // Do not compute orthogonal matrix.
OrthoUnit OrthoComp = 'I' // Argument is initialized to the unit matrix and the orthogonal matrix is returned.
OrthoEntry OrthoComp = 'V' // Argument Q contains orthogonal matrix Q1 on entry and the product Q1*Q is returned.
OrthoNone OrthoComp = 'N' // Do not compute the orthogonal matrix.
OrthoExplicit OrthoComp = 'I' // The orthogonal matrix is formed explicitly and returned in the argument.
OrthoPostmul OrthoComp = 'V' // The orthogonal matrix is post-multiplied into the matrix stored in the argument on entry.
)

View File

@@ -5,7 +5,7 @@
package testlapack
import (
"math"
"fmt"
"testing"
"golang.org/x/exp/rand"
@@ -20,112 +20,115 @@ type Dgghrder interface {
}
func DgghrdTest(t *testing.T, impl Dgghrder) {
const tol = 1e-13
const ldAdd = 5
rnd := rand.New(rand.NewSource(1))
comps := []lapack.OrthoComp{lapack.OrthoUnit, lapack.OrthoNone, lapack.OrthoEntry}
comps := []lapack.OrthoComp{lapack.OrthoExplicit, lapack.OrthoNone, lapack.OrthoPostmul}
for _, compq := range comps {
for _, compz := range comps {
for _, n := range []int{2, 0, 1, 4, 15} {
ldMin := max(1, n)
for _, lda := range []int{ldMin, ldMin + ldAdd} {
for _, ldb := range []int{ldMin, ldMin + ldAdd} {
for _, ldq := range []int{ldMin, ldMin + ldAdd} {
for _, ldz := range []int{ldMin, ldMin + ldAdd} {
testDgghrd(t, impl, rnd, tol, compq, compz, n, 0, n-1, lda, ldb, ldq, ldz)
}
}
}
for _, n := range []int{0, 1, 2, 3, 4, 15} {
for _, ld := range []int{max(1, n), n + 5} {
testDgghrd(t, impl, rnd, compq, compz, n, 0, n-1, ld, ld, ld, ld)
}
}
}
}
}
func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, tol float64, compq, compz lapack.OrthoComp, n, ilo, ihi, lda, ldb, ldq, ldz int) {
func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, compq, compz lapack.OrthoComp, n, ilo, ihi, lda, ldb, ldq, ldz int) {
const tol = 1e-13
a := randomGeneral(n, n, lda, rnd)
b := blockedUpperTriGeneral(n, n, 0, n, ldb, false, rnd)
var q, q1, z, z1 blas64.General
if compq == lapack.OrthoEntry {
b := randomGeneral(n, n, ldb, rnd)
var q, q1 blas64.General
switch compq {
case lapack.OrthoExplicit:
// Initialize q to a non-orthogonal matrix, Dgghrd should overwrite it
// with an orthogonal Q.
q = randomGeneral(n, n, ldq, rnd)
case lapack.OrthoPostmul:
// Initialize q to an orthogonal matrix Q1, so that the result Q1*Q is
// again orthogonal.
q = randomOrthogonal(n, rnd)
q1 = cloneGeneral(q)
} else {
q = nanGeneral(n, n, ldq)
}
if compz == lapack.OrthoEntry {
var z, z1 blas64.General
switch compz {
case lapack.OrthoExplicit:
z = randomGeneral(n, n, ldz, rnd)
case lapack.OrthoPostmul:
z = randomOrthogonal(n, rnd)
z1 = cloneGeneral(z)
} else {
z = nanGeneral(n, n, ldz)
}
hGot := cloneGeneral(a)
tGot := cloneGeneral(b)
for i := 1; i < n; i++ {
for j := 0; j < i; j++ {
// Set all lower tri elems to NaN to catch bad implementations.
tGot.Data[i*tGot.Stride+j] = math.NaN()
}
}
impl.Dgghrd(compq, compz, n, ilo, ihi, hGot.Data, hGot.Stride, tGot.Data, tGot.Stride, q.Data, q.Stride, z.Data, z.Stride)
impl.Dgghrd(compq, compz, n, ilo, ihi, hGot.Data, hGot.Stride, tGot.Data, tGot.Stride, q.Data, max(1, q.Stride), z.Data, max(1, z.Stride))
if n == 0 {
return
}
name := fmt.Sprintf("Case compq=%v,compz=%v,n=%v,ilo=%v,ihi=%v", compq, compz, n, ilo, ihi)
if !isUpperHessenberg(hGot) {
t.Error("H is not upper Hessenberg")
}
if !isNaNFree(tGot) || !isNaNFree(hGot) {
t.Error("T or H is/or not NaN free")
t.Errorf("%v: H is not upper Hessenberg", name)
}
if !isUpperTriangular(tGot) {
t.Error("T is not upper triangular")
t.Errorf("%v: T is not upper triangular", name)
}
if compq == lapack.OrthoNone {
if !isAllNaN(q.Data) {
t.Errorf("Q is not NaN")
if compq != lapack.OrthoNone {
if resid := residualOrthogonal(q, true); resid > tol {
t.Errorf("%v: Q is not orthogonal, resid=%v", name, resid)
}
return
}
if compz == lapack.OrthoNone {
if !isAllNaN(z.Data) {
t.Errorf("Z is not NaN")
if compz != lapack.OrthoNone {
if resid := residualOrthogonal(z, true); resid > tol {
t.Errorf("%v: Z is not orthogonal, resid=%v", name, resid)
}
return
}
if compq != compz {
return // Do not handle mixed case
}
comp := compq
aux := zeros(n, n, n)
switch comp {
case lapack.OrthoUnit:
if compq != compz {
// Verify reduction only when both Q and Z are computed.
return
}
// Zero out the lower triangle of B.
for i := 1; i < n; i++ {
for j := 0; j < i; j++ {
b.Data[i*b.Stride+j] = 0
}
}
aux := zeros(n, n, n)
switch compq {
case lapack.OrthoExplicit:
// Qᵀ*A*Z = H
hCalc := zeros(n, n, n)
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, a, 0, aux)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, hCalc)
if !equalApproxGeneral(hGot, hCalc, tol) {
t.Errorf("Qᵀ*A*Z != H")
t.Errorf("%v: Qᵀ*A*Z != H", name)
}
// Qᵀ*B*Z = T
tCalc := zeros(n, n, n)
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, b, 0, aux)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, tCalc)
if !equalApproxGeneral(hGot, hCalc, tol) {
t.Errorf("Qᵀ*B*Z != T")
if !equalApproxGeneral(tGot, tCalc, tol) {
t.Errorf("%v: Qᵀ*B*Z != T", name)
}
case lapack.OrthoEntry:
case lapack.OrthoPostmul:
// Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ
lhs := zeros(n, n, n)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, a, 0, aux)
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs) // lhs = Q1 * A * Z1ᵀ
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs)
rhs := zeros(n, n, n)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, hGot, 0, aux)
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs)
if !equalApproxGeneral(lhs, rhs, tol) {
t.Errorf("Q1 * A * Z1ᵀ != (Q1*Q) * H * (Z1*Z)ᵀ")
t.Errorf("%v: Q1 * A * Z1ᵀ != (Q1*Q) * H * (Z1*Z)ᵀ", name)
}
// Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ
@@ -135,7 +138,7 @@ func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, tol float64, compq,
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, tGot, 0, aux)
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs)
if !equalApproxGeneral(lhs, rhs, tol) {
t.Errorf("Q1 * B * Z1ᵀ != (Q1*Q) * T * (Z1*Z)ᵀ")
t.Errorf("%v: Q1 * B * Z1ᵀ != (Q1*Q) * T * (Z1*Z)ᵀ", name)
}
}
}

View File

@@ -1198,23 +1198,13 @@ func isUpperHessenberg(h blas64.General) bool {
// isUpperTriangular returns whether a contains only zeros below the diagonal.
func isUpperTriangular(a blas64.General) bool {
if a.Rows != a.Cols {
panic("matrix not square")
}
n := a.Rows
for i := 1; i < n; i++ {
for j := 0; j < i; j++ {
v := a.Data[i*a.Stride+j]
if v != 0 || math.IsNaN(v) {
return false
}
}
}
return true
}
// isNaNFree returns whether a does not contain NaN elements in reachable elements.
func isNaNFree(a blas64.General) bool {
for i := 0; i < a.Rows; i++ {
for j := 0; j < a.Cols; j++ {
if math.IsNaN(a.Data[i*a.Stride+j]) {
if a.Data[i*a.Stride+j] != 0 {
return false
}
}