mirror of
https://github.com/gonum/gonum.git
synced 2025-10-22 06:39:26 +08:00
Add Dorgql, dependency, and helper function for test
This commit is contained in:
43
native/dgeql2.go
Normal file
43
native/dgeql2.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
// Copyright ©2016 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import "github.com/gonum/blas"
|
||||||
|
|
||||||
|
// Dgeql2 computes the QL factorization of the m×n matrix A. That is, Dgelq2
|
||||||
|
// computes Q and L such that
|
||||||
|
// A = Q * L
|
||||||
|
// where Q is an m×m orthonormal matrix and L is a lower trapezoidal matrix.
|
||||||
|
//
|
||||||
|
// Q is represented as a product of elementary reflectors,
|
||||||
|
// Q = H[k-1] * ... * H[1] * H[0]
|
||||||
|
// where k = min(m,n) and each H[i] has the form
|
||||||
|
// H[i] = I - tau[i] * v_i * v_i^T
|
||||||
|
// Vector v_i has v[m-k+i+1:m] = 0, v[m-k+i] = 1, and v[:m-k+i+1] is stored on
|
||||||
|
// exit in A[0:m-k+i-1, n-k+i].
|
||||||
|
//
|
||||||
|
// tau must have length at least min(m,n), and Dgeql2 will panic otherwise.
|
||||||
|
//
|
||||||
|
// work is temporary memory storage and must have length at least n.
|
||||||
|
func (impl Implementation) Dgeql2(m, n int, a []float64, lda int, tau, work []float64) {
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
if len(tau) < min(m, n) {
|
||||||
|
panic(badTau)
|
||||||
|
}
|
||||||
|
if len(work) < n {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
k := min(m, n)
|
||||||
|
var aii float64
|
||||||
|
for i := k - 1; i >= 0; i-- {
|
||||||
|
// Generate elementary reflector H[i] to annihilate A[0:m-k+i-1, n-k+i].
|
||||||
|
aii, tau[i] = impl.Dlarfg(m-k+i+1, a[(m-k+i)*lda+n-k+i], a[n-k+i:], lda)
|
||||||
|
|
||||||
|
// Apply H[i] to A[0:m-k+i, 0:n-k+i-1] from the left.
|
||||||
|
a[(m-k+i)*lda+n-k+i] = 1
|
||||||
|
impl.Dlarf(blas.Left, m-k+i+1, n-k+i, a[n-k+i:], lda, tau[i], a, lda, work)
|
||||||
|
a[(m-k+i)*lda+n-k+i] = aii
|
||||||
|
}
|
||||||
|
}
|
63
native/dorg2l.go
Normal file
63
native/dorg2l.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
// Copyright ©2016 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dorg2l generates an m×n matrix Q with orthonormal columns which is defined
|
||||||
|
// as the last n columns of a product of k elementary reflectors of order m.
|
||||||
|
// Q = H[k-1] * ... * H[1] * H[0]
|
||||||
|
// See Dgelqf for more information. It must be that m >= n >= k.
|
||||||
|
//
|
||||||
|
// tau contains the scalar reflectors computed by Dgeqlf. tau must have length
|
||||||
|
// at least k, and Dorg2l will panic otherwise.
|
||||||
|
//
|
||||||
|
// work contains temporary memory, and must have length at least n. Dorg2l will
|
||||||
|
// panic otherwise.
|
||||||
|
func (impl Implementation) Dorg2l(m, n, k int, a []float64, lda int, tau, work []float64) {
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
if len(tau) < k {
|
||||||
|
panic(badTau)
|
||||||
|
}
|
||||||
|
if len(work) < n {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
if m < n {
|
||||||
|
panic(mLTN)
|
||||||
|
}
|
||||||
|
if k > n {
|
||||||
|
panic(kGTN)
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize columns 0:n-k to columns of the unit matrix.
|
||||||
|
for j := 0; j < n-k; j++ {
|
||||||
|
for l := 0; l < m; l++ {
|
||||||
|
a[l*lda+j] = 0
|
||||||
|
}
|
||||||
|
a[(m-n+j)*lda+j] = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
bi := blas64.Implementation()
|
||||||
|
for i := 0; i < k; i++ {
|
||||||
|
ii := n - k + i
|
||||||
|
|
||||||
|
// Apply H[i] to A[0:m-k+i, 0:n-k+i] from the left.
|
||||||
|
a[(m-n+ii)*lda+ii] = 1
|
||||||
|
impl.Dlarf(blas.Left, m-n+ii+1, ii, a[ii:], lda, tau[i], a, lda, work)
|
||||||
|
bi.Dscal(m-n+ii, -tau[i], a[ii:], lda)
|
||||||
|
a[(m-n+ii)*lda+ii] = 1 - tau[i]
|
||||||
|
|
||||||
|
// Set A[m-k+i:m, n-k+i+1] to zero.
|
||||||
|
for l := m - n + ii + 1; l < m; l++ {
|
||||||
|
a[l*lda+ii] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
105
native/dorgql.go
Normal file
105
native/dorgql.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
// Copyright ©2016 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/lapack"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dorgql generates the m×n matrix Q with orthonormal columns defined as the
|
||||||
|
// last n columns of a product of k elementary reflectors of order m as returned
|
||||||
|
// by Dgelqf. That is,
|
||||||
|
// Q = H[k-1] * ... * H[1] * H[0].
|
||||||
|
// See Dgelqf for more information.
|
||||||
|
//
|
||||||
|
// tau must have length at least k, and Dorgql will panic otherwise.
|
||||||
|
//
|
||||||
|
// Work is temporary storage, and lwork specifies the usable memory length. At minimum,
|
||||||
|
// lwork >= n, and Dorgql will panic otherwise. The amount of blocking is
|
||||||
|
// limited by the usable length.
|
||||||
|
// If lwork == -1, instead of computing Dorgql the optimal work length is stored
|
||||||
|
// into work[0].
|
||||||
|
func (impl Implementation) Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int) {
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
if len(tau) < k {
|
||||||
|
panic(badTau)
|
||||||
|
}
|
||||||
|
nb := impl.Ilaenv(1, "DORGQL", " ", m, n, k, -1)
|
||||||
|
lworkopt := n * nb
|
||||||
|
work[0] = float64(lworkopt)
|
||||||
|
if lwork == -1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if lwork < n {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
nbmin := 2
|
||||||
|
var nx, ldwork int
|
||||||
|
iws := n
|
||||||
|
if nb > 1 && nb < k {
|
||||||
|
// Determine when to cross over from blocked to unblocked code.
|
||||||
|
nx = max(0, impl.Ilaenv(3, "DORGQL", " ", m, n, k, -1))
|
||||||
|
if nx < k {
|
||||||
|
// Determine if workspace is large enough for blocked code.
|
||||||
|
ldwork = nb
|
||||||
|
iws = n * nb
|
||||||
|
if lwork < iws {
|
||||||
|
// Not enough workspace to use optimal nb: reduce nb and determine
|
||||||
|
// the minimum value of nb.
|
||||||
|
nb = lwork / n
|
||||||
|
nbmin = max(2, impl.Ilaenv(2, "DORGQL", " ", m, n, k, -1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var kk int
|
||||||
|
if nb >= nbmin && nb < k && nx < k {
|
||||||
|
// Use blocked code after the first block. The last kk columns are handled
|
||||||
|
// by the block method.
|
||||||
|
kk = min(k, ((k-nx+nb-1)/nb)*nb)
|
||||||
|
|
||||||
|
// Set A(m-kk:m, 0:n-kk) to zero.
|
||||||
|
for i := m - kk; i < m; i++ {
|
||||||
|
for j := 0; j < n-kk; j++ {
|
||||||
|
a[i*lda+j] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use unblocked code for the first or only block.
|
||||||
|
impl.Dorg2l(m-kk, n-kk, k-kk, a, lda, tau, work)
|
||||||
|
if kk > 0 {
|
||||||
|
// Use blocked code.
|
||||||
|
for i := k - kk; i < k; i += nb {
|
||||||
|
ib := min(nb, k-i)
|
||||||
|
if n-k+i > 0 {
|
||||||
|
// Form the triangular factor of the block reflector
|
||||||
|
// H = H[i+ib-1] * ... * H[i+1] * H[i].
|
||||||
|
impl.Dlarft(lapack.Backward, lapack.ColumnWise, m-k+i+ib, ib,
|
||||||
|
a[n-k+i:], lda, tau[i:], work, ldwork)
|
||||||
|
|
||||||
|
// Apply H to A[0:m-k+i+ib, 0:n-k+i] from the left.
|
||||||
|
impl.Dlarfb(blas.Left, blas.NoTrans, lapack.Backward, lapack.ColumnWise,
|
||||||
|
m-k+i+ib, n-k+i, ib, a[n-k+i:], lda, work, ldwork,
|
||||||
|
a, lda, work[ib*ldwork:], ldwork)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply H to rows 0:m-k+i+ib of current block.
|
||||||
|
impl.Dorg2l(m-k+i+ib, ib, ib, a[n-k+i:], lda, tau[i:], work)
|
||||||
|
|
||||||
|
// Set rows m-k+i+ib:m of current block to zero.
|
||||||
|
for j := n - k + i; j < n-k+i+ib; j++ {
|
||||||
|
for l := m - k + i + ib; l < m; l++ {
|
||||||
|
a[l*lda+j] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@@ -36,6 +36,10 @@ func TestDgelq2(t *testing.T) {
|
|||||||
testlapack.Dgelq2Test(t, impl)
|
testlapack.Dgelq2Test(t, impl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDgeql2(t *testing.T) {
|
||||||
|
testlapack.Dgeql2Test(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDgels(t *testing.T) {
|
func TestDgels(t *testing.T) {
|
||||||
testlapack.DgelsTest(t, impl)
|
testlapack.DgelsTest(t, impl)
|
||||||
}
|
}
|
||||||
@@ -164,6 +168,10 @@ func TestDorgbr(t *testing.T) {
|
|||||||
testlapack.DorgbrTest(t, impl)
|
testlapack.DorgbrTest(t, impl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDorg2l(t *testing.T) {
|
||||||
|
testlapack.Dorg2lTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDorgl2(t *testing.T) {
|
func TestDorgl2(t *testing.T) {
|
||||||
testlapack.Dorgl2Test(t, impl)
|
testlapack.Dorgl2Test(t, impl)
|
||||||
}
|
}
|
||||||
@@ -172,6 +180,10 @@ func TestDorglq(t *testing.T) {
|
|||||||
testlapack.DorglqTest(t, impl)
|
testlapack.DorglqTest(t, impl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDorgql(t *testing.T) {
|
||||||
|
testlapack.DorgqlTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDorgqr(t *testing.T) {
|
func TestDorgqr(t *testing.T) {
|
||||||
testlapack.DorgqrTest(t, impl)
|
testlapack.DorgqrTest(t, impl)
|
||||||
}
|
}
|
||||||
|
98
testlapack/dgeql2.go
Normal file
98
testlapack/dgeql2.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
// Copyright ©2016 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
"github.com/gonum/floats"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dgeql2er interface {
|
||||||
|
Dgeql2(m, n int, a []float64, lda int, tau, work []float64)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Dgeql2Test(t *testing.T, impl Dgeql2er) {
|
||||||
|
// TODO(btracey): Add tests for m < n.
|
||||||
|
for _, test := range []struct {
|
||||||
|
m, n, lda int
|
||||||
|
}{
|
||||||
|
{5, 5, 0},
|
||||||
|
{5, 3, 0},
|
||||||
|
{5, 4, 0},
|
||||||
|
} {
|
||||||
|
m := test.m
|
||||||
|
n := test.n
|
||||||
|
lda := test.lda
|
||||||
|
if lda == 0 {
|
||||||
|
lda = n
|
||||||
|
}
|
||||||
|
a := make([]float64, m*lda)
|
||||||
|
for i := range a {
|
||||||
|
a[i] = rand.NormFloat64()
|
||||||
|
}
|
||||||
|
tau := nanSlice(min(m, n))
|
||||||
|
work := nanSlice(n)
|
||||||
|
|
||||||
|
aCopy := make([]float64, len(a))
|
||||||
|
copy(aCopy, a)
|
||||||
|
impl.Dgeql2(m, n, a, lda, tau, work)
|
||||||
|
|
||||||
|
k := min(m, n)
|
||||||
|
// Construct Q.
|
||||||
|
q := blas64.General{
|
||||||
|
Rows: m,
|
||||||
|
Cols: m,
|
||||||
|
Stride: m,
|
||||||
|
Data: make([]float64, m*m),
|
||||||
|
}
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
q.Data[i*q.Stride+i] = 1
|
||||||
|
}
|
||||||
|
for i := 0; i < k; i++ {
|
||||||
|
h := blas64.General{Rows: m, Cols: m, Stride: m, Data: make([]float64, m*m)}
|
||||||
|
for j := 0; j < m; j++ {
|
||||||
|
h.Data[j*h.Stride+j] = 1
|
||||||
|
}
|
||||||
|
v := blas64.Vector{Inc: 1, Data: make([]float64, m)}
|
||||||
|
v.Data[m-k+i] = 1
|
||||||
|
for j := 0; j < m-k+i; j++ {
|
||||||
|
v.Data[j] = a[j*lda+n-k+i]
|
||||||
|
}
|
||||||
|
blas64.Ger(-tau[i], v, v, h)
|
||||||
|
qTmp := blas64.General{Rows: q.Rows, Cols: q.Cols, Stride: q.Stride, Data: make([]float64, len(q.Data))}
|
||||||
|
copy(qTmp.Data, q.Data)
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, qTmp, 0, q)
|
||||||
|
}
|
||||||
|
if !isOrthonormal(q) {
|
||||||
|
t.Errorf("Q is not orthonormal")
|
||||||
|
}
|
||||||
|
l := blas64.General{
|
||||||
|
Rows: m,
|
||||||
|
Cols: n,
|
||||||
|
Stride: n,
|
||||||
|
Data: make([]float64, m*n),
|
||||||
|
}
|
||||||
|
if m >= n {
|
||||||
|
for i := m - n; i < m; i++ {
|
||||||
|
for j := 0; j <= min(i-(m-n), n-1); j++ {
|
||||||
|
l.Data[i*l.Stride+j] = a[i*lda+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
panic("untested")
|
||||||
|
}
|
||||||
|
ans := blas64.General{Rows: m, Cols: n, Stride: lda, Data: make([]float64, len(a))}
|
||||||
|
copy(ans.Data, a)
|
||||||
|
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, l, 0, ans)
|
||||||
|
if !floats.EqualApprox(ans.Data, aCopy, 1e-10) {
|
||||||
|
t.Errorf("Reconstruction mismatch: m = %v, n = %v", m, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@@ -216,30 +216,6 @@ func dlatrdCheckDecomposition(t *testing.T, uplo blas.Uplo, n, nb int, e, tau, a
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// isOrthonormal checks that a general matrix is orthonormal.
|
|
||||||
// TODO(btracey): Replace other tests with a call to this function.
|
|
||||||
func isOrthonormal(q blas64.General) bool {
|
|
||||||
n := q.Rows
|
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
for j := i; j < n; j++ {
|
|
||||||
dot := blas64.Dot(n,
|
|
||||||
blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]},
|
|
||||||
blas64.Vector{Inc: 1, Data: q.Data[j*q.Stride:]},
|
|
||||||
)
|
|
||||||
if i == j {
|
|
||||||
if math.Abs(dot-1) > 1e-10 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if math.Abs(dot) > 1e-10 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// genFromSym constructs a (symmetric) general matrix from the data in the
|
// genFromSym constructs a (symmetric) general matrix from the data in the
|
||||||
// symmetric.
|
// symmetric.
|
||||||
// TODO(btracey): Replace other constructions of this with a call to this function.
|
// TODO(btracey): Replace other constructions of this with a call to this function.
|
||||||
|
74
testlapack/dorg2l.go
Normal file
74
testlapack/dorg2l.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
// Copyright ©2016 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dorg2ler interface {
|
||||||
|
Dorg2l(m, n, k int, a []float64, lda int, tau, work []float64)
|
||||||
|
Dgeql2er
|
||||||
|
}
|
||||||
|
|
||||||
|
func Dorg2lTest(t *testing.T, impl Dorg2ler) {
|
||||||
|
for _, test := range []struct {
|
||||||
|
m, n, k, lda int
|
||||||
|
}{
|
||||||
|
{5, 4, 3, 0},
|
||||||
|
{5, 4, 4, 0},
|
||||||
|
{3, 3, 2, 0},
|
||||||
|
{5, 5, 5, 0},
|
||||||
|
} {
|
||||||
|
m := test.m
|
||||||
|
n := test.n
|
||||||
|
k := test.k
|
||||||
|
lda := test.lda
|
||||||
|
if lda == 0 {
|
||||||
|
lda = n
|
||||||
|
}
|
||||||
|
|
||||||
|
a := make([]float64, m*lda)
|
||||||
|
for i := range a {
|
||||||
|
a[i] = rand.NormFloat64()
|
||||||
|
}
|
||||||
|
tau := nanSlice(max(m, n))
|
||||||
|
work := make([]float64, n)
|
||||||
|
impl.Dgeql2(m, n, a, lda, tau, work)
|
||||||
|
|
||||||
|
aCopy := make([]float64, len(a))
|
||||||
|
copy(aCopy, a)
|
||||||
|
impl.Dorg2l(m, n, k, a, lda, tau[n-k:], work)
|
||||||
|
if !hasOrthonormalColumns(m, n, a, lda) {
|
||||||
|
t.Errorf("Q is not orthonormal. m = %v, n = %v, k = %v", m, n, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasOrthornormalColumns checks that the columns of a are orthonormal.
|
||||||
|
func hasOrthonormalColumns(m, n int, a []float64, lda int) bool {
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
for j := i; j < n; j++ {
|
||||||
|
dot := blas64.Dot(m,
|
||||||
|
blas64.Vector{Inc: lda, Data: a[i:]},
|
||||||
|
blas64.Vector{Inc: lda, Data: a[j:]},
|
||||||
|
)
|
||||||
|
if i == j {
|
||||||
|
if math.Abs(dot-1) > 1e-10 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if math.Abs(dot) > 1e-10 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
60
testlapack/dorgql.go
Normal file
60
testlapack/dorgql.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
// Copyright ©2016 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/floats"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dorgqler interface {
|
||||||
|
Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
|
||||||
|
Dorg2ler
|
||||||
|
}
|
||||||
|
|
||||||
|
func DorgqlTest(t *testing.T, impl Dorgqler) {
|
||||||
|
for _, test := range []struct {
|
||||||
|
m, n, k, lda int
|
||||||
|
}{
|
||||||
|
{5, 4, 3, 0},
|
||||||
|
{100, 100, 100, 0},
|
||||||
|
{200, 100, 50, 0},
|
||||||
|
{200, 200, 50, 0},
|
||||||
|
} {
|
||||||
|
m := test.m
|
||||||
|
n := test.n
|
||||||
|
k := test.k
|
||||||
|
lda := test.lda
|
||||||
|
if lda == 0 {
|
||||||
|
lda = n
|
||||||
|
}
|
||||||
|
a := make([]float64, m*lda)
|
||||||
|
for i := range a {
|
||||||
|
a[i] = rand.NormFloat64()
|
||||||
|
}
|
||||||
|
tau := nanSlice(min(m, n))
|
||||||
|
work := nanSlice(max(m, n))
|
||||||
|
|
||||||
|
impl.Dgeql2(m, n, a, lda, tau, work)
|
||||||
|
|
||||||
|
aCopy := make([]float64, len(a))
|
||||||
|
copy(aCopy, a)
|
||||||
|
|
||||||
|
impl.Dorg2l(m, n, k, a, lda, tau, work)
|
||||||
|
ans := make([]float64, len(a))
|
||||||
|
copy(ans, a)
|
||||||
|
|
||||||
|
impl.Dorgql(m, n, k, a, lda, tau, work, -1)
|
||||||
|
work = make([]float64, int(work[0]))
|
||||||
|
copy(a, aCopy)
|
||||||
|
impl.Dorgql(m, n, k, a, lda, tau, work, len(work))
|
||||||
|
|
||||||
|
if !floats.EqualApprox(a, ans, 1e-8) {
|
||||||
|
t.Errorf("Answer mismatch. m = %v, n = %v, k = %v", m, n, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@@ -571,3 +571,26 @@ func printRowise(a []float64, m, n, lda int, beyond bool) {
|
|||||||
fmt.Println(a[i*lda : i*lda+end])
|
fmt.Println(a[i*lda : i*lda+end])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isOrthonormal checks that a general matrix is orthonormal.
|
||||||
|
func isOrthonormal(q blas64.General) bool {
|
||||||
|
n := q.Rows
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
for j := i; j < n; j++ {
|
||||||
|
dot := blas64.Dot(n,
|
||||||
|
blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]},
|
||||||
|
blas64.Vector{Inc: 1, Data: q.Data[j*q.Stride:]},
|
||||||
|
)
|
||||||
|
if i == j {
|
||||||
|
if math.Abs(dot-1) > 1e-10 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if math.Abs(dot) > 1e-10 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user