mirror of
https://github.com/gonum/gonum.git
synced 2025-10-22 22:59:24 +08:00
Working implementation of blocked QR
Improved function documentation Fixed dlarfb and dlarft and added full tests Added dgelq2 Working Dgels Fix many comments and tests Many PR comment responses Responded to more PR comments Many PR comments
This commit is contained in:
112
testlapack/dgelq2.go
Normal file
112
testlapack/dgelq2.go
Normal file
@@ -0,0 +1,112 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
"github.com/gonum/floats"
|
||||
)
|
||||
|
||||
type Dgelq2er interface {
|
||||
Dgelq2(m, n int, a []float64, lda int, tau, work []float64)
|
||||
}
|
||||
|
||||
func Dgelq2Test(t *testing.T, impl Dgelq2er) {
|
||||
for c, test := range []struct {
|
||||
m, n, lda int
|
||||
}{
|
||||
{1, 1, 0},
|
||||
{2, 2, 0},
|
||||
{3, 2, 0},
|
||||
{2, 3, 0},
|
||||
{1, 12, 0},
|
||||
{2, 6, 0},
|
||||
{3, 4, 0},
|
||||
{4, 3, 0},
|
||||
{6, 2, 0},
|
||||
{1, 12, 0},
|
||||
{1, 1, 20},
|
||||
{2, 2, 20},
|
||||
{3, 2, 20},
|
||||
{2, 3, 20},
|
||||
{1, 12, 20},
|
||||
{2, 6, 20},
|
||||
{3, 4, 20},
|
||||
{4, 3, 20},
|
||||
{6, 2, 20},
|
||||
{1, 12, 20},
|
||||
} {
|
||||
n := test.n
|
||||
m := test.m
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = test.n
|
||||
}
|
||||
k := min(m, n)
|
||||
tau := make([]float64, k)
|
||||
for i := range tau {
|
||||
tau[i] = rand.Float64()
|
||||
}
|
||||
work := make([]float64, m)
|
||||
for i := range work {
|
||||
work[i] = rand.Float64()
|
||||
}
|
||||
a := make([]float64, m*lda)
|
||||
for i := 0; i < m*lda; i++ {
|
||||
a[i] = rand.Float64()
|
||||
}
|
||||
aCopy := make([]float64, len(a))
|
||||
copy(aCopy, a)
|
||||
impl.Dgelq2(m, n, a, lda, tau, work)
|
||||
|
||||
Q := constructQ("LQ", m, n, a, lda, tau)
|
||||
|
||||
// Check that Q is orthonormal
|
||||
for i := 0; i < Q.Rows; i++ {
|
||||
nrm := blas64.Nrm2(Q.Cols, blas64.Vector{Inc: 1, Data: Q.Data[i*Q.Stride:]})
|
||||
if math.Abs(nrm-1) > 1e-14 {
|
||||
t.Errorf("Q not normal. Norm is %v", nrm)
|
||||
}
|
||||
for j := 0; j < i; j++ {
|
||||
dot := blas64.Dot(Q.Rows,
|
||||
blas64.Vector{Inc: 1, Data: Q.Data[i*Q.Stride:]},
|
||||
blas64.Vector{Inc: 1, Data: Q.Data[j*Q.Stride:]},
|
||||
)
|
||||
if math.Abs(dot) > 1e-14 {
|
||||
t.Errorf("Q not orthogonal. Dot is %v", dot)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
L := blas64.General{
|
||||
Rows: m,
|
||||
Cols: n,
|
||||
Stride: n,
|
||||
Data: make([]float64, m*n),
|
||||
}
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j <= min(i, n-1); j++ {
|
||||
L.Data[i*L.Stride+j] = a[i*lda+j]
|
||||
}
|
||||
}
|
||||
|
||||
ans := blas64.General{
|
||||
Rows: m,
|
||||
Cols: n,
|
||||
Stride: lda,
|
||||
Data: make([]float64, m*lda),
|
||||
}
|
||||
copy(ans.Data, aCopy)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, Q, 0, ans)
|
||||
if !floats.EqualApprox(aCopy, ans.Data, 1e-14) {
|
||||
t.Errorf("Case %v, LQ mismatch. Want %v, got %v.", c, aCopy, ans.Data)
|
||||
}
|
||||
}
|
||||
}
|
94
testlapack/dgelqf.go
Normal file
94
testlapack/dgelqf.go
Normal file
@@ -0,0 +1,94 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/floats"
|
||||
)
|
||||
|
||||
type Dgelqfer interface {
|
||||
Dgelq2er
|
||||
Dgelqf(m, n int, a []float64, lda int, tau, work []float64, lwork int)
|
||||
}
|
||||
|
||||
func DgelqfTest(t *testing.T, impl Dgelqfer) {
|
||||
for c, test := range []struct {
|
||||
m, n, lda int
|
||||
}{
|
||||
{10, 5, 0},
|
||||
{5, 10, 0},
|
||||
{10, 10, 0},
|
||||
{300, 5, 0},
|
||||
{3, 500, 0},
|
||||
{200, 200, 0},
|
||||
{300, 200, 0},
|
||||
{204, 300, 0},
|
||||
{1, 3000, 0},
|
||||
{3000, 1, 0},
|
||||
{10, 5, 30},
|
||||
{5, 10, 30},
|
||||
{10, 10, 30},
|
||||
{300, 5, 500},
|
||||
{3, 500, 600},
|
||||
{200, 200, 300},
|
||||
{300, 200, 300},
|
||||
{204, 300, 400},
|
||||
{1, 3000, 4000},
|
||||
{3000, 1, 4000},
|
||||
} {
|
||||
m := test.m
|
||||
n := test.n
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = n
|
||||
}
|
||||
a := make([]float64, m*lda)
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
a[i*lda+j] = rand.Float64()
|
||||
}
|
||||
}
|
||||
tau := make([]float64, n)
|
||||
for i := 0; i < n; i++ {
|
||||
tau[i] = rand.Float64()
|
||||
}
|
||||
aCopy := make([]float64, len(a))
|
||||
copy(aCopy, a)
|
||||
ans := make([]float64, len(a))
|
||||
copy(ans, a)
|
||||
work := make([]float64, m)
|
||||
for i := range work {
|
||||
work[i] = rand.Float64()
|
||||
}
|
||||
// Compute unblocked QR.
|
||||
impl.Dgelq2(m, n, ans, lda, tau, work)
|
||||
// Compute blocked QR with small work.
|
||||
impl.Dgelqf(m, n, a, lda, tau, work, len(work))
|
||||
if !floats.EqualApprox(ans, a, 1e-14) {
|
||||
t.Errorf("Case %v, mismatch small work.", c)
|
||||
}
|
||||
// Try the full length of work.
|
||||
impl.Dgelqf(m, n, a, lda, tau, work, -1)
|
||||
lwork := int(work[0])
|
||||
work = make([]float64, lwork)
|
||||
copy(a, aCopy)
|
||||
impl.Dgelqf(m, n, a, lda, tau, work, lwork)
|
||||
if !floats.EqualApprox(ans, a, 1e-12) {
|
||||
t.Errorf("Case %v, mismatch large work.", c)
|
||||
}
|
||||
|
||||
// Try a slightly smaller version of work to test blocking code.
|
||||
work = work[1:]
|
||||
lwork--
|
||||
copy(a, aCopy)
|
||||
impl.Dgelqf(m, n, a, lda, tau, work, lwork)
|
||||
if !floats.EqualApprox(ans, a, 1e-12) {
|
||||
t.Errorf("Case %v, mismatch large work.", c)
|
||||
}
|
||||
}
|
||||
}
|
181
testlapack/dgels.go
Normal file
181
testlapack/dgels.go
Normal file
@@ -0,0 +1,181 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
"github.com/gonum/floats"
|
||||
)
|
||||
|
||||
type Dgelser interface {
|
||||
Dgels(trans blas.Transpose, m, n, nrhs int, a []float64, lda int, b []float64, ldb int, work []float64, lwork int) bool
|
||||
}
|
||||
|
||||
func DgelsTest(t *testing.T, impl Dgelser) {
|
||||
for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
|
||||
for _, test := range []struct {
|
||||
m, n, nrhs, lda, ldb int
|
||||
}{
|
||||
{3, 4, 5, 0, 0},
|
||||
{3, 5, 4, 0, 0},
|
||||
{4, 3, 5, 0, 0},
|
||||
{4, 5, 3, 0, 0},
|
||||
{5, 3, 4, 0, 0},
|
||||
{5, 4, 3, 0, 0},
|
||||
{3, 4, 5, 10, 20},
|
||||
{3, 5, 4, 10, 20},
|
||||
{4, 3, 5, 10, 20},
|
||||
{4, 5, 3, 10, 20},
|
||||
{5, 3, 4, 10, 20},
|
||||
{5, 4, 3, 10, 20},
|
||||
{3, 4, 5, 20, 10},
|
||||
{3, 5, 4, 20, 10},
|
||||
{4, 3, 5, 20, 10},
|
||||
{4, 5, 3, 20, 10},
|
||||
{5, 3, 4, 20, 10},
|
||||
{5, 4, 3, 20, 10},
|
||||
{200, 300, 400, 0, 0},
|
||||
{200, 400, 300, 0, 0},
|
||||
{300, 200, 400, 0, 0},
|
||||
{300, 400, 200, 0, 0},
|
||||
{400, 200, 300, 0, 0},
|
||||
{400, 300, 200, 0, 0},
|
||||
{200, 300, 400, 500, 600},
|
||||
{200, 400, 300, 500, 600},
|
||||
{300, 200, 400, 500, 600},
|
||||
{300, 400, 200, 500, 600},
|
||||
{400, 200, 300, 500, 600},
|
||||
{400, 300, 200, 500, 600},
|
||||
{200, 300, 400, 600, 500},
|
||||
{200, 400, 300, 600, 500},
|
||||
{300, 200, 400, 600, 500},
|
||||
{300, 400, 200, 600, 500},
|
||||
{400, 200, 300, 600, 500},
|
||||
{400, 300, 200, 600, 500},
|
||||
} {
|
||||
m := test.m
|
||||
n := test.n
|
||||
nrhs := test.nrhs
|
||||
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = n
|
||||
}
|
||||
a := make([]float64, m*lda)
|
||||
for i := range a {
|
||||
a[i] = rand.Float64()
|
||||
}
|
||||
aCopy := make([]float64, len(a))
|
||||
copy(aCopy, a)
|
||||
|
||||
// Size of b is the same trans or no trans, because the number of rows
|
||||
// has to be the max of (m,n).
|
||||
mb := max(m, n)
|
||||
nb := nrhs
|
||||
ldb := test.ldb
|
||||
if ldb == 0 {
|
||||
ldb = nb
|
||||
}
|
||||
b := make([]float64, mb*ldb)
|
||||
for i := range b {
|
||||
b[i] = rand.Float64()
|
||||
}
|
||||
bCopy := make([]float64, len(b))
|
||||
copy(bCopy, b)
|
||||
|
||||
// Find optimal work length.
|
||||
work := make([]float64, 1)
|
||||
impl.Dgels(trans, m, n, nrhs, a, lda, b, ldb, work, -1)
|
||||
|
||||
// Perform linear solve
|
||||
work = make([]float64, int(work[0]))
|
||||
lwork := len(work)
|
||||
for i := range work {
|
||||
work[i] = rand.Float64()
|
||||
}
|
||||
impl.Dgels(trans, m, n, nrhs, a, lda, b, ldb, work, lwork)
|
||||
|
||||
// Check that the answer is correct by comparing to the normal equations.
|
||||
aMat := blas64.General{
|
||||
Rows: m,
|
||||
Cols: n,
|
||||
Stride: lda,
|
||||
Data: make([]float64, len(aCopy)),
|
||||
}
|
||||
copy(aMat.Data, aCopy)
|
||||
szAta := n
|
||||
if trans == blas.Trans {
|
||||
szAta = m
|
||||
}
|
||||
aTA := blas64.General{
|
||||
Rows: szAta,
|
||||
Cols: szAta,
|
||||
Stride: szAta,
|
||||
Data: make([]float64, szAta*szAta),
|
||||
}
|
||||
|
||||
// Compute A^T * A if notrans and A * A^T otherwise.
|
||||
if trans == blas.NoTrans {
|
||||
blas64.Gemm(blas.Trans, blas.NoTrans, 1, aMat, aMat, 0, aTA)
|
||||
} else {
|
||||
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aMat, aMat, 0, aTA)
|
||||
}
|
||||
|
||||
// Multiply by X.
|
||||
X := blas64.General{
|
||||
Rows: szAta,
|
||||
Cols: nrhs,
|
||||
Stride: ldb,
|
||||
Data: b,
|
||||
}
|
||||
ans := blas64.General{
|
||||
Rows: aTA.Rows,
|
||||
Cols: X.Cols,
|
||||
Stride: X.Cols,
|
||||
Data: make([]float64, aTA.Rows*X.Cols),
|
||||
}
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aTA, X, 0, ans)
|
||||
|
||||
B := blas64.General{
|
||||
Rows: szAta,
|
||||
Cols: nrhs,
|
||||
Stride: ldb,
|
||||
Data: make([]float64, len(bCopy)),
|
||||
}
|
||||
|
||||
copy(B.Data, bCopy)
|
||||
var ans2 blas64.General
|
||||
if trans == blas.NoTrans {
|
||||
ans2 = blas64.General{
|
||||
Rows: aMat.Cols,
|
||||
Cols: B.Cols,
|
||||
Stride: B.Cols,
|
||||
Data: make([]float64, aMat.Cols*B.Cols),
|
||||
}
|
||||
} else {
|
||||
ans2 = blas64.General{
|
||||
Rows: aMat.Rows,
|
||||
Cols: B.Cols,
|
||||
Stride: B.Cols,
|
||||
Data: make([]float64, aMat.Rows*B.Cols),
|
||||
}
|
||||
}
|
||||
|
||||
// Compute A^T B if Trans or A * B otherwise
|
||||
if trans == blas.NoTrans {
|
||||
blas64.Gemm(blas.Trans, blas.NoTrans, 1, aMat, B, 0, ans2)
|
||||
} else {
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, B, 0, ans2)
|
||||
}
|
||||
if !floats.EqualApprox(ans.Data, ans2.Data, 1e-12) {
|
||||
t.Errorf("Normal equations not satisfied")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
110
testlapack/dgeqr2.go
Normal file
110
testlapack/dgeqr2.go
Normal file
@@ -0,0 +1,110 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
"github.com/gonum/floats"
|
||||
)
|
||||
|
||||
type Dgeqr2er interface {
|
||||
Dgeqr2(m, n int, a []float64, lda int, tau []float64, work []float64)
|
||||
}
|
||||
|
||||
func Dgeqr2Test(t *testing.T, impl Dgeqr2er) {
|
||||
for c, test := range []struct {
|
||||
m, n, lda int
|
||||
}{
|
||||
{1, 1, 0},
|
||||
{2, 2, 0},
|
||||
{3, 2, 0},
|
||||
{2, 3, 0},
|
||||
{1, 12, 0},
|
||||
{2, 6, 0},
|
||||
{3, 4, 0},
|
||||
{4, 3, 0},
|
||||
{6, 2, 0},
|
||||
{12, 1, 0},
|
||||
{1, 1, 20},
|
||||
{2, 2, 20},
|
||||
{3, 2, 20},
|
||||
{2, 3, 20},
|
||||
{1, 12, 20},
|
||||
{2, 6, 20},
|
||||
{3, 4, 20},
|
||||
{4, 3, 20},
|
||||
{6, 2, 20},
|
||||
{12, 1, 20},
|
||||
} {
|
||||
n := test.n
|
||||
m := test.m
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = test.n
|
||||
}
|
||||
a := make([]float64, m*lda)
|
||||
for i := range a {
|
||||
a[i] = rand.Float64()
|
||||
}
|
||||
aCopy := make([]float64, len(a))
|
||||
k := min(m, n)
|
||||
tau := make([]float64, k)
|
||||
for i := range tau {
|
||||
tau[i] = rand.Float64()
|
||||
}
|
||||
work := make([]float64, n)
|
||||
for i := range work {
|
||||
work[i] = rand.Float64()
|
||||
}
|
||||
copy(aCopy, a)
|
||||
impl.Dgeqr2(m, n, a, lda, tau, work)
|
||||
|
||||
// Test that the QR factorization has completed successfully. Compute
|
||||
// Q based on the vectors.
|
||||
q := constructQ("QR", m, n, a, lda, tau)
|
||||
|
||||
// Check that q is orthonormal
|
||||
for i := 0; i < m; i++ {
|
||||
nrm := blas64.Nrm2(m, blas64.Vector{1, q.Data[i*m:]})
|
||||
if math.Abs(nrm-1) > 1e-14 {
|
||||
t.Errorf("Case %v, q not normal", c)
|
||||
}
|
||||
for j := 0; j < i; j++ {
|
||||
dot := blas64.Dot(m, blas64.Vector{1, q.Data[i*m:]}, blas64.Vector{1, q.Data[j*m:]})
|
||||
if math.Abs(dot) > 1e-14 {
|
||||
t.Errorf("Case %v, q not orthogonal", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check that A = Q * R
|
||||
r := blas64.General{
|
||||
Rows: m,
|
||||
Cols: n,
|
||||
Stride: n,
|
||||
Data: make([]float64, m*n),
|
||||
}
|
||||
for i := 0; i < m; i++ {
|
||||
for j := i; j < n; j++ {
|
||||
r.Data[i*n+j] = a[i*lda+j]
|
||||
}
|
||||
}
|
||||
atmp := blas64.General{
|
||||
Rows: m,
|
||||
Cols: n,
|
||||
Stride: lda,
|
||||
Data: make([]float64, m*lda),
|
||||
}
|
||||
copy(atmp.Data, a)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, r, 0, atmp)
|
||||
if !floats.EqualApprox(atmp.Data, aCopy, 1e-14) {
|
||||
t.Errorf("Q*R != a")
|
||||
}
|
||||
}
|
||||
}
|
91
testlapack/dgeqrf.go
Normal file
91
testlapack/dgeqrf.go
Normal file
@@ -0,0 +1,91 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/floats"
|
||||
)
|
||||
|
||||
type Dgeqrfer interface {
|
||||
Dgeqr2er
|
||||
Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int)
|
||||
}
|
||||
|
||||
func DgeqrfTest(t *testing.T, impl Dgeqrfer) {
|
||||
for c, test := range []struct {
|
||||
m, n, lda int
|
||||
}{
|
||||
{10, 5, 0},
|
||||
{5, 10, 0},
|
||||
{10, 10, 0},
|
||||
{300, 5, 0},
|
||||
{3, 500, 0},
|
||||
{200, 200, 0},
|
||||
{300, 200, 0},
|
||||
{204, 300, 0},
|
||||
{1, 3000, 0},
|
||||
{3000, 1, 0},
|
||||
{10, 5, 20},
|
||||
{5, 10, 20},
|
||||
{10, 10, 20},
|
||||
{300, 5, 400},
|
||||
{3, 500, 600},
|
||||
{200, 200, 300},
|
||||
{300, 200, 300},
|
||||
{204, 300, 400},
|
||||
{1, 3000, 4000},
|
||||
{3000, 1, 4000},
|
||||
} {
|
||||
m := test.m
|
||||
n := test.n
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = test.n
|
||||
}
|
||||
a := make([]float64, m*lda)
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
a[i*lda+j] = rand.Float64()
|
||||
}
|
||||
}
|
||||
tau := make([]float64, n)
|
||||
for i := 0; i < n; i++ {
|
||||
tau[i] = rand.Float64()
|
||||
}
|
||||
aCopy := make([]float64, len(a))
|
||||
copy(aCopy, a)
|
||||
ans := make([]float64, len(a))
|
||||
copy(ans, a)
|
||||
work := make([]float64, n)
|
||||
// Compute unblocked QR.
|
||||
impl.Dgeqr2(m, n, ans, lda, tau, work)
|
||||
// Compute blocked QR with small work.
|
||||
impl.Dgeqrf(m, n, a, lda, tau, work, len(work))
|
||||
if !floats.EqualApprox(ans, a, 1e-14) {
|
||||
t.Errorf("Case %v, mismatch small work.", c)
|
||||
}
|
||||
// Try the full length of work.
|
||||
impl.Dgeqrf(m, n, a, lda, tau, work, -1)
|
||||
lwork := int(work[0])
|
||||
work = make([]float64, lwork)
|
||||
copy(a, aCopy)
|
||||
impl.Dgeqrf(m, n, a, lda, tau, work, lwork)
|
||||
if !floats.EqualApprox(ans, a, 1e-12) {
|
||||
t.Errorf("Case %v, mismatch large work.", c)
|
||||
}
|
||||
|
||||
// Try a slightly smaller version of work to test blocking.
|
||||
work = work[1:]
|
||||
lwork--
|
||||
copy(a, aCopy)
|
||||
impl.Dgeqrf(m, n, a, lda, tau, work, lwork)
|
||||
if !floats.EqualApprox(ans, a, 1e-12) {
|
||||
t.Errorf("Case %v, mismatch large work.", c)
|
||||
}
|
||||
}
|
||||
}
|
92
testlapack/dlange.go
Normal file
92
testlapack/dlange.go
Normal file
@@ -0,0 +1,92 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas/blas64"
|
||||
"github.com/gonum/lapack"
|
||||
)
|
||||
|
||||
type Dlanger interface {
|
||||
Dlange(norm lapack.MatrixNorm, m, n int, a []float64, lda int, work []float64) float64
|
||||
}
|
||||
|
||||
func DlangeTest(t *testing.T, impl Dlanger) {
|
||||
for _, test := range []struct {
|
||||
m, n, lda int
|
||||
}{
|
||||
{4, 3, 0},
|
||||
{3, 4, 0},
|
||||
{4, 3, 100},
|
||||
{3, 4, 100},
|
||||
} {
|
||||
m := test.m
|
||||
n := test.n
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = n
|
||||
}
|
||||
a := make([]float64, m*lda)
|
||||
for i := range a {
|
||||
a[i] = (rand.Float64() - 0.5)
|
||||
}
|
||||
work := make([]float64, n)
|
||||
for i := range work {
|
||||
work[i] = rand.Float64()
|
||||
}
|
||||
aCopy := make([]float64, len(a))
|
||||
copy(aCopy, a)
|
||||
|
||||
// Test MaxAbs norm.
|
||||
norm := impl.Dlange(lapack.MaxAbs, m, n, a, lda, work)
|
||||
var ans float64
|
||||
for i := 0; i < m; i++ {
|
||||
idx := blas64.Iamax(n, blas64.Vector{1, aCopy[i*lda:]})
|
||||
ans = math.Max(ans, math.Abs(a[i*lda+idx]))
|
||||
}
|
||||
// Should be strictly equal because there is no floating point summation error.
|
||||
if ans != norm {
|
||||
t.Errorf("MaxAbs mismatch. Want %v, got %v.", ans, norm)
|
||||
}
|
||||
|
||||
// Test MaxColumnSum norm.
|
||||
norm = impl.Dlange(lapack.MaxColumnSum, m, n, a, lda, work)
|
||||
ans = 0
|
||||
for i := 0; i < n; i++ {
|
||||
sum := blas64.Asum(m, blas64.Vector{lda, aCopy[i:]})
|
||||
ans = math.Max(ans, sum)
|
||||
}
|
||||
if math.Abs(norm-ans) > 1e-14 {
|
||||
t.Errorf("MaxColumnSum mismatch. Want %v, got %v.", ans, norm)
|
||||
}
|
||||
|
||||
// Test MaxRowSum norm.
|
||||
norm = impl.Dlange(lapack.MaxRowSum, m, n, a, lda, work)
|
||||
ans = 0
|
||||
for i := 0; i < m; i++ {
|
||||
sum := blas64.Asum(n, blas64.Vector{1, aCopy[i*lda:]})
|
||||
ans = math.Max(ans, sum)
|
||||
}
|
||||
if math.Abs(norm-ans) > 1e-14 {
|
||||
t.Errorf("MaxRowSum mismatch. Want %v, got %v.", ans, norm)
|
||||
}
|
||||
|
||||
// Test Frobenius norm
|
||||
norm = impl.Dlange(lapack.NormFrob, m, n, a, lda, work)
|
||||
ans = 0
|
||||
for i := 0; i < m; i++ {
|
||||
sum := blas64.Nrm2(n, blas64.Vector{1, aCopy[i*lda:]})
|
||||
ans += sum * sum
|
||||
}
|
||||
ans = math.Sqrt(ans)
|
||||
if math.Abs(norm-ans) > 1e-14 {
|
||||
t.Errorf("NormFrob mismatch. Want %v, got %v.", ans, norm)
|
||||
}
|
||||
}
|
||||
}
|
@@ -63,6 +63,19 @@ func DlarfTest(t *testing.T, impl Dlarfer) {
|
||||
|
||||
tau: 2,
|
||||
},
|
||||
{
|
||||
m: 2,
|
||||
n: 3,
|
||||
ldc: 3,
|
||||
|
||||
incv: 4,
|
||||
lastv: 0,
|
||||
|
||||
lastr: 0,
|
||||
lastc: 1,
|
||||
|
||||
tau: 2,
|
||||
},
|
||||
{
|
||||
m: 10,
|
||||
n: 10,
|
||||
@@ -93,7 +106,7 @@ func DlarfTest(t *testing.T, impl Dlarfer) {
|
||||
sz := max(test.m, test.n) // so v works for both right and left side.
|
||||
v := make([]float64, test.incv*sz+1)
|
||||
// Fill with nonzero entries up until lastv.
|
||||
for i := 0; i < test.lastv; i++ {
|
||||
for i := 0; i <= test.lastv; i++ {
|
||||
v[i*test.incv] = rand.Float64()
|
||||
}
|
||||
// Construct h explicitly to compare.
|
||||
@@ -132,7 +145,7 @@ func DlarfTest(t *testing.T, impl Dlarfer) {
|
||||
work := make([]float64, sz)
|
||||
impl.Dlarf(blas.Right, test.m, test.n, v, test.incv, test.tau, c, test.ldc, work)
|
||||
if !floats.EqualApprox(c, cMat.Data, 1e-14) {
|
||||
t.Errorf("Dlarf mismatch case %v. Want %v, got %v", i, cMat.Data, c)
|
||||
t.Errorf("Dlarf mismatch right, case %v. Want %v, got %v", i, cMat.Data, c)
|
||||
}
|
||||
|
||||
// Test on the left side.
|
||||
@@ -153,7 +166,7 @@ func DlarfTest(t *testing.T, impl Dlarfer) {
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hMat, cMat2, 0, cMat)
|
||||
impl.Dlarf(blas.Left, test.m, test.n, v, test.incv, test.tau, c, test.ldc, work)
|
||||
if !floats.EqualApprox(c, cMat.Data, 1e-14) {
|
||||
t.Errorf("Dlarf mismatch case %v. Want %v, got %v", i, cMat.Data, c)
|
||||
t.Errorf("Dlarf mismatch left, case %v. Want %v, got %v", i, cMat.Data, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
160
testlapack/dlarfb.go
Normal file
160
testlapack/dlarfb.go
Normal file
@@ -0,0 +1,160 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
"github.com/gonum/floats"
|
||||
"github.com/gonum/lapack"
|
||||
)
|
||||
|
||||
type Dlarfber interface {
|
||||
Dlarfter
|
||||
Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct,
|
||||
store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int,
|
||||
c []float64, ldc int, work []float64, ldwork int)
|
||||
}
|
||||
|
||||
func DlarfbTest(t *testing.T, impl Dlarfber) {
|
||||
for _, store := range []lapack.StoreV{lapack.ColumnWise, lapack.RowWise} {
|
||||
for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} {
|
||||
for _, side := range []blas.Side{blas.Left, blas.Right} {
|
||||
for _, trans := range []blas.Transpose{blas.Trans, blas.NoTrans} {
|
||||
for cas, test := range []struct {
|
||||
ma, na, cdim, lda, ldt, ldc int
|
||||
}{
|
||||
{6, 6, 6, 0, 0, 0},
|
||||
{6, 8, 10, 0, 0, 0},
|
||||
{6, 10, 8, 0, 0, 0},
|
||||
{8, 6, 10, 0, 0, 0},
|
||||
{8, 10, 6, 0, 0, 0},
|
||||
{10, 6, 8, 0, 0, 0},
|
||||
{10, 8, 6, 0, 0, 0},
|
||||
{6, 6, 6, 12, 15, 30},
|
||||
{6, 8, 10, 12, 15, 30},
|
||||
{6, 10, 8, 12, 15, 30},
|
||||
{8, 6, 10, 12, 15, 30},
|
||||
{8, 10, 6, 12, 15, 30},
|
||||
{10, 6, 8, 12, 15, 30},
|
||||
{10, 8, 6, 12, 15, 30},
|
||||
{6, 6, 6, 15, 12, 30},
|
||||
{6, 8, 10, 15, 12, 30},
|
||||
{6, 10, 8, 15, 12, 30},
|
||||
{8, 6, 10, 15, 12, 30},
|
||||
{8, 10, 6, 15, 12, 30},
|
||||
{10, 6, 8, 15, 12, 30},
|
||||
{10, 8, 6, 15, 12, 30},
|
||||
} {
|
||||
// Generate a matrix for QR
|
||||
ma := test.ma
|
||||
na := test.na
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = na
|
||||
}
|
||||
a := make([]float64, ma*lda)
|
||||
for i := 0; i < ma; i++ {
|
||||
for j := 0; j < lda; j++ {
|
||||
a[i*lda+j] = rand.Float64()
|
||||
}
|
||||
}
|
||||
k := min(ma, na)
|
||||
|
||||
// H is always ma x ma
|
||||
var m, n, rowsWork int
|
||||
switch {
|
||||
default:
|
||||
panic("not implemented")
|
||||
case side == blas.Left:
|
||||
m = test.ma
|
||||
n = test.cdim
|
||||
rowsWork = n
|
||||
case side == blas.Right:
|
||||
m = test.cdim
|
||||
n = test.ma
|
||||
rowsWork = m
|
||||
}
|
||||
|
||||
// Use dgeqr2 to find the v vectors
|
||||
tau := make([]float64, na)
|
||||
work := make([]float64, na)
|
||||
impl.Dgeqr2(ma, k, a, lda, tau, work)
|
||||
|
||||
// Correct the v vectors based on the direct and store
|
||||
vMatTmp := extractVMat(ma, na, a, lda, lapack.Forward, lapack.ColumnWise)
|
||||
vMat := constructVMat(vMatTmp, store, direct)
|
||||
v := vMat.Data
|
||||
ldv := vMat.Stride
|
||||
|
||||
// Use dlarft to find the t vector
|
||||
ldt := test.ldt
|
||||
if ldt == 0 {
|
||||
ldt = k
|
||||
}
|
||||
tm := make([]float64, k*ldt)
|
||||
|
||||
impl.Dlarft(direct, store, ma, k, v, ldv, tau, tm, ldt)
|
||||
|
||||
// Generate c matrix
|
||||
ldc := test.ldc
|
||||
if ldc == 0 {
|
||||
ldc = n
|
||||
}
|
||||
c := make([]float64, m*ldc)
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < ldc; j++ {
|
||||
c[i*ldc+j] = rand.Float64()
|
||||
}
|
||||
}
|
||||
cCopy := make([]float64, len(c))
|
||||
copy(cCopy, c)
|
||||
|
||||
ldwork := k
|
||||
work = make([]float64, rowsWork*k)
|
||||
|
||||
// Call Dlarfb with this information
|
||||
impl.Dlarfb(side, trans, direct, store, m, n, k, v, ldv, tm, ldt, c, ldc, work, ldwork)
|
||||
|
||||
h := constructH(tau, vMat, store, direct)
|
||||
|
||||
cMat := blas64.General{
|
||||
Rows: m,
|
||||
Cols: n,
|
||||
Stride: ldc,
|
||||
Data: make([]float64, m*ldc),
|
||||
}
|
||||
copy(cMat.Data, cCopy)
|
||||
ans := blas64.General{
|
||||
Rows: m,
|
||||
Cols: n,
|
||||
Stride: ldc,
|
||||
Data: make([]float64, m*ldc),
|
||||
}
|
||||
copy(ans.Data, cMat.Data)
|
||||
switch {
|
||||
default:
|
||||
panic("not implemented")
|
||||
case side == blas.Left && trans == blas.NoTrans:
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, cMat, 0, ans)
|
||||
case side == blas.Left && trans == blas.Trans:
|
||||
blas64.Gemm(blas.Trans, blas.NoTrans, 1, h, cMat, 0, ans)
|
||||
case side == blas.Right && trans == blas.NoTrans:
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMat, h, 0, ans)
|
||||
case side == blas.Right && trans == blas.Trans:
|
||||
blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMat, h, 0, ans)
|
||||
}
|
||||
if !floats.EqualApprox(ans.Data, c, 1e-14) {
|
||||
t.Errorf("Cas %v mismatch. Want %v, got %v.", cas, ans.Data, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
133
testlapack/dlarfg.go
Normal file
133
testlapack/dlarfg.go
Normal file
@@ -0,0 +1,133 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
)
|
||||
|
||||
type Dlarfger interface {
|
||||
Dlarfg(n int, alpha float64, x []float64, incX int) (beta, tau float64)
|
||||
}
|
||||
|
||||
func DlarfgTest(t *testing.T, impl Dlarfger) {
|
||||
for i, test := range []struct {
|
||||
alpha float64
|
||||
n int
|
||||
x []float64
|
||||
}{
|
||||
{
|
||||
alpha: 4,
|
||||
n: 3,
|
||||
},
|
||||
{
|
||||
alpha: -2,
|
||||
n: 3,
|
||||
},
|
||||
{
|
||||
alpha: 0,
|
||||
n: 3,
|
||||
},
|
||||
{
|
||||
alpha: 1,
|
||||
n: 1,
|
||||
},
|
||||
{
|
||||
alpha: 1,
|
||||
n: 2,
|
||||
x: []float64{4, 5, 6},
|
||||
},
|
||||
} {
|
||||
n := test.n
|
||||
incX := 1
|
||||
var x []float64
|
||||
if test.x == nil {
|
||||
x = make([]float64, n-1)
|
||||
for i := range x {
|
||||
x[i] = rand.Float64()
|
||||
}
|
||||
} else {
|
||||
x = make([]float64, n-1)
|
||||
copy(x, test.x)
|
||||
}
|
||||
xcopy := make([]float64, n-1)
|
||||
copy(xcopy, x)
|
||||
alpha := test.alpha
|
||||
beta, tau := impl.Dlarfg(n, alpha, x, incX)
|
||||
|
||||
// Verify the returns and the values in v. Construct h and perform
|
||||
// the explicit multiplication.
|
||||
h := make([]float64, n*n)
|
||||
for i := 0; i < n; i++ {
|
||||
h[i*n+i] = 1
|
||||
}
|
||||
hmat := blas64.General{
|
||||
Rows: n,
|
||||
Cols: n,
|
||||
Stride: n,
|
||||
Data: h,
|
||||
}
|
||||
v := make([]float64, n)
|
||||
copy(v[1:], x)
|
||||
v[0] = 1
|
||||
vVec := blas64.Vector{
|
||||
Inc: 1,
|
||||
Data: v,
|
||||
}
|
||||
blas64.Ger(-tau, vVec, vVec, hmat)
|
||||
eye := blas64.General{
|
||||
Rows: n,
|
||||
Cols: n,
|
||||
Stride: n,
|
||||
Data: make([]float64, n*n),
|
||||
}
|
||||
blas64.Gemm(blas.Trans, blas.NoTrans, 1, hmat, hmat, 0, eye)
|
||||
iseye := true
|
||||
for i := 0; i < n; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
if i == j {
|
||||
if math.Abs(eye.Data[i*n+j]-1) > 1e-14 {
|
||||
iseye = false
|
||||
}
|
||||
} else {
|
||||
if math.Abs(eye.Data[i*n+j]) > 1e-14 {
|
||||
iseye = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !iseye {
|
||||
t.Errorf("H^T * H is not I %V", eye)
|
||||
}
|
||||
|
||||
xVec := blas64.Vector{
|
||||
Inc: 1,
|
||||
Data: make([]float64, n),
|
||||
}
|
||||
xVec.Data[0] = test.alpha
|
||||
copy(xVec.Data[1:], xcopy)
|
||||
|
||||
ans := make([]float64, n)
|
||||
ansVec := blas64.Vector{
|
||||
Inc: 1,
|
||||
Data: ans,
|
||||
}
|
||||
blas64.Gemv(blas.NoTrans, 1, hmat, xVec, 0, ansVec)
|
||||
if math.Abs(ans[0]-beta) > 1e-14 {
|
||||
t.Errorf("Case %v, beta mismatch. Want %v, got %v", i, ans[0], beta)
|
||||
}
|
||||
for i := 1; i < n; i++ {
|
||||
if math.Abs(ans[i]) > 1e-14 {
|
||||
t.Errorf("Case %v, nonzero answer %v", i, ans)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
167
testlapack/dlarft.go
Normal file
167
testlapack/dlarft.go
Normal file
@@ -0,0 +1,167 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
"github.com/gonum/floats"
|
||||
"github.com/gonum/lapack"
|
||||
)
|
||||
|
||||
type Dlarfter interface {
|
||||
Dgeqr2er
|
||||
Dlarft(direct lapack.Direct, store lapack.StoreV, n, k int, v []float64, ldv int, tau []float64, t []float64, ldt int)
|
||||
}
|
||||
|
||||
func DlarftTest(t *testing.T, impl Dlarfter) {
|
||||
for _, store := range []lapack.StoreV{lapack.ColumnWise, lapack.RowWise} {
|
||||
for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} {
|
||||
for _, test := range []struct {
|
||||
m, n, ldv, ldt int
|
||||
}{
|
||||
{6, 6, 0, 0},
|
||||
{8, 6, 0, 0},
|
||||
{6, 8, 0, 0},
|
||||
{6, 6, 10, 15},
|
||||
{8, 6, 10, 15},
|
||||
{6, 8, 10, 15},
|
||||
{6, 6, 15, 10},
|
||||
{8, 6, 15, 10},
|
||||
{6, 8, 15, 10},
|
||||
} {
|
||||
// Generate a matrix
|
||||
m := test.m
|
||||
n := test.n
|
||||
lda := n
|
||||
if lda == 0 {
|
||||
lda = n
|
||||
}
|
||||
|
||||
a := make([]float64, m*lda)
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < lda; j++ {
|
||||
a[i*lda+j] = rand.Float64()
|
||||
}
|
||||
}
|
||||
// Use dgeqr2 to find the v vectors
|
||||
tau := make([]float64, n)
|
||||
work := make([]float64, n)
|
||||
impl.Dgeqr2(m, n, a, lda, tau, work)
|
||||
|
||||
// Construct H using these answers
|
||||
vMatTmp := extractVMat(m, n, a, lda, lapack.Forward, lapack.ColumnWise)
|
||||
vMat := constructVMat(vMatTmp, store, direct)
|
||||
v := vMat.Data
|
||||
ldv := vMat.Stride
|
||||
|
||||
h := constructH(tau, vMat, store, direct)
|
||||
|
||||
k := min(m, n)
|
||||
ldt := test.ldt
|
||||
if ldt == 0 {
|
||||
ldt = k
|
||||
}
|
||||
// Find T from the actual function
|
||||
tm := make([]float64, k*ldt)
|
||||
for i := range tm {
|
||||
tm[i] = 100 + rand.Float64()
|
||||
}
|
||||
// The v data has been put into a.
|
||||
impl.Dlarft(direct, store, m, k, v, ldv, tau, tm, ldt)
|
||||
|
||||
tData := make([]float64, len(tm))
|
||||
copy(tData, tm)
|
||||
if direct == lapack.Forward {
|
||||
// Zero out the lower traingular portion.
|
||||
for i := 0; i < k; i++ {
|
||||
for j := 0; j < i; j++ {
|
||||
tData[i*ldt+j] = 0
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Zero out the upper traingular portion.
|
||||
for i := 0; i < k; i++ {
|
||||
for j := i + 1; j < k; j++ {
|
||||
tData[i*ldt+j] = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
T := blas64.General{
|
||||
Rows: k,
|
||||
Cols: k,
|
||||
Stride: ldt,
|
||||
Data: tData,
|
||||
}
|
||||
|
||||
vMatT := blas64.General{
|
||||
Rows: vMat.Cols,
|
||||
Cols: vMat.Rows,
|
||||
Stride: vMat.Rows,
|
||||
Data: make([]float64, vMat.Cols*vMat.Rows),
|
||||
}
|
||||
for i := 0; i < vMat.Rows; i++ {
|
||||
for j := 0; j < vMat.Cols; j++ {
|
||||
vMatT.Data[j*vMatT.Stride+i] = vMat.Data[i*vMat.Stride+j]
|
||||
}
|
||||
}
|
||||
var comp blas64.General
|
||||
if store == lapack.ColumnWise {
|
||||
// H = I - V * T * V^T
|
||||
tmp := blas64.General{
|
||||
Rows: T.Rows,
|
||||
Cols: vMatT.Cols,
|
||||
Stride: vMatT.Cols,
|
||||
Data: make([]float64, T.Rows*vMatT.Cols),
|
||||
}
|
||||
// T * V^T
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, T, vMatT, 0, tmp)
|
||||
comp = blas64.General{
|
||||
Rows: vMat.Rows,
|
||||
Cols: tmp.Cols,
|
||||
Stride: tmp.Cols,
|
||||
Data: make([]float64, vMat.Rows*tmp.Cols),
|
||||
}
|
||||
// V * (T * V^T)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vMat, tmp, 0, comp)
|
||||
} else {
|
||||
// H = I - V^T * T * V
|
||||
tmp := blas64.General{
|
||||
Rows: T.Rows,
|
||||
Cols: vMat.Cols,
|
||||
Stride: vMat.Cols,
|
||||
Data: make([]float64, T.Rows*vMat.Cols),
|
||||
}
|
||||
// T * V
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, T, vMat, 0, tmp)
|
||||
comp = blas64.General{
|
||||
Rows: vMatT.Rows,
|
||||
Cols: tmp.Cols,
|
||||
Stride: tmp.Cols,
|
||||
Data: make([]float64, vMatT.Rows*tmp.Cols),
|
||||
}
|
||||
// V^T * (T * V)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vMatT, tmp, 0, comp)
|
||||
}
|
||||
// I - V^T * T * V
|
||||
for i := 0; i < comp.Rows; i++ {
|
||||
for j := 0; j < comp.Cols; j++ {
|
||||
comp.Data[i*m+j] *= -1
|
||||
if i == j {
|
||||
comp.Data[i*m+j] += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
if !floats.EqualApprox(comp.Data, h.Data, 1e-14) {
|
||||
t.Errorf("T does not construct proper H. Store = %v, Direct = %v.\nWant %v\ngot %v.", string(store), string(direct), h.Data, comp.Data)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
138
testlapack/dorm2r.go
Normal file
138
testlapack/dorm2r.go
Normal file
@@ -0,0 +1,138 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
"github.com/gonum/floats"
|
||||
)
|
||||
|
||||
type Dorm2rer interface {
|
||||
Dgeqrfer
|
||||
Dorm2r(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64)
|
||||
}
|
||||
|
||||
func Dorm2rTest(t *testing.T, impl Dorm2rer) {
|
||||
for _, side := range []blas.Side{blas.Left, blas.Right} {
|
||||
for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
|
||||
for _, test := range []struct {
|
||||
common, adim, cdim, lda, ldc int
|
||||
}{
|
||||
{3, 4, 5, 0, 0},
|
||||
{3, 5, 4, 0, 0},
|
||||
{4, 3, 5, 0, 0},
|
||||
{4, 5, 3, 0, 0},
|
||||
{5, 3, 4, 0, 0},
|
||||
{5, 4, 3, 0, 0},
|
||||
{3, 4, 5, 6, 20},
|
||||
{3, 5, 4, 6, 20},
|
||||
{4, 3, 5, 6, 20},
|
||||
{4, 5, 3, 6, 20},
|
||||
{5, 3, 4, 6, 20},
|
||||
{5, 4, 3, 6, 20},
|
||||
{3, 4, 5, 20, 6},
|
||||
{3, 5, 4, 20, 6},
|
||||
{4, 3, 5, 20, 6},
|
||||
{4, 5, 3, 20, 6},
|
||||
{5, 3, 4, 20, 6},
|
||||
{5, 4, 3, 20, 6},
|
||||
} {
|
||||
var ma, na, mc, nc int
|
||||
if side == blas.Left {
|
||||
ma = test.common
|
||||
na = test.adim
|
||||
mc = test.common
|
||||
nc = test.cdim
|
||||
} else {
|
||||
ma = test.common
|
||||
na = test.adim
|
||||
mc = test.cdim
|
||||
nc = test.common
|
||||
}
|
||||
|
||||
// Generate a random matrix
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = na
|
||||
}
|
||||
a := make([]float64, ma*lda)
|
||||
for i := range a {
|
||||
a[i] = rand.Float64()
|
||||
}
|
||||
ldc := test.ldc
|
||||
if ldc == 0 {
|
||||
ldc = nc
|
||||
}
|
||||
// Compute random C matrix
|
||||
c := make([]float64, mc*ldc)
|
||||
for i := range c {
|
||||
c[i] = rand.Float64()
|
||||
}
|
||||
|
||||
// Compute QR
|
||||
k := min(ma, na)
|
||||
tau := make([]float64, k)
|
||||
work := make([]float64, 1)
|
||||
impl.Dgeqrf(ma, na, a, lda, tau, work, -1)
|
||||
work = make([]float64, int(work[0]))
|
||||
impl.Dgeqrf(ma, na, a, lda, tau, work, len(work))
|
||||
|
||||
// Build Q from result
|
||||
q := constructQ("QR", ma, na, a, lda, tau)
|
||||
|
||||
cMat := blas64.General{
|
||||
Rows: mc,
|
||||
Cols: nc,
|
||||
Stride: ldc,
|
||||
Data: make([]float64, len(c)),
|
||||
}
|
||||
copy(cMat.Data, c)
|
||||
cMatCopy := blas64.General{
|
||||
Rows: cMat.Rows,
|
||||
Cols: cMat.Cols,
|
||||
Stride: cMat.Stride,
|
||||
Data: make([]float64, len(cMat.Data)),
|
||||
}
|
||||
copy(cMatCopy.Data, cMat.Data)
|
||||
switch {
|
||||
default:
|
||||
panic("bad test")
|
||||
case side == blas.Left && trans == blas.NoTrans:
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, cMatCopy, 0, cMat)
|
||||
case side == blas.Left && trans == blas.Trans:
|
||||
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, cMatCopy, 0, cMat)
|
||||
case side == blas.Right && trans == blas.NoTrans:
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMatCopy, q, 0, cMat)
|
||||
case side == blas.Right && trans == blas.Trans:
|
||||
blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMatCopy, q, 0, cMat)
|
||||
}
|
||||
// Do Dorm2r ard compare
|
||||
if side == blas.Left {
|
||||
work = make([]float64, nc)
|
||||
} else {
|
||||
work = make([]float64, mc)
|
||||
}
|
||||
aCopy := make([]float64, len(a))
|
||||
copy(aCopy, a)
|
||||
tauCopy := make([]float64, len(tau))
|
||||
copy(tauCopy, tau)
|
||||
impl.Dorm2r(side, trans, mc, nc, k, a, lda, tau, c, ldc, work)
|
||||
if !floats.Equal(a, aCopy) {
|
||||
t.Errorf("a changed in call")
|
||||
}
|
||||
if !floats.Equal(tau, tauCopy) {
|
||||
t.Errorf("tau changed in call")
|
||||
}
|
||||
if !floats.EqualApprox(cMat.Data, c, 1e-14) {
|
||||
t.Errorf("Multiplication mismatch.\n Want %v \n got %v.", cMat.Data, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
137
testlapack/dorml2.go
Normal file
137
testlapack/dorml2.go
Normal file
@@ -0,0 +1,137 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
"github.com/gonum/floats"
|
||||
)
|
||||
|
||||
type Dorml2er interface {
|
||||
Dgelqfer
|
||||
Dorml2(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64)
|
||||
}
|
||||
|
||||
func Dorml2Test(t *testing.T, impl Dorml2er) {
|
||||
for _, side := range []blas.Side{blas.Left, blas.Right} {
|
||||
for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
|
||||
for _, test := range []struct {
|
||||
common, adim, cdim, lda, ldc int
|
||||
}{
|
||||
{3, 4, 5, 0, 0},
|
||||
{3, 5, 4, 0, 0},
|
||||
{4, 3, 5, 0, 0},
|
||||
{4, 5, 3, 0, 0},
|
||||
{5, 3, 4, 0, 0},
|
||||
{5, 4, 3, 0, 0},
|
||||
{3, 4, 5, 6, 20},
|
||||
{3, 5, 4, 6, 20},
|
||||
{4, 3, 5, 6, 20},
|
||||
{4, 5, 3, 6, 20},
|
||||
{5, 3, 4, 6, 20},
|
||||
{5, 4, 3, 6, 20},
|
||||
{3, 4, 5, 20, 6},
|
||||
{3, 5, 4, 20, 6},
|
||||
{4, 3, 5, 20, 6},
|
||||
{4, 5, 3, 20, 6},
|
||||
{5, 3, 4, 20, 6},
|
||||
{5, 4, 3, 20, 6},
|
||||
} {
|
||||
var ma, na, mc, nc int
|
||||
if side == blas.Left {
|
||||
ma = test.adim
|
||||
na = test.common
|
||||
mc = test.common
|
||||
nc = test.cdim
|
||||
} else {
|
||||
ma = test.adim
|
||||
na = test.common
|
||||
mc = test.cdim
|
||||
nc = test.common
|
||||
}
|
||||
// Generate a random matrix
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = na
|
||||
}
|
||||
a := make([]float64, ma*lda)
|
||||
for i := range a {
|
||||
a[i] = rand.Float64()
|
||||
}
|
||||
ldc := test.ldc
|
||||
if ldc == 0 {
|
||||
ldc = nc
|
||||
}
|
||||
// Compute random C matrix
|
||||
c := make([]float64, mc*ldc)
|
||||
for i := range c {
|
||||
c[i] = rand.Float64()
|
||||
}
|
||||
|
||||
// Compute LQ
|
||||
k := min(ma, na)
|
||||
tau := make([]float64, k)
|
||||
work := make([]float64, 1)
|
||||
impl.Dgelqf(ma, na, a, lda, tau, work, -1)
|
||||
work = make([]float64, int(work[0]))
|
||||
impl.Dgelqf(ma, na, a, lda, tau, work, len(work))
|
||||
|
||||
// Build Q from result
|
||||
q := constructQ("LQ", ma, na, a, lda, tau)
|
||||
|
||||
cMat := blas64.General{
|
||||
Rows: mc,
|
||||
Cols: nc,
|
||||
Stride: ldc,
|
||||
Data: make([]float64, len(c)),
|
||||
}
|
||||
copy(cMat.Data, c)
|
||||
cMatCopy := blas64.General{
|
||||
Rows: cMat.Rows,
|
||||
Cols: cMat.Cols,
|
||||
Stride: cMat.Stride,
|
||||
Data: make([]float64, len(cMat.Data)),
|
||||
}
|
||||
copy(cMatCopy.Data, cMat.Data)
|
||||
switch {
|
||||
default:
|
||||
panic("bad test")
|
||||
case side == blas.Left && trans == blas.NoTrans:
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, cMatCopy, 0, cMat)
|
||||
case side == blas.Left && trans == blas.Trans:
|
||||
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, cMatCopy, 0, cMat)
|
||||
case side == blas.Right && trans == blas.NoTrans:
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMatCopy, q, 0, cMat)
|
||||
case side == blas.Right && trans == blas.Trans:
|
||||
blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMatCopy, q, 0, cMat)
|
||||
}
|
||||
// Do Dorm2r ard compare
|
||||
if side == blas.Left {
|
||||
work = make([]float64, nc)
|
||||
} else {
|
||||
work = make([]float64, mc)
|
||||
}
|
||||
aCopy := make([]float64, len(a))
|
||||
copy(aCopy, a)
|
||||
tauCopy := make([]float64, len(tau))
|
||||
copy(tauCopy, tau)
|
||||
impl.Dorml2(side, trans, mc, nc, k, a, lda, tau, c, ldc, work)
|
||||
if !floats.Equal(a, aCopy) {
|
||||
t.Errorf("a changed in call")
|
||||
}
|
||||
if !floats.Equal(tau, tauCopy) {
|
||||
t.Errorf("tau changed in call")
|
||||
}
|
||||
if !floats.EqualApprox(cMat.Data, c, 1e-14) {
|
||||
t.Errorf("Multiplication mismatch.\n Want %v \n got %v.", cMat.Data, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
149
testlapack/dormlq.go
Normal file
149
testlapack/dormlq.go
Normal file
@@ -0,0 +1,149 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/floats"
|
||||
)
|
||||
|
||||
type Dormlqer interface {
|
||||
Dorml2er
|
||||
Dormlq(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
|
||||
}
|
||||
|
||||
func DormlqTest(t *testing.T, impl Dormlqer) {
|
||||
for _, side := range []blas.Side{blas.Left, blas.Right} {
|
||||
for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
|
||||
for _, test := range []struct {
|
||||
common, adim, cdim, lda, ldc int
|
||||
}{
|
||||
{6, 7, 8, 0, 0},
|
||||
{6, 8, 7, 0, 0},
|
||||
{7, 6, 8, 0, 0},
|
||||
{7, 8, 6, 0, 0},
|
||||
{8, 6, 7, 0, 0},
|
||||
{8, 7, 6, 0, 0},
|
||||
{100, 200, 300, 0, 0},
|
||||
{100, 300, 200, 0, 0},
|
||||
{200, 100, 300, 0, 0},
|
||||
{200, 300, 100, 0, 0},
|
||||
{300, 100, 200, 0, 0},
|
||||
{300, 200, 100, 0, 0},
|
||||
{100, 200, 300, 400, 500},
|
||||
{100, 300, 200, 400, 500},
|
||||
{200, 100, 300, 400, 500},
|
||||
{200, 300, 100, 400, 500},
|
||||
{300, 100, 200, 400, 500},
|
||||
{300, 200, 100, 400, 500},
|
||||
{100, 200, 300, 500, 400},
|
||||
{100, 300, 200, 500, 400},
|
||||
{200, 100, 300, 500, 400},
|
||||
{200, 300, 100, 500, 400},
|
||||
{300, 100, 200, 500, 400},
|
||||
{300, 200, 100, 500, 400},
|
||||
} {
|
||||
var ma, na, mc, nc int
|
||||
if side == blas.Left {
|
||||
ma = test.adim
|
||||
na = test.common
|
||||
mc = test.common
|
||||
nc = test.cdim
|
||||
} else {
|
||||
ma = test.adim
|
||||
na = test.common
|
||||
mc = test.cdim
|
||||
nc = test.common
|
||||
}
|
||||
// Generate a random matrix
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = na
|
||||
}
|
||||
a := make([]float64, ma*lda)
|
||||
for i := range a {
|
||||
a[i] = rand.Float64()
|
||||
}
|
||||
// Compute random C matrix
|
||||
ldc := test.ldc
|
||||
if ldc == 0 {
|
||||
ldc = nc
|
||||
}
|
||||
c := make([]float64, mc*ldc)
|
||||
for i := range c {
|
||||
c[i] = rand.Float64()
|
||||
}
|
||||
|
||||
// Compute LQ
|
||||
k := min(ma, na)
|
||||
tau := make([]float64, k)
|
||||
work := make([]float64, 1)
|
||||
impl.Dgelqf(ma, na, a, lda, tau, work, -1)
|
||||
work = make([]float64, int(work[0]))
|
||||
impl.Dgelqf(ma, na, a, lda, tau, work, len(work))
|
||||
|
||||
cCopy := make([]float64, len(c))
|
||||
copy(cCopy, c)
|
||||
ans := make([]float64, len(c))
|
||||
copy(ans, cCopy)
|
||||
|
||||
if side == blas.Left {
|
||||
work = make([]float64, nc)
|
||||
} else {
|
||||
work = make([]float64, mc)
|
||||
}
|
||||
impl.Dorml2(side, trans, mc, nc, k, a, lda, tau, ans, ldc, work)
|
||||
|
||||
// Make sure Dorml2 and Dormlq match with small work
|
||||
for i := range work {
|
||||
work[i] = rand.Float64()
|
||||
}
|
||||
lwork := len(work)
|
||||
copy(c, cCopy)
|
||||
impl.Dormlq(side, trans, mc, nc, k, a, lda, tau, c, ldc, work, lwork)
|
||||
if !floats.EqualApprox(c, ans, 1e-12) {
|
||||
t.Errorf("Dormqr and Dorm2r mismatch for small work")
|
||||
}
|
||||
|
||||
// Try with the optimum amount of work
|
||||
copy(c, cCopy)
|
||||
impl.Dormlq(side, trans, mc, nc, k, a, lda, tau, c, ldc, work, -1)
|
||||
work = make([]float64, int(work[0]))
|
||||
lwork = len(work)
|
||||
for i := range work {
|
||||
work[i] = rand.Float64()
|
||||
}
|
||||
impl.Dormlq(side, trans, mc, nc, k, a, lda, tau, c, ldc, work, lwork)
|
||||
if !floats.EqualApprox(c, ans, 1e-12) {
|
||||
t.Errorf("Dormqr and Dorm2r mismatch for full work")
|
||||
fmt.Println("ccopy")
|
||||
for i := 0; i < mc; i++ {
|
||||
fmt.Println(cCopy[i*ldc : (i+1)*ldc])
|
||||
}
|
||||
fmt.Println("ans =")
|
||||
for i := 0; i < mc; i++ {
|
||||
fmt.Println(ans[i*ldc : (i+1)*ldc])
|
||||
}
|
||||
fmt.Println("c =")
|
||||
for i := 0; i < mc; i++ {
|
||||
fmt.Println(c[i*ldc : (i+1)*ldc])
|
||||
}
|
||||
}
|
||||
// Try with less than the optimum amount of work
|
||||
copy(c, cCopy)
|
||||
work = work[1:]
|
||||
lwork--
|
||||
impl.Dormlq(side, trans, mc, nc, k, a, lda, tau, c, ldc, work, lwork)
|
||||
if !floats.EqualApprox(c, ans, 1e-12) {
|
||||
t.Errorf("Dormqr and Dorm2r mismatch for medium work")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
150
testlapack/dormqr.go
Normal file
150
testlapack/dormqr.go
Normal file
@@ -0,0 +1,150 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/floats"
|
||||
)
|
||||
|
||||
type Dormqrer interface {
|
||||
Dorm2rer
|
||||
Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
|
||||
}
|
||||
|
||||
func DormqrTest(t *testing.T, impl Dormqrer) {
|
||||
for _, side := range []blas.Side{blas.Left, blas.Right} {
|
||||
for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
|
||||
for _, test := range []struct {
|
||||
common, adim, cdim, lda, ldc int
|
||||
}{
|
||||
{6, 7, 8, 0, 0},
|
||||
{6, 8, 7, 0, 0},
|
||||
{7, 6, 8, 0, 0},
|
||||
{7, 8, 6, 0, 0},
|
||||
{8, 6, 7, 0, 0},
|
||||
{8, 7, 6, 0, 0},
|
||||
{100, 200, 300, 0, 0},
|
||||
{100, 300, 200, 0, 0},
|
||||
{200, 100, 300, 0, 0},
|
||||
{200, 300, 100, 0, 0},
|
||||
{300, 100, 200, 0, 0},
|
||||
{300, 200, 100, 0, 0},
|
||||
{100, 200, 300, 400, 500},
|
||||
{100, 300, 200, 400, 500},
|
||||
{200, 100, 300, 400, 500},
|
||||
{200, 300, 100, 400, 500},
|
||||
{300, 100, 200, 400, 500},
|
||||
{300, 200, 100, 400, 500},
|
||||
{100, 200, 300, 500, 400},
|
||||
{100, 300, 200, 500, 400},
|
||||
{200, 100, 300, 500, 400},
|
||||
{200, 300, 100, 500, 400},
|
||||
{300, 100, 200, 500, 400},
|
||||
{300, 200, 100, 500, 400},
|
||||
} {
|
||||
var ma, na, mc, nc int
|
||||
if side == blas.Left {
|
||||
ma = test.common
|
||||
na = test.adim
|
||||
mc = test.common
|
||||
nc = test.cdim
|
||||
} else {
|
||||
ma = test.common
|
||||
na = test.adim
|
||||
mc = test.cdim
|
||||
nc = test.common
|
||||
}
|
||||
// Generate a random matrix
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = na
|
||||
}
|
||||
a := make([]float64, ma*lda)
|
||||
for i := range a {
|
||||
a[i] = rand.Float64()
|
||||
}
|
||||
// Compute random C matrix
|
||||
ldc := test.ldc
|
||||
if ldc == 0 {
|
||||
ldc = nc
|
||||
}
|
||||
c := make([]float64, mc*ldc)
|
||||
for i := range c {
|
||||
c[i] = rand.Float64()
|
||||
}
|
||||
|
||||
// Compute QR
|
||||
k := min(ma, na)
|
||||
tau := make([]float64, k)
|
||||
work := make([]float64, 1)
|
||||
impl.Dgeqrf(ma, na, a, lda, tau, work, -1)
|
||||
work = make([]float64, int(work[0]))
|
||||
impl.Dgeqrf(ma, na, a, lda, tau, work, len(work))
|
||||
|
||||
cCopy := make([]float64, len(c))
|
||||
copy(cCopy, c)
|
||||
ans := make([]float64, len(c))
|
||||
copy(ans, cCopy)
|
||||
|
||||
if side == blas.Left {
|
||||
work = make([]float64, nc)
|
||||
} else {
|
||||
work = make([]float64, mc)
|
||||
}
|
||||
impl.Dorm2r(side, trans, mc, nc, k, a, lda, tau, ans, ldc, work)
|
||||
|
||||
// Make sure Dorm2r and Dormqr match with small work
|
||||
for i := range work {
|
||||
work[i] = rand.Float64()
|
||||
}
|
||||
lwork := len(work)
|
||||
copy(c, cCopy)
|
||||
impl.Dormqr(side, trans, mc, nc, k, a, lda, tau, c, ldc, work, lwork)
|
||||
if !floats.EqualApprox(c, ans, 1e-12) {
|
||||
t.Errorf("Dormqr and Dorm2r mismatch for small work")
|
||||
}
|
||||
|
||||
// Try with the optimum amount of work
|
||||
copy(c, cCopy)
|
||||
impl.Dormqr(side, trans, mc, nc, k, a, lda, tau, c, ldc, work, -1)
|
||||
work = make([]float64, int(work[0]))
|
||||
lwork = len(work)
|
||||
for i := range work {
|
||||
work[i] = rand.Float64()
|
||||
}
|
||||
_ = lwork
|
||||
impl.Dormqr(side, trans, mc, nc, k, a, lda, tau, c, ldc, work, lwork)
|
||||
if !floats.EqualApprox(c, ans, 1e-12) {
|
||||
t.Errorf("Dormqr and Dorm2r mismatch for full work")
|
||||
fmt.Println("ccopy")
|
||||
for i := 0; i < mc; i++ {
|
||||
fmt.Println(cCopy[i*ldc : (i+1)*ldc])
|
||||
}
|
||||
fmt.Println("ans =")
|
||||
for i := 0; i < mc; i++ {
|
||||
fmt.Println(ans[i*ldc : (i+1)*ldc])
|
||||
}
|
||||
fmt.Println("c =")
|
||||
for i := 0; i < mc; i++ {
|
||||
fmt.Println(c[i*ldc : (i+1)*ldc])
|
||||
}
|
||||
}
|
||||
// Try with less than the optimum amount of work
|
||||
copy(c, cCopy)
|
||||
work = work[1:]
|
||||
lwork--
|
||||
impl.Dormqr(side, trans, mc, nc, k, a, lda, tau, c, ldc, work, lwork)
|
||||
if !floats.EqualApprox(c, ans, 1e-12) {
|
||||
t.Errorf("Dormqr and Dorm2r mismatch for medium work")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@@ -4,9 +4,289 @@
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
"github.com/gonum/lapack"
|
||||
)
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// extractVMat collects the single reflectors from a into a matrix.
|
||||
func extractVMat(m, n int, a []float64, lda int, direct lapack.Direct, store lapack.StoreV) blas64.General {
|
||||
k := min(m, n)
|
||||
switch {
|
||||
default:
|
||||
panic("not implemented")
|
||||
case direct == lapack.Forward && store == lapack.ColumnWise:
|
||||
v := blas64.General{
|
||||
Rows: m,
|
||||
Cols: k,
|
||||
Stride: k,
|
||||
Data: make([]float64, m*k),
|
||||
}
|
||||
for i := 0; i < k; i++ {
|
||||
for j := 0; j < i; j++ {
|
||||
v.Data[j*v.Stride+i] = 0
|
||||
}
|
||||
v.Data[i*v.Stride+i] = 1
|
||||
for j := i + 1; j < m; j++ {
|
||||
v.Data[j*v.Stride+i] = a[j*lda+i]
|
||||
}
|
||||
}
|
||||
return v
|
||||
case direct == lapack.Forward && store == lapack.RowWise:
|
||||
v := blas64.General{
|
||||
Rows: k,
|
||||
Cols: n,
|
||||
Stride: n,
|
||||
Data: make([]float64, k*n),
|
||||
}
|
||||
for i := 0; i < k; i++ {
|
||||
for j := 0; j < i; j++ {
|
||||
v.Data[i*v.Stride+j] = 0
|
||||
}
|
||||
v.Data[i*v.Stride+i] = 1
|
||||
for j := i + 1; j < n; j++ {
|
||||
v.Data[i*v.Stride+j] = a[i*lda+j]
|
||||
}
|
||||
}
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// constructVMat transforms the v matrix based on the storage.
|
||||
func constructVMat(vMat blas64.General, store lapack.StoreV, direct lapack.Direct) blas64.General {
|
||||
m := vMat.Rows
|
||||
k := vMat.Cols
|
||||
switch {
|
||||
default:
|
||||
panic("not implemented")
|
||||
case store == lapack.ColumnWise && direct == lapack.Forward:
|
||||
ldv := k
|
||||
v := make([]float64, m*k)
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < k; j++ {
|
||||
if j > i {
|
||||
v[i*ldv+j] = 0
|
||||
} else if j == i {
|
||||
v[i*ldv+i] = 1
|
||||
} else {
|
||||
v[i*ldv+j] = vMat.Data[i*vMat.Stride+j]
|
||||
}
|
||||
}
|
||||
}
|
||||
return blas64.General{
|
||||
Rows: m,
|
||||
Cols: k,
|
||||
Stride: k,
|
||||
Data: v,
|
||||
}
|
||||
case store == lapack.RowWise && direct == lapack.Forward:
|
||||
ldv := m
|
||||
v := make([]float64, m*k)
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < k; j++ {
|
||||
if j > i {
|
||||
v[j*ldv+i] = 0
|
||||
} else if j == i {
|
||||
v[j*ldv+i] = 1
|
||||
} else {
|
||||
v[j*ldv+i] = vMat.Data[i*vMat.Stride+j]
|
||||
}
|
||||
}
|
||||
}
|
||||
return blas64.General{
|
||||
Rows: k,
|
||||
Cols: m,
|
||||
Stride: m,
|
||||
Data: v,
|
||||
}
|
||||
case store == lapack.ColumnWise && direct == lapack.Backward:
|
||||
rowsv := m
|
||||
ldv := k
|
||||
v := make([]float64, m*k)
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < k; j++ {
|
||||
vrow := rowsv - i - 1
|
||||
vcol := k - j - 1
|
||||
if j > i {
|
||||
v[vrow*ldv+vcol] = 0
|
||||
} else if j == i {
|
||||
v[vrow*ldv+vcol] = 1
|
||||
} else {
|
||||
v[vrow*ldv+vcol] = vMat.Data[i*vMat.Stride+j]
|
||||
}
|
||||
}
|
||||
}
|
||||
return blas64.General{
|
||||
Rows: rowsv,
|
||||
Cols: ldv,
|
||||
Stride: ldv,
|
||||
Data: v,
|
||||
}
|
||||
case store == lapack.RowWise && direct == lapack.Backward:
|
||||
rowsv := k
|
||||
ldv := m
|
||||
v := make([]float64, m*k)
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < k; j++ {
|
||||
vcol := ldv - i - 1
|
||||
vrow := k - j - 1
|
||||
if j > i {
|
||||
v[vrow*ldv+vcol] = 0
|
||||
} else if j == i {
|
||||
v[vrow*ldv+vcol] = 1
|
||||
} else {
|
||||
v[vrow*ldv+vcol] = vMat.Data[i*vMat.Stride+j]
|
||||
}
|
||||
}
|
||||
}
|
||||
return blas64.General{
|
||||
Rows: rowsv,
|
||||
Cols: ldv,
|
||||
Stride: ldv,
|
||||
Data: v,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func constructH(tau []float64, v blas64.General, store lapack.StoreV, direct lapack.Direct) blas64.General {
|
||||
m := v.Rows
|
||||
k := v.Cols
|
||||
if store == lapack.RowWise {
|
||||
m, k = k, m
|
||||
}
|
||||
h := blas64.General{
|
||||
Rows: m,
|
||||
Cols: m,
|
||||
Stride: m,
|
||||
Data: make([]float64, m*m),
|
||||
}
|
||||
for i := 0; i < m; i++ {
|
||||
h.Data[i*m+i] = 1
|
||||
}
|
||||
for i := 0; i < k; i++ {
|
||||
vecData := make([]float64, m)
|
||||
if store == lapack.ColumnWise {
|
||||
for j := 0; j < m; j++ {
|
||||
vecData[j] = v.Data[j*v.Cols+i]
|
||||
}
|
||||
} else {
|
||||
for j := 0; j < m; j++ {
|
||||
vecData[j] = v.Data[i*v.Cols+j]
|
||||
}
|
||||
}
|
||||
vec := blas64.Vector{
|
||||
Inc: 1,
|
||||
Data: vecData,
|
||||
}
|
||||
|
||||
hi := blas64.General{
|
||||
Rows: m,
|
||||
Cols: m,
|
||||
Stride: m,
|
||||
Data: make([]float64, m*m),
|
||||
}
|
||||
for i := 0; i < m; i++ {
|
||||
hi.Data[i*m+i] = 1
|
||||
}
|
||||
// hi = I - tau * v * v^T
|
||||
blas64.Ger(-tau[i], vec, vec, hi)
|
||||
|
||||
hcopy := blas64.General{
|
||||
Rows: m,
|
||||
Cols: m,
|
||||
Stride: m,
|
||||
Data: make([]float64, m*m),
|
||||
}
|
||||
copy(hcopy.Data, h.Data)
|
||||
if direct == lapack.Forward {
|
||||
// H = H * H_I in forward mode
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hcopy, hi, 0, h)
|
||||
} else {
|
||||
// H = H_I * H in backward mode
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hi, hcopy, 0, h)
|
||||
}
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// constructQ constructs the Q matrix from the result of dgeqrf and dgeqr2
|
||||
func constructQ(kind string, m, n int, a []float64, lda int, tau []float64) blas64.General {
|
||||
k := min(m, n)
|
||||
var sz int
|
||||
switch kind {
|
||||
case "QR":
|
||||
sz = m
|
||||
case "LQ":
|
||||
sz = n
|
||||
}
|
||||
|
||||
q := blas64.General{
|
||||
Rows: sz,
|
||||
Cols: sz,
|
||||
Stride: sz,
|
||||
Data: make([]float64, sz*sz),
|
||||
}
|
||||
for i := 0; i < sz; i++ {
|
||||
q.Data[i*sz+i] = 1
|
||||
}
|
||||
qCopy := blas64.General{
|
||||
Rows: q.Rows,
|
||||
Cols: q.Cols,
|
||||
Stride: q.Stride,
|
||||
Data: make([]float64, len(q.Data)),
|
||||
}
|
||||
for i := 0; i < k; i++ {
|
||||
h := blas64.General{
|
||||
Rows: sz,
|
||||
Cols: sz,
|
||||
Stride: sz,
|
||||
Data: make([]float64, sz*sz),
|
||||
}
|
||||
for j := 0; j < sz; j++ {
|
||||
h.Data[j*sz+j] = 1
|
||||
}
|
||||
vVec := blas64.Vector{
|
||||
Inc: 1,
|
||||
Data: make([]float64, sz),
|
||||
}
|
||||
for j := 0; j < i; j++ {
|
||||
vVec.Data[j] = 0
|
||||
}
|
||||
vVec.Data[i] = 1
|
||||
switch kind {
|
||||
case "QR":
|
||||
for j := i + 1; j < sz; j++ {
|
||||
vVec.Data[j] = a[lda*j+i]
|
||||
}
|
||||
case "LQ":
|
||||
for j := i + 1; j < sz; j++ {
|
||||
vVec.Data[j] = a[i*lda+j]
|
||||
}
|
||||
}
|
||||
blas64.Ger(-tau[i], vVec, vVec, h)
|
||||
copy(qCopy.Data, q.Data)
|
||||
// Mulitply q by the new h
|
||||
switch kind {
|
||||
case "QR":
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qCopy, h, 0, q)
|
||||
case "LQ":
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, qCopy, 0, q)
|
||||
}
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
Reference in New Issue
Block a user