mirror of
https://github.com/gonum/gonum.git
synced 2025-10-06 07:37:03 +08:00
lapack/gonum: clean up Dgghrd and its test
This commit is contained in:

committed by
Vladimír Chalupecký

parent
f0a57a452a
commit
7df15c334b
@@ -10,57 +10,45 @@ import (
|
||||
"gonum.org/v1/gonum/lapack"
|
||||
)
|
||||
|
||||
// Dgghrd reduces a pair of real matrices (A,B) to generalized upper
|
||||
// Hessenberg form using orthogonal transformations, where A is a
|
||||
// general matrix and B is upper triangular. The form of the
|
||||
// generalized eigenvalue problem is
|
||||
// Dgghrd reduces a pair of real matrices (A,B) to generalized upper Hessenberg
|
||||
// form using orthogonal transformations, where A is a general matrix and B is
|
||||
// upper triangular.
|
||||
//
|
||||
// A*x = lambda*B*x,
|
||||
// This subroutine simultaneously reduces A to a Hessenberg matrix H
|
||||
//
|
||||
// and B is typically made upper triangular by computing its QR
|
||||
// factorization and moving the orthogonal matrix Q to the left side
|
||||
// of the equation.
|
||||
// This subroutine simultaneously reduces A to a Hessenberg matrix H:
|
||||
// Qᵀ*A*Z = H,
|
||||
//
|
||||
// Qᵀ*A*Z = H
|
||||
// and transforms B to another upper triangular matrix T
|
||||
//
|
||||
// and transforms B to another upper triangular matrix T:
|
||||
//
|
||||
// Qᵀ*B*Z = T
|
||||
//
|
||||
// in order to reduce the problem to its standard form
|
||||
//
|
||||
// H*y = lambda*T*y
|
||||
//
|
||||
// where y = Zᵀ*x.
|
||||
// Qᵀ*B*Z = T.
|
||||
//
|
||||
// The orthogonal matrices Q and Z are determined as products of Givens
|
||||
// rotations. They may either be formed explicitly, or they may be
|
||||
// postmultiplied into input matrices Q1 and Z1, so that
|
||||
// rotations. They may either be formed explicitly (lapack.OrthoExplicit), or
|
||||
// they may be postmultiplied into input matrices Q1 and Z1
|
||||
// (lapack.OrthoPostmul), so that
|
||||
//
|
||||
// Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ
|
||||
// Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ
|
||||
// Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ,
|
||||
// Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ.
|
||||
//
|
||||
// If Q1 is the orthogonal matrix from the QR factorization of B in the
|
||||
// original equation A*x = lambda*B*x, then Dgghrd reduces the original
|
||||
// problem to generalized Hessenberg form.
|
||||
// ilo and ihi determine the block of A that will be reduced. It must hold that
|
||||
//
|
||||
// - 0 <= ilo <= ihi < n if n > 0,
|
||||
// - ilo == 0 and ihi == -1 if n == 0,
|
||||
//
|
||||
// otherwise Dgghrd will panic.
|
||||
//
|
||||
// Dgghrd is an internal routine. It is exported for testing purposes.
|
||||
func (impl Implementation) Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int, a []float64, lda int, b []float64, ldb int, q []float64, ldq int, z []float64, ldz int) {
|
||||
switch {
|
||||
case compq != lapack.OrthoNone && compq != lapack.OrthoEntry && compq != lapack.OrthoUnit:
|
||||
case compq != lapack.OrthoNone && compq != lapack.OrthoExplicit && compq != lapack.OrthoPostmul:
|
||||
panic(badOrthoComp)
|
||||
case compz != lapack.OrthoNone && compz != lapack.OrthoEntry && compz != lapack.OrthoUnit:
|
||||
case compz != lapack.OrthoNone && compz != lapack.OrthoExplicit && compz != lapack.OrthoPostmul:
|
||||
panic(badOrthoComp)
|
||||
case len(a) < (n-1)*lda+n:
|
||||
panic(shortA)
|
||||
case len(b) < (n-1)*ldb+n:
|
||||
panic(shortB)
|
||||
case n < 0:
|
||||
panic(nLT0)
|
||||
case ilo < 0:
|
||||
case ilo < 0 || max(0, n-1) < ilo:
|
||||
panic(badIlo)
|
||||
case ihi < ilo-1 || ihi >= n:
|
||||
case ihi < min(ilo, n-1) || n <= ihi:
|
||||
panic(badIhi)
|
||||
case lda < max(1, n):
|
||||
panic(badLdA)
|
||||
@@ -70,20 +58,34 @@ func (impl Implementation) Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int
|
||||
panic(badLdQ)
|
||||
case (compz != lapack.OrthoNone && ldz < n) || ldz < 1:
|
||||
panic(badLdZ)
|
||||
}
|
||||
|
||||
// Quick return if possible.
|
||||
if n == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(a) < (n-1)*lda+n:
|
||||
panic(shortA)
|
||||
case len(b) < (n-1)*ldb+n:
|
||||
panic(shortB)
|
||||
case compq != lapack.OrthoNone && len(q) < (n-1)*ldq+n:
|
||||
panic(shortQ)
|
||||
case compz != lapack.OrthoNone && len(z) < (n-1)*ldz+n:
|
||||
panic(shortZ)
|
||||
}
|
||||
|
||||
if compq == lapack.OrthoUnit {
|
||||
if compq == lapack.OrthoExplicit {
|
||||
impl.Dlaset(blas.All, n, n, 0, 1, q, ldq)
|
||||
}
|
||||
if compz == lapack.OrthoUnit {
|
||||
if compz == lapack.OrthoExplicit {
|
||||
impl.Dlaset(blas.All, n, n, 0, 1, z, ldz)
|
||||
}
|
||||
if n <= 1 {
|
||||
return // Quick return if possible.
|
||||
|
||||
// Quick return if possible.
|
||||
if n == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
// Zero out lower triangle of B.
|
||||
@@ -96,21 +98,19 @@ func (impl Implementation) Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int
|
||||
// Reduce A and B.
|
||||
for jcol := ilo; jcol <= ihi-2; jcol++ {
|
||||
for jrow := ihi; jrow >= jcol+2; jrow-- {
|
||||
// Step 1: rotate rows JROW-1, JROW to kill A(JROW,JCOL).
|
||||
// Step 1: rotate rows jrow-1, jrow to kill A[jrow,jcol].
|
||||
var c, s float64
|
||||
c, s, a[(jrow-1)*lda+jcol] = impl.Dlartg(a[(jrow-1)*lda+jcol], a[jrow*lda+jcol])
|
||||
a[jrow*lda+jcol] = 0
|
||||
bi.Drot(n-jcol-1, a[(jrow-1)*lda+jcol+1:], 1,
|
||||
a[jrow*lda+jcol+1:], 1, c, s)
|
||||
|
||||
bi.Drot(n+2-jrow-1, b[(jrow-1)*ldb+jrow-1:], 1,
|
||||
b[jrow*ldb+jrow-1:], 1, c, s)
|
||||
bi.Drot(n-jcol-1, a[(jrow-1)*lda+jcol+1:], 1, a[jrow*lda+jcol+1:], 1, c, s)
|
||||
bi.Drot(n+2-jrow-1, b[(jrow-1)*ldb+jrow-1:], 1, b[jrow*ldb+jrow-1:], 1, c, s)
|
||||
|
||||
if compq != lapack.OrthoNone {
|
||||
bi.Drot(n, q[jrow-1:], ldq, q[jrow:], ldq, c, s)
|
||||
}
|
||||
|
||||
// Step 2: rotate columns JROW, JROW-1 to kill B(JROW,JROW-1).
|
||||
// Step 2: rotate columns jrow, jrow-1 to kill B[jrow,jrow-1].
|
||||
c, s, b[jrow*ldb+jrow] = impl.Dlartg(b[jrow*ldb+jrow], b[jrow*ldb+jrow-1])
|
||||
b[jrow*ldb+jrow-1] = 0
|
||||
|
||||
|
@@ -231,7 +231,7 @@ const (
|
||||
type OrthoComp byte
|
||||
|
||||
const (
|
||||
OrthoNone OrthoComp = 'N' // Do not compute orthogonal matrix.
|
||||
OrthoUnit OrthoComp = 'I' // Argument is initialized to the unit matrix and the orthogonal matrix is returned.
|
||||
OrthoEntry OrthoComp = 'V' // Argument Q contains orthogonal matrix Q1 on entry and the product Q1*Q is returned.
|
||||
OrthoNone OrthoComp = 'N' // Do not compute the orthogonal matrix.
|
||||
OrthoExplicit OrthoComp = 'I' // The orthogonal matrix is formed explicitly and returned in the argument.
|
||||
OrthoPostmul OrthoComp = 'V' // The orthogonal matrix is post-multiplied into the matrix stored in the argument on entry.
|
||||
)
|
||||
|
@@ -5,7 +5,7 @@
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/exp/rand"
|
||||
@@ -20,112 +20,115 @@ type Dgghrder interface {
|
||||
}
|
||||
|
||||
func DgghrdTest(t *testing.T, impl Dgghrder) {
|
||||
const tol = 1e-13
|
||||
const ldAdd = 5
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
comps := []lapack.OrthoComp{lapack.OrthoUnit, lapack.OrthoNone, lapack.OrthoEntry}
|
||||
comps := []lapack.OrthoComp{lapack.OrthoExplicit, lapack.OrthoNone, lapack.OrthoPostmul}
|
||||
for _, compq := range comps {
|
||||
for _, compz := range comps {
|
||||
for _, n := range []int{2, 0, 1, 4, 15} {
|
||||
ldMin := max(1, n)
|
||||
for _, lda := range []int{ldMin, ldMin + ldAdd} {
|
||||
for _, ldb := range []int{ldMin, ldMin + ldAdd} {
|
||||
for _, ldq := range []int{ldMin, ldMin + ldAdd} {
|
||||
for _, ldz := range []int{ldMin, ldMin + ldAdd} {
|
||||
testDgghrd(t, impl, rnd, tol, compq, compz, n, 0, n-1, lda, ldb, ldq, ldz)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, n := range []int{0, 1, 2, 3, 4, 15} {
|
||||
for _, ld := range []int{max(1, n), n + 5} {
|
||||
testDgghrd(t, impl, rnd, compq, compz, n, 0, n-1, ld, ld, ld, ld)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, tol float64, compq, compz lapack.OrthoComp, n, ilo, ihi, lda, ldb, ldq, ldz int) {
|
||||
func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, compq, compz lapack.OrthoComp, n, ilo, ihi, lda, ldb, ldq, ldz int) {
|
||||
const tol = 1e-13
|
||||
|
||||
a := randomGeneral(n, n, lda, rnd)
|
||||
b := blockedUpperTriGeneral(n, n, 0, n, ldb, false, rnd)
|
||||
var q, q1, z, z1 blas64.General
|
||||
if compq == lapack.OrthoEntry {
|
||||
b := randomGeneral(n, n, ldb, rnd)
|
||||
|
||||
var q, q1 blas64.General
|
||||
switch compq {
|
||||
case lapack.OrthoExplicit:
|
||||
// Initialize q to a non-orthogonal matrix, Dgghrd should overwrite it
|
||||
// with an orthogonal Q.
|
||||
q = randomGeneral(n, n, ldq, rnd)
|
||||
case lapack.OrthoPostmul:
|
||||
// Initialize q to an orthogonal matrix Q1, so that the result Q1*Q is
|
||||
// again orthogonal.
|
||||
q = randomOrthogonal(n, rnd)
|
||||
q1 = cloneGeneral(q)
|
||||
} else {
|
||||
q = nanGeneral(n, n, ldq)
|
||||
}
|
||||
if compz == lapack.OrthoEntry {
|
||||
|
||||
var z, z1 blas64.General
|
||||
switch compz {
|
||||
case lapack.OrthoExplicit:
|
||||
z = randomGeneral(n, n, ldz, rnd)
|
||||
case lapack.OrthoPostmul:
|
||||
z = randomOrthogonal(n, rnd)
|
||||
z1 = cloneGeneral(z)
|
||||
} else {
|
||||
z = nanGeneral(n, n, ldz)
|
||||
}
|
||||
|
||||
hGot := cloneGeneral(a)
|
||||
tGot := cloneGeneral(b)
|
||||
for i := 1; i < n; i++ {
|
||||
for j := 0; j < i; j++ {
|
||||
// Set all lower tri elems to NaN to catch bad implementations.
|
||||
tGot.Data[i*tGot.Stride+j] = math.NaN()
|
||||
}
|
||||
}
|
||||
impl.Dgghrd(compq, compz, n, ilo, ihi, hGot.Data, hGot.Stride, tGot.Data, tGot.Stride, q.Data, q.Stride, z.Data, z.Stride)
|
||||
impl.Dgghrd(compq, compz, n, ilo, ihi, hGot.Data, hGot.Stride, tGot.Data, tGot.Stride, q.Data, max(1, q.Stride), z.Data, max(1, z.Stride))
|
||||
|
||||
if n == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("Case compq=%v,compz=%v,n=%v,ilo=%v,ihi=%v", compq, compz, n, ilo, ihi)
|
||||
|
||||
if !isUpperHessenberg(hGot) {
|
||||
t.Error("H is not upper Hessenberg")
|
||||
}
|
||||
if !isNaNFree(tGot) || !isNaNFree(hGot) {
|
||||
t.Error("T or H is/or not NaN free")
|
||||
t.Errorf("%v: H is not upper Hessenberg", name)
|
||||
}
|
||||
if !isUpperTriangular(tGot) {
|
||||
t.Error("T is not upper triangular")
|
||||
t.Errorf("%v: T is not upper triangular", name)
|
||||
}
|
||||
if compq == lapack.OrthoNone {
|
||||
if !isAllNaN(q.Data) {
|
||||
t.Errorf("Q is not NaN")
|
||||
if compq != lapack.OrthoNone {
|
||||
if resid := residualOrthogonal(q, true); resid > tol {
|
||||
t.Errorf("%v: Q is not orthogonal, resid=%v", name, resid)
|
||||
}
|
||||
return
|
||||
}
|
||||
if compz == lapack.OrthoNone {
|
||||
if !isAllNaN(z.Data) {
|
||||
t.Errorf("Z is not NaN")
|
||||
if compz != lapack.OrthoNone {
|
||||
if resid := residualOrthogonal(z, true); resid > tol {
|
||||
t.Errorf("%v: Z is not orthogonal, resid=%v", name, resid)
|
||||
}
|
||||
return
|
||||
}
|
||||
if compq != compz {
|
||||
return // Do not handle mixed case
|
||||
}
|
||||
comp := compq
|
||||
aux := zeros(n, n, n)
|
||||
|
||||
switch comp {
|
||||
case lapack.OrthoUnit:
|
||||
if compq != compz {
|
||||
// Verify reduction only when both Q and Z are computed.
|
||||
return
|
||||
}
|
||||
|
||||
// Zero out the lower triangle of B.
|
||||
for i := 1; i < n; i++ {
|
||||
for j := 0; j < i; j++ {
|
||||
b.Data[i*b.Stride+j] = 0
|
||||
}
|
||||
}
|
||||
|
||||
aux := zeros(n, n, n)
|
||||
switch compq {
|
||||
case lapack.OrthoExplicit:
|
||||
// Qᵀ*A*Z = H
|
||||
hCalc := zeros(n, n, n)
|
||||
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, a, 0, aux)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, hCalc)
|
||||
if !equalApproxGeneral(hGot, hCalc, tol) {
|
||||
t.Errorf("Qᵀ*A*Z != H")
|
||||
t.Errorf("%v: Qᵀ*A*Z != H", name)
|
||||
}
|
||||
|
||||
// Qᵀ*B*Z = T
|
||||
tCalc := zeros(n, n, n)
|
||||
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, b, 0, aux)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, tCalc)
|
||||
if !equalApproxGeneral(hGot, hCalc, tol) {
|
||||
t.Errorf("Qᵀ*B*Z != T")
|
||||
if !equalApproxGeneral(tGot, tCalc, tol) {
|
||||
t.Errorf("%v: Qᵀ*B*Z != T", name)
|
||||
}
|
||||
case lapack.OrthoEntry:
|
||||
case lapack.OrthoPostmul:
|
||||
// Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ
|
||||
lhs := zeros(n, n, n)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, a, 0, aux)
|
||||
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs) // lhs = Q1 * A * Z1ᵀ
|
||||
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs)
|
||||
|
||||
rhs := zeros(n, n, n)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, hGot, 0, aux)
|
||||
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs)
|
||||
if !equalApproxGeneral(lhs, rhs, tol) {
|
||||
t.Errorf("Q1 * A * Z1ᵀ != (Q1*Q) * H * (Z1*Z)ᵀ")
|
||||
t.Errorf("%v: Q1 * A * Z1ᵀ != (Q1*Q) * H * (Z1*Z)ᵀ", name)
|
||||
}
|
||||
|
||||
// Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ
|
||||
@@ -135,7 +138,7 @@ func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, tol float64, compq,
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, tGot, 0, aux)
|
||||
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs)
|
||||
if !equalApproxGeneral(lhs, rhs, tol) {
|
||||
t.Errorf("Q1 * B * Z1ᵀ != (Q1*Q) * T * (Z1*Z)ᵀ")
|
||||
t.Errorf("%v: Q1 * B * Z1ᵀ != (Q1*Q) * T * (Z1*Z)ᵀ", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -1198,23 +1198,13 @@ func isUpperHessenberg(h blas64.General) bool {
|
||||
|
||||
// isUpperTriangular returns whether a contains only zeros below the diagonal.
|
||||
func isUpperTriangular(a blas64.General) bool {
|
||||
if a.Rows != a.Cols {
|
||||
panic("matrix not square")
|
||||
}
|
||||
n := a.Rows
|
||||
for i := 1; i < n; i++ {
|
||||
for j := 0; j < i; j++ {
|
||||
v := a.Data[i*a.Stride+j]
|
||||
if v != 0 || math.IsNaN(v) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isNaNFree returns whether a does not contain NaN elements in reachable elements.
|
||||
func isNaNFree(a blas64.General) bool {
|
||||
for i := 0; i < a.Rows; i++ {
|
||||
for j := 0; j < a.Cols; j++ {
|
||||
if math.IsNaN(a.Data[i*a.Stride+j]) {
|
||||
if a.Data[i*a.Stride+j] != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user