mirror of
https://github.com/gonum/gonum.git
synced 2025-10-20 21:59:25 +08:00
Add dgetrf and tests
removed unnecessary requirement Implement Dlaswp in both directions Responded to PR comments
This commit is contained in:
@@ -40,86 +40,11 @@ func Dgetf2Test(t *testing.T, impl Dgetf2er) {
|
||||
|
||||
mn := min(m, n)
|
||||
ipiv := make([]int, mn)
|
||||
for i := range ipiv {
|
||||
ipiv[i] = rand.Int()
|
||||
}
|
||||
ok := impl.Dgetf2(m, n, a, lda, ipiv)
|
||||
var hasZeroDiagonal bool
|
||||
for i := 0; i < min(m, n); i++ {
|
||||
if a[i*lda+i] == 0 {
|
||||
hasZeroDiagonal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasZeroDiagonal && ok {
|
||||
t.Errorf("Has a zero diagonal but returned ok")
|
||||
}
|
||||
if !hasZeroDiagonal && !ok {
|
||||
t.Errorf("Non-zero diagonal but returned !ok")
|
||||
}
|
||||
// Check that the LU decomposition is correct.
|
||||
l := make([]float64, m*mn)
|
||||
ldl := mn
|
||||
u := make([]float64, mn*n)
|
||||
ldu := n
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
v := a[i*lda+j]
|
||||
switch {
|
||||
case i == j:
|
||||
l[i*ldl+i] = 1
|
||||
u[i*ldu+i] = v
|
||||
case i > j:
|
||||
l[i*ldl+j] = v
|
||||
case i < j:
|
||||
u[i*ldu+j] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LU := blas64.General{
|
||||
Rows: m,
|
||||
Cols: n,
|
||||
Stride: n,
|
||||
Data: make([]float64, m*n),
|
||||
}
|
||||
U := blas64.General{
|
||||
Rows: mn,
|
||||
Cols: n,
|
||||
Stride: ldu,
|
||||
Data: u,
|
||||
}
|
||||
L := blas64.General{
|
||||
Rows: m,
|
||||
Cols: mn,
|
||||
Stride: ldl,
|
||||
Data: l,
|
||||
}
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, U, 0, LU)
|
||||
|
||||
p := make([]float64, m*m)
|
||||
ldp := m
|
||||
for i := 0; i < m; i++ {
|
||||
p[i*ldp+i] = 1
|
||||
}
|
||||
for i := len(ipiv) - 1; i >= 0; i-- {
|
||||
v := ipiv[i]
|
||||
blas64.Swap(m, blas64.Vector{1, p[i*ldp:]}, blas64.Vector{1, p[v*ldp:]})
|
||||
}
|
||||
P := blas64.General{
|
||||
Rows: m,
|
||||
Cols: m,
|
||||
Stride: m,
|
||||
Data: p,
|
||||
}
|
||||
aComp := blas64.General{
|
||||
Rows: m,
|
||||
Cols: n,
|
||||
Stride: lda,
|
||||
Data: make([]float64, m*lda),
|
||||
}
|
||||
copy(aComp.Data, a)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, P, LU, 0, aComp)
|
||||
if !floats.EqualApprox(aComp.Data, aCopy, 1e-14) {
|
||||
t.Errorf("Answer mismatch.\nWant\n %v,\nGot %v.", aCopy, aComp.Data)
|
||||
}
|
||||
checkPLU(t, ok, m, n, lda, ipiv, a, aCopy, 1e-14, true)
|
||||
}
|
||||
|
||||
// Test with singular matrices (random matrices are almost surely non-singular).
|
||||
@@ -173,3 +98,93 @@ func Dgetf2Test(t *testing.T, impl Dgetf2er) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkPLU checks that the PLU factorization contained in factorize matches
|
||||
// the original matrix contained in original.
|
||||
func checkPLU(t *testing.T, ok bool, m, n, lda int, ipiv []int, factorized, original []float64, tol float64, print bool) {
|
||||
var hasZeroDiagonal bool
|
||||
for i := 0; i < min(m, n); i++ {
|
||||
if factorized[i*lda+i] == 0 {
|
||||
hasZeroDiagonal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasZeroDiagonal && ok {
|
||||
t.Errorf("Has a zero diagonal but returned ok")
|
||||
}
|
||||
if !hasZeroDiagonal && !ok {
|
||||
t.Errorf("Non-zero diagonal but returned !ok")
|
||||
}
|
||||
|
||||
// Check that the LU decomposition is correct.
|
||||
mn := min(m, n)
|
||||
l := make([]float64, m*mn)
|
||||
ldl := mn
|
||||
u := make([]float64, mn*n)
|
||||
ldu := n
|
||||
for i := 0; i < m; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
v := factorized[i*lda+j]
|
||||
switch {
|
||||
case i == j:
|
||||
l[i*ldl+i] = 1
|
||||
u[i*ldu+i] = v
|
||||
case i > j:
|
||||
l[i*ldl+j] = v
|
||||
case i < j:
|
||||
u[i*ldu+j] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LU := blas64.General{
|
||||
Rows: m,
|
||||
Cols: n,
|
||||
Stride: n,
|
||||
Data: make([]float64, m*n),
|
||||
}
|
||||
U := blas64.General{
|
||||
Rows: mn,
|
||||
Cols: n,
|
||||
Stride: ldu,
|
||||
Data: u,
|
||||
}
|
||||
L := blas64.General{
|
||||
Rows: m,
|
||||
Cols: mn,
|
||||
Stride: ldl,
|
||||
Data: l,
|
||||
}
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, U, 0, LU)
|
||||
|
||||
p := make([]float64, m*m)
|
||||
ldp := m
|
||||
for i := 0; i < m; i++ {
|
||||
p[i*ldp+i] = 1
|
||||
}
|
||||
for i := len(ipiv) - 1; i >= 0; i-- {
|
||||
v := ipiv[i]
|
||||
blas64.Swap(m, blas64.Vector{1, p[i*ldp:]}, blas64.Vector{1, p[v*ldp:]})
|
||||
}
|
||||
P := blas64.General{
|
||||
Rows: m,
|
||||
Cols: m,
|
||||
Stride: m,
|
||||
Data: p,
|
||||
}
|
||||
aComp := blas64.General{
|
||||
Rows: m,
|
||||
Cols: n,
|
||||
Stride: lda,
|
||||
Data: make([]float64, m*lda),
|
||||
}
|
||||
copy(aComp.Data, factorized)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, P, LU, 0, aComp)
|
||||
if !floats.EqualApprox(aComp.Data, original, tol) {
|
||||
if print {
|
||||
t.Errorf("PLU multiplication does not match original matrix.\nWant: %v\nGot: %v", original, aComp.Data)
|
||||
return
|
||||
}
|
||||
t.Errorf("PLU multiplication does not match original matrix.")
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user