From 32bdb776ef9a215c9e9511080285522c17f08f52 Mon Sep 17 00:00:00 2001
From: btracey
Date: Sat, 1 Aug 2015 23:30:43 -0600
Subject: [PATCH] Add Dgetrs (compute a solution based on LU factorization) and test. Responded to PR comments

---
 cgo/lapack.go         |  24 +++++++++
 cgo/lapack_test.go    |   4 ++
 native/dgetrs.go      |  51 +++++++++++++++++++
 native/lapack_test.go |   4 ++
 testlapack/dgetrs.go  | 116 ++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 199 insertions(+)
 create mode 100644 native/dgetrs.go
 create mode 100644 testlapack/dgetrs.go

diff --git a/cgo/lapack.go b/cgo/lapack.go
index 66480d3f..cc5e0ef4 100644
--- a/cgo/lapack.go
+++ b/cgo/lapack.go
@@ -136,3 +136,27 @@ func (impl Implementation) Dgetrf(m, n int, a []float64, lda int, ipiv []int) (o
 	}
 	return ok
 }
+
+// Dgetrs solves a system of equations using an LU factorization.
+// The system of equations solved is
+//  A * X = B    if trans == blas.NoTrans
+//  A^T * X = B  if trans == blas.Trans
+// A is a general n×n matrix with stride lda. B is a general matrix of size n×nrhs.
+//
+// On entry b contains the elements of the matrix B. On exit, b contains the
+// elements of X, the solution to the system of equations.
+//
+// a and ipiv contain the LU factorization of A and the permutation indices as
+// computed by Dgetrf. ipiv is zero-indexed.
+func (impl Implementation) Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int) {
+	checkMatrix(n, n, a, lda)
+	checkMatrix(n, nrhs, b, ldb)
+	if len(ipiv) < n {
+		panic(badIpiv)
+	}
+	ipiv32 := make([]int32, len(ipiv))
+	for i, v := range ipiv {
+		ipiv32[i] = int32(v) + 1 // Transform to one-indexed.
+	}
+	clapack.Dgetrs(trans, n, nrhs, a, lda, ipiv32, b, ldb)
+}
diff --git a/cgo/lapack_test.go b/cgo/lapack_test.go
index dd2ea3b1..7a69d7e2 100644
--- a/cgo/lapack_test.go
+++ b/cgo/lapack_test.go
@@ -23,3 +23,7 @@ func TestDgetf2(t *testing.T) {
 func TestDgetrf(t *testing.T) {
 	testlapack.DgetrfTest(t, impl)
 }
+
+func TestDgetrs(t *testing.T) {
+	testlapack.DgetrsTest(t, impl)
+}
diff --git a/native/dgetrs.go b/native/dgetrs.go
new file mode 100644
index 00000000..8f05decb
--- /dev/null
+++ b/native/dgetrs.go
@@ -0,0 +1,51 @@
+package native
+
+import (
+	"github.com/gonum/blas"
+	"github.com/gonum/blas/blas64"
+)
+
+// Dgetrs solves a system of equations using an LU factorization.
+// The system of equations solved is
+//  A * X = B    if trans == blas.NoTrans
+//  A^T * X = B  if trans == blas.Trans
+// A is a general n×n matrix with stride lda. B is a general matrix of size n×nrhs.
+//
+// On entry b contains the elements of the matrix B. On exit, b contains the
+// elements of X, the solution to the system of equations.
+//
+// a and ipiv contain the LU factorization of A and the permutation indices as
+// computed by Dgetrf. ipiv is zero-indexed.
+func (impl Implementation) Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int) {
+	checkMatrix(n, n, a, lda)
+	checkMatrix(n, nrhs, b, ldb)
+	if len(ipiv) < n {
+		panic(badIpiv)
+	}
+	if trans != blas.Trans && trans != blas.NoTrans {
+		panic(badTrans)
+	}
+	if n == 0 || nrhs == 0 {
+		return
+	}
+	bi := blas64.Implementation()
+	if trans == blas.NoTrans {
+		// Solve A * X = B.
+		impl.Dlaswp(nrhs, b, ldb, 0, n-1, ipiv, 1)
+		// Solve L * X = B, updating b.
+		bi.Dtrsm(blas.Left, blas.Lower, blas.NoTrans, blas.Unit,
+			n, nrhs, 1, a, lda, b, ldb)
+		// Solve U * X = B, updating b.
+		bi.Dtrsm(blas.Left, blas.Upper, blas.NoTrans, blas.NonUnit,
+			n, nrhs, 1, a, lda, b, ldb)
+		return
+	}
+	// Solve A^T * X = B.
+	// Solve U^T * X = B, updating b.
+	bi.Dtrsm(blas.Left, blas.Upper, blas.Trans, blas.NonUnit,
+		n, nrhs, 1, a, lda, b, ldb)
+	// Solve L^T * X = B, updating b.
+	bi.Dtrsm(blas.Left, blas.Lower, blas.Trans, blas.Unit,
+		n, nrhs, 1, a, lda, b, ldb)
+	impl.Dlaswp(nrhs, b, ldb, 0, n-1, ipiv, -1)
+}
diff --git a/native/lapack_test.go b/native/lapack_test.go
index 9418dd52..20fa692f 100644
--- a/native/lapack_test.go
+++ b/native/lapack_test.go
@@ -40,6 +40,10 @@ func TestDgetrf(t *testing.T) {
 	testlapack.DgetrfTest(t, impl)
 }
 
+func TestDgetrs(t *testing.T) {
+	testlapack.DgetrsTest(t, impl)
+}
+
 func TestDlange(t *testing.T) {
 	testlapack.DlangeTest(t, impl)
 }
diff --git a/testlapack/dgetrs.go b/testlapack/dgetrs.go
new file mode 100644
index 00000000..6c94f788
--- /dev/null
+++ b/testlapack/dgetrs.go
@@ -0,0 +1,116 @@
+package testlapack
+
+import (
+	"math/rand"
+	"testing"
+
+	"github.com/gonum/blas"
+	"github.com/gonum/blas/blas64"
+	"github.com/gonum/floats"
+)
+
+type Dgetrser interface {
+	Dgetrfer
+	Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int)
+}
+
+func DgetrsTest(t *testing.T, impl Dgetrser) {
+	for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
+		for _, test := range []struct {
+			n, nrhs, lda, ldb int
+			tol               float64
+		}{
+			{3, 3, 0, 0, 1e-14},
+			{3, 3, 0, 0, 1e-14},
+			{3, 5, 0, 0, 1e-14},
+			{3, 5, 0, 0, 1e-14},
+			{5, 3, 0, 0, 1e-14},
+			{5, 3, 0, 0, 1e-14},
+
+			{3, 3, 8, 10, 1e-14},
+			{3, 3, 8, 10, 1e-14},
+			{3, 5, 8, 10, 1e-14},
+			{3, 5, 8, 10, 1e-14},
+			{5, 3, 8, 10, 1e-14},
+			{5, 3, 8, 10, 1e-14},
+
+			{300, 300, 0, 0, 1e-10},
+			{300, 300, 0, 0, 1e-10},
+			{300, 500, 0, 0, 1e-10},
+			{300, 500, 0, 0, 1e-10},
+			{500, 300, 0, 0, 1e-10},
+			{500, 300, 0, 0, 1e-10},
+
+			{300, 300, 700, 600, 1e-10},
+			{300, 300, 700, 600, 1e-10},
+			{300, 500, 700, 600, 1e-10},
+			{300, 500, 700, 600, 1e-10},
+			{500, 300, 700, 600, 1e-10},
+			{500, 300, 700, 600, 1e-10},
+		} {
+			n := test.n
+			nrhs := test.nrhs
+			lda := test.lda
+			if lda == 0 {
+				lda = n
+			}
+			ldb := test.ldb
+			if ldb == 0 {
+				ldb = nrhs
+			}
+			a := make([]float64, n*lda)
+			for i := range a {
+				a[i] = rand.Float64()
+			}
+			b := make([]float64, n*ldb)
+			for i := range b {
+				b[i] = rand.Float64()
+			}
+			aCopy := make([]float64, len(a))
+			copy(aCopy, a)
+			bCopy := make([]float64, len(b))
+			copy(bCopy, b)
+
+			ipiv := make([]int, n)
+			for i := range ipiv {
+				ipiv[i] = rand.Int()
+			}
+
+			// Compute the LU factorization.
+			impl.Dgetrf(n, n, a, lda, ipiv)
+			// Solve the system of equations given the result.
+			impl.Dgetrs(trans, n, nrhs, a, lda, ipiv, b, ldb)
+
+			// Check that the system of equations holds.
+			A := blas64.General{
+				Rows:   n,
+				Cols:   n,
+				Stride: lda,
+				Data:   aCopy,
+			}
+			B := blas64.General{
+				Rows:   n,
+				Cols:   nrhs,
+				Stride: ldb,
+				Data:   bCopy,
+			}
+			X := blas64.General{
+				Rows:   n,
+				Cols:   nrhs,
+				Stride: ldb,
+				Data:   b,
+			}
+			tmp := blas64.General{
+				Rows:   n,
+				Cols:   nrhs,
+				Stride: ldb,
+				Data:   make([]float64, n*ldb),
+			}
+			copy(tmp.Data, bCopy)
+			blas64.Gemm(trans, blas.NoTrans, 1, A, X, 0, B) // Overwrites bCopy (B's data) with op(A) * X; tmp keeps the original B.
+			if !floats.EqualApprox(tmp.Data, bCopy, test.tol) {
+				t.Errorf("Linear solve mismatch. trans = %v, n = %v, nrhs = %v, lda = %v, ldb = %v", trans, n, nrhs, lda, ldb)
+			}
+		}
+	}
+}