mirror of
https://github.com/gonum/gonum.git
synced 2025-10-20 21:59:25 +08:00
AddDgetri for computing matrix inverses
This commit is contained in:
@@ -395,6 +395,38 @@ func (impl Implementation) Dgetrf(m, n int, a []float64, lda int, ipiv []int) (o
|
|||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dgetri computes the inverse of the matrix A using the LU factorization computed
|
||||||
|
// by Dgetrf. On entry, a contains the PLU decomposition of A as computed by
|
||||||
|
// Dgetrf and on exit contains the reciprocal of the original matrix.
|
||||||
|
//
|
||||||
|
// Dtrtri will not perform the inversion if the matrix is singular, and returns
|
||||||
|
// a boolean indicating whether the inversion was successful.
|
||||||
|
//
|
||||||
|
// The C interface does not support providing temporary storage. To provide compatibility
|
||||||
|
// with native, lwork == -1 will not run Dgetri but will instead write the minimum
|
||||||
|
// work necessary to work[0]. If len(work) < lwork, Dgetri will panic.
|
||||||
|
func (impl Implementation) Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) (ok bool) {
|
||||||
|
checkMatrix(n, n, a, lda)
|
||||||
|
if len(ipiv) < n {
|
||||||
|
panic(badIpiv)
|
||||||
|
}
|
||||||
|
if lwork == -1 {
|
||||||
|
work[0] = float64(n)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if lwork < n {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
if len(work) < lwork {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
ipiv32 := make([]int32, len(ipiv))
|
||||||
|
for i, v := range ipiv {
|
||||||
|
ipiv32[i] = int32(v) + 1 // Transform to one-indexed.
|
||||||
|
}
|
||||||
|
return clapack.Dgetri(n, a, lda, ipiv32)
|
||||||
|
}
|
||||||
|
|
||||||
// Dgetrs solves a system of equations using an LU factorization.
|
// Dgetrs solves a system of equations using an LU factorization.
|
||||||
// The system of equations solved is
|
// The system of equations solved is
|
||||||
// A * X = B if trans == blas.Trans
|
// A * X = B if trans == blas.Trans
|
||||||
|
@@ -57,6 +57,10 @@ func TestDgetrf(t *testing.T) {
|
|||||||
testlapack.DgetrfTest(t, impl)
|
testlapack.DgetrfTest(t, impl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDgetri(t *testing.T) {
|
||||||
|
testlapack.DgetriTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDgetrs(t *testing.T) {
|
func TestDgetrs(t *testing.T) {
|
||||||
testlapack.DgetrsTest(t, impl)
|
testlapack.DgetrsTest(t, impl)
|
||||||
}
|
}
|
||||||
|
88
native/dgetri.go
Normal file
88
native/dgetri.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
package native
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dgetri computes the inverse of the matrix A using the LU factorization computed
|
||||||
|
// by Dgetrf. On entry, a contains the PLU decomposition of A as computed by
|
||||||
|
// Dgetrf and on exit contains the reciprocal of the original matrix.
|
||||||
|
//
|
||||||
|
// Dgetri will not perform the inversion if the matrix is singular, and returns
|
||||||
|
// a boolean indicating whether the inversion was successful.
|
||||||
|
//
|
||||||
|
// Work is temporary storage, and lwork specifies the usable memory length.
|
||||||
|
// At minimum, lwork >= n and this function will panic otherwise.
|
||||||
|
// Dgetri is a blocked inversion, but the block size is limited
|
||||||
|
// by the temporary space available. If lwork == -1, instead of performing Dgetri,
|
||||||
|
// the optimal work length will be stored into work[0].
|
||||||
|
func (impl Implementation) Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) (ok bool) {
|
||||||
|
checkMatrix(n, n, a, lda)
|
||||||
|
if len(ipiv) < n {
|
||||||
|
panic(badIpiv)
|
||||||
|
}
|
||||||
|
nb := impl.Ilaenv(1, "DGETRI", " ", n, -1, -1, -1)
|
||||||
|
if lwork == -1 {
|
||||||
|
work[0] = float64(n * nb)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if lwork < n {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
if len(work) < lwork {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
ok = impl.Dtrtri(blas.Upper, blas.NonUnit, n, a, lda)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
nbmin := 2
|
||||||
|
ldwork := nb
|
||||||
|
if nb > 1 && nb < n {
|
||||||
|
iws := max(ldwork*n, 1)
|
||||||
|
if lwork < iws {
|
||||||
|
nb = lwork / ldwork
|
||||||
|
nbmin = max(2, impl.Ilaenv(2, "DGETRI", " ", n, -1, -1, -1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bi := blas64.Implementation()
|
||||||
|
// TODO(btracey): Replace this with a more row-major oriented algorithm.
|
||||||
|
if nb < nbmin || nb >= n {
|
||||||
|
// Unblocked code.
|
||||||
|
for j := n - 1; j >= 0; j-- {
|
||||||
|
for i := j + 1; i < n; i++ {
|
||||||
|
work[i*ldwork] = a[i*lda+j]
|
||||||
|
a[i*lda+j] = 0
|
||||||
|
}
|
||||||
|
if j < n {
|
||||||
|
bi.Dgemv(blas.NoTrans, n, n-j-1, -1, a[(j+1):], lda, work[(j+1)*ldwork:], ldwork, 1, a[j:], lda)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
nn := ((n - 1) / nb) * nb
|
||||||
|
for j := nn; j >= 0; j -= nb {
|
||||||
|
jb := min(nb, n-j)
|
||||||
|
for jj := j; jj < j+jb-1; jj++ {
|
||||||
|
for i := jj + 1; i < n; i++ {
|
||||||
|
work[i*ldwork+(jj-j)] = a[i*lda+jj]
|
||||||
|
a[i*lda+jj] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if j+jb < n {
|
||||||
|
bi.Dgemm(blas.NoTrans, blas.NoTrans, n, jb, n-j-jb, -1, a[(j+jb):], lda, work[(j+jb)*ldwork:], ldwork, 1, a[j:], lda)
|
||||||
|
bi.Dtrsm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, jb, 1, work[j*ldwork:], ldwork, a[j:], lda)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for j := n - 2; j >= 0; j-- {
|
||||||
|
jp := ipiv[j]
|
||||||
|
if jp != j {
|
||||||
|
bi.Dswap(n, a[j:], lda, a[jp:], lda)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
@@ -9,8 +9,8 @@ import (
|
|||||||
// into a. This is the BLAS level 3 version of the algorithm which builds upon
|
// into a. This is the BLAS level 3 version of the algorithm which builds upon
|
||||||
// Dtrti2 to operate on matrix blocks instead of only individual columns.
|
// Dtrti2 to operate on matrix blocks instead of only individual columns.
|
||||||
//
|
//
|
||||||
// Dtrti returns whether the matrix a is singular or whether it's not singular.
|
// Dtrtri will not perform the inversion if the matrix is singular, and returns
|
||||||
// If the matrix is singular the inversion is not performed.
|
// a boolean indicating whether the inversion was successful.
|
||||||
func (impl Implementation) Dtrtri(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) (ok bool) {
|
func (impl Implementation) Dtrtri(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) (ok bool) {
|
||||||
checkMatrix(n, n, a, lda)
|
checkMatrix(n, n, a, lda)
|
||||||
if uplo != blas.Upper && uplo != blas.Lower {
|
if uplo != blas.Upper && uplo != blas.Lower {
|
||||||
|
@@ -36,6 +36,10 @@ func TestDgeqrf(t *testing.T) {
|
|||||||
testlapack.DgeqrfTest(t, impl)
|
testlapack.DgeqrfTest(t, impl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDgetri(t *testing.T) {
|
||||||
|
testlapack.DgetriTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDgetf2(t *testing.T) {
|
func TestDgetf2(t *testing.T) {
|
||||||
testlapack.Dgetf2Test(t, impl)
|
testlapack.Dgetf2Test(t, impl)
|
||||||
}
|
}
|
||||||
|
84
testlapack/dgetri.go
Normal file
84
testlapack/dgetri.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dgetrier interface {
|
||||||
|
Dgetrfer
|
||||||
|
Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func DgetriTest(t *testing.T, impl Dgetrier) {
|
||||||
|
bi := blas64.Implementation()
|
||||||
|
for _, test := range []struct {
|
||||||
|
n, lda int
|
||||||
|
}{
|
||||||
|
{5, 0},
|
||||||
|
{5, 8},
|
||||||
|
{45, 0},
|
||||||
|
{45, 50},
|
||||||
|
{65, 0},
|
||||||
|
{65, 70},
|
||||||
|
{150, 0},
|
||||||
|
{150, 250},
|
||||||
|
} {
|
||||||
|
n := test.n
|
||||||
|
lda := test.lda
|
||||||
|
if lda == 0 {
|
||||||
|
lda = n
|
||||||
|
}
|
||||||
|
// Generate a random well conditioned matrix
|
||||||
|
perm := rand.Perm(n)
|
||||||
|
a := make([]float64, n*lda)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
a[i*lda+perm[i]] = 1
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
a[i] += 0.01 * rand.Float64()
|
||||||
|
}
|
||||||
|
aCopy := make([]float64, len(a))
|
||||||
|
copy(aCopy, a)
|
||||||
|
ipiv := make([]int, n)
|
||||||
|
// Compute LU decomposition.
|
||||||
|
impl.Dgetrf(n, n, a, lda, ipiv)
|
||||||
|
// Compute inverse.
|
||||||
|
work := make([]float64, 1)
|
||||||
|
impl.Dgetri(n, a, lda, ipiv, work, -1)
|
||||||
|
work = make([]float64, int(work[0]))
|
||||||
|
lwork := len(work)
|
||||||
|
|
||||||
|
ok := impl.Dgetri(n, a, lda, ipiv, work, lwork)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("Unexpected singular matrix.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that A(inv) * A = I.
|
||||||
|
ans := make([]float64, len(a))
|
||||||
|
bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, aCopy, lda, a, lda, 0, ans, lda)
|
||||||
|
isEye := true
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
for j := 0; j < n; j++ {
|
||||||
|
if i == j {
|
||||||
|
// This tolerance is so high because computing matrix inverses
|
||||||
|
// is very unstable.
|
||||||
|
if math.Abs(ans[i*lda+j]-1) > 2e-2 {
|
||||||
|
isEye = false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if math.Abs(ans[i*lda+j]) > 2e-2 {
|
||||||
|
isEye = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !isEye {
|
||||||
|
t.Errorf("Inv(A) * A != I. n = %v, lda = %v", n, lda)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user