diff --git a/native/dtrti2.go b/native/dtrti2.go
new file mode 100644
index 00000000..50dc3bee
--- /dev/null
+++ b/native/dtrti2.go
@@ -0,0 +1,55 @@
+package native
+
+import (
+	"github.com/gonum/blas"
+	"github.com/gonum/blas/blas64"
+)
+
+// Dtrti2 inverts the n-by-n triangular matrix held in a with stride lda and
+// overwrites a with the result.
+//
+// uplo selects the triangle of a that holds the matrix, and diag specifies
+// whether the diagonal is taken to be all ones instead of the stored values.
+// Dtrti2 panics with badUplo or badDiag if either argument is not one of the
+// two accepted values.
+func (impl Implementation) Dtrti2(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) {
+	checkMatrix(n, n, a, lda)
+	if uplo != blas.Upper && uplo != blas.Lower {
+		panic(badUplo)
+	}
+	if diag != blas.NonUnit && diag != blas.Unit {
+		panic(badDiag)
+	}
+	bi := blas64.Implementation()
+
+	// negInvDiag returns the factor that scales the off-diagonal part of
+	// column j. For a non-unit diagonal it first replaces the diagonal
+	// element with its reciprocal and then returns the negated reciprocal.
+	negInvDiag := func(j int) float64 {
+		if diag == blas.Unit {
+			return -1
+		}
+		ajj := 1 / a[j*lda+j]
+		a[j*lda+j] = ajj
+		return -ajj
+	}
+
+	// TODO(btracey): Replace this with a row-major ordering.
+	if uplo == blas.Upper {
+		for j := 0; j < n; j++ {
+			ajj := negInvDiag(j)
+			// Column j above the diagonal, accessed with stride lda.
+			bi.Dtrmv(blas.Upper, blas.NoTrans, diag, j, a, lda, a[j:], lda)
+			bi.Dscal(j, ajj, a[j:], lda)
+		}
+		return
+	}
+	for j := n - 1; j >= 0; j-- {
+		ajj := negInvDiag(j)
+		if j < n-1 {
+			// Column j below the diagonal, accessed with stride lda.
+			bi.Dtrmv(blas.Lower, blas.NoTrans, diag, n-j-1, a[(j+1)*lda+j+1:], lda, a[(j+1)*lda+j:], lda)
+			bi.Dscal(n-j-1, ajj, a[(j+1)*lda+j:], lda)
+		}
+	}
+}
diff --git a/native/lapack_test.go b/native/lapack_test.go
index f42aa1f1..efdae019 100644
--- a/native/lapack_test.go
+++ b/native/lapack_test.go
@@ -116,6 +116,10 @@ func TestDtrcon(t *testing.T) {
 	testlapack.DtrconTest(t, impl)
 }
 
+func TestDtrtri2(t *testing.T) {
+	testlapack.Dtrtri2Test(t, impl)
+}
+
 func TestIladlc(t *testing.T) {
 	testlapack.IladlcTest(t, impl)
 }
diff --git a/testlapack/dtrti2.go b/testlapack/dtrti2.go
new file mode 100644
index 00000000..ff2b089b
--- /dev/null
+++ b/testlapack/dtrti2.go
@@ -0,0 +1,166 @@
+package testlapack
+
+import (
+	"math"
+	"math/rand"
+	"testing"
+
+	"github.com/gonum/blas"
+	"github.com/gonum/blas/blas64"
+	"github.com/gonum/floats"
+)
+
+// Dtrti2er is the interface a LAPACK implementation must satisfy to be
+// exercised by Dtrtri2Test.
+type Dtrti2er interface {
+	Dtrti2(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int)
+}
+
+// Dtrtri2Test checks impl.Dtrti2 against hand-computed inverses, then checks
+// on random triangular matrices, for both triangles and both diagonal kinds,
+// that the computed inverse multiplied by the original matrix is the identity.
+func Dtrtri2Test(t *testing.T, impl Dtrti2er) {
+	for _, test := range []struct {
+		a    []float64
+		n    int
+		uplo blas.Uplo
+		diag blas.Diag
+		ans  []float64
+	}{
+		{
+			// The 8 below the diagonal must be left untouched.
+			a: []float64{
+				2, 3, 4,
+				0, 5, 6,
+				8, 0, 8},
+			n:    3,
+			uplo: blas.Upper,
+			diag: blas.NonUnit,
+			ans: []float64{
+				0.5, -0.3, -0.025,
+				0, 0.2, -0.15,
+				8, 0, 0.125,
+			},
+		},
+		{
+			// The stored diagonal is ignored and must be left untouched.
+			a: []float64{
+				5, 3, 4,
+				0, 7, 6,
+				10, 0, 8},
+			n:    3,
+			uplo: blas.Upper,
+			diag: blas.Unit,
+			ans: []float64{
+				5, -3, 14,
+				0, 7, -6,
+				10, 0, 8,
+			},
+		},
+		{
+			a: []float64{
+				2, 0, 0,
+				3, 5, 0,
+				4, 6, 8},
+			n:    3,
+			uplo: blas.Lower,
+			diag: blas.NonUnit,
+			ans: []float64{
+				0.5, 0, 0,
+				-0.3, 0.2, 0,
+				-0.025, -0.15, 0.125,
+			},
+		},
+		{
+			a: []float64{
+				1, 0, 0,
+				3, 1, 0,
+				4, 6, 1},
+			n:    3,
+			uplo: blas.Lower,
+			diag: blas.Unit,
+			ans: []float64{
+				1, 0, 0,
+				-3, 1, 0,
+				14, -6, 1,
+			},
+		},
+	} {
+		impl.Dtrti2(test.uplo, test.diag, test.n, test.a, test.n)
+		if !floats.EqualApprox(test.ans, test.a, 1e-14) {
+			t.Errorf("Matrix inverse mismatch. Want %v, got %v.", test.ans, test.a)
+		}
+	}
+
+	// Use a fixed seed so that the random cases are reproducible.
+	rnd := rand.New(rand.NewSource(1))
+	bi := blas64.Implementation()
+	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
+		for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} {
+			for _, test := range []struct {
+				n, lda int
+			}{
+				{3, 0},
+				{3, 5},
+			} {
+				n := test.n
+				lda := test.lda
+				if lda == 0 {
+					lda = n
+				}
+				a := make([]float64, n*lda)
+				for i := range a {
+					a[i] = rnd.Float64()
+				}
+				// Keep the diagonal away from zero so that the matrix is
+				// well conditioned and the fixed tolerance is meaningful.
+				for i := 0; i < n; i++ {
+					a[i*lda+i]++
+				}
+				aCopy := make([]float64, len(a))
+				copy(aCopy, a)
+				impl.Dtrti2(uplo, diag, n, a, lda)
+				// Zero the triangle that is not part of the matrix so that
+				// both matrices can be multiplied as dense matrices.
+				if uplo == blas.Upper {
+					for i := 1; i < n; i++ {
+						for j := 0; j < i; j++ {
+							aCopy[i*lda+j] = 0
+							a[i*lda+j] = 0
+						}
+					}
+				} else {
+					for i := 0; i < n; i++ {
+						for j := i + 1; j < n; j++ {
+							aCopy[i*lda+j] = 0
+							a[i*lda+j] = 0
+						}
+					}
+				}
+				if diag == blas.Unit {
+					for i := 0; i < n; i++ {
+						a[i*lda+i] = 1
+						aCopy[i*lda+i] = 1
+					}
+				}
+				ans := make([]float64, len(a))
+				bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda)
+				iseye := true
+				for i := 0; i < n; i++ {
+					for j := 0; j < n; j++ {
+						want := 0.0
+						if i == j {
+							want = 1
+						}
+						if math.Abs(ans[i*lda+j]-want) > 1e-14 {
+							iseye = false
+						}
+					}
+				}
+				if !iseye {
+					t.Errorf("inv(A) * A != I. Upper = %v, unit = %v, ans = %v", uplo == blas.Upper, diag == blas.Unit, ans)
+				}
+			}
+		}
+	}
+}