diff --git a/cgo/lapack.go b/cgo/lapack.go index bbfa6174..29993f2a 100644 --- a/cgo/lapack.go +++ b/cgo/lapack.go @@ -36,6 +36,13 @@ func min(m, n int) int { return n } +func max(m, n int) int { + if m < n { + return n + } + return m +} + // checkMatrix verifies the parameters of a matrix input. // Copied from lapack/native. Keep in sync. func checkMatrix(m, n int, a []float64, lda int) { @@ -193,6 +200,50 @@ func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []fl clapack.Dgeqrf(m, n, a, lda, tau) } +// Dgels finds a minimum-norm solution based on the matrices A and B using the +// QR or LQ factorization. Dgels returns false if the matrix +// A is singular, and true if this solution was successfully found. +// +// The minimization problem solved depends on the input parameters. +// +// 1. If m >= n and trans == blas.NoTrans, Dgels finds X such that || A*X - B||_2 +// is minimized. +// 2. If m < n and trans == blas.NoTrans, Dgels finds the minimum norm solution of +// A * X = B. +// 3. If m >= n and trans == blas.Trans, Dgels finds the minimum norm solution of +// A^T * X = B. +// 4. If m < n and trans == blas.Trans, Dgels finds X such that || A*X - B||_2 +// is minimized. +// Note that the least-squares solutions (cases 1 and 3) perform the minimization +// per column of B. This is not the same as finding the minimum-norm matrix. +// +// The matrix A is a general matrix of size m×n and is modified during this call. +// The input matrix B is of size max(m,n)×nrhs, and serves two purposes. On entry, +// the elements of b specify the input matrix B. B has size m×nrhs if +// trans == blas.NoTrans, and n×nrhs if trans == blas.Trans. On exit, the +// leading submatrix of b contains the solution vectors X. If trans == blas.NoTrans, +// this submatrix is of size n×nrhs, and of size m×nrhs otherwise. +// +// The C interface does not support providing temporary storage. To provide compatibility +// with native, lwork == -1 will not run Dgeqrf but will instead write the minimum +// work necessary to work[0]. If len(work) < lwork, Dgeqrf will panic. +func (impl Implementation) Dgels(trans blas.Transpose, m, n, nrhs int, a []float64, lda int, b []float64, ldb int, work []float64, lwork int) bool { + mn := min(m, n) + if lwork == -1 { + work[0] = float64(mn + max(mn, nrhs)) + return true + } + checkMatrix(m, n, a, lda) + checkMatrix(mn, nrhs, b, ldb) + if len(work) < lwork { + panic(shortWork) + } + if lwork < mn+max(mn, nrhs) { + panic(badWork) + } + return clapack.Dgels(trans, m, n, nrhs, a, lda, b, ldb) +} + // Dgetf2 computes the LU decomposition of the m×n matrix A. // The LU decomposition is a factorization of a into // A = P * L * U @@ -223,7 +274,7 @@ func (Implementation) Dgetf2(m, n int, a []float64, lda int, ipiv []int) (ok boo } // Dgetrf computes the LU decomposition of the m×n matrix A. -// The LU decomposition is a factorization of a into +// The LU decomposition is a factorization of A into // A = P * L * U // where P is a permutation matrix, L is a unit lower triangular matrix, and // U is a (usually) non-unit upper triangular matrix. On exit, L and U are stored diff --git a/cgo/lapack_test.go b/cgo/lapack_test.go index f81067a4..88de4367 100644 --- a/cgo/lapack_test.go +++ b/cgo/lapack_test.go @@ -20,6 +20,10 @@ func TestDgelq2(t *testing.T) { testlapack.Dgelq2Test(t, impl) } +func TestDgels(t *testing.T) { + testlapack.DgelsTest(t, impl) +} + func TestDgelqf(t *testing.T) { testlapack.DgelqfTest(t, impl) } diff --git a/lapack.go b/lapack.go index fbfa51fa..89eb0b99 100644 --- a/lapack.go +++ b/lapack.go @@ -23,9 +23,12 @@ type Complex128 interface{} // Float64 defines the public float64 LAPACK API supported by gonum/lapack. type Float64 interface { + Dgels(trans blas.Transpose, m, n, nrhs int, a []float64, lda int, b []float64, ldb int, work []float64, lwork int) bool Dgelqf(m, n int, a []float64, lda int, tau, work []float64, lwork int) Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool) + Dgetrf(m, n int, a []float64, lda int, ipiv []int) (ok bool) + Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int) } // Direct specifies the direction of the multiplication for the Householder matrix. diff --git a/lapack64/lapack64.go b/lapack64/lapack64.go index e3c67f7f..685c2afa 100644 --- a/lapack64/lapack64.go +++ b/lapack64/lapack64.go @@ -48,6 +48,39 @@ func Potrf(a blas64.Symmetric) (t blas64.Triangular, ok bool) { return } +// Gels finds a minimum-norm solution based on the matrices A and B using the +// QR or LQ factorization. Dgels returns false if the matrix +// A is singular, and true if this solution was successfully found. +// +// The minimization problem solved depends on the input parameters. +// +// 1. If m >= n and trans == blas.NoTrans, Dgels finds X such that || A*X - B||_2 +// is minimized. +// 2. If m < n and trans == blas.NoTrans, Dgels finds the minimum norm solution of +// A * X = B. +// 3. If m >= n and trans == blas.Trans, Dgels finds the minimum norm solution of +// A^T * X = B. +// 4. If m < n and trans == blas.Trans, Dgels finds X such that || A*X - B||_2 +// is minimized. +// Note that the least-squares solutions (cases 1 and 3) perform the minimization +// per column of B. This is not the same as finding the minimum-norm matrix. +// +// The matrix A is a general matrix of size m×n and is modified during this call. +// The input matrix B is of size max(m,n)×nrhs, and serves two purposes. On entry, +// the elements of b specify the input matrix B. B has size m×nrhs if +// trans == blas.NoTrans, and n×nrhs if trans == blas.Trans. On exit, the +// leading submatrix of b contains the solution vectors X. If trans == blas.NoTrans, +// this submatrix is of size n×nrhs, and of size m×nrhs otherwise. +// +// Work is temporary storage, and lwork specifies the usable memory length. +// At minimum, lwork >= max(m,n) + max(m,n,nrhs), and this function will panic +// otherwise. A longer work will enable blocked algorithms to be called. +// In the special case that lwork == -1, work[0] will be set to the optimal working +// length. +func Gels(trans blas.Transpose, a blas64.General, b blas64.General, work []float64, lwork int) { + lapack64.Dgels(trans, a.Rows, a.Cols, b.Cols, a.Data, a.Stride, b.Data, b.Stride, work, lwork) +} + // Geqrf computes the QR factorization of the m×n matrix A using a blocked // algorithm. A is modified to contain the information to construct Q and R. // The upper triangle of a contains the matrix R. The lower triangular elements @@ -93,3 +126,39 @@ func Geqrf(a blas64.General, tau, work []float64, lwork int) { func Gelqf(a blas64.General, tau, work []float64, lwork int) { lapack64.Dgelqf(a.Rows, a.Cols, a.Data, a.Stride, tau, work, lwork) } + +// Getrf computes the LU decomposition of the m×n matrix A. +// The LU decomposition is a factorization of A into +// A = P * L * U +// where P is a permutation matrix, L is a unit lower triangular matrix, and +// U is a (usually) non-unit upper triangular matrix. On exit, L and U are stored +// in place into a. +// +// ipiv is a permutation vector. It indicates that row i of the matrix was +// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic +// otherwise. ipiv is zero-indexed. +// +// Dgetrf is the blocked version of the algorithm. +// +// Dgetrf returns whether the matrix A is singular. The LU decomposition will +// be computed regardless of the singularity of A, but division by zero +// will occur if the false is returned and the result is used to solve a +// system of equations. +func Getrf(a blas64.General, ipiv []int) bool { + return lapack64.Dgetrf(a.Rows, a.Cols, a.Data, a.Stride, ipiv) +} + +// Dgetrs solves a system of equations using an LU factorization. +// The system of equations solved is +// A * X = B if trans == blas.Trans +// A^T * X = B if trans == blas.NoTrans +// A is a general n×n matrix with stride lda. B is a general matrix of size n×nrhs. +// +// On entry b contains the elements of the matrix B. On exit, b contains the +// elements of X, the solution to the system of equations. +// +// a and ipiv contain the LU factorization of A and the permutation indices as +// computed by Getrf. ipiv is zero-indexed. +func Getrs(trans blas.Transpose, a blas64.General, b blas64.General, ipiv []int) { + lapack64.Dgetrs(trans, a.Cols, b.Cols, a.Data, a.Stride, ipiv, b.Data, b.Stride) +} diff --git a/native/dgels.go b/native/dgels.go index 7759561b..47bfa68e 100644 --- a/native/dgels.go +++ b/native/dgels.go @@ -9,25 +9,25 @@ import ( "github.com/gonum/lapack" ) -// Dgels finds a minimum-norm solution based on the matrices a and b using the +// Dgels finds a minimum-norm solution based on the matrices A and B using the // QR or LQ factorization. Dgels returns false if the matrix // A is singular, and true if this solution was successfully found. // // The minimization problem solved depends on the input parameters. // // 1. If m >= n and trans == blas.NoTrans, Dgels finds X such that || A*X - B||_2 -// is minimized. +// is minimized. // 2. If m < n and trans == blas.NoTrans, Dgels finds the minimum norm solution of -// A * X = B. +// A * X = B. // 3. If m >= n and trans == blas.Trans, Dgels finds the minimum norm solution of -// A^T * X = B. +// A^T * X = B. // 4. If m < n and trans == blas.Trans, Dgels finds X such that || A*X - B||_2 -// is minimized. +// is minimized. // Note that the least-squares solutions (cases 1 and 3) perform the minimization // per column of B. This is not the same as finding the minimum-norm matrix. // -// The matrix a is a general matrix of size m×n and is modified during this call. -// The input matrix b is of size max(m,n)×nrhs, and serves two purposes. On entry, +// The matrix A is a general matrix of size m×n and is modified during this call. +// The input matrix B is of size max(m,n)×nrhs, and serves two purposes. On entry, // the elements of b specify the input matrix B. B has size m×nrhs if // trans == blas.NoTrans, and n×nrhs if trans == blas.Trans. On exit, the // leading submatrix of b contains the solution vectors X. If trans == blas.NoTrans, diff --git a/native/dgetrf.go b/native/dgetrf.go index 9fe17661..97d55fc2 100644 --- a/native/dgetrf.go +++ b/native/dgetrf.go @@ -5,8 +5,8 @@ import ( "github.com/gonum/blas/blas64" ) -// Dgetrf computes the LU decomposition of the m×n matrix a. -// The LU decomposition is a factorization of a into +// Dgetrf computes the LU decomposition of the m×n matrix A. +// The LU decomposition is a factorization of A into // A = P * L * U // where P is a permutation matrix, L is a unit lower triangular matrix, and // U is a (usually) non-unit upper triangular matrix. On exit, L and U are stored