diff --git a/native/dlatrd.go b/native/dlatrd.go
new file mode 100644
index 00000000..596c2ce9
--- /dev/null
+++ b/native/dlatrd.go
@@ -0,0 +1,143 @@
+// Copyright ©2016 The gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package native
+
+import (
+	"github.com/gonum/blas"
+	"github.com/gonum/blas/blas64"
+)
+
+// Dlatrd reduces nb rows and columns of a real n×n symmetric matrix A to symmetric
+// tridiagonal form. It computes the orthonormal similarity transformation
+//  Q^T * A * Q
+// and returns the matrices V and W to apply to the unreduced part of A. If
+// uplo == blas.Upper, the upper triangle is supplied and the last nb rows are
+// reduced. If uplo == blas.Lower, the lower triangle is supplied and the first
+// nb rows are reduced.
+//
+// a contains the symmetric matrix on entry with active triangular half specified
+// by uplo. On exit, the nb columns have been reduced to tridiagonal form. The
+// diagonal contains the diagonal of the reduced matrix, the off-diagonal is
+// set to 1, and the remaining elements contain the data to construct Q.
+//
+// If uplo == blas.Upper, with n = 5 and nb = 2, on exit a is
+//  [a   a   a  v4  v5]
+//  [    a   a  v4  v5]
+//  [        a   1  v5]
+//  [            d   1]
+//  [                d]
+//
+// If uplo == blas.Lower, with n = 5 and nb = 2, on exit a is
+//  [ d            ]
+//  [ 1  d         ]
+//  [v1  1  a      ]
+//  [v1 v2  a  a   ]
+//  [v1 v2  a  a  a]
+//
+// e contains the superdiagonal elements of the reduced matrix. If uplo == blas.Upper,
+// e[i-1] is set for each reduced column i in [n-nb, n-1] with i > 0, while if
+// uplo == blas.Lower, e[i] is set for each reduced column i in [0, nb-1] with i < n-1.
+// e must have length at least n-1, and Dlatrd will panic otherwise.
+//
+// tau contains the scalar factors of the elementary reflectors needed to construct Q.
+// The factors are stored alongside e: in tau[i-1] for column i if uplo == blas.Upper,
+// and in tau[i] if uplo == blas.Lower. tau must have length at least n-1, and Dlatrd
+// will panic otherwise.
+//
+// w is an n×nb matrix. On exit it contains the data to update the unreduced part
+// of A.
+//
+// The matrix Q is represented as a product of elementary reflectors. Each reflector
+// H[i], generated while reducing column i (0-based), has the form
+//  I - tau * v * v^T
+// If uplo == blas.Upper,
+//  Q = H[n-1] * H[n-2] * ... * H[n-nb]
+// where v[:i-1] is stored in A[:i-1,i], v[i-1] = 1, and v[i:n] = 0.
+//
+// If uplo == blas.Lower,
+//  Q = H[0] * H[1] * ... * H[nb-1]
+// where v[:i+1] = 0, v[i+1] = 1, and v[i+2:n] is stored in A[i+2:n,i].
+//
+// The vectors v form the n×nb matrix V which is used with W to apply a
+// symmetric rank-2 update to the unreduced part of A
+//  A = A - V * W^T - W * V^T
+func (impl Implementation) Dlatrd(uplo blas.Uplo, n, nb int, a []float64, lda int, e, tau, w []float64, ldw int) {
+	checkMatrix(n, n, a, lda)
+	checkMatrix(n, nb, w, ldw)
+	if len(e) < n-1 {
+		panic(badE)
+	}
+	if len(tau) < n-1 {
+		panic(badTau)
+	}
+	if n <= 0 {
+		return
+	}
+	bi := blas64.Implementation()
+	if uplo == blas.Upper {
+		for i := n - 1; i >= n-nb; i-- {
+			iw := i - n + nb // Column of w paired with column i of a.
+			if i < n-1 {
+				// Update A(0:i, i) with the columns of a and w to the right of column i.
+				bi.Dgemv(blas.NoTrans, i+1, n-i-1, -1, a[i+1:], lda,
+					w[i*ldw+iw+1:], 1, 1, a[i:], lda)
+				bi.Dgemv(blas.NoTrans, i+1, n-i-1, -1, w[iw+1:], ldw,
+					a[i*lda+i+1:], 1, 1, a[i:], lda)
+			}
+			if i > 0 {
+				// Generate elementary reflector H(i) to annihilate A(0:i-2,i).
+				e[i-1], tau[i-1] = impl.Dlarfg(i, a[(i-1)*lda+i], a[i:], lda)
+				a[(i-1)*lda+i] = 1
+
+				// Compute W(0:i-1, i); rows of column iw below i-1 serve as scratch for the Dgemv products.
+				bi.Dsymv(blas.Upper, i, 1, a, lda, a[i:], lda, 0, w[iw:], ldw)
+				if i < n-1 {
+					bi.Dgemv(blas.Trans, i, n-i-1, 1, w[iw+1:], ldw,
+						a[i:], lda, 0, w[(i+1)*ldw+iw:], ldw)
+					bi.Dgemv(blas.NoTrans, i, n-i-1, -1, a[i+1:], lda,
+						w[(i+1)*ldw+iw:], ldw, 1, w[iw:], ldw)
+					bi.Dgemv(blas.Trans, i, n-i-1, 1, a[i+1:], lda,
+						a[i:], lda, 0, w[(i+1)*ldw+iw:], ldw)
+					bi.Dgemv(blas.NoTrans, i, n-i-1, -1, w[iw+1:], ldw,
+						w[(i+1)*ldw+iw:], ldw, 1, w[iw:], ldw)
+				}
+				bi.Dscal(i, tau[i-1], w[iw:], ldw)
+				alpha := -0.5 * tau[i-1] * bi.Ddot(i, w[iw:], ldw, a[i:], lda)
+				bi.Daxpy(i, alpha, a[i:], lda, w[iw:], ldw)
+			}
+		}
+	} else {
+		// Reduce first nb columns of lower triangle (taken for any uplo other than blas.Upper).
+		for i := 0; i < nb; i++ {
+			// Update A(i:n, i) with the first i columns of a and w.
+			bi.Dgemv(blas.NoTrans, n-i, i, -1, a[i*lda:], lda,
+				w[i*ldw:], 1, 1, a[i*lda+i:], lda)
+			bi.Dgemv(blas.NoTrans, n-i, i, -1, w[i*ldw:], ldw,
+				a[i*lda:], 1, 1, a[i*lda+i:], lda)
+			if i < n-1 {
+				// Generate elementary reflector H(i) to annihilate A(i+2:n,i); min keeps the row index at most n-1 when i == n-2.
+				e[i], tau[i] = impl.Dlarfg(n-i-1, a[(i+1)*lda+i], a[min(i+2, n-1)*lda+i:], lda)
+				a[(i+1)*lda+i] = 1
+
+				// Compute W(i+1:n,i); the first i entries of row 0 of w (w[i:], stride ldw) serve as scratch.
+				bi.Dsymv(blas.Lower, n-i-1, 1, a[(i+1)*lda+i+1:], lda,
+					a[(i+1)*lda+i:], lda, 0, w[(i+1)*ldw+i:], ldw)
+				bi.Dgemv(blas.Trans, n-i-1, i, 1, w[(i+1)*ldw:], ldw,
+					a[(i+1)*lda+i:], lda, 0, w[i:], ldw)
+				bi.Dgemv(blas.NoTrans, n-i-1, i, -1, a[(i+1)*lda:], lda,
+					w[i:], ldw, 1, w[(i+1)*ldw+i:], ldw)
+				bi.Dgemv(blas.Trans, n-i-1, i, 1, a[(i+1)*lda:], lda,
+					a[(i+1)*lda+i:], lda, 0, w[i:], ldw)
+				bi.Dgemv(blas.NoTrans, n-i-1, i, -1, w[(i+1)*ldw:], ldw,
+					w[i:], ldw, 1, w[(i+1)*ldw+i:], ldw)
+				bi.Dscal(n-i-1, tau[i], w[(i+1)*ldw+i:], ldw)
+				alpha := -0.5 * tau[i] * bi.Ddot(n-i-1, w[(i+1)*ldw+i:], ldw,
+					a[(i+1)*lda+i:], lda)
+				bi.Daxpy(n-i-1, alpha, a[(i+1)*lda+i:], lda,
+					w[(i+1)*ldw+i:], ldw)
+			}
+		}
+	}
+}
diff --git a/native/lapack_test.go b/native/lapack_test.go
index de3b1881..782f9487 100644
--- a/native/lapack_test.go
+++ b/native/lapack_test.go
@@ -152,6 +152,10 @@ func TestDlasv2(t *testing.T) {
 	testlapack.Dlasv2Test(t, impl)
 }
 
+func TestDlatrd(t *testing.T) {
+	testlapack.DlatrdTest(t, impl)
+}
+
 func TestDorg2r(t *testing.T) {
 	testlapack.Dorg2rTest(t, impl)
 }
diff --git a/testlapack/dlatrd.go b/testlapack/dlatrd.go
new file mode 100644
index 00000000..7b7528a9
--- /dev/null
+++ b/testlapack/dlatrd.go
@@ -0,0 +1,268 @@
+// Copyright ©2016 The gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package testlapack
+
+import (
+	"fmt"
+	"math"
+	"math/rand"
+	"testing"
+
+	"github.com/gonum/blas"
+	"github.com/gonum/blas/blas64"
+)
+
+type Dlatrder interface { // Dlatrder is the method set DlatrdTest requires of the implementation under test.
+	Dlatrd(uplo blas.Uplo, n, nb int, a []float64, lda int, e, tau, w []float64, ldw int)
+}
+
+func DlatrdTest(t *testing.T, impl Dlatrder) {
+	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { // Exercise both storage triangles.
+		for _, test := range []struct {
+			n, nb, lda, ldw int
+		}{
+			{5, 2, 0, 0}, // lda == 0 and ldw == 0 are replaced below by the minimal strides n and nb.
+			{5, 5, 0, 0},
+
+			{5, 3, 10, 11},
+			{5, 5, 10, 11},
+		} {
+			n := test.n
+			nb := test.nb
+			lda := test.lda
+			if lda == 0 {
+				lda = n
+			}
+			ldw := test.ldw
+			if ldw == 0 {
+				ldw = nb
+			}
+
+			a := make([]float64, n*lda)
+			for i := range a {
+				a[i] = rand.NormFloat64()
+			}
+
+			e := make([]float64, n-1) // e, tau and w start as NaN, presumably so unwritten elements are detectable.
+			for i := range e {
+				e[i] = math.NaN()
+			}
+			tau := make([]float64, n-1)
+			for i := range tau {
+				tau[i] = math.NaN()
+			}
+			w := make([]float64, n*ldw)
+			for i := range w {
+				w[i] = math.NaN()
+			}
+
+			aCopy := make([]float64, len(a)) // Keep the input; a is passed to Dlatrd and read back as output below.
+			copy(aCopy, a)
+
+			impl.Dlatrd(uplo, n, nb, a, lda, e, tau, w, ldw)
+
+			// Construct Q explicitly by multiplying the reflectors H = I - tau * v * v^T onto the identity.
+			ldq := n
+			q := blas64.General{
+				Rows:   n,
+				Cols:   n,
+				Stride: ldq,
+				Data:   make([]float64, n*ldq),
+			}
+			for i := 0; i < n; i++ {
+				q.Data[i*ldq+i] = 1
+			}
+			if uplo == blas.Upper {
+				for i := n - 1; i >= n-nb; i-- {
+					if i == 0 { // tau[i-1] would be out of range; no reflector is applied for i == 0.
+						continue
+					}
+					h := blas64.General{
+						Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
+					}
+					for j := 0; j < n; j++ {
+						h.Data[j*n+j] = 1
+					}
+					v := blas64.Vector{
+						Inc:  1,
+						Data: make([]float64, n),
+					}
+					for j := 0; j < i-1; j++ {
+						v.Data[j] = a[j*lda+i]
+					}
+					v.Data[i-1] = 1
+
+					blas64.Ger(-tau[i-1], v, v, h)
+
+					qTmp := blas64.General{
+						Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
+					}
+					copy(qTmp.Data, q.Data)
+					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qTmp, h, 0, q)
+				}
+			} else {
+				for i := 0; i < nb; i++ {
+					if i == n-1 { // v.Data[i+1] and tau[i] would be out of range; no reflector is applied for i == n-1.
+						continue
+					}
+					h := blas64.General{
+						Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
+					}
+					for j := 0; j < n; j++ {
+						h.Data[j*n+j] = 1
+					}
+					v := blas64.Vector{
+						Inc:  1,
+						Data: make([]float64, n),
+					}
+					v.Data[i+1] = 1
+					for j := i + 2; j < n; j++ {
+						v.Data[j] = a[j*lda+i]
+					}
+					blas64.Ger(-tau[i], v, v, h)
+
+					qTmp := blas64.General{
+						Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
+					}
+					copy(qTmp.Data, q.Data)
+					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qTmp, h, 0, q)
+				}
+			}
+			errStr := fmt.Sprintf("isUpper = %v, n = %v, nb = %v", uplo == blas.Upper, n, nb)
+			if !isOrthonormal(q) {
+				t.Errorf("Q not orthonormal. %s", errStr)
+			}
+			aGen := genFromSym(blas64.Symmetric{N: n, Stride: lda, Uplo: uplo, Data: aCopy})
+			if !dlatrdCheckDecomposition(t, uplo, n, nb, e, tau, a, lda, aGen, q) {
+				t.Errorf("Decomposition mismatch. %s", errStr)
+			}
+		}
+	}
+}
+
+// dlatrdCheckDecomposition checks that the nb reduced rows of Q^T * A * Q (the last nb rows
+// for blas.Upper, the first nb otherwise) are tridiagonal and agree with a and e to within 1e-10.
+func dlatrdCheckDecomposition(t *testing.T, uplo blas.Uplo, n, nb int, e, tau, a []float64, lda int, aGen, q blas64.General) bool {
+	// Compute Q^T * A * Q.
+	tmp := blas64.General{
+		Rows:   n,
+		Cols:   n,
+		Stride: n,
+		Data:   make([]float64, n*n),
+	}
+
+	ans := blas64.General{
+		Rows:   n,
+		Cols:   n,
+		Stride: n,
+		Data:   make([]float64, n*n),
+	}
+
+	blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aGen, 0, tmp)
+	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, ans)
+
+	// Compare with T: diagonal against a, one off-diagonal against e, and zero away from the three diagonals.
+	if uplo == blas.Upper {
+		for i := n - 1; i >= n-nb; i-- {
+			for j := 0; j < n; j++ {
+				v := ans.Data[i*ans.Stride+j]
+				switch {
+				case i == j:
+					if math.Abs(v-a[i*lda+j]) > 1e-10 {
+						return false
+					}
+				case i == j-1:
+					if math.Abs(a[i*lda+j]-1) > 1e-10 {
+						return false
+					}
+					if math.Abs(v-e[i]) > 1e-10 {
+						return false
+					}
+				case i == j+1: // Subdiagonal element: not checked in the upper case.
+				default:
+					if math.Abs(v) > 1e-10 {
+						return false
+					}
+				}
+			}
+		}
+	} else {
+		for i := 0; i < nb; i++ {
+			for j := 0; j < n; j++ {
+				v := ans.Data[i*ans.Stride+j]
+				switch {
+				case i == j:
+					if math.Abs(v-a[i*lda+j]) > 1e-10 {
+						return false
+					}
+				case i == j-1: // Superdiagonal element: not checked in the lower case.
+				case i == j+1:
+					if math.Abs(a[i*lda+j]-1) > 1e-10 {
+						return false
+					}
+					if math.Abs(v-e[i-1]) > 1e-10 {
+						return false
+					}
+				default:
+					if math.Abs(v) > 1e-10 {
+						return false
+					}
+				}
+			}
+		}
+	}
+	return true
+}
+
+// isOrthonormal checks that the rows of q are orthonormal to within 1e-10; q.Rows is used as both the row count and the row length.
+// TODO(btracey): Replace other tests with a call to this function.
+func isOrthonormal(q blas64.General) bool {
+	n := q.Rows
+	for i := 0; i < n; i++ {
+		for j := i; j < n; j++ {
+			dot := blas64.Dot(n,
+				blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]},
+				blas64.Vector{Inc: 1, Data: q.Data[j*q.Stride:]},
+			)
+			if i == j {
+				if math.Abs(dot-1) > 1e-10 {
+					return false
+				}
+			} else {
+				if math.Abs(dot) > 1e-10 {
+					return false
+				}
+			}
+		}
+	}
+	return true
+}
+
+// genFromSym constructs a dense n×n general matrix from a blas64.Symmetric by mirroring
+// the stored triangle (lower if a.Uplo == blas.Lower, upper otherwise) across the diagonal.
+// TODO(btracey): Replace other constructions of this with a call to this function.
+func genFromSym(a blas64.Symmetric) blas64.General {
+	// Elements come from the lower triangle when a.Uplo == blas.Lower, and from the upper triangle otherwise.
+	lower := a.Uplo == blas.Lower
+	dim, stride := a.N, a.Stride
+	full := blas64.General{
+		Rows:   dim,
+		Cols:   dim,
+		Stride: dim,
+		Data:   make([]float64, dim*dim),
+	}
+	for r := 0; r < dim; r++ {
+		for c := r; c < dim; c++ {
+			val := a.Data[r*stride+c]
+			if lower {
+				val = a.Data[c*stride+r]
+			}
+			// Write the element and its reflection across the diagonal.
+			full.Data[r*dim+c] = val
+			full.Data[c*dim+r] = val
+		}
+	}
+	return full
+}