mirror of
https://github.com/gonum/gonum.git
synced 2025-10-04 23:02:42 +08:00

blas64: add length field N to Vector. Alongside, fix the implementation of mat.VecDense and mat.Diagonal, and make the other changes needed to accommodate this change. Fixes #736.
198 lines
3.7 KiB
Go
198 lines
3.7 KiB
Go
// Copyright ©2015 The Gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package testlapack
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"golang.org/x/exp/rand"
|
|
|
|
"gonum.org/v1/gonum/blas"
|
|
"gonum.org/v1/gonum/blas/blas64"
|
|
"gonum.org/v1/gonum/floats"
|
|
)
|
|
|
|
// Dgetf2er is the interface an implementation must satisfy to be exercised
// by Dgetf2Test.
type Dgetf2er interface {
	// Dgetf2 factorizes, in place, the m×n matrix stored row-major in a with
	// row stride lda, and records the row interchanges in ipiv. The tests in
	// this file expect the result to be false when the computed factor has
	// an exactly zero diagonal element and true otherwise.
	Dgetf2(m, n int, a []float64, lda int, ipiv []int) bool
}
|
|
|
|
func Dgetf2Test(t *testing.T, impl Dgetf2er) {
|
|
rnd := rand.New(rand.NewSource(1))
|
|
for _, test := range []struct {
|
|
m, n, lda int
|
|
}{
|
|
{10, 10, 0},
|
|
{10, 5, 0},
|
|
{10, 5, 0},
|
|
|
|
{10, 10, 20},
|
|
{5, 10, 20},
|
|
{10, 5, 20},
|
|
} {
|
|
m := test.m
|
|
n := test.n
|
|
lda := test.lda
|
|
if lda == 0 {
|
|
lda = n
|
|
}
|
|
a := make([]float64, m*lda)
|
|
for i := range a {
|
|
a[i] = rnd.Float64()
|
|
}
|
|
aCopy := make([]float64, len(a))
|
|
copy(aCopy, a)
|
|
|
|
mn := min(m, n)
|
|
ipiv := make([]int, mn)
|
|
for i := range ipiv {
|
|
ipiv[i] = rnd.Int()
|
|
}
|
|
ok := impl.Dgetf2(m, n, a, lda, ipiv)
|
|
checkPLU(t, ok, m, n, lda, ipiv, a, aCopy, 1e-14, true)
|
|
}
|
|
|
|
// Test with singular matrices (random matrices are almost surely non-singular).
|
|
for _, test := range []struct {
|
|
m, n, lda int
|
|
a []float64
|
|
}{
|
|
{
|
|
m: 2,
|
|
n: 2,
|
|
lda: 2,
|
|
a: []float64{
|
|
1, 0,
|
|
0, 0,
|
|
},
|
|
},
|
|
{
|
|
m: 2,
|
|
n: 2,
|
|
lda: 2,
|
|
a: []float64{
|
|
1, 5,
|
|
2, 10,
|
|
},
|
|
},
|
|
{
|
|
m: 3,
|
|
n: 3,
|
|
lda: 3,
|
|
// row 3 = row1 + 2 * row2
|
|
a: []float64{
|
|
1, 5, 7,
|
|
2, 10, -3,
|
|
5, 25, 1,
|
|
},
|
|
},
|
|
{
|
|
m: 3,
|
|
n: 4,
|
|
lda: 4,
|
|
// row 3 = row1 + 2 * row2
|
|
a: []float64{
|
|
1, 5, 7, 9,
|
|
2, 10, -3, 11,
|
|
5, 25, 1, 31,
|
|
},
|
|
},
|
|
} {
|
|
if impl.Dgetf2(test.m, test.n, test.a, test.lda, make([]int, min(test.m, test.n))) {
|
|
t.Log("Returned ok with singular matrix.")
|
|
}
|
|
}
|
|
}
|
|
|
|
// checkPLU checks that the PLU factorization contained in factorize matches
|
|
// the original matrix contained in original.
|
|
func checkPLU(t *testing.T, ok bool, m, n, lda int, ipiv []int, factorized, original []float64, tol float64, print bool) {
|
|
var hasZeroDiagonal bool
|
|
for i := 0; i < min(m, n); i++ {
|
|
if factorized[i*lda+i] == 0 {
|
|
hasZeroDiagonal = true
|
|
break
|
|
}
|
|
}
|
|
if hasZeroDiagonal && ok {
|
|
t.Error("Has a zero diagonal but returned ok")
|
|
}
|
|
if !hasZeroDiagonal && !ok {
|
|
t.Error("Non-zero diagonal but returned !ok")
|
|
}
|
|
|
|
// Check that the LU decomposition is correct.
|
|
mn := min(m, n)
|
|
l := make([]float64, m*mn)
|
|
ldl := mn
|
|
u := make([]float64, mn*n)
|
|
ldu := n
|
|
for i := 0; i < m; i++ {
|
|
for j := 0; j < n; j++ {
|
|
v := factorized[i*lda+j]
|
|
switch {
|
|
case i == j:
|
|
l[i*ldl+i] = 1
|
|
u[i*ldu+i] = v
|
|
case i > j:
|
|
l[i*ldl+j] = v
|
|
case i < j:
|
|
u[i*ldu+j] = v
|
|
}
|
|
}
|
|
}
|
|
|
|
LU := blas64.General{
|
|
Rows: m,
|
|
Cols: n,
|
|
Stride: n,
|
|
Data: make([]float64, m*n),
|
|
}
|
|
U := blas64.General{
|
|
Rows: mn,
|
|
Cols: n,
|
|
Stride: ldu,
|
|
Data: u,
|
|
}
|
|
L := blas64.General{
|
|
Rows: m,
|
|
Cols: mn,
|
|
Stride: ldl,
|
|
Data: l,
|
|
}
|
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, U, 0, LU)
|
|
|
|
p := make([]float64, m*m)
|
|
ldp := m
|
|
for i := 0; i < m; i++ {
|
|
p[i*ldp+i] = 1
|
|
}
|
|
for i := len(ipiv) - 1; i >= 0; i-- {
|
|
v := ipiv[i]
|
|
blas64.Swap(blas64.Vector{N: m, Inc: 1, Data: p[i*ldp:]},
|
|
blas64.Vector{N: m, Inc: 1, Data: p[v*ldp:]})
|
|
}
|
|
P := blas64.General{
|
|
Rows: m,
|
|
Cols: m,
|
|
Stride: m,
|
|
Data: p,
|
|
}
|
|
aComp := blas64.General{
|
|
Rows: m,
|
|
Cols: n,
|
|
Stride: lda,
|
|
Data: make([]float64, m*lda),
|
|
}
|
|
copy(aComp.Data, factorized)
|
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, P, LU, 0, aComp)
|
|
if !floats.EqualApprox(aComp.Data, original, tol) {
|
|
if print {
|
|
t.Errorf("PLU multiplication does not match original matrix.\nWant: %v\nGot: %v", original, aComp.Data)
|
|
return
|
|
}
|
|
t.Error("PLU multiplication does not match original matrix.")
|
|
}
|
|
}
|