diff --git a/cgo/lapack.go b/cgo/lapack.go index 71d32222..66480d3f 100644 --- a/cgo/lapack.go +++ b/cgo/lapack.go @@ -13,6 +13,7 @@ import ( // Copied from lapack/native. Keep in sync. const ( + absIncNotOne = "lapack: increment not one or negative one" badDirect = "lapack: bad direct" badIpiv = "lapack: insufficient permutation length" badLdA = "lapack: index of a out of range" @@ -76,7 +77,7 @@ func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok return clapack.Dpotrf(ul, n, a, lda) } -// Dgetf2 computes the LU decomposition of the m×n matrix a. +// Dgetf2 computes the LU decomposition of the m×n matrix A. // The LU decomposition is a factorization of a into // A = P * L * U // where P is a permutation matrix, L is a unit lower triangular matrix, and @@ -85,9 +86,9 @@ func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok // // ipiv is a permutation vector. It indicates that row i of the matrix was // changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic -// otherwise. +// otherwise. ipiv is zero-indexed. // -// Dgetf2 returns whether the matrix a is singular. The LU decomposition will +// Dgetf2 returns whether the matrix A is singular. The LU decomposition will // be computed regardless of the singularity of A, but division by zero // will occur if the false is returned and the result is used to solve a // system of equations. @@ -100,7 +101,38 @@ func (Implementation) Dgetf2(m, n int, a []float64, lda int, ipiv []int) (ok boo ipiv32 := make([]int32, len(ipiv)) ok = clapack.Dgetf2(m, n, a, lda, ipiv32) for i, v := range ipiv32 { - ipiv[i] = int(v) - 1 // OpenBLAS returns one indexed. + ipiv[i] = int(v) - 1 // Transform to zero-indexed. + } + return ok +} + +// Dgetrf computes the LU decomposition of the m×n matrix A. 
+// The LU decomposition is a factorization of A into
+//  A = P * L * U
+// where P is a permutation matrix, L is a unit lower triangular matrix, and
+// U is a (usually) non-unit upper triangular matrix. On exit, L and U are stored
+// in place into a.
+//
+// ipiv is a permutation vector. It indicates that row i of the matrix was
+// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
+// otherwise. ipiv is zero-indexed.
+//
+// Dgetrf is the blocked version of the algorithm.
+//
+// Dgetrf returns whether the matrix A is nonsingular. The LU decomposition will
+// be computed regardless of the singularity of A, but division by zero
+// will occur if false is returned and the result is used to solve a
+// system of equations.
+func (impl Implementation) Dgetrf(m, n int, a []float64, lda int, ipiv []int) (ok bool) {
+	mn := min(m, n)
+	checkMatrix(m, n, a, lda)
+	if len(ipiv) < mn {
+		panic(badIpiv)
+	}
+	ipiv32 := make([]int32, len(ipiv))
+	ok = clapack.Dgetrf(m, n, a, lda, ipiv32)
+	for i, v := range ipiv32[:mn] { // Entries of ipiv past min(m,n) are left untouched.
+		ipiv[i] = int(v) - 1 // Transform to zero-indexed.
 	}
 	return ok
 }
diff --git a/cgo/lapack_test.go b/cgo/lapack_test.go
index 60bbdfbd..dd2ea3b1 100644
--- a/cgo/lapack_test.go
+++ b/cgo/lapack_test.go
@@ -19,3 +19,7 @@ func TestDpotrf(t *testing.T) {
 func TestDgetf2(t *testing.T) {
 	testlapack.Dgetf2Test(t, impl)
 }
+
+func TestDgetrf(t *testing.T) {
+	testlapack.DgetrfTest(t, impl)
+}
diff --git a/native/dgeqrf.go b/native/dgeqrf.go
index e7c05601..4745a08a 100644
--- a/native/dgeqrf.go
+++ b/native/dgeqrf.go
@@ -9,7 +9,7 @@ import (
 	"github.com/gonum/lapack"
 )
 
-// Dgeqrf computes the QR factorization of the m×n matrix a using a blocked
+// Dgeqrf computes the QR factorization of the m×n matrix A using a blocked
 // algorithm. Please see the documentation for Dgeqr2 for a description of the
 // parameters at entry and exit.
 //
@@ -21,9 +21,6 @@ import (
 //
 // tau must be at least len min(m,n), and this function will panic otherwise.
func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) { - // TODO(btracey): This algorithm is oriented for column-major storage. - // Consider modifying the algorithm to better suit row-major storage. - // nb is the optimal blocksize, i.e. the number of columns transformed at a time. nb := impl.Ilaenv(1, "DGEQRF", " ", m, n, -1, -1) lworkopt := n * max(nb, 1) diff --git a/native/dgetf2.go b/native/dgetf2.go index ca14f8f9..953258a6 100644 --- a/native/dgetf2.go +++ b/native/dgetf2.go @@ -6,7 +6,7 @@ import ( "github.com/gonum/blas/blas64" ) -// Dgetf2 computes the LU decomposition of the m×n matrix a. +// Dgetf2 computes the LU decomposition of the m×n matrix A. // The LU decomposition is a factorization of a into // A = P * L * U // where P is a permutation matrix, L is a unit lower triangular matrix, and @@ -15,9 +15,9 @@ import ( // // ipiv is a permutation vector. It indicates that row i of the matrix was // changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic -// otherwise. +// otherwise. ipiv is zero-indexed. // -// Dgetf2 returns whether the matrix a is singular. The LU decomposition will +// Dgetf2 returns whether the matrix A is singular. The LU decomposition will // be computed regardless of the singularity of A, but division by zero // will occur if the false is returned and the result is used to solve a // system of equations. diff --git a/native/dgetrf.go b/native/dgetrf.go new file mode 100644 index 00000000..9fe17661 --- /dev/null +++ b/native/dgetrf.go @@ -0,0 +1,66 @@ +package native + +import ( + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" +) + +// Dgetrf computes the LU decomposition of the m×n matrix a. +// The LU decomposition is a factorization of a into +// A = P * L * U +// where P is a permutation matrix, L is a unit lower triangular matrix, and +// U is a (usually) non-unit upper triangular matrix. On exit, L and U are stored +// in place into a. 
+//
+// ipiv is a permutation vector. It indicates that row i of the matrix was
+// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
+// otherwise. ipiv is zero-indexed.
+//
+// Dgetrf is the blocked version of the algorithm.
+//
+// Dgetrf returns whether the matrix A is nonsingular. The LU decomposition will
+// be computed regardless of the singularity of A, but division by zero
+// will occur if false is returned and the result is used to solve a
+// system of equations.
+func (impl Implementation) Dgetrf(m, n int, a []float64, lda int, ipiv []int) (ok bool) {
+	mn := min(m, n)
+	checkMatrix(m, n, a, lda)
+	if len(ipiv) < mn {
+		panic(badIpiv)
+	}
+	if m == 0 || n == 0 {
+		return true // Nothing to factor, so no zero pivot was met.
+	}
+	bi := blas64.Implementation()
+	nb := impl.Ilaenv(1, "DGETRF", " ", m, n, -1, -1)
+	if nb <= 1 || nb >= mn {
+		// Use the unblocked algorithm.
+		return impl.Dgetf2(m, n, a, lda, ipiv)
+	}
+	ok = true
+	for j := 0; j < mn; j += nb {
+		jb := min(mn-j, nb)
+		// Factor the panel of jb columns starting at the diagonal element (j, j).
+		if !impl.Dgetf2(m-j, jb, a[j*lda+j:], lda, ipiv[j:]) {
+			ok = false
+		}
+		// Dgetf2 reports pivots relative to row j; shift them to rows of a.
+		for i := j; i < j+jb; i++ {
+			ipiv[i] += j
+		}
+		// Apply the panel's row swaps to the columns to its left.
+		impl.Dlaswp(j, a, lda, j, j+jb-1, ipiv, 1)
+		if j+jb < n {
+			// Swap rows to the right of the panel, then update those columns.
+			impl.Dlaswp(n-j-jb, a[j+jb:], lda, j, j+jb-1, ipiv, 1)
+			bi.Dtrsm(blas.Left, blas.Lower, blas.NoTrans, blas.Unit, jb, n-j-jb, 1,
+				a[j*lda+j:], lda, a[j*lda+j+jb:], lda)
+			if j+jb < m {
+				bi.Dgemm(blas.NoTrans, blas.NoTrans, m-j-jb, n-j-jb, jb, -1,
+					a[(j+jb)*lda+j:], lda, a[j*lda+j+jb:], lda,
+					1, a[(j+jb)*lda+j+jb:], lda)
+			}
+		}
+	}
+	return ok
+}
diff --git a/native/dlaswp.go b/native/dlaswp.go
new file mode 100644
index 00000000..3b3be418
--- /dev/null
+++ b/native/dlaswp.go
@@ -0,0 +1,25 @@
+package native
+
+import "github.com/gonum/blas/blas64"
+
+// Dlaswp swaps the rows k1 to k2 of a according to the indices in ipiv: for
+// each k in that range, row k is exchanged with row ipiv[k].
+// a is a matrix with n columns and stride lda. k1 and k2 are zero-indexed.
+//
+// incX must be 1 or -1, otherwise Dlaswp panics. If incX is 1 the swaps are
+// applied from k1 to k2; if incX is -1 they are applied from k2 to k1.
+func (impl Implementation) Dlaswp(n int, a []float64, lda, k1, k2 int, ipiv []int, incX int) {
+	if incX != 1 && incX != -1 {
+		panic(absIncNotOne)
+	}
+	bi := blas64.Implementation()
+	if incX == 1 {
+		for k := k1; k <= k2; k++ {
+			bi.Dswap(n, a[k*lda:], 1, a[ipiv[k]*lda:], 1)
+		}
+		return
+	}
+	for k := k2; k >= k1; k-- {
+		bi.Dswap(n, a[k*lda:], 1, a[ipiv[k]*lda:], 1)
+	}
+}
diff --git a/native/general.go b/native/general.go
index d9901177..d91b4ad3 100644
--- a/native/general.go
+++ b/native/general.go
@@ -17,7 +17,9 @@ type Implementation struct{}
 
 var _ lapack.Float64 = Implementation{}
 
+// This list is duplicated in lapack/cgo. Keep in sync.
 const (
+	absIncNotOne = "lapack: increment not one or negative one"
 	badDirect = "lapack: bad direct"
 	badIpiv = "lapack: insufficient permutation length"
 	badLdA = "lapack: index of a out of range"
diff --git a/native/lapack_test.go b/native/lapack_test.go
index 1da45504..9418dd52 100644
--- a/native/lapack_test.go
+++ b/native/lapack_test.go
@@ -36,6 +36,10 @@ func TestDgetf2(t *testing.T) {
 	testlapack.Dgetf2Test(t, impl)
 }
 
+func TestDgetrf(t *testing.T) {
+	testlapack.DgetrfTest(t, impl)
+}
+
 func TestDlange(t *testing.T) {
 	testlapack.DlangeTest(t, impl)
 }
diff --git a/testlapack/dgetf2.go b/testlapack/dgetf2.go
index 07bbdf7c..6c3040e2 100644
--- a/testlapack/dgetf2.go
+++ b/testlapack/dgetf2.go
@@ -40,86 +40,11 @@ func Dgetf2Test(t *testing.T, impl Dgetf2er) {
 
 		mn := min(m, n)
 		ipiv := make([]int, mn)
+		for i := range ipiv {
+			ipiv[i] = rand.Int()
+		}
 		ok := impl.Dgetf2(m, n, a, lda, ipiv)
-		var hasZeroDiagonal bool
-		for i := 0; i < min(m, n); i++ {
-			if a[i*lda+i] == 0 {
-				hasZeroDiagonal = true
-				break
-			}
-		}
-		if hasZeroDiagonal && ok {
-			t.Errorf("Has a zero diagonal but returned ok")
-		}
-		if !hasZeroDiagonal && !ok {
-			t.Errorf("Non-zero diagonal but returned !ok")
-		}
-		// Check that the LU decomposition is correct.
- l := make([]float64, m*mn) - ldl := mn - u := make([]float64, mn*n) - ldu := n - for i := 0; i < m; i++ { - for j := 0; j < n; j++ { - v := a[i*lda+j] - switch { - case i == j: - l[i*ldl+i] = 1 - u[i*ldu+i] = v - case i > j: - l[i*ldl+j] = v - case i < j: - u[i*ldu+j] = v - } - } - } - - LU := blas64.General{ - Rows: m, - Cols: n, - Stride: n, - Data: make([]float64, m*n), - } - U := blas64.General{ - Rows: mn, - Cols: n, - Stride: ldu, - Data: u, - } - L := blas64.General{ - Rows: m, - Cols: mn, - Stride: ldl, - Data: l, - } - blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, U, 0, LU) - - p := make([]float64, m*m) - ldp := m - for i := 0; i < m; i++ { - p[i*ldp+i] = 1 - } - for i := len(ipiv) - 1; i >= 0; i-- { - v := ipiv[i] - blas64.Swap(m, blas64.Vector{1, p[i*ldp:]}, blas64.Vector{1, p[v*ldp:]}) - } - P := blas64.General{ - Rows: m, - Cols: m, - Stride: m, - Data: p, - } - aComp := blas64.General{ - Rows: m, - Cols: n, - Stride: lda, - Data: make([]float64, m*lda), - } - copy(aComp.Data, a) - blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, P, LU, 0, aComp) - if !floats.EqualApprox(aComp.Data, aCopy, 1e-14) { - t.Errorf("Answer mismatch.\nWant\n %v,\nGot %v.", aCopy, aComp.Data) - } + checkPLU(t, ok, m, n, lda, ipiv, a, aCopy, 1e-14, true) } // Test with singular matrices (random matrices are almost surely non-singular). @@ -173,3 +98,93 @@ func Dgetf2Test(t *testing.T, impl Dgetf2er) { } } } + +// checkPLU checks that the PLU factorization contained in factorize matches +// the original matrix contained in original. 
+func checkPLU(t *testing.T, ok bool, m, n, lda int, ipiv []int, factorized, original []float64, tol float64, print bool) {
+	var hasZeroDiagonal bool // A zero on the diagonal of U means A is singular.
+	for i := 0; i < min(m, n); i++ {
+		if factorized[i*lda+i] == 0 {
+			hasZeroDiagonal = true
+			break
+		}
+	}
+	if hasZeroDiagonal && ok {
+		t.Errorf("Has a zero diagonal but returned ok")
+	}
+	if !hasZeroDiagonal && !ok {
+		t.Errorf("Non-zero diagonal but returned !ok")
+	}
+
+	// Check that the LU decomposition is correct. First unpack the in-place
+	// factors into an explicit unit lower triangular L and upper triangular U.
+	mn := min(m, n)
+	l := make([]float64, m*mn)
+	u := make([]float64, mn*n)
+	ldl, ldu := mn, n
+	for i := 0; i < m; i++ {
+		for j := 0; j < n; j++ {
+			v := factorized[i*lda+j]
+			switch {
+			case i == j:
+				l[i*ldl+i] = 1
+				u[i*ldu+i] = v
+			case i > j:
+				l[i*ldl+j] = v
+			case i < j:
+				u[i*ldu+j] = v
+			}
+		}
+	}
+
+	LU := blas64.General{
+		Rows:   m,
+		Cols:   n,
+		Stride: n,
+		Data:   make([]float64, m*n),
+	}
+	U := blas64.General{
+		Rows:   mn,
+		Cols:   n,
+		Stride: ldu,
+		Data:   u,
+	}
+	L := blas64.General{
+		Rows:   m,
+		Cols:   mn,
+		Stride: ldl,
+		Data:   l,
+	}
+	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, U, 0, LU) // LU = L * U
+
+	// Build P by applying the recorded row swaps, last to first, to the identity.
+	p := make([]float64, m*m)
+	for i := 0; i < m; i++ {
+		p[i*m+i] = 1
+	}
+	for i := len(ipiv) - 1; i >= 0; i-- {
+		blas64.Swap(m, blas64.Vector{Inc: 1, Data: p[i*m:]}, blas64.Vector{Inc: 1, Data: p[ipiv[i]*m:]})
+	}
+	P := blas64.General{
+		Rows:   m,
+		Cols:   m,
+		Stride: m,
+		Data:   p,
+	}
+	aComp := blas64.General{
+		Rows:   m,
+		Cols:   n,
+		Stride: lda,
+		Data:   make([]float64, m*lda),
+	}
+	// Gemm writes only the m×n block; seed the lda > n padding from factorized.
+	copy(aComp.Data, factorized)
+	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, P, LU, 0, aComp) // aComp = P * L * U
+	if !floats.EqualApprox(aComp.Data, original, tol) {
+		if print {
+			t.Errorf("PLU multiplication does not match original matrix.\nWant: %v\nGot: %v", original, aComp.Data)
+			return
+		}
+		t.Errorf("PLU multiplication does not match original matrix.")
+	}
+}
diff --git a/testlapack/dgetrf.go b/testlapack/dgetrf.go
new file mode 100644
index 00000000..e8004b65
--- /dev/null
+++ b/testlapack/dgetrf.go
@@ -0,0 +1,67 @@
+package testlapack
+
+import (
+	"math/rand"
+	"testing"
+)
+
+// Dgetrfer is implemented by types that provide the blocked LU
+// factorization routine Dgetrf.
+type Dgetrfer interface {
+	Dgetrf(m, n int, a []float64, lda int, ipiv []int) bool
+}
+
+// DgetrfTest factorizes random m×n matrices with impl.Dgetrf, both tightly
+// packed (lda == n) and with padded rows (lda > n), and checks that P * L * U
+// reproduces the original matrix.
+func DgetrfTest(t *testing.T, impl Dgetrfer) {
+	for _, test := range []struct {
+		m, n, lda int
+	}{
+		{10, 5, 0},
+		{5, 10, 0},
+		{10, 10, 0},
+		{300, 5, 0},
+		{3, 500, 0},
+		{4, 5, 0},
+		{300, 200, 0},
+		{204, 300, 0},
+		{1, 3000, 0},
+		{3000, 1, 0},
+		{10, 5, 20},
+		{5, 10, 20},
+		{10, 10, 20},
+		{300, 5, 400},
+		{3, 500, 600},
+		{200, 200, 300},
+		{300, 200, 300},
+		{204, 300, 400},
+		{1, 3000, 4000},
+		{3000, 1, 4000},
+	} {
+		m := test.m
+		n := test.n
+		lda := test.lda
+		if lda == 0 {
+			// A zero lda in the table means the matrix is tightly packed.
+			lda = n
+		}
+		a := make([]float64, m*lda)
+		for i := range a {
+			a[i] = rand.Float64()
+		}
+		mn := min(m, n)
+		ipiv := make([]int, mn)
+		// Pre-fill ipiv with random values, presumably to catch unset entries.
+		for i := range ipiv {
+			ipiv[i] = rand.Int()
+		}
+
+		// Cannot compare the outputs of Dgetrf and Dgetf2 because the pivoting may
+		// happen differently. Instead check that the PLU factorization is correct.
+		aCopy := make([]float64, len(a))
+		copy(aCopy, a)
+		ok := impl.Dgetrf(m, n, a, lda, ipiv)
+		checkPLU(t, ok, m, n, lda, ipiv, a, aCopy, 1e-10, false)
+	}
+}