mirror of
				https://github.com/gonum/gonum.git
				synced 2025-10-25 08:10:28 +08:00 
			
		
		
		
	Add Dtrti2
This commit is contained in:
		| @@ -576,6 +576,21 @@ func (impl Implementation) Dtrcon(norm lapack.MatrixNorm, uplo blas.Uplo, diag b | |||||||
| 	return rcond[0] | 	return rcond[0] | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Dtrtri computes the inverse of a triangular matrix, storing the result in place | ||||||
|  | // into a. This is the BLAS level 3 version of the algorithm. | ||||||
|  | // | ||||||
|  | // Dtrti returns whether the matrix a is singular. | ||||||
|  | func (impl Implementation) Dtrtri(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) (ok bool) { | ||||||
|  | 	checkMatrix(n, n, a, lda) | ||||||
|  | 	if uplo != blas.Upper && uplo != blas.Lower { | ||||||
|  | 		panic(badUplo) | ||||||
|  | 	} | ||||||
|  | 	if diag != blas.NonUnit && diag != blas.Unit { | ||||||
|  | 		panic(badDiag) | ||||||
|  | 	} | ||||||
|  | 	return clapack.Dtrtri(uplo, diag, n, a, lda) | ||||||
|  | } | ||||||
|  |  | ||||||
| // Dtrtrs solves a triangular system of the form A * X = B or A^T * X = B. Dtrtrs | // Dtrtrs solves a triangular system of the form A * X = B or A^T * X = B. Dtrtrs | ||||||
| // returns whether the solve completed successfully. If A is singular, no solve is performed. | // returns whether the solve completed successfully. If A is singular, no solve is performed. | ||||||
| func (impl Implementation) Dtrtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, nrhs int, a []float64, lda int, b []float64, ldb int) (ok bool) { | func (impl Implementation) Dtrtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, nrhs int, a []float64, lda int, b []float64, ldb int) (ok bool) { | ||||||
|   | |||||||
| @@ -94,3 +94,7 @@ func TestDpocon(t *testing.T) { | |||||||
| func TestDtrcon(t *testing.T) { | func TestDtrcon(t *testing.T) { | ||||||
| 	testlapack.DtrconTest(t, impl) | 	testlapack.DtrconTest(t, impl) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestDtrtri(t *testing.T) { | ||||||
|  | 	testlapack.DtrtriTest(t, impl) | ||||||
|  | } | ||||||
|   | |||||||
							
								
								
									
										51
									
								
								native/dtrti2.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								native/dtrti2.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,51 @@ | |||||||
|  | package native | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"github.com/gonum/blas" | ||||||
|  | 	"github.com/gonum/blas/blas64" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // Dtrti2 computes the inverse of a triangular matrix, storing the result in place | ||||||
|  | // into a. This is the BLAS level 2 version of the algorithm. | ||||||
|  | func (impl Implementation) Dtrti2(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) { | ||||||
|  | 	checkMatrix(n, n, a, lda) | ||||||
|  | 	if uplo != blas.Upper && uplo != blas.Lower { | ||||||
|  | 		panic(badUplo) | ||||||
|  | 	} | ||||||
|  | 	if diag != blas.NonUnit && diag != blas.Unit { | ||||||
|  | 		panic(badDiag) | ||||||
|  | 	} | ||||||
|  | 	bi := blas64.Implementation() | ||||||
|  |  | ||||||
|  | 	nonUnit := diag == blas.NonUnit | ||||||
|  | 	// TODO(btracey): Replace this with a row-major ordering. | ||||||
|  | 	if uplo == blas.Upper { | ||||||
|  | 		for j := 0; j < n; j++ { | ||||||
|  | 			var ajj float64 | ||||||
|  | 			if nonUnit { | ||||||
|  | 				ajj = 1 / a[j*lda+j] | ||||||
|  | 				a[j*lda+j] = ajj | ||||||
|  | 				ajj *= -1 | ||||||
|  | 			} else { | ||||||
|  | 				ajj = -1 | ||||||
|  | 			} | ||||||
|  | 			bi.Dtrmv(blas.Upper, blas.NoTrans, diag, j, a, lda, a[j:], lda) | ||||||
|  | 			bi.Dscal(j, ajj, a[j:], lda) | ||||||
|  | 		} | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	for j := n - 1; j >= 0; j-- { | ||||||
|  | 		var ajj float64 | ||||||
|  | 		if nonUnit { | ||||||
|  | 			ajj = 1 / a[j*lda+j] | ||||||
|  | 			a[j*lda+j] = ajj | ||||||
|  | 			ajj *= -1 | ||||||
|  | 		} else { | ||||||
|  | 			ajj = -1 | ||||||
|  | 		} | ||||||
|  | 		if j < n-1 { | ||||||
|  | 			bi.Dtrmv(blas.Lower, blas.NoTrans, diag, n-j-1, a[(j+1)*lda+j+1:], lda, a[(j+1)*lda+j:], lda) | ||||||
|  | 			bi.Dscal(n-j-1, ajj, a[(j+1)*lda+j:], lda) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										60
									
								
								native/dtrtri.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								native/dtrtri.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,60 @@ | |||||||
|  | package native | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"github.com/gonum/blas" | ||||||
|  | 	"github.com/gonum/blas/blas64" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // Dtrtri computes the inverse of a triangular matrix, storing the result in place | ||||||
|  | // into a. This is the BLAS level 3 version of the algorithm which builds upon | ||||||
|  | // Dtrti2 to operate on matrix blocks instead of only individual columns. | ||||||
|  | // | ||||||
|  | // Dtrti returns whether the matrix a is singular or whether it's not singular. | ||||||
|  | // If the matrix is singular the inversion is not performed. | ||||||
|  | func (impl Implementation) Dtrtri(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) (ok bool) { | ||||||
|  | 	checkMatrix(n, n, a, lda) | ||||||
|  | 	if uplo != blas.Upper && uplo != blas.Lower { | ||||||
|  | 		panic(badUplo) | ||||||
|  | 	} | ||||||
|  | 	if diag != blas.NonUnit && diag != blas.Unit { | ||||||
|  | 		panic(badDiag) | ||||||
|  | 	} | ||||||
|  | 	if n == 0 { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	nonUnit := diag == blas.NonUnit | ||||||
|  | 	if nonUnit { | ||||||
|  | 		for i := 0; i < n; i++ { | ||||||
|  | 			if a[i*lda+i] == 0 { | ||||||
|  | 				return false | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	bi := blas64.Implementation() | ||||||
|  |  | ||||||
|  | 	nb := impl.Ilaenv(1, "DTRTRI", "UD", n, -1, -1, -1) | ||||||
|  | 	if nb <= 1 || nb > n { | ||||||
|  | 		impl.Dtrti2(uplo, diag, n, a, lda) | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if uplo == blas.Upper { | ||||||
|  | 		for j := 0; j < n; j += nb { | ||||||
|  | 			jb := min(nb, n-j) | ||||||
|  | 			bi.Dtrmm(blas.Left, blas.Upper, blas.NoTrans, diag, j, jb, 1, a, lda, a[j:], lda) | ||||||
|  | 			bi.Dtrsm(blas.Right, blas.Upper, blas.NoTrans, diag, j, jb, -1, a[j*lda+j:], lda, a[j:], lda) | ||||||
|  | 			impl.Dtrti2(blas.Upper, diag, jb, a[j*lda+j:], lda) | ||||||
|  | 		} | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	nn := ((n - 1) / nb) * nb | ||||||
|  | 	for j := nn; j >= 0; j -= nb { | ||||||
|  | 		jb := min(nb, n-j) | ||||||
|  | 		if j+jb <= n-1 { | ||||||
|  | 			bi.Dtrmm(blas.Left, blas.Lower, blas.NoTrans, diag, n-j-jb, jb, 1, a[(j+jb)*lda+j+jb:], lda, a[(j+jb)*lda+j:], lda) | ||||||
|  | 			bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, diag, n-j-jb, jb, -1, a[j*lda+j:], lda, a[(j+jb)*lda+j:], lda) | ||||||
|  | 		} | ||||||
|  | 		impl.Dtrti2(blas.Lower, diag, jb, a[j*lda+j:], lda) | ||||||
|  | 	} | ||||||
|  | 	return true | ||||||
|  | } | ||||||
| @@ -116,6 +116,14 @@ func TestDtrcon(t *testing.T) { | |||||||
| 	testlapack.DtrconTest(t, impl) | 	testlapack.DtrconTest(t, impl) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestDtrtri(t *testing.T) { | ||||||
|  | 	testlapack.DtrtriTest(t, impl) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestDtrti2(t *testing.T) { | ||||||
|  | 	testlapack.Dtrti2Test(t, impl) | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestIladlc(t *testing.T) { | func TestIladlc(t *testing.T) { | ||||||
| 	testlapack.IladlcTest(t, impl) | 	testlapack.IladlcTest(t, impl) | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										153
									
								
								testlapack/dtrti2.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										153
									
								
								testlapack/dtrti2.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,153 @@ | |||||||
|  | package testlapack | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"math" | ||||||
|  | 	"math/rand" | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	"github.com/gonum/blas" | ||||||
|  | 	"github.com/gonum/blas/blas64" | ||||||
|  | 	"github.com/gonum/floats" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Dtrti2er interface { | ||||||
|  | 	Dtrti2(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Dtrti2Test(t *testing.T, impl Dtrti2er) { | ||||||
|  | 	for _, test := range []struct { | ||||||
|  | 		a    []float64 | ||||||
|  | 		n    int | ||||||
|  | 		uplo blas.Uplo | ||||||
|  | 		diag blas.Diag | ||||||
|  | 		ans  []float64 | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			a: []float64{ | ||||||
|  | 				2, 3, 4, | ||||||
|  | 				0, 5, 6, | ||||||
|  | 				8, 0, 8}, | ||||||
|  | 			n:    3, | ||||||
|  | 			uplo: blas.Upper, | ||||||
|  | 			diag: blas.NonUnit, | ||||||
|  | 			ans: []float64{ | ||||||
|  | 				0.5, -0.3, -0.025, | ||||||
|  | 				0, 0.2, -0.15, | ||||||
|  | 				8, 0, 0.125, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			a: []float64{ | ||||||
|  | 				5, 3, 4, | ||||||
|  | 				0, 7, 6, | ||||||
|  | 				10, 0, 8}, | ||||||
|  | 			n:    3, | ||||||
|  | 			uplo: blas.Upper, | ||||||
|  | 			diag: blas.Unit, | ||||||
|  | 			ans: []float64{ | ||||||
|  | 				5, -3, 14, | ||||||
|  | 				0, 7, -6, | ||||||
|  | 				10, 0, 8, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			a: []float64{ | ||||||
|  | 				2, 0, 0, | ||||||
|  | 				3, 5, 0, | ||||||
|  | 				4, 6, 8}, | ||||||
|  | 			n:    3, | ||||||
|  | 			uplo: blas.Lower, | ||||||
|  | 			diag: blas.NonUnit, | ||||||
|  | 			ans: []float64{ | ||||||
|  | 				0.5, 0, 0, | ||||||
|  | 				-0.3, 0.2, 0, | ||||||
|  | 				-0.025, -0.15, 0.125, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			a: []float64{ | ||||||
|  | 				1, 0, 0, | ||||||
|  | 				3, 1, 0, | ||||||
|  | 				4, 6, 1}, | ||||||
|  | 			n:    3, | ||||||
|  | 			uplo: blas.Lower, | ||||||
|  | 			diag: blas.Unit, | ||||||
|  | 			ans: []float64{ | ||||||
|  | 				1, 0, 0, | ||||||
|  | 				-3, 1, 0, | ||||||
|  | 				14, -6, 1, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} { | ||||||
|  | 		impl.Dtrti2(test.uplo, test.diag, test.n, test.a, test.n) | ||||||
|  | 		if !floats.EqualApprox(test.ans, test.a, 1e-14) { | ||||||
|  | 			t.Errorf("Matrix inverse mismatch. Want %v, got %v.", test.ans, test.a) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	bi := blas64.Implementation() | ||||||
|  | 	for _, uplo := range []blas.Uplo{blas.Upper} { | ||||||
|  | 		for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} { | ||||||
|  | 			for _, test := range []struct { | ||||||
|  | 				n, lda int | ||||||
|  | 			}{ | ||||||
|  | 				{3, 0}, | ||||||
|  | 				{3, 5}, | ||||||
|  | 			} { | ||||||
|  | 				n := test.n | ||||||
|  | 				lda := test.lda | ||||||
|  | 				if lda == 0 { | ||||||
|  | 					lda = n | ||||||
|  | 				} | ||||||
|  | 				a := make([]float64, n*lda) | ||||||
|  | 				for i := range a { | ||||||
|  | 					a[i] = rand.Float64() | ||||||
|  | 				} | ||||||
|  | 				aCopy := make([]float64, len(a)) | ||||||
|  | 				copy(aCopy, a) | ||||||
|  | 				impl.Dtrti2(uplo, diag, n, a, lda) | ||||||
|  | 				if uplo == blas.Upper { | ||||||
|  | 					for i := 1; i < n; i++ { | ||||||
|  | 						for j := 0; j < i; j++ { | ||||||
|  | 							aCopy[i*lda+j] = 0 | ||||||
|  | 							a[i*lda+j] = 0 | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 				} else { | ||||||
|  | 					for i := 1; i < n; i++ { | ||||||
|  | 						for j := i + 1; j < n; j++ { | ||||||
|  | 							aCopy[i*lda+j] = 0 | ||||||
|  | 							a[i*lda+j] = 0 | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 				if diag == blas.Unit { | ||||||
|  | 					for i := 0; i < n; i++ { | ||||||
|  | 						a[i*lda+i] = 1 | ||||||
|  | 						aCopy[i*lda+i] = 1 | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 				ans := make([]float64, len(a)) | ||||||
|  | 				bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda) | ||||||
|  | 				iseye := true | ||||||
|  | 				for i := 0; i < n; i++ { | ||||||
|  | 					for j := 0; j < n; j++ { | ||||||
|  | 						if i == j { | ||||||
|  | 							if math.Abs(ans[i*lda+i]-1) > 1e-14 { | ||||||
|  | 								iseye = false | ||||||
|  | 								break | ||||||
|  | 							} | ||||||
|  | 						} else { | ||||||
|  | 							if math.Abs(ans[i*lda+j]) > 1e-14 { | ||||||
|  | 								iseye = false | ||||||
|  | 								break | ||||||
|  | 							} | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 				if !iseye { | ||||||
|  | 					t.Errorf("inv(A) * A != I. Upper = %v, unit = %v, ans = %v", uplo == blas.Upper, diag == blas.Unit, ans) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										89
									
								
								testlapack/dtrtri.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								testlapack/dtrtri.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,89 @@ | |||||||
|  | package testlapack | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"math" | ||||||
|  | 	"math/rand" | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	"github.com/gonum/blas" | ||||||
|  | 	"github.com/gonum/blas/blas64" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Dtrtrier interface { | ||||||
|  | 	Dtrconer | ||||||
|  | 	Dtrtri(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) bool | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func DtrtriTest(t *testing.T, impl Dtrtrier) { | ||||||
|  | 	bi := blas64.Implementation() | ||||||
|  | 	for _, uplo := range []blas.Uplo{blas.Upper} { | ||||||
|  | 		for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} { | ||||||
|  | 			for _, test := range []struct { | ||||||
|  | 				n, lda int | ||||||
|  | 			}{ | ||||||
|  | 				{3, 0}, | ||||||
|  | 				{70, 0}, | ||||||
|  | 				{200, 0}, | ||||||
|  | 				{3, 5}, | ||||||
|  | 				{70, 92}, | ||||||
|  | 				{200, 205}, | ||||||
|  | 			} { | ||||||
|  | 				n := test.n | ||||||
|  | 				lda := test.lda | ||||||
|  | 				if lda == 0 { | ||||||
|  | 					lda = n | ||||||
|  | 				} | ||||||
|  | 				a := make([]float64, n*lda) | ||||||
|  | 				for i := range a { | ||||||
|  | 					a[i] = rand.Float64() + 1 // This keeps the matrices well conditioned. | ||||||
|  | 				} | ||||||
|  | 				aCopy := make([]float64, len(a)) | ||||||
|  | 				copy(aCopy, a) | ||||||
|  | 				impl.Dtrtri(uplo, diag, n, a, lda) | ||||||
|  | 				if uplo == blas.Upper { | ||||||
|  | 					for i := 1; i < n; i++ { | ||||||
|  | 						for j := 0; j < i; j++ { | ||||||
|  | 							aCopy[i*lda+j] = 0 | ||||||
|  | 							a[i*lda+j] = 0 | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 				} else { | ||||||
|  | 					for i := 1; i < n; i++ { | ||||||
|  | 						for j := i + 1; j < n; j++ { | ||||||
|  | 							aCopy[i*lda+j] = 0 | ||||||
|  | 							a[i*lda+j] = 0 | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 				if diag == blas.Unit { | ||||||
|  | 					for i := 0; i < n; i++ { | ||||||
|  | 						a[i*lda+i] = 1 | ||||||
|  | 						aCopy[i*lda+i] = 1 | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 				ans := make([]float64, len(a)) | ||||||
|  | 				bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda) | ||||||
|  | 				iseye := true | ||||||
|  | 				for i := 0; i < n; i++ { | ||||||
|  | 					for j := 0; j < n; j++ { | ||||||
|  | 						if i == j { | ||||||
|  | 							if math.Abs(ans[i*lda+i]-1) > 1e-4 { | ||||||
|  | 								iseye = false | ||||||
|  | 								break | ||||||
|  | 							} | ||||||
|  | 						} else { | ||||||
|  | 							if math.Abs(ans[i*lda+j]) > 1e-4 { | ||||||
|  | 								iseye = false | ||||||
|  | 								break | ||||||
|  | 							} | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 				if !iseye { | ||||||
|  | 					t.Errorf("inv(A) * A != I. Upper = %v, unit = %v, n = %v, lda = %v", | ||||||
|  | 						uplo == blas.Upper, diag == blas.Unit, n, lda) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user
	 btracey
					btracey