mirror of
https://github.com/gonum/gonum.git
synced 2025-10-20 21:59:25 +08:00
Add dgetrf and tests
removed unnecessary requirement Implement Dlaswp in both directions Responded to PR comments
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
// Copied from lapack/native. Keep in sync.
|
||||
const (
|
||||
absIncNotOne = "lapack: increment not one or negative one"
|
||||
badDirect = "lapack: bad direct"
|
||||
badIpiv = "lapack: insufficient permutation length"
|
||||
badLdA = "lapack: index of a out of range"
|
||||
@@ -76,7 +77,7 @@ func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok
|
||||
return clapack.Dpotrf(ul, n, a, lda)
|
||||
}
|
||||
|
||||
// Dgetf2 computes the LU decomposition of the m×n matrix a.
|
||||
// Dgetf2 computes the LU decomposition of the m×n matrix A.
|
||||
// The LU decomposition is a factorization of a into
|
||||
// A = P * L * U
|
||||
// where P is a permutation matrix, L is a unit lower triangular matrix, and
|
||||
@@ -85,9 +86,9 @@ func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok
|
||||
//
|
||||
// ipiv is a permutation vector. It indicates that row i of the matrix was
|
||||
// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
|
||||
// otherwise.
|
||||
// otherwise. ipiv is zero-indexed.
|
||||
//
|
||||
// Dgetf2 returns whether the matrix a is singular. The LU decomposition will
|
||||
// Dgetf2 returns whether the matrix A is nonsingular. The LU decomposition will
|
||||
// be computed regardless of the singularity of A, but division by zero
|
||||
// will occur if false is returned and the result is used to solve a
|
||||
// system of equations.
|
||||
@@ -100,7 +101,38 @@ func (Implementation) Dgetf2(m, n int, a []float64, lda int, ipiv []int) (ok boo
|
||||
ipiv32 := make([]int32, len(ipiv))
|
||||
ok = clapack.Dgetf2(m, n, a, lda, ipiv32)
|
||||
for i, v := range ipiv32 {
|
||||
ipiv[i] = int(v) - 1 // OpenBLAS returns one indexed.
|
||||
ipiv[i] = int(v) - 1 // Transform to zero-indexed.
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// Dgetrf computes the LU decomposition of the m×n matrix A.
|
||||
// The LU decomposition is a factorization of a into
|
||||
// A = P * L * U
|
||||
// where P is a permutation matrix, L is a unit lower triangular matrix, and
|
||||
// U is a (usually) non-unit upper triangular matrix. On exit, L and U are stored
|
||||
// in place into a.
|
||||
//
|
||||
// ipiv is a permutation vector. It indicates that row i of the matrix was
|
||||
// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
|
||||
// otherwise. ipiv is zero-indexed.
|
||||
//
|
||||
// Dgetrf is the blocked version of the algorithm.
|
||||
//
|
||||
// Dgetrf returns whether the matrix A is singular. The LU decomposition will
|
||||
// be computed regardless of the singularity of A, but division by zero
|
||||
// will occur if the false is returned and the result is used to solve a
|
||||
// system of equations.
|
||||
func (impl Implementation) Dgetrf(m, n int, a []float64, lda int, ipiv []int) (ok bool) {
|
||||
mn := min(m, n)
|
||||
checkMatrix(m, n, a, lda)
|
||||
if len(ipiv) < mn {
|
||||
panic(badIpiv)
|
||||
}
|
||||
ipiv32 := make([]int32, len(ipiv))
|
||||
ok = clapack.Dgetrf(m, n, a, lda, ipiv32)
|
||||
for i, v := range ipiv32 {
|
||||
ipiv[i] = int(v) - 1 // Transform to zero-indexed.
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
@@ -19,3 +19,7 @@ func TestDpotrf(t *testing.T) {
|
||||
// TestDgetf2 runs the shared testlapack Dgetf2 test against impl.
func TestDgetf2(t *testing.T) {
|
||||
testlapack.Dgetf2Test(t, impl)
|
||||
}
|
||||
|
||||
// TestDgetrf runs the shared testlapack Dgetrf test against impl.
func TestDgetrf(t *testing.T) {
|
||||
testlapack.DgetrfTest(t, impl)
|
||||
}
|
||||
|
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/gonum/lapack"
|
||||
)
|
||||
|
||||
// Dgeqrf computes the QR factorization of the m×n matrix a using a blocked
|
||||
// Dgeqrf computes the QR factorization of the m×n matrix A using a blocked
|
||||
// algorithm. Please see the documentation for Dgeqr2 for a description of the
|
||||
// parameters at entry and exit.
|
||||
//
|
||||
@@ -21,9 +21,6 @@ import (
|
||||
//
|
||||
// tau must be at least len min(m,n), and this function will panic otherwise.
|
||||
func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) {
|
||||
// TODO(btracey): This algorithm is oriented for column-major storage.
|
||||
// Consider modifying the algorithm to better suit row-major storage.
|
||||
|
||||
// nb is the optimal blocksize, i.e. the number of columns transformed at a time.
|
||||
nb := impl.Ilaenv(1, "DGEQRF", " ", m, n, -1, -1)
|
||||
lworkopt := n * max(nb, 1)
|
||||
|
@@ -6,7 +6,7 @@ import (
|
||||
"github.com/gonum/blas/blas64"
|
||||
)
|
||||
|
||||
// Dgetf2 computes the LU decomposition of the m×n matrix a.
|
||||
// Dgetf2 computes the LU decomposition of the m×n matrix A.
|
||||
// The LU decomposition is a factorization of a into
|
||||
// A = P * L * U
|
||||
// where P is a permutation matrix, L is a unit lower triangular matrix, and
|
||||
@@ -15,9 +15,9 @@ import (
|
||||
//
|
||||
// ipiv is a permutation vector. It indicates that row i of the matrix was
|
||||
// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
|
||||
// otherwise.
|
||||
// otherwise. ipiv is zero-indexed.
|
||||
//
|
||||
// Dgetf2 returns whether the matrix a is singular. The LU decomposition will
|
||||
// Dgetf2 returns whether the matrix A is nonsingular. The LU decomposition will
|
||||
// be computed regardless of the singularity of A, but division by zero
|
||||
// will occur if false is returned and the result is used to solve a
|
||||
// system of equations.
|
||||
|
66
native/dgetrf.go
Normal file
66
native/dgetrf.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package native
|
||||
|
||||
import (
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
)
|
||||
|
||||
// Dgetrf computes the LU decomposition of the m×n matrix a.
|
||||
// The LU decomposition is a factorization of a into
|
||||
// A = P * L * U
|
||||
// where P is a permutation matrix, L is a unit lower triangular matrix, and
|
||||
// U is a (usually) non-unit upper triangular matrix. On exit, L and U are stored
|
||||
// in place into a.
|
||||
//
|
||||
// ipiv is a permutation vector. It indicates that row i of the matrix was
|
||||
// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
|
||||
// otherwise. ipiv is zero-indexed.
|
||||
//
|
||||
// Dgetrf is the blocked version of the algorithm.
|
||||
//
|
||||
// Dgetrf returns whether the matrix A is singular. The LU decomposition will
|
||||
// be computed regardless of the singularity of A, but division by zero
|
||||
// will occur if the false is returned and the result is used to solve a
|
||||
// system of equations.
|
||||
func (impl Implementation) Dgetrf(m, n int, a []float64, lda int, ipiv []int) (ok bool) {
|
||||
mn := min(m, n)
|
||||
checkMatrix(m, n, a, lda)
|
||||
if len(ipiv) < mn {
|
||||
panic(badIpiv)
|
||||
}
|
||||
if m == 0 || n == 0 {
|
||||
return
|
||||
}
|
||||
bi := blas64.Implementation()
|
||||
nb := impl.Ilaenv(1, "DGETRF", " ", m, n, -1, -1)
|
||||
if nb <= 1 || nb >= min(m, n) {
|
||||
// Use the unblocked algorithm.
|
||||
return impl.Dgetf2(m, n, a, lda, ipiv)
|
||||
}
|
||||
ok = true
|
||||
for j := 0; j < mn; j += nb {
|
||||
jb := min(mn-j, nb)
|
||||
blockOk := impl.Dgetf2(m-j, jb, a[j*lda+j:], lda, ipiv[j:])
|
||||
if !blockOk {
|
||||
ok = false
|
||||
}
|
||||
for i := j; i <= min(m-1, j+jb-1); i++ {
|
||||
ipiv[i] = j + ipiv[i]
|
||||
}
|
||||
impl.Dlaswp(j, a, lda, j, j+jb-1, ipiv, 1)
|
||||
if j+jb < n {
|
||||
impl.Dlaswp(n-j-jb, a[j+jb:], lda, j, j+jb-1, ipiv, 1)
|
||||
bi.Dtrsm(blas.Left, blas.Lower, blas.NoTrans, blas.Unit,
|
||||
jb, n-j-jb, 1,
|
||||
a[j*lda+j:], lda,
|
||||
a[j*lda+j+jb:], lda)
|
||||
if j+jb < m {
|
||||
bi.Dgemm(blas.NoTrans, blas.NoTrans, m-j-jb, n-j-jb, jb, -1,
|
||||
a[(j+jb)*lda+j:], lda,
|
||||
a[j*lda+j+jb:], lda,
|
||||
1, a[(j+jb)*lda+j+jb:], lda)
|
||||
}
|
||||
}
|
||||
}
|
||||
return ok
|
||||
}
|
23
native/dlaswp.go
Normal file
23
native/dlaswp.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package native
|
||||
|
||||
import "github.com/gonum/blas/blas64"
|
||||
|
||||
// Dlaswp swaps the rows k1 to k2 of a according to the indices in ipiv.
|
||||
// a is a matrix with n columns and stride lda. incX is the increment for ipiv.
|
||||
// k1 and k2 are zero-indexed. If incX is negative, then loops from k2 to k1
|
||||
func (impl Implementation) Dlaswp(n int, a []float64, lda, k1, k2 int, ipiv []int, incX int) {
|
||||
if incX != 1 && incX != -1 {
|
||||
panic(absIncNotOne)
|
||||
}
|
||||
bi := blas64.Implementation()
|
||||
if incX == 1 {
|
||||
for k := k1; k <= k2; k++ {
|
||||
bi.Dswap(n, a[k*lda:], 1, a[ipiv[k]*lda:], 1)
|
||||
}
|
||||
return
|
||||
}
|
||||
for k := k2; k >= k1; k-- {
|
||||
bi.Dswap(n, a[k*lda:], 1, a[ipiv[k]*lda:], 1)
|
||||
}
|
||||
return
|
||||
}
|
@@ -17,7 +17,9 @@ type Implementation struct{}
|
||||
|
||||
var _ lapack.Float64 = Implementation{}
|
||||
|
||||
// This list is duplicated in lapack/cgo. Keep in sync.
|
||||
const (
|
||||
absIncNotOne = "lapack: increment not one or negative one"
|
||||
badDirect = "lapack: bad direct"
|
||||
badIpiv = "lapack: insufficient permutation length"
|
||||
badLdA = "lapack: index of a out of range"
|
||||
|
@@ -36,6 +36,10 @@ func TestDgetf2(t *testing.T) {
|
||||
testlapack.Dgetf2Test(t, impl)
|
||||
}
|
||||
|
||||
// TestDgetrf runs the shared testlapack Dgetrf test against impl.
func TestDgetrf(t *testing.T) {
|
||||
testlapack.DgetrfTest(t, impl)
|
||||
}
|
||||
|
||||
// TestDlange runs the shared testlapack Dlange test against impl.
func TestDlange(t *testing.T) {
|
||||
testlapack.DlangeTest(t, impl)
|
||||
}
|
||||
|
@@ -40,86 +40,11 @@ func Dgetf2Test(t *testing.T, impl Dgetf2er) {
|
||||
|
||||
mn := min(m, n)
|
||||
ipiv := make([]int, mn)
|
||||
for i := range ipiv {
|
||||
ipiv[i] = rand.Int()
|
||||
}
|
||||
ok := impl.Dgetf2(m, n, a, lda, ipiv)
|
||||
var hasZeroDiagonal bool
|
||||
for i := 0; i < min(m, n); i++ {
|
||||
if a[i*lda+i] == 0 {
|
||||
hasZeroDiagonal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasZeroDiagonal && ok {
|
||||
t.Errorf("Has a zero diagonal but returned ok")
|
||||
}
|
||||
if !hasZeroDiagonal && !ok {
|
||||
t.Errorf("Non-zero diagonal but returned !ok")
|
||||
}
|
||||
// Check that the LU decomposition is correct.
|
||||
l := make([]float64, m*mn)
|
||||
ldl := mn
|
||||
u := make([]float64, mn*n)
|
||||
ldu := n
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
v := a[i*lda+j]
|
||||
switch {
|
||||
case i == j:
|
||||
l[i*ldl+i] = 1
|
||||
u[i*ldu+i] = v
|
||||
case i > j:
|
||||
l[i*ldl+j] = v
|
||||
case i < j:
|
||||
u[i*ldu+j] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LU := blas64.General{
|
||||
Rows: m,
|
||||
Cols: n,
|
||||
Stride: n,
|
||||
Data: make([]float64, m*n),
|
||||
}
|
||||
U := blas64.General{
|
||||
Rows: mn,
|
||||
Cols: n,
|
||||
Stride: ldu,
|
||||
Data: u,
|
||||
}
|
||||
L := blas64.General{
|
||||
Rows: m,
|
||||
Cols: mn,
|
||||
Stride: ldl,
|
||||
Data: l,
|
||||
}
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, U, 0, LU)
|
||||
|
||||
p := make([]float64, m*m)
|
||||
ldp := m
|
||||
for i := 0; i < m; i++ {
|
||||
p[i*ldp+i] = 1
|
||||
}
|
||||
for i := len(ipiv) - 1; i >= 0; i-- {
|
||||
v := ipiv[i]
|
||||
blas64.Swap(m, blas64.Vector{1, p[i*ldp:]}, blas64.Vector{1, p[v*ldp:]})
|
||||
}
|
||||
P := blas64.General{
|
||||
Rows: m,
|
||||
Cols: m,
|
||||
Stride: m,
|
||||
Data: p,
|
||||
}
|
||||
aComp := blas64.General{
|
||||
Rows: m,
|
||||
Cols: n,
|
||||
Stride: lda,
|
||||
Data: make([]float64, m*lda),
|
||||
}
|
||||
copy(aComp.Data, a)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, P, LU, 0, aComp)
|
||||
if !floats.EqualApprox(aComp.Data, aCopy, 1e-14) {
|
||||
t.Errorf("Answer mismatch.\nWant\n %v,\nGot %v.", aCopy, aComp.Data)
|
||||
}
|
||||
checkPLU(t, ok, m, n, lda, ipiv, a, aCopy, 1e-14, true)
|
||||
}
|
||||
|
||||
// Test with singular matrices (random matrices are almost surely non-singular).
|
||||
@@ -173,3 +98,93 @@ func Dgetf2Test(t *testing.T, impl Dgetf2er) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkPLU checks that the PLU factorization contained in factorize matches
|
||||
// the original matrix contained in original.
|
||||
func checkPLU(t *testing.T, ok bool, m, n, lda int, ipiv []int, factorized, original []float64, tol float64, print bool) {
|
||||
var hasZeroDiagonal bool
|
||||
for i := 0; i < min(m, n); i++ {
|
||||
if factorized[i*lda+i] == 0 {
|
||||
hasZeroDiagonal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasZeroDiagonal && ok {
|
||||
t.Errorf("Has a zero diagonal but returned ok")
|
||||
}
|
||||
if !hasZeroDiagonal && !ok {
|
||||
t.Errorf("Non-zero diagonal but returned !ok")
|
||||
}
|
||||
|
||||
// Check that the LU decomposition is correct.
|
||||
mn := min(m, n)
|
||||
l := make([]float64, m*mn)
|
||||
ldl := mn
|
||||
u := make([]float64, mn*n)
|
||||
ldu := n
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
v := factorized[i*lda+j]
|
||||
switch {
|
||||
case i == j:
|
||||
l[i*ldl+i] = 1
|
||||
u[i*ldu+i] = v
|
||||
case i > j:
|
||||
l[i*ldl+j] = v
|
||||
case i < j:
|
||||
u[i*ldu+j] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LU := blas64.General{
|
||||
Rows: m,
|
||||
Cols: n,
|
||||
Stride: n,
|
||||
Data: make([]float64, m*n),
|
||||
}
|
||||
U := blas64.General{
|
||||
Rows: mn,
|
||||
Cols: n,
|
||||
Stride: ldu,
|
||||
Data: u,
|
||||
}
|
||||
L := blas64.General{
|
||||
Rows: m,
|
||||
Cols: mn,
|
||||
Stride: ldl,
|
||||
Data: l,
|
||||
}
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, U, 0, LU)
|
||||
|
||||
p := make([]float64, m*m)
|
||||
ldp := m
|
||||
for i := 0; i < m; i++ {
|
||||
p[i*ldp+i] = 1
|
||||
}
|
||||
for i := len(ipiv) - 1; i >= 0; i-- {
|
||||
v := ipiv[i]
|
||||
blas64.Swap(m, blas64.Vector{1, p[i*ldp:]}, blas64.Vector{1, p[v*ldp:]})
|
||||
}
|
||||
P := blas64.General{
|
||||
Rows: m,
|
||||
Cols: m,
|
||||
Stride: m,
|
||||
Data: p,
|
||||
}
|
||||
aComp := blas64.General{
|
||||
Rows: m,
|
||||
Cols: n,
|
||||
Stride: lda,
|
||||
Data: make([]float64, m*lda),
|
||||
}
|
||||
copy(aComp.Data, factorized)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, P, LU, 0, aComp)
|
||||
if !floats.EqualApprox(aComp.Data, original, tol) {
|
||||
if print {
|
||||
t.Errorf("PLU multiplication does not match original matrix.\nWant: %v\nGot: %v", original, aComp.Data)
|
||||
return
|
||||
}
|
||||
t.Errorf("PLU multiplication does not match original matrix.")
|
||||
}
|
||||
}
|
||||
|
60
testlapack/dgetrf.go
Normal file
60
testlapack/dgetrf.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Dgetrfer is the interface of implementations that provide the Dgetrf
// routine exercised by DgetrfTest.
type Dgetrfer interface {
|
||||
Dgetrf(m, n int, a []float64, lda int, ipiv []int) bool
|
||||
}
|
||||
|
||||
func DgetrfTest(t *testing.T, impl Dgetrfer) {
|
||||
for _, test := range []struct {
|
||||
m, n, lda int
|
||||
}{
|
||||
{10, 5, 0},
|
||||
{5, 10, 0},
|
||||
{10, 10, 0},
|
||||
{300, 5, 0},
|
||||
{3, 500, 0},
|
||||
{4, 5, 0},
|
||||
{300, 200, 0},
|
||||
{204, 300, 0},
|
||||
{1, 3000, 0},
|
||||
{3000, 1, 0},
|
||||
{10, 5, 20},
|
||||
{5, 10, 20},
|
||||
{10, 10, 20},
|
||||
{300, 5, 400},
|
||||
{3, 500, 600},
|
||||
{200, 200, 300},
|
||||
{300, 200, 300},
|
||||
{204, 300, 400},
|
||||
{1, 3000, 4000},
|
||||
{3000, 1, 4000},
|
||||
} {
|
||||
m := test.m
|
||||
n := test.n
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = n
|
||||
}
|
||||
a := make([]float64, m*lda)
|
||||
for i := range a {
|
||||
a[i] = rand.Float64()
|
||||
}
|
||||
mn := min(m, n)
|
||||
ipiv := make([]int, mn)
|
||||
for i := range ipiv {
|
||||
ipiv[i] = rand.Int()
|
||||
}
|
||||
|
||||
// Cannot compare the outputs of Dgetrf and Dgetf2 because the pivoting may
|
||||
// happen differently. Instead check that the LPQ factorization is correct.
|
||||
aCopy := make([]float64, len(a))
|
||||
copy(aCopy, a)
|
||||
ok := impl.Dgetrf(m, n, a, lda, ipiv)
|
||||
checkPLU(t, ok, m, n, lda, ipiv, a, aCopy, 1e-10, false)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user