mirror of
https://github.com/gonum/gonum.git
synced 2025-10-20 21:59:25 +08:00
Add dgetrf and tests
removed unnecessary requirement Implement Dlaswp in both directions Responded to PR comments
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
// Copied from lapack/native. Keep in sync.
|
// Copied from lapack/native. Keep in sync.
|
||||||
const (
|
const (
|
||||||
|
absIncNotOne = "lapack: increment not one or negative one"
|
||||||
badDirect = "lapack: bad direct"
|
badDirect = "lapack: bad direct"
|
||||||
badIpiv = "lapack: insufficient permutation length"
|
badIpiv = "lapack: insufficient permutation length"
|
||||||
badLdA = "lapack: index of a out of range"
|
badLdA = "lapack: index of a out of range"
|
||||||
@@ -76,7 +77,7 @@ func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok
|
|||||||
return clapack.Dpotrf(ul, n, a, lda)
|
return clapack.Dpotrf(ul, n, a, lda)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dgetf2 computes the LU decomposition of the m×n matrix a.
|
// Dgetf2 computes the LU decomposition of the m×n matrix A.
|
||||||
// The LU decomposition is a factorization of a into
|
// The LU decomposition is a factorization of a into
|
||||||
// A = P * L * U
|
// A = P * L * U
|
||||||
// where P is a permutation matrix, L is a unit lower triangular matrix, and
|
// where P is a permutation matrix, L is a unit lower triangular matrix, and
|
||||||
@@ -85,9 +86,9 @@ func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok
|
|||||||
//
|
//
|
||||||
// ipiv is a permutation vector. It indicates that row i of the matrix was
|
// ipiv is a permutation vector. It indicates that row i of the matrix was
|
||||||
// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
|
// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
|
||||||
// otherwise.
|
// otherwise. ipiv is zero-indexed.
|
||||||
//
|
//
|
||||||
// Dgetf2 returns whether the matrix a is singular. The LU decomposition will
|
// Dgetf2 returns whether the matrix A is nonsingular. The LU decomposition will
|
||||||
// be computed regardless of the singularity of A, but division by zero
|
// be computed regardless of the singularity of A, but division by zero
|
||||||
// will occur if the false is returned and the result is used to solve a
|
// will occur if the false is returned and the result is used to solve a
|
||||||
// system of equations.
|
// system of equations.
|
||||||
@@ -100,7 +101,38 @@ func (Implementation) Dgetf2(m, n int, a []float64, lda int, ipiv []int) (ok boo
|
|||||||
ipiv32 := make([]int32, len(ipiv))
|
ipiv32 := make([]int32, len(ipiv))
|
||||||
ok = clapack.Dgetf2(m, n, a, lda, ipiv32)
|
ok = clapack.Dgetf2(m, n, a, lda, ipiv32)
|
||||||
for i, v := range ipiv32 {
|
for i, v := range ipiv32 {
|
||||||
ipiv[i] = int(v) - 1 // OpenBLAS returns one indexed.
|
ipiv[i] = int(v) - 1 // Transform to zero-indexed.
|
||||||
|
}
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dgetrf computes the LU decomposition of the m×n matrix A.
|
||||||
|
// The LU decomposition is a factorization of A into
|
||||||
|
// A = P * L * U
|
||||||
|
// where P is a permutation matrix, L is a unit lower triangular matrix, and
|
||||||
|
// U is a (usually) non-unit upper triangular matrix. On exit, L and U are stored
|
||||||
|
// in place into a.
|
||||||
|
//
|
||||||
|
// ipiv is a permutation vector. It indicates that row i of the matrix was
|
||||||
|
// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
|
||||||
|
// otherwise. ipiv is zero-indexed.
|
||||||
|
//
|
||||||
|
// Dgetrf is the blocked version of the algorithm.
|
||||||
|
//
|
||||||
|
// Dgetrf returns whether the matrix A is nonsingular. The LU decomposition will
|
||||||
|
// be computed regardless of the singularity of A, but division by zero
|
||||||
|
// will occur if false is returned and the result is used to solve a
|
||||||
|
// system of equations.
|
||||||
|
func (impl Implementation) Dgetrf(m, n int, a []float64, lda int, ipiv []int) (ok bool) {
|
||||||
|
mn := min(m, n)
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
if len(ipiv) < mn {
|
||||||
|
panic(badIpiv)
|
||||||
|
}
|
||||||
|
ipiv32 := make([]int32, len(ipiv))
|
||||||
|
ok = clapack.Dgetrf(m, n, a, lda, ipiv32)
|
||||||
|
for i, v := range ipiv32 {
|
||||||
|
ipiv[i] = int(v) - 1 // Transform to zero-indexed.
|
||||||
}
|
}
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
@@ -19,3 +19,7 @@ func TestDpotrf(t *testing.T) {
|
|||||||
func TestDgetf2(t *testing.T) {
|
func TestDgetf2(t *testing.T) {
|
||||||
testlapack.Dgetf2Test(t, impl)
|
testlapack.Dgetf2Test(t, impl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDgetrf(t *testing.T) {
|
||||||
|
testlapack.DgetrfTest(t, impl)
|
||||||
|
}
|
||||||
|
@@ -9,7 +9,7 @@ import (
|
|||||||
"github.com/gonum/lapack"
|
"github.com/gonum/lapack"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Dgeqrf computes the QR factorization of the m×n matrix a using a blocked
|
// Dgeqrf computes the QR factorization of the m×n matrix A using a blocked
|
||||||
// algorithm. Please see the documentation for Dgeqr2 for a description of the
|
// algorithm. Please see the documentation for Dgeqr2 for a description of the
|
||||||
// parameters at entry and exit.
|
// parameters at entry and exit.
|
||||||
//
|
//
|
||||||
@@ -21,9 +21,6 @@ import (
|
|||||||
//
|
//
|
||||||
// tau must be at least len min(m,n), and this function will panic otherwise.
|
// tau must be at least len min(m,n), and this function will panic otherwise.
|
||||||
func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) {
|
func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) {
|
||||||
// TODO(btracey): This algorithm is oriented for column-major storage.
|
|
||||||
// Consider modifying the algorithm to better suit row-major storage.
|
|
||||||
|
|
||||||
// nb is the optimal blocksize, i.e. the number of columns transformed at a time.
|
// nb is the optimal blocksize, i.e. the number of columns transformed at a time.
|
||||||
nb := impl.Ilaenv(1, "DGEQRF", " ", m, n, -1, -1)
|
nb := impl.Ilaenv(1, "DGEQRF", " ", m, n, -1, -1)
|
||||||
lworkopt := n * max(nb, 1)
|
lworkopt := n * max(nb, 1)
|
||||||
|
@@ -6,7 +6,7 @@ import (
|
|||||||
"github.com/gonum/blas/blas64"
|
"github.com/gonum/blas/blas64"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Dgetf2 computes the LU decomposition of the m×n matrix a.
|
// Dgetf2 computes the LU decomposition of the m×n matrix A.
|
||||||
// The LU decomposition is a factorization of a into
|
// The LU decomposition is a factorization of a into
|
||||||
// A = P * L * U
|
// A = P * L * U
|
||||||
// where P is a permutation matrix, L is a unit lower triangular matrix, and
|
// where P is a permutation matrix, L is a unit lower triangular matrix, and
|
||||||
@@ -15,9 +15,9 @@ import (
|
|||||||
//
|
//
|
||||||
// ipiv is a permutation vector. It indicates that row i of the matrix was
|
// ipiv is a permutation vector. It indicates that row i of the matrix was
|
||||||
// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
|
// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
|
||||||
// otherwise.
|
// otherwise. ipiv is zero-indexed.
|
||||||
//
|
//
|
||||||
// Dgetf2 returns whether the matrix a is singular. The LU decomposition will
|
// Dgetf2 returns whether the matrix A is nonsingular. The LU decomposition will
|
||||||
// be computed regardless of the singularity of A, but division by zero
|
// be computed regardless of the singularity of A, but division by zero
|
||||||
// will occur if the false is returned and the result is used to solve a
|
// will occur if the false is returned and the result is used to solve a
|
||||||
// system of equations.
|
// system of equations.
|
||||||
|
66
native/dgetrf.go
Normal file
66
native/dgetrf.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package native
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dgetrf computes the LU decomposition of the m×n matrix a.
|
||||||
|
// The LU decomposition is a factorization of a into
|
||||||
|
// A = P * L * U
|
||||||
|
// where P is a permutation matrix, L is a unit lower triangular matrix, and
|
||||||
|
// U is a (usually) non-unit upper triangular matrix. On exit, L and U are stored
|
||||||
|
// in place into a.
|
||||||
|
//
|
||||||
|
// ipiv is a permutation vector. It indicates that row i of the matrix was
|
||||||
|
// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
|
||||||
|
// otherwise. ipiv is zero-indexed.
|
||||||
|
//
|
||||||
|
// Dgetrf is the blocked version of the algorithm.
|
||||||
|
//
|
||||||
|
// Dgetrf returns whether the matrix A is singular. The LU decomposition will
|
||||||
|
// be computed regardless of the singularity of A, but division by zero
|
||||||
|
// will occur if the false is returned and the result is used to solve a
|
||||||
|
// system of equations.
|
||||||
|
func (impl Implementation) Dgetrf(m, n int, a []float64, lda int, ipiv []int) (ok bool) {
|
||||||
|
mn := min(m, n)
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
if len(ipiv) < mn {
|
||||||
|
panic(badIpiv)
|
||||||
|
}
|
||||||
|
if m == 0 || n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bi := blas64.Implementation()
|
||||||
|
nb := impl.Ilaenv(1, "DGETRF", " ", m, n, -1, -1)
|
||||||
|
if nb <= 1 || nb >= min(m, n) {
|
||||||
|
// Use the unblocked algorithm.
|
||||||
|
return impl.Dgetf2(m, n, a, lda, ipiv)
|
||||||
|
}
|
||||||
|
ok = true
|
||||||
|
for j := 0; j < mn; j += nb {
|
||||||
|
jb := min(mn-j, nb)
|
||||||
|
blockOk := impl.Dgetf2(m-j, jb, a[j*lda+j:], lda, ipiv[j:])
|
||||||
|
if !blockOk {
|
||||||
|
ok = false
|
||||||
|
}
|
||||||
|
for i := j; i <= min(m-1, j+jb-1); i++ {
|
||||||
|
ipiv[i] = j + ipiv[i]
|
||||||
|
}
|
||||||
|
impl.Dlaswp(j, a, lda, j, j+jb-1, ipiv, 1)
|
||||||
|
if j+jb < n {
|
||||||
|
impl.Dlaswp(n-j-jb, a[j+jb:], lda, j, j+jb-1, ipiv, 1)
|
||||||
|
bi.Dtrsm(blas.Left, blas.Lower, blas.NoTrans, blas.Unit,
|
||||||
|
jb, n-j-jb, 1,
|
||||||
|
a[j*lda+j:], lda,
|
||||||
|
a[j*lda+j+jb:], lda)
|
||||||
|
if j+jb < m {
|
||||||
|
bi.Dgemm(blas.NoTrans, blas.NoTrans, m-j-jb, n-j-jb, jb, -1,
|
||||||
|
a[(j+jb)*lda+j:], lda,
|
||||||
|
a[j*lda+j+jb:], lda,
|
||||||
|
1, a[(j+jb)*lda+j+jb:], lda)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ok
|
||||||
|
}
|
23
native/dlaswp.go
Normal file
23
native/dlaswp.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package native
|
||||||
|
|
||||||
|
import "github.com/gonum/blas/blas64"
|
||||||
|
|
||||||
|
// Dlaswp swaps the rows k1 to k2 of a according to the indices in ipiv.
|
||||||
|
// a is a matrix with n columns and stride lda. incX is the increment for ipiv.
|
||||||
|
// k1 and k2 are zero-indexed. If incX is negative, then loops from k2 to k1
|
||||||
|
func (impl Implementation) Dlaswp(n int, a []float64, lda, k1, k2 int, ipiv []int, incX int) {
|
||||||
|
if incX != 1 && incX != -1 {
|
||||||
|
panic(absIncNotOne)
|
||||||
|
}
|
||||||
|
bi := blas64.Implementation()
|
||||||
|
if incX == 1 {
|
||||||
|
for k := k1; k <= k2; k++ {
|
||||||
|
bi.Dswap(n, a[k*lda:], 1, a[ipiv[k]*lda:], 1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for k := k2; k >= k1; k-- {
|
||||||
|
bi.Dswap(n, a[k*lda:], 1, a[ipiv[k]*lda:], 1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
@@ -17,7 +17,9 @@ type Implementation struct{}
|
|||||||
|
|
||||||
var _ lapack.Float64 = Implementation{}
|
var _ lapack.Float64 = Implementation{}
|
||||||
|
|
||||||
|
// This list is duplicated in lapack/cgo. Keep in sync.
|
||||||
const (
|
const (
|
||||||
|
absIncNotOne = "lapack: increment not one or negative one"
|
||||||
badDirect = "lapack: bad direct"
|
badDirect = "lapack: bad direct"
|
||||||
badIpiv = "lapack: insufficient permutation length"
|
badIpiv = "lapack: insufficient permutation length"
|
||||||
badLdA = "lapack: index of a out of range"
|
badLdA = "lapack: index of a out of range"
|
||||||
|
@@ -36,6 +36,10 @@ func TestDgetf2(t *testing.T) {
|
|||||||
testlapack.Dgetf2Test(t, impl)
|
testlapack.Dgetf2Test(t, impl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDgetrf(t *testing.T) {
|
||||||
|
testlapack.DgetrfTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDlange(t *testing.T) {
|
func TestDlange(t *testing.T) {
|
||||||
testlapack.DlangeTest(t, impl)
|
testlapack.DlangeTest(t, impl)
|
||||||
}
|
}
|
||||||
|
@@ -40,86 +40,11 @@ func Dgetf2Test(t *testing.T, impl Dgetf2er) {
|
|||||||
|
|
||||||
mn := min(m, n)
|
mn := min(m, n)
|
||||||
ipiv := make([]int, mn)
|
ipiv := make([]int, mn)
|
||||||
|
for i := range ipiv {
|
||||||
|
ipiv[i] = rand.Int()
|
||||||
|
}
|
||||||
ok := impl.Dgetf2(m, n, a, lda, ipiv)
|
ok := impl.Dgetf2(m, n, a, lda, ipiv)
|
||||||
var hasZeroDiagonal bool
|
checkPLU(t, ok, m, n, lda, ipiv, a, aCopy, 1e-14, true)
|
||||||
for i := 0; i < min(m, n); i++ {
|
|
||||||
if a[i*lda+i] == 0 {
|
|
||||||
hasZeroDiagonal = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if hasZeroDiagonal && ok {
|
|
||||||
t.Errorf("Has a zero diagonal but returned ok")
|
|
||||||
}
|
|
||||||
if !hasZeroDiagonal && !ok {
|
|
||||||
t.Errorf("Non-zero diagonal but returned !ok")
|
|
||||||
}
|
|
||||||
// Check that the LU decomposition is correct.
|
|
||||||
l := make([]float64, m*mn)
|
|
||||||
ldl := mn
|
|
||||||
u := make([]float64, mn*n)
|
|
||||||
ldu := n
|
|
||||||
for i := 0; i < m; i++ {
|
|
||||||
for j := 0; j < n; j++ {
|
|
||||||
v := a[i*lda+j]
|
|
||||||
switch {
|
|
||||||
case i == j:
|
|
||||||
l[i*ldl+i] = 1
|
|
||||||
u[i*ldu+i] = v
|
|
||||||
case i > j:
|
|
||||||
l[i*ldl+j] = v
|
|
||||||
case i < j:
|
|
||||||
u[i*ldu+j] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
LU := blas64.General{
|
|
||||||
Rows: m,
|
|
||||||
Cols: n,
|
|
||||||
Stride: n,
|
|
||||||
Data: make([]float64, m*n),
|
|
||||||
}
|
|
||||||
U := blas64.General{
|
|
||||||
Rows: mn,
|
|
||||||
Cols: n,
|
|
||||||
Stride: ldu,
|
|
||||||
Data: u,
|
|
||||||
}
|
|
||||||
L := blas64.General{
|
|
||||||
Rows: m,
|
|
||||||
Cols: mn,
|
|
||||||
Stride: ldl,
|
|
||||||
Data: l,
|
|
||||||
}
|
|
||||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, U, 0, LU)
|
|
||||||
|
|
||||||
p := make([]float64, m*m)
|
|
||||||
ldp := m
|
|
||||||
for i := 0; i < m; i++ {
|
|
||||||
p[i*ldp+i] = 1
|
|
||||||
}
|
|
||||||
for i := len(ipiv) - 1; i >= 0; i-- {
|
|
||||||
v := ipiv[i]
|
|
||||||
blas64.Swap(m, blas64.Vector{1, p[i*ldp:]}, blas64.Vector{1, p[v*ldp:]})
|
|
||||||
}
|
|
||||||
P := blas64.General{
|
|
||||||
Rows: m,
|
|
||||||
Cols: m,
|
|
||||||
Stride: m,
|
|
||||||
Data: p,
|
|
||||||
}
|
|
||||||
aComp := blas64.General{
|
|
||||||
Rows: m,
|
|
||||||
Cols: n,
|
|
||||||
Stride: lda,
|
|
||||||
Data: make([]float64, m*lda),
|
|
||||||
}
|
|
||||||
copy(aComp.Data, a)
|
|
||||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, P, LU, 0, aComp)
|
|
||||||
if !floats.EqualApprox(aComp.Data, aCopy, 1e-14) {
|
|
||||||
t.Errorf("Answer mismatch.\nWant\n %v,\nGot %v.", aCopy, aComp.Data)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test with singular matrices (random matrices are almost surely non-singular).
|
// Test with singular matrices (random matrices are almost surely non-singular).
|
||||||
@@ -173,3 +98,93 @@ func Dgetf2Test(t *testing.T, impl Dgetf2er) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkPLU checks that the PLU factorization contained in factorize matches
|
||||||
|
// the original matrix contained in original.
|
||||||
|
func checkPLU(t *testing.T, ok bool, m, n, lda int, ipiv []int, factorized, original []float64, tol float64, print bool) {
|
||||||
|
var hasZeroDiagonal bool
|
||||||
|
for i := 0; i < min(m, n); i++ {
|
||||||
|
if factorized[i*lda+i] == 0 {
|
||||||
|
hasZeroDiagonal = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasZeroDiagonal && ok {
|
||||||
|
t.Errorf("Has a zero diagonal but returned ok")
|
||||||
|
}
|
||||||
|
if !hasZeroDiagonal && !ok {
|
||||||
|
t.Errorf("Non-zero diagonal but returned !ok")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the LU decomposition is correct.
|
||||||
|
mn := min(m, n)
|
||||||
|
l := make([]float64, m*mn)
|
||||||
|
ldl := mn
|
||||||
|
u := make([]float64, mn*n)
|
||||||
|
ldu := n
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < n; j++ {
|
||||||
|
v := factorized[i*lda+j]
|
||||||
|
switch {
|
||||||
|
case i == j:
|
||||||
|
l[i*ldl+i] = 1
|
||||||
|
u[i*ldu+i] = v
|
||||||
|
case i > j:
|
||||||
|
l[i*ldl+j] = v
|
||||||
|
case i < j:
|
||||||
|
u[i*ldu+j] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
LU := blas64.General{
|
||||||
|
Rows: m,
|
||||||
|
Cols: n,
|
||||||
|
Stride: n,
|
||||||
|
Data: make([]float64, m*n),
|
||||||
|
}
|
||||||
|
U := blas64.General{
|
||||||
|
Rows: mn,
|
||||||
|
Cols: n,
|
||||||
|
Stride: ldu,
|
||||||
|
Data: u,
|
||||||
|
}
|
||||||
|
L := blas64.General{
|
||||||
|
Rows: m,
|
||||||
|
Cols: mn,
|
||||||
|
Stride: ldl,
|
||||||
|
Data: l,
|
||||||
|
}
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, U, 0, LU)
|
||||||
|
|
||||||
|
p := make([]float64, m*m)
|
||||||
|
ldp := m
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
p[i*ldp+i] = 1
|
||||||
|
}
|
||||||
|
for i := len(ipiv) - 1; i >= 0; i-- {
|
||||||
|
v := ipiv[i]
|
||||||
|
blas64.Swap(m, blas64.Vector{1, p[i*ldp:]}, blas64.Vector{1, p[v*ldp:]})
|
||||||
|
}
|
||||||
|
P := blas64.General{
|
||||||
|
Rows: m,
|
||||||
|
Cols: m,
|
||||||
|
Stride: m,
|
||||||
|
Data: p,
|
||||||
|
}
|
||||||
|
aComp := blas64.General{
|
||||||
|
Rows: m,
|
||||||
|
Cols: n,
|
||||||
|
Stride: lda,
|
||||||
|
Data: make([]float64, m*lda),
|
||||||
|
}
|
||||||
|
copy(aComp.Data, factorized)
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, P, LU, 0, aComp)
|
||||||
|
if !floats.EqualApprox(aComp.Data, original, tol) {
|
||||||
|
if print {
|
||||||
|
t.Errorf("PLU multiplication does not match original matrix.\nWant: %v\nGot: %v", original, aComp.Data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Errorf("PLU multiplication does not match original matrix.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
60
testlapack/dgetrf.go
Normal file
60
testlapack/dgetrf.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dgetrfer interface {
|
||||||
|
Dgetrf(m, n int, a []float64, lda int, ipiv []int) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func DgetrfTest(t *testing.T, impl Dgetrfer) {
|
||||||
|
for _, test := range []struct {
|
||||||
|
m, n, lda int
|
||||||
|
}{
|
||||||
|
{10, 5, 0},
|
||||||
|
{5, 10, 0},
|
||||||
|
{10, 10, 0},
|
||||||
|
{300, 5, 0},
|
||||||
|
{3, 500, 0},
|
||||||
|
{4, 5, 0},
|
||||||
|
{300, 200, 0},
|
||||||
|
{204, 300, 0},
|
||||||
|
{1, 3000, 0},
|
||||||
|
{3000, 1, 0},
|
||||||
|
{10, 5, 20},
|
||||||
|
{5, 10, 20},
|
||||||
|
{10, 10, 20},
|
||||||
|
{300, 5, 400},
|
||||||
|
{3, 500, 600},
|
||||||
|
{200, 200, 300},
|
||||||
|
{300, 200, 300},
|
||||||
|
{204, 300, 400},
|
||||||
|
{1, 3000, 4000},
|
||||||
|
{3000, 1, 4000},
|
||||||
|
} {
|
||||||
|
m := test.m
|
||||||
|
n := test.n
|
||||||
|
lda := test.lda
|
||||||
|
if lda == 0 {
|
||||||
|
lda = n
|
||||||
|
}
|
||||||
|
a := make([]float64, m*lda)
|
||||||
|
for i := range a {
|
||||||
|
a[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
mn := min(m, n)
|
||||||
|
ipiv := make([]int, mn)
|
||||||
|
for i := range ipiv {
|
||||||
|
ipiv[i] = rand.Int()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cannot compare the outputs of Dgetrf and Dgetf2 because the pivoting may
|
||||||
|
// happen differently. Instead check that the LPQ factorization is correct.
|
||||||
|
aCopy := make([]float64, len(a))
|
||||||
|
copy(aCopy, a)
|
||||||
|
ok := impl.Dgetrf(m, n, a, lda, ipiv)
|
||||||
|
checkPLU(t, ok, m, n, lda, ipiv, a, aCopy, 1e-10, false)
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user