mirror of
https://github.com/gonum/gonum.git
synced 2025-10-20 21:59:25 +08:00
228 lines
7.7 KiB
Go
228 lines
7.7 KiB
Go
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||
// Use of this source code is governed by a BSD-style
|
||
// license that can be found in the LICENSE file.
|
||
|
||
// Package cgo provides an interface to bindings for a C LAPACK library.
|
||
package cgo
|
||
|
||
import (
|
||
"github.com/gonum/blas"
|
||
"github.com/gonum/lapack"
|
||
"github.com/gonum/lapack/cgo/clapack"
|
||
)
|
||
|
||
// Messages passed to panic when an argument check fails. Not every message
// is used by the routines in this file.
//
// Copied from lapack/native. Keep in sync.
const (
	absIncNotOne  = "lapack: increment not one or negative one"
	badDirect     = "lapack: bad direct"
	badIpiv       = "lapack: insufficient permutation length"
	badLdA        = "lapack: index of a out of range"
	badSide       = "lapack: bad side"
	badStore      = "lapack: bad store"
	badTau        = "lapack: tau has insufficient length"
	badTrans      = "lapack: bad trans"
	badUplo       = "lapack: illegal triangle"
	badWork       = "lapack: insufficient working memory"
	badWorkStride = "lapack: insufficient working array stride"
	negDimension  = "lapack: negative matrix dimension"
	nLT0          = "lapack: n < 0"
	shortWork     = "lapack: working array shorter than declared"
)
|
||
|
||
// min returns the smaller of m and n.
func min(m, n int) int {
	if n <= m {
		return n
	}
	return m
}
|
||
|
||
// checkMatrix verifies the parameters of a matrix input and panics if they
// are inconsistent.
//
// m and n are the number of rows and columns, a is the row-major backing
// slice and lda is the stride between consecutive rows. checkMatrix panics if
// m or n is negative, if lda < n, or if a holds fewer than (m-1)*lda+n
// elements.
//
// Copied from lapack/native. Keep in sync.
func checkMatrix(m, n int, a []float64, lda int) {
	if m < 0 {
		panic("lapack: has negative number of rows")
	}
	if n < 0 {
		panic("lapack: has negative number of columns")
	}
	if lda < n {
		panic("lapack: stride less than number of columns")
	}
	// The final row only needs n elements rather than a full stride.
	if len(a) < (m-1)*lda+n {
		panic("lapack: insufficient matrix slice length")
	}
}
|
||
|
||
// Implementation is the cgo-based C implementation of LAPACK routines.
|
||
type Implementation struct{}
|
||
|
||
var _ lapack.Float64 = Implementation{}
|
||
|
||
// Dpotrf computes the cholesky decomposition of the symmetric positive definite
|
||
// matrix a. If ul == blas.Upper, then a is stored as an upper-triangular matrix,
|
||
// and a = U U^T is stored in place into a. If ul == blas.Lower, then a = L L^T
|
||
// is computed and stored in-place into a. If a is not positive definite, false
|
||
// is returned. This is the blocked version of the algorithm.
|
||
func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool) {
|
||
// ul is checked in clapack.Dpotrf.
|
||
if n < 0 {
|
||
panic(nLT0)
|
||
}
|
||
if lda < n {
|
||
panic(badLdA)
|
||
}
|
||
if n == 0 {
|
||
return true
|
||
}
|
||
return clapack.Dpotrf(ul, n, a, lda)
|
||
}
|
||
|
||
// Dgeqr2 computes a QR factorization of the m×n matrix A.
|
||
//
|
||
// In a QR factorization, Q is an m×m orthonormal matrix, and R is an
|
||
// upper triangular m×n matrix.
|
||
//
|
||
// During Dgeqr2, a is modified to contain the information to construct Q and R.
|
||
// The upper triangle of a contains the matrix R. The lower triangular elements
|
||
// (not including the diagonal) contain the elementary reflectors. Tau is modified
|
||
// to contain the reflector scales. Tau must have length at least k = min(m,n), and
|
||
// this function will panic otherwise.
|
||
//
|
||
// The ith elementary reflector can be explicitly constructed by first extracting
|
||
// the
|
||
// v[j] = 0 j < i
|
||
// v[j] = i j == i
|
||
// v[j] = a[i*lda+j] j > i
|
||
// and computing h_i = I - tau[i] * v * v^T.
|
||
//
|
||
// The orthonormal matrix Q can be constucted from a product of these elementary
|
||
// reflectors, Q = H_1*H_2 ... H_k, where k = min(m,n).
|
||
//
|
||
// Work is temporary storage of length at least n and this function will panic otherwise.
|
||
func (impl Implementation) Dgeqr2(m, n int, a []float64, lda int, tau, work []float64) {
|
||
// TODO(btracey): This is oriented such that columns of a are eliminated.
|
||
// This likely could be re-arranged to take better advantage of row-major
|
||
// storage.
|
||
checkMatrix(m, n, a, lda)
|
||
if len(work) < n {
|
||
panic(badWork)
|
||
}
|
||
k := min(m, n)
|
||
if len(tau) < k {
|
||
panic(badTau)
|
||
}
|
||
clapack.Dgeqr2(m, n, a, lda, tau)
|
||
}
|
||
|
||
// Dgeqrf computes the QR factorization of the m×n matrix A using a blocked
|
||
// algorithm. Please see the documentation for Dgeqr2 for a description of the
|
||
// parameters at entry and exit.
|
||
//
|
||
// The C interface does not support providing temporary storage. To provide compatibility
|
||
// with native, lwork == -1 will not run Dgeqrf but will instead write the minimum
|
||
// work necessary to work[0]. If len(work) < lwork, Dgels will panic.
|
||
//
|
||
// tau must be at least len min(m,n), and this function will panic otherwise.
|
||
func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) {
|
||
if lwork == -1 {
|
||
work[0] = float64(n)
|
||
return
|
||
}
|
||
checkMatrix(m, n, a, lda)
|
||
if len(work) < lwork {
|
||
panic(shortWork)
|
||
}
|
||
if lwork < n {
|
||
panic(badWork)
|
||
}
|
||
k := min(m, n)
|
||
if len(tau) < k {
|
||
panic(badTau)
|
||
}
|
||
clapack.Dgeqrf(m, n, a, lda, tau)
|
||
}
|
||
|
||
// Dgetf2 computes the LU decomposition of the m×n matrix A.
|
||
// The LU decomposition is a factorization of a into
|
||
// A = P * L * U
|
||
// where P is a permutation matrix, L is a unit lower triangular matrix, and
|
||
// U is a (usually) non-unit upper triangular matrix. On exit, L and U are stored
|
||
// in place into a.
|
||
//
|
||
// ipiv is a permutation vector. It indicates that row i of the matrix was
|
||
// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
|
||
// otherwise. ipiv is zero-indexed.
|
||
//
|
||
// Dgetf2 returns whether the matrix A is singular. The LU decomposition will
|
||
// be computed regardless of the singularity of A, but division by zero
|
||
// will occur if the false is returned and the result is used to solve a
|
||
// system of equations.
|
||
func (Implementation) Dgetf2(m, n int, a []float64, lda int, ipiv []int) (ok bool) {
|
||
mn := min(m, n)
|
||
checkMatrix(m, n, a, lda)
|
||
if len(ipiv) < mn {
|
||
panic(badIpiv)
|
||
}
|
||
ipiv32 := make([]int32, len(ipiv))
|
||
ok = clapack.Dgetf2(m, n, a, lda, ipiv32)
|
||
for i, v := range ipiv32 {
|
||
ipiv[i] = int(v) - 1 // Transform to zero-indexed.
|
||
}
|
||
return ok
|
||
}
|
||
|
||
// Dgetrf computes the LU decomposition of the m×n matrix A.
|
||
// The LU decomposition is a factorization of a into
|
||
// A = P * L * U
|
||
// where P is a permutation matrix, L is a unit lower triangular matrix, and
|
||
// U is a (usually) non-unit upper triangular matrix. On exit, L and U are stored
|
||
// in place into a.
|
||
//
|
||
// ipiv is a permutation vector. It indicates that row i of the matrix was
|
||
// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
|
||
// otherwise. ipiv is zero-indexed.
|
||
//
|
||
// Dgetrf is the blocked version of the algorithm.
|
||
//
|
||
// Dgetrf returns whether the matrix A is singular. The LU decomposition will
|
||
// be computed regardless of the singularity of A, but division by zero
|
||
// will occur if the false is returned and the result is used to solve a
|
||
// system of equations.
|
||
func (impl Implementation) Dgetrf(m, n int, a []float64, lda int, ipiv []int) (ok bool) {
|
||
mn := min(m, n)
|
||
checkMatrix(m, n, a, lda)
|
||
if len(ipiv) < mn {
|
||
panic(badIpiv)
|
||
}
|
||
ipiv32 := make([]int32, len(ipiv))
|
||
ok = clapack.Dgetrf(m, n, a, lda, ipiv32)
|
||
for i, v := range ipiv32 {
|
||
ipiv[i] = int(v) - 1 // Transform to zero-indexed.
|
||
}
|
||
return ok
|
||
}
|
||
|
||
// Dgetrs solves a system of equations using an LU factorization.
|
||
// The system of equations solved is
|
||
// A * X = B if trans == blas.Trans
|
||
// A^T * X = B if trans == blas.NoTrans
|
||
// A is a general n×n matrix with stride lda. B is a general matrix of size n×nrhs.
|
||
//
|
||
// On entry b contains the elements of the matrix B. On exit, b contains the
|
||
// elements of X, the solution to the system of equations.
|
||
//
|
||
// a and ipiv contain the LU factorization of A and the permutation indices as
|
||
// computed by Dgetrf. ipiv is zero-indexed.
|
||
func (impl Implementation) Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int) {
|
||
checkMatrix(n, n, a, lda)
|
||
checkMatrix(n, nrhs, b, ldb)
|
||
if len(ipiv) < n {
|
||
panic(badIpiv)
|
||
}
|
||
ipiv32 := make([]int32, len(ipiv))
|
||
for i, v := range ipiv {
|
||
ipiv32[i] = int32(v) + 1 // Transform to one-indexed.
|
||
}
|
||
clapack.Dgetrs(trans, n, nrhs, a, lda, ipiv32, b, ldb)
|
||
}
|