diff --git a/cgo/lapack.go b/cgo/lapack.go index ad34cbd3..d9c51ccd 100644 --- a/cgo/lapack.go +++ b/cgo/lapack.go @@ -576,6 +576,21 @@ func (impl Implementation) Dtrcon(norm lapack.MatrixNorm, uplo blas.Uplo, diag b return rcond[0] } +// Dtrtri computes the inverse of a triangular matrix, storing the result in place +// into a. This is the BLAS level 3 version of the algorithm. +// +// Dtrti returns whether the matrix a is singular. +func (impl Implementation) Dtrtri(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) (ok bool) { + checkMatrix(n, n, a, lda) + if uplo != blas.Upper && uplo != blas.Lower { + panic(badUplo) + } + if diag != blas.NonUnit && diag != blas.Unit { + panic(badDiag) + } + return clapack.Dtrtri(uplo, diag, n, a, lda) +} + // Dtrtrs solves a triangular system of the form A * X = B or A^T * X = B. Dtrtrs // returns whether the solve completed successfully. If A is singular, no solve is performed. func (impl Implementation) Dtrtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, nrhs int, a []float64, lda int, b []float64, ldb int) (ok bool) { diff --git a/cgo/lapack_test.go b/cgo/lapack_test.go index 71676166..017e78ae 100644 --- a/cgo/lapack_test.go +++ b/cgo/lapack_test.go @@ -94,3 +94,7 @@ func TestDpocon(t *testing.T) { func TestDtrcon(t *testing.T) { testlapack.DtrconTest(t, impl) } + +func TestDtrtri(t *testing.T) { + testlapack.DtrtriTest(t, impl) +} diff --git a/native/dtrti2.go b/native/dtrti2.go new file mode 100644 index 00000000..69e1f3eb --- /dev/null +++ b/native/dtrti2.go @@ -0,0 +1,51 @@ +package native + +import ( + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" +) + +// Dtrti2 computes the inverse of a triangular matrix, storing the result in place +// into a. This is the BLAS level 2 version of the algorithm. +func (impl Implementation) Dtrti2(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) { + checkMatrix(n, n, a, lda) + if uplo != blas.Upper && uplo != blas.Lower { + panic(badUplo) + } + if diag != blas.NonUnit && diag != blas.Unit { + panic(badDiag) + } + bi := blas64.Implementation() + + nonUnit := diag == blas.NonUnit + // TODO(btracey): Replace this with a row-major ordering. + if uplo == blas.Upper { + for j := 0; j < n; j++ { + var ajj float64 + if nonUnit { + ajj = 1 / a[j*lda+j] + a[j*lda+j] = ajj + ajj *= -1 + } else { + ajj = -1 + } + bi.Dtrmv(blas.Upper, blas.NoTrans, diag, j, a, lda, a[j:], lda) + bi.Dscal(j, ajj, a[j:], lda) + } + return + } + for j := n - 1; j >= 0; j-- { + var ajj float64 + if nonUnit { + ajj = 1 / a[j*lda+j] + a[j*lda+j] = ajj + ajj *= -1 + } else { + ajj = -1 + } + if j < n-1 { + bi.Dtrmv(blas.Lower, blas.NoTrans, diag, n-j-1, a[(j+1)*lda+j+1:], lda, a[(j+1)*lda+j:], lda) + bi.Dscal(n-j-1, ajj, a[(j+1)*lda+j:], lda) + } + } +} diff --git a/native/dtrtri.go b/native/dtrtri.go new file mode 100644 index 00000000..73ff2da4 --- /dev/null +++ b/native/dtrtri.go @@ -0,0 +1,60 @@ +package native + +import ( + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" +) + +// Dtrtri computes the inverse of a triangular matrix, storing the result in place +// into a. This is the BLAS level 3 version of the algorithm which builds upon +// Dtrti2 to operate on matrix blocks instead of only individual columns. +// +// Dtrti returns whether the matrix a is singular or whether it's not singular. +// If the matrix is singular the inversion is not performed. +func (impl Implementation) Dtrtri(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) (ok bool) { + checkMatrix(n, n, a, lda) + if uplo != blas.Upper && uplo != blas.Lower { + panic(badUplo) + } + if diag != blas.NonUnit && diag != blas.Unit { + panic(badDiag) + } + if n == 0 { + return + } + nonUnit := diag == blas.NonUnit + if nonUnit { + for i := 0; i < n; i++ { + if a[i*lda+i] == 0 { + return false + } + } + } + + bi := blas64.Implementation() + + nb := impl.Ilaenv(1, "DTRTRI", "UD", n, -1, -1, -1) + if nb <= 1 || nb > n { + impl.Dtrti2(uplo, diag, n, a, lda) + return true + } + if uplo == blas.Upper { + for j := 0; j < n; j += nb { + jb := min(nb, n-j) + bi.Dtrmm(blas.Left, blas.Upper, blas.NoTrans, diag, j, jb, 1, a, lda, a[j:], lda) + bi.Dtrsm(blas.Right, blas.Upper, blas.NoTrans, diag, j, jb, -1, a[j*lda+j:], lda, a[j:], lda) + impl.Dtrti2(blas.Upper, diag, jb, a[j*lda+j:], lda) + } + return true + } + nn := ((n - 1) / nb) * nb + for j := nn; j >= 0; j -= nb { + jb := min(nb, n-j) + if j+jb <= n-1 { + bi.Dtrmm(blas.Left, blas.Lower, blas.NoTrans, diag, n-j-jb, jb, 1, a[(j+jb)*lda+j+jb:], lda, a[(j+jb)*lda+j:], lda) + bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, diag, n-j-jb, jb, -1, a[j*lda+j:], lda, a[(j+jb)*lda+j:], lda) + } + impl.Dtrti2(blas.Lower, diag, jb, a[j*lda+j:], lda) + } + return true +} diff --git a/native/lapack_test.go b/native/lapack_test.go index f42aa1f1..e50f43bc 100644 --- a/native/lapack_test.go +++ b/native/lapack_test.go @@ -116,6 +116,14 @@ func TestDtrcon(t *testing.T) { testlapack.DtrconTest(t, impl) } +func TestDtrtri(t *testing.T) { + testlapack.DtrtriTest(t, impl) +} + +func TestDtrti2(t *testing.T) { + testlapack.Dtrti2Test(t, impl) +} + func TestIladlc(t *testing.T) { testlapack.IladlcTest(t, impl) } diff --git a/testlapack/dtrti2.go b/testlapack/dtrti2.go new file mode 100644 index 00000000..dada82a3 --- /dev/null +++ b/testlapack/dtrti2.go @@ -0,0 +1,153 @@ +package testlapack + +import ( + "math" + "math/rand" + "testing" + + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" + "github.com/gonum/floats" +) + +type Dtrti2er interface { + Dtrti2(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) +} + +func Dtrti2Test(t *testing.T, impl Dtrti2er) { + for _, test := range []struct { + a []float64 + n int + uplo blas.Uplo + diag blas.Diag + ans []float64 + }{ + { + a: []float64{ + 2, 3, 4, + 0, 5, 6, + 8, 0, 8}, + n: 3, + uplo: blas.Upper, + diag: blas.NonUnit, + ans: []float64{ + 0.5, -0.3, -0.025, + 0, 0.2, -0.15, + 8, 0, 0.125, + }, + }, + { + a: []float64{ + 5, 3, 4, + 0, 7, 6, + 10, 0, 8}, + n: 3, + uplo: blas.Upper, + diag: blas.Unit, + ans: []float64{ + 5, -3, 14, + 0, 7, -6, + 10, 0, 8, + }, + }, + { + a: []float64{ + 2, 0, 0, + 3, 5, 0, + 4, 6, 8}, + n: 3, + uplo: blas.Lower, + diag: blas.NonUnit, + ans: []float64{ + 0.5, 0, 0, + -0.3, 0.2, 0, + -0.025, -0.15, 0.125, + }, + }, + { + a: []float64{ + 1, 0, 0, + 3, 1, 0, + 4, 6, 1}, + n: 3, + uplo: blas.Lower, + diag: blas.Unit, + ans: []float64{ + 1, 0, 0, + -3, 1, 0, + 14, -6, 1, + }, + }, + } { + impl.Dtrti2(test.uplo, test.diag, test.n, test.a, test.n) + if !floats.EqualApprox(test.ans, test.a, 1e-14) { + t.Errorf("Matrix inverse mismatch. Want %v, got %v.", test.ans, test.a) + } + } + bi := blas64.Implementation() + for _, uplo := range []blas.Uplo{blas.Upper} { + for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} { + for _, test := range []struct { + n, lda int + }{ + {3, 0}, + {3, 5}, + } { + n := test.n + lda := test.lda + if lda == 0 { + lda = n + } + a := make([]float64, n*lda) + for i := range a { + a[i] = rand.Float64() + } + aCopy := make([]float64, len(a)) + copy(aCopy, a) + impl.Dtrti2(uplo, diag, n, a, lda) + if uplo == blas.Upper { + for i := 1; i < n; i++ { + for j := 0; j < i; j++ { + aCopy[i*lda+j] = 0 + a[i*lda+j] = 0 + } + } + } else { + for i := 1; i < n; i++ { + for j := i + 1; j < n; j++ { + aCopy[i*lda+j] = 0 + a[i*lda+j] = 0 + } + } + } + if diag == blas.Unit { + for i := 0; i < n; i++ { + a[i*lda+i] = 1 + aCopy[i*lda+i] = 1 + } + } + ans := make([]float64, len(a)) + bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda) + iseye := true + for i := 0; i < n; i++ { + for j := 0; j < n; j++ { + if i == j { + if math.Abs(ans[i*lda+i]-1) > 1e-14 { + iseye = false + break + } + } else { + if math.Abs(ans[i*lda+j]) > 1e-14 { + iseye = false + break + } + } + } + } + if !iseye { + t.Errorf("inv(A) * A != I. Upper = %v, unit = %v, ans = %v", uplo == blas.Upper, diag == blas.Unit, ans) + } + } + } + } +} diff --git a/testlapack/dtrtri.go b/testlapack/dtrtri.go new file mode 100644 index 00000000..694d494b --- /dev/null +++ b/testlapack/dtrtri.go @@ -0,0 +1,89 @@ +package testlapack + +import ( + "math" + "math/rand" + "testing" + + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" +) + +type Dtrtrier interface { + Dtrconer + Dtrtri(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) bool +} + +func DtrtriTest(t *testing.T, impl Dtrtrier) { + bi := blas64.Implementation() + for _, uplo := range []blas.Uplo{blas.Upper} { + for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} { + for _, test := range []struct { + n, lda int + }{ + {3, 0}, + {70, 0}, + {200, 0}, + {3, 5}, + {70, 92}, + {200, 205}, + } { + n := test.n + lda := test.lda + if lda == 0 { + lda = n + } + a := make([]float64, n*lda) + for i := range a { + a[i] = rand.Float64() + 1 // This keeps the matrices well conditioned. + } + aCopy := make([]float64, len(a)) + copy(aCopy, a) + impl.Dtrtri(uplo, diag, n, a, lda) + if uplo == blas.Upper { + for i := 1; i < n; i++ { + for j := 0; j < i; j++ { + aCopy[i*lda+j] = 0 + a[i*lda+j] = 0 + } + } + } else { + for i := 1; i < n; i++ { + for j := i + 1; j < n; j++ { + aCopy[i*lda+j] = 0 + a[i*lda+j] = 0 + } + } + } + if diag == blas.Unit { + for i := 0; i < n; i++ { + a[i*lda+i] = 1 + aCopy[i*lda+i] = 1 + } + } + ans := make([]float64, len(a)) + bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda) + iseye := true + for i := 0; i < n; i++ { + for j := 0; j < n; j++ { + if i == j { + if math.Abs(ans[i*lda+i]-1) > 1e-4 { + iseye = false + break + } + } else { + if math.Abs(ans[i*lda+j]) > 1e-4 { + iseye = false + break + } + } + } + } + if !iseye { + t.Errorf("inv(A) * A != I. Upper = %v, unit = %v, n = %v, lda = %v", + uplo == blas.Upper, diag == blas.Unit, n, lda) + } + } + } + } +}