lapack/gonum: add Dgghrd and its test

This commit is contained in:
Patricio Whittingslow
2023-09-07 08:24:34 -03:00
committed by GitHub
parent 7bed099d44
commit f0a57a452a
7 changed files with 296 additions and 2 deletions

View File

@@ -20,7 +20,7 @@ func Use(b blas.Float64) {
// Implementation returns the current BLAS float64 implementation.
//
// Implementation allows direct calls to the current the BLAS float64 implementation
// Implementation allows direct calls to the current BLAS float64 implementation
// giving finer control of parameters.
func Implementation() blas.Float64 {
return blas64

125
lapack/gonum/dgghrd.go Normal file
View File

@@ -0,0 +1,125 @@
// Copyright ©2023 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gonum
import (
"gonum.org/v1/gonum/blas"
"gonum.org/v1/gonum/blas/blas64"
"gonum.org/v1/gonum/lapack"
)
// Dgghrd reduces a pair of real matrices (A,B) to generalized upper
// Hessenberg form using orthogonal transformations, where A is a
// general matrix and B is upper triangular. The form of the
// generalized eigenvalue problem is
//
// A*x = lambda*B*x,
//
// and B is typically made upper triangular by computing its QR
// factorization and moving the orthogonal matrix Q to the left side
// of the equation.
// This subroutine simultaneously reduces A to a Hessenberg matrix H:
//
// Qᵀ*A*Z = H
//
// and transforms B to another upper triangular matrix T:
//
// Qᵀ*B*Z = T
//
// in order to reduce the problem to its standard form
//
// H*y = lambda*T*y
//
// where y = Zᵀ*x.
//
// The orthogonal matrices Q and Z are determined as products of Givens
// rotations. They may either be formed explicitly, or they may be
// postmultiplied into input matrices Q1 and Z1, so that
//
// Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ
// Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ
//
// If Q1 is the orthogonal matrix from the QR factorization of B in the
// original equation A*x = lambda*B*x, then Dgghrd reduces the original
// problem to generalized Hessenberg form.
//
// Dgghrd is an internal routine. It is exported for testing purposes.
func (impl Implementation) Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int, a []float64, lda int, b []float64, ldb int, q []float64, ldq int, z []float64, ldz int) {
switch {
case compq != lapack.OrthoNone && compq != lapack.OrthoEntry && compq != lapack.OrthoUnit:
panic(badOrthoComp)
case compz != lapack.OrthoNone && compz != lapack.OrthoEntry && compz != lapack.OrthoUnit:
panic(badOrthoComp)
case len(a) < (n-1)*lda+n:
panic(shortA)
case len(b) < (n-1)*ldb+n:
panic(shortB)
case n < 0:
panic(nLT0)
case ilo < 0:
panic(badIlo)
case ihi < ilo-1 || ihi >= n:
panic(badIhi)
case lda < max(1, n):
panic(badLdA)
case ldb < max(1, n):
panic(badLdB)
case (compq != lapack.OrthoNone && ldq < n) || ldq < 1:
panic(badLdQ)
case (compz != lapack.OrthoNone && ldz < n) || ldz < 1:
panic(badLdZ)
case compq != lapack.OrthoNone && len(q) < (n-1)*ldq+n:
panic(shortQ)
case compz != lapack.OrthoNone && len(z) < (n-1)*ldz+n:
panic(shortZ)
}
if compq == lapack.OrthoUnit {
impl.Dlaset(blas.All, n, n, 0, 1, q, ldq)
}
if compz == lapack.OrthoUnit {
impl.Dlaset(blas.All, n, n, 0, 1, z, ldz)
}
if n <= 1 {
return // Quick return if possible.
}
// Zero out lower triangle of B.
for i := 1; i < n; i++ {
for j := 0; j < i; j++ {
b[i*ldb+j] = 0
}
}
bi := blas64.Implementation()
// Reduce A and B.
for jcol := ilo; jcol <= ihi-2; jcol++ {
for jrow := ihi; jrow >= jcol+2; jrow-- {
// Step 1: rotate rows JROW-1, JROW to kill A(JROW,JCOL).
var c, s float64
c, s, a[(jrow-1)*lda+jcol] = impl.Dlartg(a[(jrow-1)*lda+jcol], a[jrow*lda+jcol])
a[jrow*lda+jcol] = 0
bi.Drot(n-jcol-1, a[(jrow-1)*lda+jcol+1:], 1,
a[jrow*lda+jcol+1:], 1, c, s)
bi.Drot(n+2-jrow-1, b[(jrow-1)*ldb+jrow-1:], 1,
b[jrow*ldb+jrow-1:], 1, c, s)
if compq != lapack.OrthoNone {
bi.Drot(n, q[jrow-1:], ldq, q[jrow:], ldq, c, s)
}
// Step 2: rotate columns JROW, JROW-1 to kill B(JROW,JROW-1).
c, s, b[jrow*ldb+jrow] = impl.Dlartg(b[jrow*ldb+jrow], b[jrow*ldb+jrow-1])
b[jrow*ldb+jrow-1] = 0
bi.Drot(ihi+1, a[jrow:], lda, a[jrow-1:], lda, c, s)
bi.Drot(jrow, b[jrow:], ldb, b[jrow-1:], ldb, c, s)
if compz != lapack.OrthoNone {
bi.Drot(n, z[jrow:], ldz, z[jrow-1:], ldz, c, s)
}
}
}
}

View File

@@ -21,6 +21,7 @@ const (
badMatrixType = "lapack: bad MatrixType"
badMaximizeNormXJob = "lapack: bad MaximizeNormXJob"
badNorm = "lapack: bad Norm"
badOrthoComp = "lapack: bad OrthoComp"
badPivot = "lapack: bad Pivot"
badRightEVJob = "lapack: bad RightEVJob"
badSVDJob = "lapack: bad SVDJob"

View File

@@ -148,6 +148,11 @@ func TestDgetrs(t *testing.T) {
testlapack.DgetrsTest(t, impl)
}
func TestDgghrd(t *testing.T) {
t.Parallel()
testlapack.DgghrdTest(t, impl)
}
func TestDggsvd3(t *testing.T) {
t.Parallel()
testlapack.Dggsvd3Test(t, impl)

View File

@@ -226,3 +226,12 @@ const (
LocalLookAhead MaximizeNormXJob = 0 // Solve Z*x=h-f where h is a vector of ±1.
NormalizedNullVector MaximizeNormXJob = 2 // Compute an approximate null-vector e of Z, normalize e and solve Z*x=±e-f.
)
// OrthoComp specifies whether and how the orthogonal matrix is computed in Dgghrd.
type OrthoComp byte
const (
OrthoNone OrthoComp = 'N' // Do not compute orthogonal matrix.
OrthoUnit OrthoComp = 'I' // Argument is initialized to the unit matrix and the orthogonal matrix is returned.
OrthoEntry OrthoComp = 'V' // Argument Q contains orthogonal matrix Q1 on entry and the product Q1*Q is returned.
)

141
lapack/testlapack/dgghrd.go Normal file
View File

@@ -0,0 +1,141 @@
// Copyright ©2023 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testlapack
import (
"math"
"testing"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/blas"
"gonum.org/v1/gonum/blas/blas64"
"gonum.org/v1/gonum/lapack"
)
type Dgghrder interface {
Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int, a []float64, lda int, b []float64, ldb int, q []float64, ldq int, z []float64, ldz int)
}
func DgghrdTest(t *testing.T, impl Dgghrder) {
const tol = 1e-13
const ldAdd = 5
rnd := rand.New(rand.NewSource(1))
comps := []lapack.OrthoComp{lapack.OrthoUnit, lapack.OrthoNone, lapack.OrthoEntry}
for _, compq := range comps {
for _, compz := range comps {
for _, n := range []int{2, 0, 1, 4, 15} {
ldMin := max(1, n)
for _, lda := range []int{ldMin, ldMin + ldAdd} {
for _, ldb := range []int{ldMin, ldMin + ldAdd} {
for _, ldq := range []int{ldMin, ldMin + ldAdd} {
for _, ldz := range []int{ldMin, ldMin + ldAdd} {
testDgghrd(t, impl, rnd, tol, compq, compz, n, 0, n-1, lda, ldb, ldq, ldz)
}
}
}
}
}
}
}
}
func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, tol float64, compq, compz lapack.OrthoComp, n, ilo, ihi, lda, ldb, ldq, ldz int) {
a := randomGeneral(n, n, lda, rnd)
b := blockedUpperTriGeneral(n, n, 0, n, ldb, false, rnd)
var q, q1, z, z1 blas64.General
if compq == lapack.OrthoEntry {
q = randomOrthogonal(n, rnd)
q1 = cloneGeneral(q)
} else {
q = nanGeneral(n, n, ldq)
}
if compz == lapack.OrthoEntry {
z = randomOrthogonal(n, rnd)
z1 = cloneGeneral(z)
} else {
z = nanGeneral(n, n, ldz)
}
hGot := cloneGeneral(a)
tGot := cloneGeneral(b)
for i := 1; i < n; i++ {
for j := 0; j < i; j++ {
// Set all lower tri elems to NaN to catch bad implementations.
tGot.Data[i*tGot.Stride+j] = math.NaN()
}
}
impl.Dgghrd(compq, compz, n, ilo, ihi, hGot.Data, hGot.Stride, tGot.Data, tGot.Stride, q.Data, q.Stride, z.Data, z.Stride)
if n == 0 {
return
}
if !isUpperHessenberg(hGot) {
t.Error("H is not upper Hessenberg")
}
if !isNaNFree(tGot) || !isNaNFree(hGot) {
t.Error("T or H is/or not NaN free")
}
if !isUpperTriangular(tGot) {
t.Error("T is not upper triangular")
}
if compq == lapack.OrthoNone {
if !isAllNaN(q.Data) {
t.Errorf("Q is not NaN")
}
return
}
if compz == lapack.OrthoNone {
if !isAllNaN(z.Data) {
t.Errorf("Z is not NaN")
}
return
}
if compq != compz {
return // Do not handle mixed case
}
comp := compq
aux := zeros(n, n, n)
switch comp {
case lapack.OrthoUnit:
// Qᵀ*A*Z = H
hCalc := zeros(n, n, n)
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, a, 0, aux)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, hCalc)
if !equalApproxGeneral(hGot, hCalc, tol) {
t.Errorf("Qᵀ*A*Z != H")
}
// Qᵀ*B*Z = T
tCalc := zeros(n, n, n)
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, b, 0, aux)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, tCalc)
if !equalApproxGeneral(hGot, hCalc, tol) {
t.Errorf("Qᵀ*B*Z != T")
}
case lapack.OrthoEntry:
// Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ
lhs := zeros(n, n, n)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, a, 0, aux)
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs) // lhs = Q1 * A * Z1ᵀ
rhs := zeros(n, n, n)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, hGot, 0, aux)
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs)
if !equalApproxGeneral(lhs, rhs, tol) {
t.Errorf("Q1 * A * Z1ᵀ != (Q1*Q) * H * (Z1*Z)ᵀ")
}
// Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, b, 0, aux)
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, tGot, 0, aux)
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs)
if !equalApproxGeneral(lhs, rhs, tol) {
t.Errorf("Q1 * B * Z1ᵀ != (Q1*Q) * T * (Z1*Z)ᵀ")
}
}
}

View File

@@ -1201,7 +1201,20 @@ func isUpperTriangular(a blas64.General) bool {
n := a.Rows
for i := 1; i < n; i++ {
for j := 0; j < i; j++ {
if a.Data[i*a.Stride+j] != 0 {
v := a.Data[i*a.Stride+j]
if v != 0 || math.IsNaN(v) {
return false
}
}
}
return true
}
// isNaNFree returns whether a does not contain NaN elements in reachable elements.
func isNaNFree(a blas64.General) bool {
for i := 0; i < a.Rows; i++ {
for j := 0; j < a.Cols; j++ {
if math.IsNaN(a.Data[i*a.Stride+j]) {
return false
}
}