diff --git a/blas/blas64/blas64.go b/blas/blas64/blas64.go
index c336dc89..64ac985c 100644
--- a/blas/blas64/blas64.go
+++ b/blas/blas64/blas64.go
@@ -20,7 +20,7 @@ func Use(b blas.Float64) {
 
 // Implementation returns the current BLAS float64 implementation.
 //
-// Implementation allows direct calls to the current the BLAS float64 implementation
+// Implementation allows direct calls to the current BLAS float64 implementation
 // giving finer control of parameters.
 func Implementation() blas.Float64 {
 	return blas64
diff --git a/lapack/gonum/dgghrd.go b/lapack/gonum/dgghrd.go
new file mode 100644
index 00000000..da09ea9e
--- /dev/null
+++ b/lapack/gonum/dgghrd.go
@@ -0,0 +1,147 @@
+// Copyright ©2023 The Gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gonum
+
+import (
+	"gonum.org/v1/gonum/blas"
+	"gonum.org/v1/gonum/blas/blas64"
+	"gonum.org/v1/gonum/lapack"
+)
+
+// Dgghrd reduces a pair of real matrices (A,B) to generalized upper
+// Hessenberg form using orthogonal transformations, where A is a
+// general matrix and B is upper triangular. The form of the
+// generalized eigenvalue problem is
+//
+//	A*x = lambda*B*x,
+//
+// and B is typically made upper triangular by computing its QR
+// factorization and moving the orthogonal matrix Q to the left side
+// of the equation.
+// This subroutine simultaneously reduces A to a Hessenberg matrix H:
+//
+//	Qᵀ*A*Z = H
+//
+// and transforms B to another upper triangular matrix T:
+//
+//	Qᵀ*B*Z = T
+//
+// in order to reduce the problem to its standard form
+//
+//	H*y = lambda*T*y
+//
+// where y = Zᵀ*x.
+//
+// The orthogonal matrices Q and Z are determined as products of Givens
+// rotations. They may either be formed explicitly, or they may be
+// postmultiplied into input matrices Q1 and Z1, so that
+//
+//	Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ
+//	Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ
+//
+// If Q1 is the orthogonal matrix from the QR factorization of B in the
+// original equation A*x = lambda*B*x, then Dgghrd reduces the original
+// problem to generalized Hessenberg form.
+//
+// compq and compz specify how q and z, respectively, are handled:
+//
+//	lapack.OrthoNone:  the matrix is not computed and its slice is not accessed,
+//	lapack.OrthoUnit:  the matrix is set to the identity before the rotations
+//	                   are accumulated into it,
+//	lapack.OrthoEntry: the rotations are accumulated into the matrix as given
+//	                   on entry.
+//
+// ldq and ldz must be at least 1, and at least n if the corresponding
+// matrix is computed.
+//
+// n is the order of A and B. ilo and ihi select the part of A that is
+// reduced: for each column jcol in [ilo,ihi-2] the elements in rows
+// jcol+2 through ihi are annihilated. It must hold that 0 <= ilo and
+// ilo-1 <= ihi < n, otherwise Dgghrd will panic.
+//
+// On return, a is overwritten with H and b with T. The strictly lower
+// triangle of b is ignored on entry and is set to zero.
+//
+// Dgghrd is an internal routine. It is exported for testing purposes.
+func (impl Implementation) Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int, a []float64, lda int, b []float64, ldb int, q []float64, ldq int, z []float64, ldz int) {
+	switch {
+	case compq != lapack.OrthoNone && compq != lapack.OrthoEntry && compq != lapack.OrthoUnit:
+		panic(badOrthoComp)
+	case compz != lapack.OrthoNone && compz != lapack.OrthoEntry && compz != lapack.OrthoUnit:
+		panic(badOrthoComp)
+	case n < 0:
+		panic(nLT0)
+	case ilo < 0:
+		panic(badIlo)
+	case ihi < ilo-1 || ihi >= n:
+		panic(badIhi)
+	case lda < max(1, n):
+		panic(badLdA)
+	case ldb < max(1, n):
+		panic(badLdB)
+	case (compq != lapack.OrthoNone && ldq < n) || ldq < 1:
+		panic(badLdQ)
+	case (compz != lapack.OrthoNone && ldz < n) || ldz < 1:
+		panic(badLdZ)
+	// Slice lengths are checked last so that a bad dimension or stride is
+	// reported as such rather than as a short slice.
+	case len(a) < (n-1)*lda+n:
+		panic(shortA)
+	case len(b) < (n-1)*ldb+n:
+		panic(shortB)
+	case compq != lapack.OrthoNone && len(q) < (n-1)*ldq+n:
+		panic(shortQ)
+	case compz != lapack.OrthoNone && len(z) < (n-1)*ldz+n:
+		panic(shortZ)
+	}
+
+	if compq == lapack.OrthoUnit {
+		impl.Dlaset(blas.All, n, n, 0, 1, q, ldq)
+	}
+	if compz == lapack.OrthoUnit {
+		impl.Dlaset(blas.All, n, n, 0, 1, z, ldz)
+	}
+	if n <= 1 {
+		return // Quick return if possible.
+	}
+
+	// Zero out lower triangle of B.
+	for i := 1; i < n; i++ {
+		for j := 0; j < i; j++ {
+			b[i*ldb+j] = 0
+		}
+	}
+	bi := blas64.Implementation()
+	// Reduce A and B.
+	for jcol := ilo; jcol <= ihi-2; jcol++ {
+		for jrow := ihi; jrow >= jcol+2; jrow-- {
+			// Step 1: rotate rows JROW-1, JROW to kill A(JROW,JCOL).
+			var c, s float64
+			c, s, a[(jrow-1)*lda+jcol] = impl.Dlartg(a[(jrow-1)*lda+jcol], a[jrow*lda+jcol])
+			a[jrow*lda+jcol] = 0
+			bi.Drot(n-jcol-1, a[(jrow-1)*lda+jcol+1:], 1,
+				a[jrow*lda+jcol+1:], 1, c, s)
+
+			// Rows jrow-1 and jrow of B are zero to the left of column jrow-1.
+			bi.Drot(n+2-jrow-1, b[(jrow-1)*ldb+jrow-1:], 1,
+				b[jrow*ldb+jrow-1:], 1, c, s)
+
+			if compq != lapack.OrthoNone {
+				bi.Drot(n, q[jrow-1:], ldq, q[jrow:], ldq, c, s)
+			}
+
+			// Step 2: rotate columns JROW, JROW-1 to kill B(JROW,JROW-1).
+			c, s, b[jrow*ldb+jrow] = impl.Dlartg(b[jrow*ldb+jrow], b[jrow*ldb+jrow-1])
+			b[jrow*ldb+jrow-1] = 0
+
+			bi.Drot(ihi+1, a[jrow:], lda, a[jrow-1:], lda, c, s)
+			bi.Drot(jrow, b[jrow:], ldb, b[jrow-1:], ldb, c, s)
+
+			if compz != lapack.OrthoNone {
+				bi.Drot(n, z[jrow:], ldz, z[jrow-1:], ldz, c, s)
+			}
+		}
+	}
+}
diff --git a/lapack/gonum/errors.go b/lapack/gonum/errors.go
index 47652841..711cc2d5 100644
--- a/lapack/gonum/errors.go
+++ b/lapack/gonum/errors.go
@@ -21,6 +21,7 @@ const (
 	badMatrixType       = "lapack: bad MatrixType"
 	badMaximizeNormXJob = "lapack: bad MaximizeNormXJob"
 	badNorm             = "lapack: bad Norm"
+	badOrthoComp        = "lapack: bad OrthoComp"
 	badPivot            = "lapack: bad Pivot"
 	badRightEVJob       = "lapack: bad RightEVJob"
 	badSVDJob           = "lapack: bad SVDJob"
diff --git a/lapack/gonum/lapack_test.go b/lapack/gonum/lapack_test.go
index 4a36c44d..55d50616 100644
--- a/lapack/gonum/lapack_test.go
+++ b/lapack/gonum/lapack_test.go
@@ -148,6 +148,11 @@ func TestDgetrs(t *testing.T) {
 	testlapack.DgetrsTest(t, impl)
 }
 
+func TestDgghrd(t *testing.T) {
+	t.Parallel()
+	testlapack.DgghrdTest(t, impl)
+}
+
 func TestDggsvd3(t *testing.T) {
 	t.Parallel()
 	testlapack.Dggsvd3Test(t, impl)
diff --git a/lapack/lapack.go b/lapack/lapack.go
index 5f60438f..72111d66 100644
--- a/lapack/lapack.go
+++ b/lapack/lapack.go
@@ -226,3 +226,12 @@ const (
 	LocalLookAhead       MaximizeNormXJob = 0 // Solve Z*x=h-f where h is a vector of ±1.
 	NormalizedNullVector MaximizeNormXJob = 2 // Compute an approximate null-vector e of Z, normalize e and solve Z*x=±e-f.
) + +// OrthoComp specifies whether and how the orthogonal matrix is computed in Dgghrd. +type OrthoComp byte + +const ( + OrthoNone OrthoComp = 'N' // Do not compute orthogonal matrix. + OrthoUnit OrthoComp = 'I' // Argument is initialized to the unit matrix and the orthogonal matrix is returned. + OrthoEntry OrthoComp = 'V' // Argument Q contains orthogonal matrix Q1 on entry and the product Q1*Q is returned. +) diff --git a/lapack/testlapack/dgghrd.go b/lapack/testlapack/dgghrd.go new file mode 100644 index 00000000..02b0bc86 --- /dev/null +++ b/lapack/testlapack/dgghrd.go @@ -0,0 +1,141 @@ +// Copyright ©2023 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testlapack + +import ( + "math" + "testing" + + "golang.org/x/exp/rand" + + "gonum.org/v1/gonum/blas" + "gonum.org/v1/gonum/blas/blas64" + "gonum.org/v1/gonum/lapack" +) + +type Dgghrder interface { + Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int, a []float64, lda int, b []float64, ldb int, q []float64, ldq int, z []float64, ldz int) +} + +func DgghrdTest(t *testing.T, impl Dgghrder) { + const tol = 1e-13 + const ldAdd = 5 + rnd := rand.New(rand.NewSource(1)) + comps := []lapack.OrthoComp{lapack.OrthoUnit, lapack.OrthoNone, lapack.OrthoEntry} + for _, compq := range comps { + for _, compz := range comps { + for _, n := range []int{2, 0, 1, 4, 15} { + ldMin := max(1, n) + for _, lda := range []int{ldMin, ldMin + ldAdd} { + for _, ldb := range []int{ldMin, ldMin + ldAdd} { + for _, ldq := range []int{ldMin, ldMin + ldAdd} { + for _, ldz := range []int{ldMin, ldMin + ldAdd} { + testDgghrd(t, impl, rnd, tol, compq, compz, n, 0, n-1, lda, ldb, ldq, ldz) + } + } + } + } + } + } + } +} + +func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, tol float64, compq, compz lapack.OrthoComp, n, ilo, ihi, lda, ldb, ldq, ldz int) { + a := randomGeneral(n, n, lda, rnd) + b := blockedUpperTriGeneral(n, n, 
0, n, ldb, false, rnd) + var q, q1, z, z1 blas64.General + if compq == lapack.OrthoEntry { + q = randomOrthogonal(n, rnd) + q1 = cloneGeneral(q) + } else { + q = nanGeneral(n, n, ldq) + } + if compz == lapack.OrthoEntry { + z = randomOrthogonal(n, rnd) + z1 = cloneGeneral(z) + } else { + z = nanGeneral(n, n, ldz) + } + + hGot := cloneGeneral(a) + tGot := cloneGeneral(b) + for i := 1; i < n; i++ { + for j := 0; j < i; j++ { + // Set all lower tri elems to NaN to catch bad implementations. + tGot.Data[i*tGot.Stride+j] = math.NaN() + } + } + impl.Dgghrd(compq, compz, n, ilo, ihi, hGot.Data, hGot.Stride, tGot.Data, tGot.Stride, q.Data, q.Stride, z.Data, z.Stride) + if n == 0 { + return + } + if !isUpperHessenberg(hGot) { + t.Error("H is not upper Hessenberg") + } + if !isNaNFree(tGot) || !isNaNFree(hGot) { + t.Error("T or H is/or not NaN free") + } + if !isUpperTriangular(tGot) { + t.Error("T is not upper triangular") + } + if compq == lapack.OrthoNone { + if !isAllNaN(q.Data) { + t.Errorf("Q is not NaN") + } + return + } + if compz == lapack.OrthoNone { + if !isAllNaN(z.Data) { + t.Errorf("Z is not NaN") + } + return + } + if compq != compz { + return // Do not handle mixed case + } + comp := compq + aux := zeros(n, n, n) + + switch comp { + case lapack.OrthoUnit: + // Qᵀ*A*Z = H + hCalc := zeros(n, n, n) + blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, a, 0, aux) + blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, hCalc) + if !equalApproxGeneral(hGot, hCalc, tol) { + t.Errorf("Qᵀ*A*Z != H") + } + + // Qᵀ*B*Z = T + tCalc := zeros(n, n, n) + blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, b, 0, aux) + blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, tCalc) + if !equalApproxGeneral(hGot, hCalc, tol) { + t.Errorf("Qᵀ*B*Z != T") + } + case lapack.OrthoEntry: + // Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ + lhs := zeros(n, n, n) + blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, a, 0, aux) + blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs) // lhs = Q1 * A * Z1ᵀ + + rhs 
:= zeros(n, n, n) + blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, hGot, 0, aux) + blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs) + if !equalApproxGeneral(lhs, rhs, tol) { + t.Errorf("Q1 * A * Z1ᵀ != (Q1*Q) * H * (Z1*Z)ᵀ") + } + + // Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ + blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, b, 0, aux) + blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs) + + blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, tGot, 0, aux) + blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs) + if !equalApproxGeneral(lhs, rhs, tol) { + t.Errorf("Q1 * B * Z1ᵀ != (Q1*Q) * T * (Z1*Z)ᵀ") + } + } +} diff --git a/lapack/testlapack/general.go b/lapack/testlapack/general.go index 9dd14143..2291dcab 100644 --- a/lapack/testlapack/general.go +++ b/lapack/testlapack/general.go @@ -1201,7 +1201,20 @@ func isUpperTriangular(a blas64.General) bool { n := a.Rows for i := 1; i < n; i++ { for j := 0; j < i; j++ { - if a.Data[i*a.Stride+j] != 0 { + v := a.Data[i*a.Stride+j] + if v != 0 || math.IsNaN(v) { + return false + } + } + } + return true +} + +// isNaNFree returns whether a does not contain NaN elements in reachable elements. +func isNaNFree(a blas64.General) bool { + for i := 0; i < a.Rows; i++ { + for j := 0; j < a.Cols; j++ { + if math.IsNaN(a.Data[i*a.Stride+j]) { return false } }