mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 23:26:52 +08:00
lapack/gonum: add Dgghrd and its test
This commit is contained in:

committed by
GitHub

parent
7bed099d44
commit
f0a57a452a
@@ -20,7 +20,7 @@ func Use(b blas.Float64) {
|
||||
|
||||
// Implementation returns the current BLAS float64 implementation.
|
||||
//
|
||||
// Implementation allows direct calls to the current the BLAS float64 implementation
|
||||
// Implementation allows direct calls to the current BLAS float64 implementation
|
||||
// giving finer control of parameters.
|
||||
func Implementation() blas.Float64 {
|
||||
return blas64
|
||||
|
125
lapack/gonum/dgghrd.go
Normal file
125
lapack/gonum/dgghrd.go
Normal file
@@ -0,0 +1,125 @@
|
||||
// Copyright ©2023 The Gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package gonum
|
||||
|
||||
import (
|
||||
"gonum.org/v1/gonum/blas"
|
||||
"gonum.org/v1/gonum/blas/blas64"
|
||||
"gonum.org/v1/gonum/lapack"
|
||||
)
|
||||
|
||||
// Dgghrd reduces a pair of real matrices (A,B) to generalized upper
|
||||
// Hessenberg form using orthogonal transformations, where A is a
|
||||
// general matrix and B is upper triangular. The form of the
|
||||
// generalized eigenvalue problem is
|
||||
//
|
||||
// A*x = lambda*B*x,
|
||||
//
|
||||
// and B is typically made upper triangular by computing its QR
|
||||
// factorization and moving the orthogonal matrix Q to the left side
|
||||
// of the equation.
|
||||
// This subroutine simultaneously reduces A to a Hessenberg matrix H:
|
||||
//
|
||||
// Qᵀ*A*Z = H
|
||||
//
|
||||
// and transforms B to another upper triangular matrix T:
|
||||
//
|
||||
// Qᵀ*B*Z = T
|
||||
//
|
||||
// in order to reduce the problem to its standard form
|
||||
//
|
||||
// H*y = lambda*T*y
|
||||
//
|
||||
// where y = Zᵀ*x.
|
||||
//
|
||||
// The orthogonal matrices Q and Z are determined as products of Givens
|
||||
// rotations. They may either be formed explicitly, or they may be
|
||||
// postmultiplied into input matrices Q1 and Z1, so that
|
||||
//
|
||||
// Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ
|
||||
// Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ
|
||||
//
|
||||
// If Q1 is the orthogonal matrix from the QR factorization of B in the
|
||||
// original equation A*x = lambda*B*x, then Dgghrd reduces the original
|
||||
// problem to generalized Hessenberg form.
|
||||
//
|
||||
// Dgghrd is an internal routine. It is exported for testing purposes.
|
||||
func (impl Implementation) Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int, a []float64, lda int, b []float64, ldb int, q []float64, ldq int, z []float64, ldz int) {
|
||||
switch {
|
||||
case compq != lapack.OrthoNone && compq != lapack.OrthoEntry && compq != lapack.OrthoUnit:
|
||||
panic(badOrthoComp)
|
||||
case compz != lapack.OrthoNone && compz != lapack.OrthoEntry && compz != lapack.OrthoUnit:
|
||||
panic(badOrthoComp)
|
||||
case len(a) < (n-1)*lda+n:
|
||||
panic(shortA)
|
||||
case len(b) < (n-1)*ldb+n:
|
||||
panic(shortB)
|
||||
case n < 0:
|
||||
panic(nLT0)
|
||||
case ilo < 0:
|
||||
panic(badIlo)
|
||||
case ihi < ilo-1 || ihi >= n:
|
||||
panic(badIhi)
|
||||
case lda < max(1, n):
|
||||
panic(badLdA)
|
||||
case ldb < max(1, n):
|
||||
panic(badLdB)
|
||||
case (compq != lapack.OrthoNone && ldq < n) || ldq < 1:
|
||||
panic(badLdQ)
|
||||
case (compz != lapack.OrthoNone && ldz < n) || ldz < 1:
|
||||
panic(badLdZ)
|
||||
case compq != lapack.OrthoNone && len(q) < (n-1)*ldq+n:
|
||||
panic(shortQ)
|
||||
case compz != lapack.OrthoNone && len(z) < (n-1)*ldz+n:
|
||||
panic(shortZ)
|
||||
}
|
||||
|
||||
if compq == lapack.OrthoUnit {
|
||||
impl.Dlaset(blas.All, n, n, 0, 1, q, ldq)
|
||||
}
|
||||
if compz == lapack.OrthoUnit {
|
||||
impl.Dlaset(blas.All, n, n, 0, 1, z, ldz)
|
||||
}
|
||||
if n <= 1 {
|
||||
return // Quick return if possible.
|
||||
}
|
||||
|
||||
// Zero out lower triangle of B.
|
||||
for i := 1; i < n; i++ {
|
||||
for j := 0; j < i; j++ {
|
||||
b[i*ldb+j] = 0
|
||||
}
|
||||
}
|
||||
bi := blas64.Implementation()
|
||||
// Reduce A and B.
|
||||
for jcol := ilo; jcol <= ihi-2; jcol++ {
|
||||
for jrow := ihi; jrow >= jcol+2; jrow-- {
|
||||
// Step 1: rotate rows JROW-1, JROW to kill A(JROW,JCOL).
|
||||
var c, s float64
|
||||
c, s, a[(jrow-1)*lda+jcol] = impl.Dlartg(a[(jrow-1)*lda+jcol], a[jrow*lda+jcol])
|
||||
a[jrow*lda+jcol] = 0
|
||||
bi.Drot(n-jcol-1, a[(jrow-1)*lda+jcol+1:], 1,
|
||||
a[jrow*lda+jcol+1:], 1, c, s)
|
||||
|
||||
bi.Drot(n+2-jrow-1, b[(jrow-1)*ldb+jrow-1:], 1,
|
||||
b[jrow*ldb+jrow-1:], 1, c, s)
|
||||
|
||||
if compq != lapack.OrthoNone {
|
||||
bi.Drot(n, q[jrow-1:], ldq, q[jrow:], ldq, c, s)
|
||||
}
|
||||
|
||||
// Step 2: rotate columns JROW, JROW-1 to kill B(JROW,JROW-1).
|
||||
c, s, b[jrow*ldb+jrow] = impl.Dlartg(b[jrow*ldb+jrow], b[jrow*ldb+jrow-1])
|
||||
b[jrow*ldb+jrow-1] = 0
|
||||
|
||||
bi.Drot(ihi+1, a[jrow:], lda, a[jrow-1:], lda, c, s)
|
||||
bi.Drot(jrow, b[jrow:], ldb, b[jrow-1:], ldb, c, s)
|
||||
|
||||
if compz != lapack.OrthoNone {
|
||||
bi.Drot(n, z[jrow:], ldz, z[jrow-1:], ldz, c, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@@ -21,6 +21,7 @@ const (
|
||||
badMatrixType = "lapack: bad MatrixType"
|
||||
badMaximizeNormXJob = "lapack: bad MaximizeNormXJob"
|
||||
badNorm = "lapack: bad Norm"
|
||||
badOrthoComp = "lapack: bad OrthoComp"
|
||||
badPivot = "lapack: bad Pivot"
|
||||
badRightEVJob = "lapack: bad RightEVJob"
|
||||
badSVDJob = "lapack: bad SVDJob"
|
||||
|
@@ -148,6 +148,11 @@ func TestDgetrs(t *testing.T) {
|
||||
testlapack.DgetrsTest(t, impl)
|
||||
}
|
||||
|
||||
func TestDgghrd(t *testing.T) {
|
||||
t.Parallel()
|
||||
testlapack.DgghrdTest(t, impl)
|
||||
}
|
||||
|
||||
func TestDggsvd3(t *testing.T) {
|
||||
t.Parallel()
|
||||
testlapack.Dggsvd3Test(t, impl)
|
||||
|
@@ -226,3 +226,12 @@ const (
|
||||
LocalLookAhead MaximizeNormXJob = 0 // Solve Z*x=h-f where h is a vector of ±1.
|
||||
NormalizedNullVector MaximizeNormXJob = 2 // Compute an approximate null-vector e of Z, normalize e and solve Z*x=±e-f.
|
||||
)
|
||||
|
||||
// OrthoComp specifies whether and how the orthogonal matrix is computed in Dgghrd.
|
||||
type OrthoComp byte
|
||||
|
||||
const (
|
||||
OrthoNone OrthoComp = 'N' // Do not compute orthogonal matrix.
|
||||
OrthoUnit OrthoComp = 'I' // Argument is initialized to the unit matrix and the orthogonal matrix is returned.
|
||||
OrthoEntry OrthoComp = 'V' // Argument Q contains orthogonal matrix Q1 on entry and the product Q1*Q is returned.
|
||||
)
|
||||
|
141
lapack/testlapack/dgghrd.go
Normal file
141
lapack/testlapack/dgghrd.go
Normal file
@@ -0,0 +1,141 @@
|
||||
// Copyright ©2023 The Gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/exp/rand"
|
||||
|
||||
"gonum.org/v1/gonum/blas"
|
||||
"gonum.org/v1/gonum/blas/blas64"
|
||||
"gonum.org/v1/gonum/lapack"
|
||||
)
|
||||
|
||||
type Dgghrder interface {
|
||||
Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int, a []float64, lda int, b []float64, ldb int, q []float64, ldq int, z []float64, ldz int)
|
||||
}
|
||||
|
||||
func DgghrdTest(t *testing.T, impl Dgghrder) {
|
||||
const tol = 1e-13
|
||||
const ldAdd = 5
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
comps := []lapack.OrthoComp{lapack.OrthoUnit, lapack.OrthoNone, lapack.OrthoEntry}
|
||||
for _, compq := range comps {
|
||||
for _, compz := range comps {
|
||||
for _, n := range []int{2, 0, 1, 4, 15} {
|
||||
ldMin := max(1, n)
|
||||
for _, lda := range []int{ldMin, ldMin + ldAdd} {
|
||||
for _, ldb := range []int{ldMin, ldMin + ldAdd} {
|
||||
for _, ldq := range []int{ldMin, ldMin + ldAdd} {
|
||||
for _, ldz := range []int{ldMin, ldMin + ldAdd} {
|
||||
testDgghrd(t, impl, rnd, tol, compq, compz, n, 0, n-1, lda, ldb, ldq, ldz)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, tol float64, compq, compz lapack.OrthoComp, n, ilo, ihi, lda, ldb, ldq, ldz int) {
|
||||
a := randomGeneral(n, n, lda, rnd)
|
||||
b := blockedUpperTriGeneral(n, n, 0, n, ldb, false, rnd)
|
||||
var q, q1, z, z1 blas64.General
|
||||
if compq == lapack.OrthoEntry {
|
||||
q = randomOrthogonal(n, rnd)
|
||||
q1 = cloneGeneral(q)
|
||||
} else {
|
||||
q = nanGeneral(n, n, ldq)
|
||||
}
|
||||
if compz == lapack.OrthoEntry {
|
||||
z = randomOrthogonal(n, rnd)
|
||||
z1 = cloneGeneral(z)
|
||||
} else {
|
||||
z = nanGeneral(n, n, ldz)
|
||||
}
|
||||
|
||||
hGot := cloneGeneral(a)
|
||||
tGot := cloneGeneral(b)
|
||||
for i := 1; i < n; i++ {
|
||||
for j := 0; j < i; j++ {
|
||||
// Set all lower tri elems to NaN to catch bad implementations.
|
||||
tGot.Data[i*tGot.Stride+j] = math.NaN()
|
||||
}
|
||||
}
|
||||
impl.Dgghrd(compq, compz, n, ilo, ihi, hGot.Data, hGot.Stride, tGot.Data, tGot.Stride, q.Data, q.Stride, z.Data, z.Stride)
|
||||
if n == 0 {
|
||||
return
|
||||
}
|
||||
if !isUpperHessenberg(hGot) {
|
||||
t.Error("H is not upper Hessenberg")
|
||||
}
|
||||
if !isNaNFree(tGot) || !isNaNFree(hGot) {
|
||||
t.Error("T or H is/or not NaN free")
|
||||
}
|
||||
if !isUpperTriangular(tGot) {
|
||||
t.Error("T is not upper triangular")
|
||||
}
|
||||
if compq == lapack.OrthoNone {
|
||||
if !isAllNaN(q.Data) {
|
||||
t.Errorf("Q is not NaN")
|
||||
}
|
||||
return
|
||||
}
|
||||
if compz == lapack.OrthoNone {
|
||||
if !isAllNaN(z.Data) {
|
||||
t.Errorf("Z is not NaN")
|
||||
}
|
||||
return
|
||||
}
|
||||
if compq != compz {
|
||||
return // Do not handle mixed case
|
||||
}
|
||||
comp := compq
|
||||
aux := zeros(n, n, n)
|
||||
|
||||
switch comp {
|
||||
case lapack.OrthoUnit:
|
||||
// Qᵀ*A*Z = H
|
||||
hCalc := zeros(n, n, n)
|
||||
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, a, 0, aux)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, hCalc)
|
||||
if !equalApproxGeneral(hGot, hCalc, tol) {
|
||||
t.Errorf("Qᵀ*A*Z != H")
|
||||
}
|
||||
|
||||
// Qᵀ*B*Z = T
|
||||
tCalc := zeros(n, n, n)
|
||||
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, b, 0, aux)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, tCalc)
|
||||
if !equalApproxGeneral(hGot, hCalc, tol) {
|
||||
t.Errorf("Qᵀ*B*Z != T")
|
||||
}
|
||||
case lapack.OrthoEntry:
|
||||
// Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ
|
||||
lhs := zeros(n, n, n)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, a, 0, aux)
|
||||
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs) // lhs = Q1 * A * Z1ᵀ
|
||||
|
||||
rhs := zeros(n, n, n)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, hGot, 0, aux)
|
||||
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs)
|
||||
if !equalApproxGeneral(lhs, rhs, tol) {
|
||||
t.Errorf("Q1 * A * Z1ᵀ != (Q1*Q) * H * (Z1*Z)ᵀ")
|
||||
}
|
||||
|
||||
// Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, b, 0, aux)
|
||||
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs)
|
||||
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, tGot, 0, aux)
|
||||
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs)
|
||||
if !equalApproxGeneral(lhs, rhs, tol) {
|
||||
t.Errorf("Q1 * B * Z1ᵀ != (Q1*Q) * T * (Z1*Z)ᵀ")
|
||||
}
|
||||
}
|
||||
}
|
@@ -1201,7 +1201,20 @@ func isUpperTriangular(a blas64.General) bool {
|
||||
n := a.Rows
|
||||
for i := 1; i < n; i++ {
|
||||
for j := 0; j < i; j++ {
|
||||
if a.Data[i*a.Stride+j] != 0 {
|
||||
v := a.Data[i*a.Stride+j]
|
||||
if v != 0 || math.IsNaN(v) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isNaNFree returns whether a does not contain NaN elements in reachable elements.
|
||||
func isNaNFree(a blas64.General) bool {
|
||||
for i := 0; i < a.Rows; i++ {
|
||||
for j := 0; j < a.Cols; j++ {
|
||||
if math.IsNaN(a.Data[i*a.Stride+j]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user