mirror of
https://github.com/gonum/gonum.git
synced 2025-10-30 18:16:32 +08:00
testlapack: rework DlarfTest
This commit is contained in:
committed by
Vladimír Chalupecký
parent
c8be30b70e
commit
489fd3c18f
@@ -5,6 +5,7 @@
|
|||||||
package testlapack
|
package testlapack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"golang.org/x/exp/rand"
|
"golang.org/x/exp/rand"
|
||||||
@@ -12,6 +13,7 @@ import (
|
|||||||
"gonum.org/v1/gonum/blas"
|
"gonum.org/v1/gonum/blas"
|
||||||
"gonum.org/v1/gonum/blas/blas64"
|
"gonum.org/v1/gonum/blas/blas64"
|
||||||
"gonum.org/v1/gonum/floats"
|
"gonum.org/v1/gonum/floats"
|
||||||
|
"gonum.org/v1/gonum/lapack"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Dlarfer interface {
|
type Dlarfer interface {
|
||||||
@@ -19,156 +21,159 @@ type Dlarfer interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DlarfTest(t *testing.T, impl Dlarfer) {
|
func DlarfTest(t *testing.T, impl Dlarfer) {
|
||||||
|
for _, side := range []blas.Side{blas.Left, blas.Right} {
|
||||||
|
name := "Right"
|
||||||
|
if side == blas.Left {
|
||||||
|
name = "Left"
|
||||||
|
}
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
runDlarfTest(t, impl, side)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runDlarfTest(t *testing.T, impl Dlarfer, side blas.Side) {
|
||||||
rnd := rand.New(rand.NewSource(1))
|
rnd := rand.New(rand.NewSource(1))
|
||||||
for i, test := range []struct {
|
for _, m := range []int{0, 1, 2, 3, 4, 5, 10} {
|
||||||
m, n, ldc int
|
for _, n := range []int{0, 1, 2, 3, 4, 5, 10} {
|
||||||
incv, lastv int
|
for _, incv := range []int{1, 4} {
|
||||||
lastr, lastc int
|
for _, ldc := range []int{max(1, n), n + 3} {
|
||||||
tau float64
|
for _, nnzv := range []int{0, 1, 2} {
|
||||||
}{
|
for _, nnzc := range []int{0, 1, 2} {
|
||||||
{
|
for _, tau := range []float64{0, rnd.NormFloat64()} {
|
||||||
m: 3,
|
dlarfTest(t, impl, rnd, side, m, n, incv, ldc, nnzv, nnzc, tau)
|
||||||
n: 2,
|
|
||||||
ldc: 2,
|
|
||||||
|
|
||||||
incv: 4,
|
|
||||||
lastv: 1,
|
|
||||||
|
|
||||||
lastr: 2,
|
|
||||||
lastc: 1,
|
|
||||||
|
|
||||||
tau: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
m: 2,
|
|
||||||
n: 3,
|
|
||||||
ldc: 3,
|
|
||||||
|
|
||||||
incv: 4,
|
|
||||||
lastv: 1,
|
|
||||||
|
|
||||||
lastr: 1,
|
|
||||||
lastc: 2,
|
|
||||||
|
|
||||||
tau: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
m: 2,
|
|
||||||
n: 3,
|
|
||||||
ldc: 3,
|
|
||||||
|
|
||||||
incv: 4,
|
|
||||||
lastv: 1,
|
|
||||||
|
|
||||||
lastr: 0,
|
|
||||||
lastc: 1,
|
|
||||||
|
|
||||||
tau: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
m: 2,
|
|
||||||
n: 3,
|
|
||||||
ldc: 3,
|
|
||||||
|
|
||||||
incv: 4,
|
|
||||||
lastv: 0,
|
|
||||||
|
|
||||||
lastr: 0,
|
|
||||||
lastc: 1,
|
|
||||||
|
|
||||||
tau: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
m: 10,
|
|
||||||
n: 10,
|
|
||||||
ldc: 10,
|
|
||||||
|
|
||||||
incv: 4,
|
|
||||||
lastv: 6,
|
|
||||||
|
|
||||||
lastr: 9,
|
|
||||||
lastc: 8,
|
|
||||||
|
|
||||||
tau: 2,
|
|
||||||
},
|
|
||||||
} {
|
|
||||||
// Construct a random matrix.
|
|
||||||
c := make([]float64, test.ldc*test.m)
|
|
||||||
for i := 0; i <= test.lastr; i++ {
|
|
||||||
for j := 0; j <= test.lastc; j++ {
|
|
||||||
c[i*test.ldc+j] = rnd.Float64()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cCopy := make([]float64, len(c))
|
|
||||||
copy(cCopy, c)
|
|
||||||
cCopy2 := make([]float64, len(c))
|
|
||||||
copy(cCopy2, c)
|
|
||||||
|
|
||||||
// Test with side right.
|
|
||||||
sz := max(test.m, test.n) // so v works for both right and left side.
|
|
||||||
v := make([]float64, test.incv*sz+1)
|
|
||||||
// Fill with nonzero entries up until lastv.
|
|
||||||
for i := 0; i <= test.lastv; i++ {
|
|
||||||
v[i*test.incv] = rnd.Float64()
|
|
||||||
}
|
}
|
||||||
// Construct h explicitly to compare.
|
|
||||||
h := make([]float64, test.n*test.n)
|
|
||||||
for i := 0; i < test.n; i++ {
|
|
||||||
h[i*test.n+i] = 1
|
|
||||||
}
|
}
|
||||||
hMat := blas64.General{
|
|
||||||
Rows: test.n,
|
|
||||||
Cols: test.n,
|
|
||||||
Stride: test.n,
|
|
||||||
Data: h,
|
|
||||||
}
|
}
|
||||||
vVec := blas64.Vector{
|
|
||||||
Inc: test.incv,
|
|
||||||
Data: v,
|
|
||||||
}
|
|
||||||
blas64.Ger(-test.tau, vVec, vVec, hMat)
|
|
||||||
|
|
||||||
// Apply multiplication (2nd copy is to avoid aliasing).
|
|
||||||
cMat := blas64.General{
|
|
||||||
Rows: test.m,
|
|
||||||
Cols: test.n,
|
|
||||||
Stride: test.ldc,
|
|
||||||
Data: cCopy,
|
|
||||||
}
|
|
||||||
cMat2 := blas64.General{
|
|
||||||
Rows: test.m,
|
|
||||||
Cols: test.n,
|
|
||||||
Stride: test.ldc,
|
|
||||||
Data: cCopy2,
|
|
||||||
}
|
|
||||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMat2, hMat, 0, cMat)
|
|
||||||
|
|
||||||
// cMat now stores the true answer. Compare with the function call.
|
|
||||||
work := make([]float64, sz)
|
|
||||||
impl.Dlarf(blas.Right, test.m, test.n, v, test.incv, test.tau, c, test.ldc, work)
|
|
||||||
if !floats.EqualApprox(c, cMat.Data, 1e-14) {
|
|
||||||
t.Errorf("Dlarf mismatch right, case %v. Want %v, got %v", i, cMat.Data, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test on the left side.
|
|
||||||
copy(c, cCopy2)
|
|
||||||
copy(cCopy, c)
|
|
||||||
// Construct h.
|
|
||||||
h = make([]float64, test.m*test.m)
|
|
||||||
for i := 0; i < test.m; i++ {
|
|
||||||
h[i*test.m+i] = 1
|
|
||||||
}
|
|
||||||
hMat = blas64.General{
|
|
||||||
Rows: test.m,
|
|
||||||
Cols: test.m,
|
|
||||||
Stride: test.m,
|
|
||||||
Data: h,
|
|
||||||
}
|
|
||||||
blas64.Ger(-test.tau, vVec, vVec, hMat)
|
|
||||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hMat, cMat2, 0, cMat)
|
|
||||||
impl.Dlarf(blas.Left, test.m, test.n, v, test.incv, test.tau, c, test.ldc, work)
|
|
||||||
if !floats.EqualApprox(c, cMat.Data, 1e-14) {
|
|
||||||
t.Errorf("Dlarf mismatch left, case %v. Want %v, got %v", i, cMat.Data, c)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func dlarfTest(t *testing.T, impl Dlarfer, rnd *rand.Rand, side blas.Side, m, n, incv, ldc, nnzv, nnzc int, tau float64) {
|
||||||
|
const tol = 1e-14
|
||||||
|
|
||||||
|
c := make([]float64, m*ldc)
|
||||||
|
for i := range c {
|
||||||
|
c[i] = rnd.NormFloat64()
|
||||||
|
}
|
||||||
|
switch nnzc {
|
||||||
|
case 0:
|
||||||
|
// Zero out all of C.
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < n; j++ {
|
||||||
|
c[i*ldc+j] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case 1:
|
||||||
|
// Zero out right or bottom half of C.
|
||||||
|
if side == blas.Left {
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := n / 2; j < n; j++ {
|
||||||
|
c[i*ldc+j] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := m / 2; i < m; i++ {
|
||||||
|
for j := 0; j < n; j++ {
|
||||||
|
c[i*ldc+j] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Leave C with random content.
|
||||||
|
}
|
||||||
|
cCopy := make([]float64, len(c))
|
||||||
|
copy(cCopy, c)
|
||||||
|
|
||||||
|
var work []float64
|
||||||
|
if side == blas.Left {
|
||||||
|
work = make([]float64, n)
|
||||||
|
} else {
|
||||||
|
work = make([]float64, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
vlen := n
|
||||||
|
if side == blas.Left {
|
||||||
|
vlen = m
|
||||||
|
}
|
||||||
|
vlen = max(1, vlen)
|
||||||
|
v := make([]float64, 1+(vlen-1)*incv)
|
||||||
|
for i := range v {
|
||||||
|
v[i] = rnd.NormFloat64()
|
||||||
|
}
|
||||||
|
switch nnzv {
|
||||||
|
case 0:
|
||||||
|
// Zero out all of v.
|
||||||
|
for i := 0; i < vlen; i++ {
|
||||||
|
v[i*incv] = 0
|
||||||
|
}
|
||||||
|
case 1:
|
||||||
|
// Zero out half of v.
|
||||||
|
for i := vlen / 2; i < vlen; i++ {
|
||||||
|
v[i*incv] = 0
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Leave v with random content.
|
||||||
|
}
|
||||||
|
vCopy := make([]float64, len(v))
|
||||||
|
copy(vCopy, v)
|
||||||
|
|
||||||
|
impl.Dlarf(side, m, n, v, incv, tau, c, ldc, work)
|
||||||
|
got := c
|
||||||
|
|
||||||
|
name := fmt.Sprintf("m=%d,n=%d,incv=%d,tau=%f,ldc=%d", m, n, incv, tau, ldc)
|
||||||
|
|
||||||
|
if !floats.Equal(v, vCopy) {
|
||||||
|
t.Errorf("%v: unexpected modification of v", name)
|
||||||
|
}
|
||||||
|
if tau == 0 && !floats.Equal(got, cCopy) {
|
||||||
|
t.Errorf("%v: unexpected modification of C", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m == 0 || n == 0 || tau == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
bi := blas64.Implementation()
|
||||||
|
|
||||||
|
want := make([]float64, len(cCopy))
|
||||||
|
if side == blas.Left {
|
||||||
|
// Compute want = (I - tau * v * vᵀ) * C
|
||||||
|
|
||||||
|
// vtc = -tau * vᵀ * C = -tau * Cᵀ * v
|
||||||
|
vtc := make([]float64, n)
|
||||||
|
bi.Dgemv(blas.Trans, m, n, -tau, cCopy, ldc, v, incv, 0, vtc, 1)
|
||||||
|
|
||||||
|
// want = C + v * vtcᵀ
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < n; j++ {
|
||||||
|
want[i*ldc+j] = cCopy[i*ldc+j] + v[i*incv]*vtc[j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Compute want = C * (I - tau * v * vᵀ)
|
||||||
|
|
||||||
|
// cv = -tau * C * v
|
||||||
|
cv := make([]float64, m)
|
||||||
|
bi.Dgemv(blas.NoTrans, m, n, -tau, cCopy, ldc, v, incv, 0, cv, 1)
|
||||||
|
|
||||||
|
// want = C + cv * vᵀ
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < n; j++ {
|
||||||
|
want[i*ldc+j] = cCopy[i*ldc+j] + cv[i]*v[j*incv]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
diff := make([]float64, m*n)
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < n; j++ {
|
||||||
|
diff[i*n+j] = got[i*ldc+j] - want[i*ldc+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resid := dlange(lapack.MaxColumnSum, m, n, diff, n)
|
||||||
|
if resid > tol*float64(max(m, n)) {
|
||||||
|
t.Errorf("%v: unexpected result; resid=%v, want<=%v", name, resid, tol*float64(max(m, n)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user