mirror of
https://github.com/gonum/gonum.git
synced 2025-10-21 14:19:35 +08:00
Add Dorgql, dependency, and helper function for test
This commit is contained in:
43
native/dgeql2.go
Normal file
43
native/dgeql2.go
Normal file
@@ -0,0 +1,43 @@
|
||||
// Copyright ©2016 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package native
|
||||
|
||||
import "github.com/gonum/blas"
|
||||
|
||||
// Dgeql2 computes the QL factorization of the m×n matrix A. That is, Dgelq2
|
||||
// computes Q and L such that
|
||||
// A = Q * L
|
||||
// where Q is an m×m orthonormal matrix and L is a lower trapezoidal matrix.
|
||||
//
|
||||
// Q is represented as a product of elementary reflectors,
|
||||
// Q = H[k-1] * ... * H[1] * H[0]
|
||||
// where k = min(m,n) and each H[i] has the form
|
||||
// H[i] = I - tau[i] * v_i * v_i^T
|
||||
// Vector v_i has v[m-k+i+1:m] = 0, v[m-k+i] = 1, and v[:m-k+i+1] is stored on
|
||||
// exit in A[0:m-k+i-1, n-k+i].
|
||||
//
|
||||
// tau must have length at least min(m,n), and Dgeql2 will panic otherwise.
|
||||
//
|
||||
// work is temporary memory storage and must have length at least n.
|
||||
func (impl Implementation) Dgeql2(m, n int, a []float64, lda int, tau, work []float64) {
|
||||
checkMatrix(m, n, a, lda)
|
||||
if len(tau) < min(m, n) {
|
||||
panic(badTau)
|
||||
}
|
||||
if len(work) < n {
|
||||
panic(badWork)
|
||||
}
|
||||
k := min(m, n)
|
||||
var aii float64
|
||||
for i := k - 1; i >= 0; i-- {
|
||||
// Generate elementary reflector H[i] to annihilate A[0:m-k+i-1, n-k+i].
|
||||
aii, tau[i] = impl.Dlarfg(m-k+i+1, a[(m-k+i)*lda+n-k+i], a[n-k+i:], lda)
|
||||
|
||||
// Apply H[i] to A[0:m-k+i, 0:n-k+i-1] from the left.
|
||||
a[(m-k+i)*lda+n-k+i] = 1
|
||||
impl.Dlarf(blas.Left, m-k+i+1, n-k+i, a[n-k+i:], lda, tau[i], a, lda, work)
|
||||
a[(m-k+i)*lda+n-k+i] = aii
|
||||
}
|
||||
}
|
63
native/dorg2l.go
Normal file
63
native/dorg2l.go
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright ©2016 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package native
|
||||
|
||||
import (
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
)
|
||||
|
||||
// Dorg2l generates an m×n matrix Q with orthonormal columns which is defined
|
||||
// as the last n columns of a product of k elementary reflectors of order m.
|
||||
// Q = H[k-1] * ... * H[1] * H[0]
|
||||
// See Dgelqf for more information. It must be that m >= n >= k.
|
||||
//
|
||||
// tau contains the scalar reflectors computed by Dgeqlf. tau must have length
|
||||
// at least k, and Dorg2l will panic otherwise.
|
||||
//
|
||||
// work contains temporary memory, and must have length at least n. Dorg2l will
|
||||
// panic otherwise.
|
||||
func (impl Implementation) Dorg2l(m, n, k int, a []float64, lda int, tau, work []float64) {
|
||||
checkMatrix(m, n, a, lda)
|
||||
if len(tau) < k {
|
||||
panic(badTau)
|
||||
}
|
||||
if len(work) < n {
|
||||
panic(badWork)
|
||||
}
|
||||
if m < n {
|
||||
panic(mLTN)
|
||||
}
|
||||
if k > n {
|
||||
panic(kGTN)
|
||||
}
|
||||
if n == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize columns 0:n-k to columns of the unit matrix.
|
||||
for j := 0; j < n-k; j++ {
|
||||
for l := 0; l < m; l++ {
|
||||
a[l*lda+j] = 0
|
||||
}
|
||||
a[(m-n+j)*lda+j] = 1
|
||||
}
|
||||
|
||||
bi := blas64.Implementation()
|
||||
for i := 0; i < k; i++ {
|
||||
ii := n - k + i
|
||||
|
||||
// Apply H[i] to A[0:m-k+i, 0:n-k+i] from the left.
|
||||
a[(m-n+ii)*lda+ii] = 1
|
||||
impl.Dlarf(blas.Left, m-n+ii+1, ii, a[ii:], lda, tau[i], a, lda, work)
|
||||
bi.Dscal(m-n+ii, -tau[i], a[ii:], lda)
|
||||
a[(m-n+ii)*lda+ii] = 1 - tau[i]
|
||||
|
||||
// Set A[m-k+i:m, n-k+i+1] to zero.
|
||||
for l := m - n + ii + 1; l < m; l++ {
|
||||
a[l*lda+ii] = 0
|
||||
}
|
||||
}
|
||||
}
|
105
native/dorgql.go
Normal file
105
native/dorgql.go
Normal file
@@ -0,0 +1,105 @@
|
||||
// Copyright ©2016 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package native
|
||||
|
||||
import (
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/lapack"
|
||||
)
|
||||
|
||||
// Dorgql generates the m×n matrix Q with orthonormal columns defined as the
|
||||
// last n columns of a product of k elementary reflectors of order m as returned
|
||||
// by Dgelqf. That is,
|
||||
// Q = H[k-1] * ... * H[1] * H[0].
|
||||
// See Dgelqf for more information.
|
||||
//
|
||||
// tau must have length at least k, and Dorgql will panic otherwise.
|
||||
//
|
||||
// Work is temporary storage, and lwork specifies the usable memory length. At minimum,
|
||||
// lwork >= n, and Dorgql will panic otherwise. The amount of blocking is
|
||||
// limited by the usable length.
|
||||
// If lwork == -1, instead of computing Dorgql the optimal work length is stored
|
||||
// into work[0].
|
||||
func (impl Implementation) Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int) {
|
||||
checkMatrix(m, n, a, lda)
|
||||
if len(tau) < k {
|
||||
panic(badTau)
|
||||
}
|
||||
nb := impl.Ilaenv(1, "DORGQL", " ", m, n, k, -1)
|
||||
lworkopt := n * nb
|
||||
work[0] = float64(lworkopt)
|
||||
if lwork == -1 {
|
||||
return
|
||||
}
|
||||
if lwork < n {
|
||||
panic(badWork)
|
||||
}
|
||||
if n == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
nbmin := 2
|
||||
var nx, ldwork int
|
||||
iws := n
|
||||
if nb > 1 && nb < k {
|
||||
// Determine when to cross over from blocked to unblocked code.
|
||||
nx = max(0, impl.Ilaenv(3, "DORGQL", " ", m, n, k, -1))
|
||||
if nx < k {
|
||||
// Determine if workspace is large enough for blocked code.
|
||||
ldwork = nb
|
||||
iws = n * nb
|
||||
if lwork < iws {
|
||||
// Not enough workspace to use optimal nb: reduce nb and determine
|
||||
// the minimum value of nb.
|
||||
nb = lwork / n
|
||||
nbmin = max(2, impl.Ilaenv(2, "DORGQL", " ", m, n, k, -1))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var kk int
|
||||
if nb >= nbmin && nb < k && nx < k {
|
||||
// Use blocked code after the first block. The last kk columns are handled
|
||||
// by the block method.
|
||||
kk = min(k, ((k-nx+nb-1)/nb)*nb)
|
||||
|
||||
// Set A(m-kk:m, 0:n-kk) to zero.
|
||||
for i := m - kk; i < m; i++ {
|
||||
for j := 0; j < n-kk; j++ {
|
||||
a[i*lda+j] = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use unblocked code for the first or only block.
|
||||
impl.Dorg2l(m-kk, n-kk, k-kk, a, lda, tau, work)
|
||||
if kk > 0 {
|
||||
// Use blocked code.
|
||||
for i := k - kk; i < k; i += nb {
|
||||
ib := min(nb, k-i)
|
||||
if n-k+i > 0 {
|
||||
// Form the triangular factor of the block reflector
|
||||
// H = H[i+ib-1] * ... * H[i+1] * H[i].
|
||||
impl.Dlarft(lapack.Backward, lapack.ColumnWise, m-k+i+ib, ib,
|
||||
a[n-k+i:], lda, tau[i:], work, ldwork)
|
||||
|
||||
// Apply H to A[0:m-k+i+ib, 0:n-k+i] from the left.
|
||||
impl.Dlarfb(blas.Left, blas.NoTrans, lapack.Backward, lapack.ColumnWise,
|
||||
m-k+i+ib, n-k+i, ib, a[n-k+i:], lda, work, ldwork,
|
||||
a, lda, work[ib*ldwork:], ldwork)
|
||||
}
|
||||
|
||||
// Apply H to rows 0:m-k+i+ib of current block.
|
||||
impl.Dorg2l(m-k+i+ib, ib, ib, a[n-k+i:], lda, tau[i:], work)
|
||||
|
||||
// Set rows m-k+i+ib:m of current block to zero.
|
||||
for j := n - k + i; j < n-k+i+ib; j++ {
|
||||
for l := m - k + i + ib; l < m; l++ {
|
||||
a[l*lda+j] = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@@ -36,6 +36,10 @@ func TestDgelq2(t *testing.T) {
|
||||
testlapack.Dgelq2Test(t, impl)
|
||||
}
|
||||
|
||||
func TestDgeql2(t *testing.T) {
|
||||
testlapack.Dgeql2Test(t, impl)
|
||||
}
|
||||
|
||||
func TestDgels(t *testing.T) {
|
||||
testlapack.DgelsTest(t, impl)
|
||||
}
|
||||
@@ -164,6 +168,10 @@ func TestDorgbr(t *testing.T) {
|
||||
testlapack.DorgbrTest(t, impl)
|
||||
}
|
||||
|
||||
func TestDorg2l(t *testing.T) {
|
||||
testlapack.Dorg2lTest(t, impl)
|
||||
}
|
||||
|
||||
func TestDorgl2(t *testing.T) {
|
||||
testlapack.Dorgl2Test(t, impl)
|
||||
}
|
||||
@@ -172,6 +180,10 @@ func TestDorglq(t *testing.T) {
|
||||
testlapack.DorglqTest(t, impl)
|
||||
}
|
||||
|
||||
func TestDorgql(t *testing.T) {
|
||||
testlapack.DorgqlTest(t, impl)
|
||||
}
|
||||
|
||||
func TestDorgqr(t *testing.T) {
|
||||
testlapack.DorgqrTest(t, impl)
|
||||
}
|
||||
|
98
testlapack/dgeql2.go
Normal file
98
testlapack/dgeql2.go
Normal file
@@ -0,0 +1,98 @@
|
||||
// Copyright ©2016 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
"github.com/gonum/floats"
|
||||
)
|
||||
|
||||
type Dgeql2er interface {
|
||||
Dgeql2(m, n int, a []float64, lda int, tau, work []float64)
|
||||
}
|
||||
|
||||
func Dgeql2Test(t *testing.T, impl Dgeql2er) {
|
||||
// TODO(btracey): Add tests for m < n.
|
||||
for _, test := range []struct {
|
||||
m, n, lda int
|
||||
}{
|
||||
{5, 5, 0},
|
||||
{5, 3, 0},
|
||||
{5, 4, 0},
|
||||
} {
|
||||
m := test.m
|
||||
n := test.n
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = n
|
||||
}
|
||||
a := make([]float64, m*lda)
|
||||
for i := range a {
|
||||
a[i] = rand.NormFloat64()
|
||||
}
|
||||
tau := nanSlice(min(m, n))
|
||||
work := nanSlice(n)
|
||||
|
||||
aCopy := make([]float64, len(a))
|
||||
copy(aCopy, a)
|
||||
impl.Dgeql2(m, n, a, lda, tau, work)
|
||||
|
||||
k := min(m, n)
|
||||
// Construct Q.
|
||||
q := blas64.General{
|
||||
Rows: m,
|
||||
Cols: m,
|
||||
Stride: m,
|
||||
Data: make([]float64, m*m),
|
||||
}
|
||||
for i := 0; i < m; i++ {
|
||||
q.Data[i*q.Stride+i] = 1
|
||||
}
|
||||
for i := 0; i < k; i++ {
|
||||
h := blas64.General{Rows: m, Cols: m, Stride: m, Data: make([]float64, m*m)}
|
||||
for j := 0; j < m; j++ {
|
||||
h.Data[j*h.Stride+j] = 1
|
||||
}
|
||||
v := blas64.Vector{Inc: 1, Data: make([]float64, m)}
|
||||
v.Data[m-k+i] = 1
|
||||
for j := 0; j < m-k+i; j++ {
|
||||
v.Data[j] = a[j*lda+n-k+i]
|
||||
}
|
||||
blas64.Ger(-tau[i], v, v, h)
|
||||
qTmp := blas64.General{Rows: q.Rows, Cols: q.Cols, Stride: q.Stride, Data: make([]float64, len(q.Data))}
|
||||
copy(qTmp.Data, q.Data)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, qTmp, 0, q)
|
||||
}
|
||||
if !isOrthonormal(q) {
|
||||
t.Errorf("Q is not orthonormal")
|
||||
}
|
||||
l := blas64.General{
|
||||
Rows: m,
|
||||
Cols: n,
|
||||
Stride: n,
|
||||
Data: make([]float64, m*n),
|
||||
}
|
||||
if m >= n {
|
||||
for i := m - n; i < m; i++ {
|
||||
for j := 0; j <= min(i-(m-n), n-1); j++ {
|
||||
l.Data[i*l.Stride+j] = a[i*lda+j]
|
||||
}
|
||||
}
|
||||
} else {
|
||||
panic("untested")
|
||||
}
|
||||
ans := blas64.General{Rows: m, Cols: n, Stride: lda, Data: make([]float64, len(a))}
|
||||
copy(ans.Data, a)
|
||||
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, l, 0, ans)
|
||||
if !floats.EqualApprox(ans.Data, aCopy, 1e-10) {
|
||||
t.Errorf("Reconstruction mismatch: m = %v, n = %v", m, n)
|
||||
}
|
||||
}
|
||||
}
|
@@ -216,30 +216,6 @@ func dlatrdCheckDecomposition(t *testing.T, uplo blas.Uplo, n, nb int, e, tau, a
|
||||
return true
|
||||
}
|
||||
|
||||
// isOrthonormal checks that a general matrix is orthonormal.
|
||||
// TODO(btracey): Replace other tests with a call to this function.
|
||||
func isOrthonormal(q blas64.General) bool {
|
||||
n := q.Rows
|
||||
for i := 0; i < n; i++ {
|
||||
for j := i; j < n; j++ {
|
||||
dot := blas64.Dot(n,
|
||||
blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]},
|
||||
blas64.Vector{Inc: 1, Data: q.Data[j*q.Stride:]},
|
||||
)
|
||||
if i == j {
|
||||
if math.Abs(dot-1) > 1e-10 {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if math.Abs(dot) > 1e-10 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// genFromSym constructs a (symmetric) general matrix from the data in the
|
||||
// symmetric.
|
||||
// TODO(btracey): Replace other constructions of this with a call to this function.
|
||||
|
74
testlapack/dorg2l.go
Normal file
74
testlapack/dorg2l.go
Normal file
@@ -0,0 +1,74 @@
|
||||
// Copyright ©2016 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas/blas64"
|
||||
)
|
||||
|
||||
type Dorg2ler interface {
|
||||
Dorg2l(m, n, k int, a []float64, lda int, tau, work []float64)
|
||||
Dgeql2er
|
||||
}
|
||||
|
||||
func Dorg2lTest(t *testing.T, impl Dorg2ler) {
|
||||
for _, test := range []struct {
|
||||
m, n, k, lda int
|
||||
}{
|
||||
{5, 4, 3, 0},
|
||||
{5, 4, 4, 0},
|
||||
{3, 3, 2, 0},
|
||||
{5, 5, 5, 0},
|
||||
} {
|
||||
m := test.m
|
||||
n := test.n
|
||||
k := test.k
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = n
|
||||
}
|
||||
|
||||
a := make([]float64, m*lda)
|
||||
for i := range a {
|
||||
a[i] = rand.NormFloat64()
|
||||
}
|
||||
tau := nanSlice(max(m, n))
|
||||
work := make([]float64, n)
|
||||
impl.Dgeql2(m, n, a, lda, tau, work)
|
||||
|
||||
aCopy := make([]float64, len(a))
|
||||
copy(aCopy, a)
|
||||
impl.Dorg2l(m, n, k, a, lda, tau[n-k:], work)
|
||||
if !hasOrthonormalColumns(m, n, a, lda) {
|
||||
t.Errorf("Q is not orthonormal. m = %v, n = %v, k = %v", m, n, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// hasOrthornormalColumns checks that the columns of a are orthonormal.
|
||||
func hasOrthonormalColumns(m, n int, a []float64, lda int) bool {
|
||||
for i := 0; i < n; i++ {
|
||||
for j := i; j < n; j++ {
|
||||
dot := blas64.Dot(m,
|
||||
blas64.Vector{Inc: lda, Data: a[i:]},
|
||||
blas64.Vector{Inc: lda, Data: a[j:]},
|
||||
)
|
||||
if i == j {
|
||||
if math.Abs(dot-1) > 1e-10 {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if math.Abs(dot) > 1e-10 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
60
testlapack/dorgql.go
Normal file
60
testlapack/dorgql.go
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright ©2016 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/floats"
|
||||
)
|
||||
|
||||
type Dorgqler interface {
|
||||
Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
|
||||
Dorg2ler
|
||||
}
|
||||
|
||||
func DorgqlTest(t *testing.T, impl Dorgqler) {
|
||||
for _, test := range []struct {
|
||||
m, n, k, lda int
|
||||
}{
|
||||
{5, 4, 3, 0},
|
||||
{100, 100, 100, 0},
|
||||
{200, 100, 50, 0},
|
||||
{200, 200, 50, 0},
|
||||
} {
|
||||
m := test.m
|
||||
n := test.n
|
||||
k := test.k
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = n
|
||||
}
|
||||
a := make([]float64, m*lda)
|
||||
for i := range a {
|
||||
a[i] = rand.NormFloat64()
|
||||
}
|
||||
tau := nanSlice(min(m, n))
|
||||
work := nanSlice(max(m, n))
|
||||
|
||||
impl.Dgeql2(m, n, a, lda, tau, work)
|
||||
|
||||
aCopy := make([]float64, len(a))
|
||||
copy(aCopy, a)
|
||||
|
||||
impl.Dorg2l(m, n, k, a, lda, tau, work)
|
||||
ans := make([]float64, len(a))
|
||||
copy(ans, a)
|
||||
|
||||
impl.Dorgql(m, n, k, a, lda, tau, work, -1)
|
||||
work = make([]float64, int(work[0]))
|
||||
copy(a, aCopy)
|
||||
impl.Dorgql(m, n, k, a, lda, tau, work, len(work))
|
||||
|
||||
if !floats.EqualApprox(a, ans, 1e-8) {
|
||||
t.Errorf("Answer mismatch. m = %v, n = %v, k = %v", m, n, k)
|
||||
}
|
||||
}
|
||||
}
|
@@ -571,3 +571,26 @@ func printRowise(a []float64, m, n, lda int, beyond bool) {
|
||||
fmt.Println(a[i*lda : i*lda+end])
|
||||
}
|
||||
}
|
||||
|
||||
// isOrthonormal checks that a general matrix is orthonormal.
|
||||
func isOrthonormal(q blas64.General) bool {
|
||||
n := q.Rows
|
||||
for i := 0; i < n; i++ {
|
||||
for j := i; j < n; j++ {
|
||||
dot := blas64.Dot(n,
|
||||
blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]},
|
||||
blas64.Vector{Inc: 1, Data: q.Data[j*q.Stride:]},
|
||||
)
|
||||
if i == j {
|
||||
if math.Abs(dot-1) > 1e-10 {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if math.Abs(dot) > 1e-10 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
Reference in New Issue
Block a user