mirror of
https://github.com/gonum/gonum.git
synced 2025-10-24 07:34:11 +08:00
Add Dorgtr and test
This commit is contained in:
89
native/dorgtr.go
Normal file
89
native/dorgtr.go
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
// Copyright ©2016 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import "github.com/gonum/blas"
|
||||||
|
|
||||||
|
// Dorgtr generates a real orthogonal matrix Q which is defined as the product
|
||||||
|
// of n-1 elementary reflectors of order n as returned by Dsytrd.
|
||||||
|
//
|
||||||
|
// The construction of Q depends on the value of uplo:
|
||||||
|
// Q = H[n-1] * ... * H[1] * H[0] if uplo == blas.Upper
|
||||||
|
// Q = H[0] * ... * H[n-1] if uplo == blas.Lower
|
||||||
|
// where H[i] is constructed from the elementary reflectors as computed by Dsytrd.
|
||||||
|
// See the documentation for Dsytrd for more information.
|
||||||
|
//
|
||||||
|
// tau must have length at least n-1, and Dorgtr will panic otherwise.
|
||||||
|
//
|
||||||
|
// work is temporary storage, and lwork specifies the usable memory length. At
|
||||||
|
// minimum, lwork >= n-1, and Dorgtr will panic otherwise. The amount of blocking
|
||||||
|
// is limited by the usable length.
|
||||||
|
// If lwork == -1, instead of computing Dorgtr the optimal work length is stored
|
||||||
|
// into work[0].
|
||||||
|
func (impl Implementation) Dorgtr(uplo blas.Uplo, n int, a []float64, lda int, tau, work []float64, lwork int) {
|
||||||
|
checkMatrix(n, n, a, lda)
|
||||||
|
if len(tau) < n-1 {
|
||||||
|
panic(badTau)
|
||||||
|
}
|
||||||
|
upper := uplo == blas.Upper
|
||||||
|
var nb int
|
||||||
|
if upper {
|
||||||
|
nb = impl.Ilaenv(1, "DORGQL", " ", n-1, n-1, n-1, -1)
|
||||||
|
} else {
|
||||||
|
nb = impl.Ilaenv(1, "DORGQR", " ", n-1, n-1, n-1, -1)
|
||||||
|
}
|
||||||
|
lworkopt := max(1, n-1) * nb
|
||||||
|
work[0] = float64(lworkopt)
|
||||||
|
if lwork == -1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(work) < lwork {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
if lwork < n-1 {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if upper {
|
||||||
|
// Q was determined by a call to Dsytrd with uplo == blas.Upper.
|
||||||
|
// Shift the vectors which define the elementary reflectors one column
|
||||||
|
// to the left, and set the last row and column of Q to those of the unit
|
||||||
|
// matrix.
|
||||||
|
for j := 0; j < n-1; j++ {
|
||||||
|
for i := 0; i < j; i++ {
|
||||||
|
a[i*lda+j] = a[i*lda+j+1]
|
||||||
|
}
|
||||||
|
a[(n-1)*lda+j] = 0
|
||||||
|
}
|
||||||
|
for i := 0; i < n-1; i++ {
|
||||||
|
a[i*lda+n-1] = 0
|
||||||
|
}
|
||||||
|
a[(n-1)*lda+n-1] = 1
|
||||||
|
|
||||||
|
// Generate Q[0:n-1, 0:n-1].
|
||||||
|
impl.Dorgql(n-1, n-1, n-1, a, lda, tau, work, lwork)
|
||||||
|
} else {
|
||||||
|
// Q was determined by a call to Dsytrd with uplo == blas.Upper.
|
||||||
|
// Shift the vectors which define the elementary reflectors one column
|
||||||
|
// to the right, and set the first row and column of Q to those of the unit
|
||||||
|
// matrix.
|
||||||
|
for j := n - 1; j > 0; j-- {
|
||||||
|
a[j] = 0
|
||||||
|
for i := j + 1; i < n; i++ {
|
||||||
|
a[i*lda+j] = a[i*lda+j-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
a[0] = 1
|
||||||
|
for i := 1; i < n; i++ {
|
||||||
|
a[i*lda] = 0
|
||||||
|
}
|
||||||
|
if n > 1 {
|
||||||
|
// Generate Q[1:n, 1:n].
|
||||||
|
impl.Dorgqr(n-1, n-1, n-1, a[lda+1:], lda, tau, work, lwork)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -135,7 +135,7 @@ func (impl Implementation) Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d
|
|||||||
// Update the unreduced submatrix A[i+ib:n, i+ib:n], using an update
|
// Update the unreduced submatrix A[i+ib:n, i+ib:n], using an update
|
||||||
// of the form A = A + V*W^T - W*V^T.
|
// of the form A = A + V*W^T - W*V^T.
|
||||||
bi.Dsyr2k(uplo, blas.NoTrans, n-i-nb, nb, -1, a[(i+nb)*lda+i:], lda,
|
bi.Dsyr2k(uplo, blas.NoTrans, n-i-nb, nb, -1, a[(i+nb)*lda+i:], lda,
|
||||||
work[nb+i:], ldwork, 1, a[(i+nb)*lda+i+nb:], lda)
|
work[nb*ldwork:], ldwork, 1, a[(i+nb)*lda+i+nb:], lda)
|
||||||
|
|
||||||
// Copy subdiagonal elements back into A, and diagonal elements into D.
|
// Copy subdiagonal elements back into A, and diagonal elements into D.
|
||||||
for j := i; j < i+nb; j++ {
|
for j := i; j < i+nb; j++ {
|
||||||
|
|||||||
@@ -188,6 +188,10 @@ func TestDorgqr(t *testing.T) {
|
|||||||
testlapack.DorgqrTest(t, impl)
|
testlapack.DorgqrTest(t, impl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDorgtr(t *testing.T) {
|
||||||
|
testlapack.DorgtrTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDormbr(t *testing.T) {
|
func TestDormbr(t *testing.T) {
|
||||||
testlapack.DormbrTest(t, impl)
|
testlapack.DormbrTest(t, impl)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ func DorgqrTest(t *testing.T, impl Dorgqrer) {
|
|||||||
{10, 10, 10, 0},
|
{10, 10, 10, 0},
|
||||||
{10, 10, 10, 20},
|
{10, 10, 10, 20},
|
||||||
{30, 10, 10, 0},
|
{30, 10, 10, 0},
|
||||||
{30, 20, 10, 0},
|
{30, 20, 10, 20},
|
||||||
|
|
||||||
{100, 100, 100, 0},
|
{100, 100, 100, 0},
|
||||||
{100, 100, 50, 0},
|
{100, 100, 50, 0},
|
||||||
|
|||||||
117
testlapack/dorgtr.go
Normal file
117
testlapack/dorgtr.go
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
// Copyright ©2016 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
"github.com/gonum/floats"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dorgtrer interface {
|
||||||
|
Dorgtr(uplo blas.Uplo, n int, a []float64, lda int, tau, work []float64, lwork int)
|
||||||
|
Dsytrder
|
||||||
|
}
|
||||||
|
|
||||||
|
func DorgtrTest(t *testing.T, impl Dorgtrer) {
|
||||||
|
for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
|
||||||
|
for _, test := range []struct {
|
||||||
|
n, lda int
|
||||||
|
}{
|
||||||
|
{6, 0},
|
||||||
|
{33, 0},
|
||||||
|
{100, 0},
|
||||||
|
|
||||||
|
{6, 10},
|
||||||
|
{33, 50},
|
||||||
|
{100, 120},
|
||||||
|
} {
|
||||||
|
n := test.n
|
||||||
|
lda := test.lda
|
||||||
|
if lda == 0 {
|
||||||
|
lda = n
|
||||||
|
}
|
||||||
|
a := make([]float64, n*lda)
|
||||||
|
for i := range a {
|
||||||
|
a[i] = rand.NormFloat64()
|
||||||
|
}
|
||||||
|
aCopy := make([]float64, len(a))
|
||||||
|
copy(aCopy, a)
|
||||||
|
|
||||||
|
d := make([]float64, n)
|
||||||
|
e := make([]float64, n-1)
|
||||||
|
tau := make([]float64, n-1)
|
||||||
|
work := make([]float64, 1)
|
||||||
|
impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, -1)
|
||||||
|
work = make([]float64, int(work[0]))
|
||||||
|
impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, len(work))
|
||||||
|
|
||||||
|
impl.Dorgtr(uplo, n, a, lda, tau, work, -1)
|
||||||
|
work = make([]float64, int(work[0]))
|
||||||
|
for i := range work {
|
||||||
|
work[i] = math.NaN()
|
||||||
|
}
|
||||||
|
impl.Dorgtr(uplo, n, a, lda, tau, work, len(work))
|
||||||
|
|
||||||
|
q := blas64.General{
|
||||||
|
Rows: n,
|
||||||
|
Cols: n,
|
||||||
|
Stride: lda,
|
||||||
|
Data: a,
|
||||||
|
}
|
||||||
|
tri := blas64.General{
|
||||||
|
Rows: n,
|
||||||
|
Cols: n,
|
||||||
|
Stride: n,
|
||||||
|
Data: make([]float64, n*n),
|
||||||
|
}
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
tri.Data[i*tri.Stride+i] = d[i]
|
||||||
|
if i != n-1 {
|
||||||
|
tri.Data[i*tri.Stride+i+1] = e[i]
|
||||||
|
tri.Data[(i+1)*tri.Stride+i] = e[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
aMat := blas64.General{
|
||||||
|
Rows: n,
|
||||||
|
Cols: n,
|
||||||
|
Stride: n,
|
||||||
|
Data: make([]float64, n*n),
|
||||||
|
}
|
||||||
|
if uplo == blas.Upper {
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
for j := i; j < n; j++ {
|
||||||
|
v := aCopy[i*lda+j]
|
||||||
|
aMat.Data[i*aMat.Stride+j] = v
|
||||||
|
aMat.Data[j*aMat.Stride+i] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
for j := 0; j <= i; j++ {
|
||||||
|
v := aCopy[i*lda+j]
|
||||||
|
aMat.Data[i*aMat.Stride+j] = v
|
||||||
|
aMat.Data[j*aMat.Stride+i] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tmp := blas64.General{Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n)}
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, q, 0, tmp)
|
||||||
|
|
||||||
|
ans := blas64.General{Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n)}
|
||||||
|
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, tmp, 0, ans)
|
||||||
|
|
||||||
|
if !floats.EqualApprox(ans.Data, tri.Data, 1e-8) {
|
||||||
|
t.Errorf("Recombination mismatch. n = %v, isUpper = %v", n, uplo == blas.Upper)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user