diff --git a/lapack/gonum/dgtsv.go b/lapack/gonum/dgtsv.go new file mode 100644 index 00000000..36d017dd --- /dev/null +++ b/lapack/gonum/dgtsv.go @@ -0,0 +1,99 @@ +// Copyright ©2020 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gonum + +import "math" + +// Dgtsv solves the equation +// A * X = B +// where A is an n×n tridiagonal matrix. It uses Gaussian elimination with +// partial pivoting. The equation Aᵀ * X = B may be solved by swapping the +// arguments for du and dl. +// +// On entry, dl, d and du contain the sub-diagonal, the diagonal and the +// super-diagonal, respectively, of A. On return, the first n-2 elements of dl, +// the first n-1 elements of du and the first n elements of d may be +// overwritten. +// +// On entry, b contains the n×nrhs right-hand side matrix B. On return, b will +// be overwritten. If ok is true, it will be overwritten by the solution matrix X. +// +// Dgtsv returns whether the solution X has been successfuly computed. +func (impl Implementation) Dgtsv(n, nrhs int, dl, d, du []float64, b []float64, ldb int) (ok bool) { + switch { + case n < 0: + panic(nLT0) + case nrhs < 0: + panic(nrhsLT0) + case ldb < max(1, nrhs): + panic(badLdB) + } + + if n == 0 || nrhs == 0 { + return true + } + + switch { + case len(dl) < n-1: + panic(shortDL) + case len(d) < n: + panic(shortD) + case len(du) < n-1: + panic(shortDU) + case len(b) < (n-1)*ldb+nrhs: + panic(shortB) + } + + dl = dl[:n-1] + d = d[:n] + du = du[:n-1] + + for i := 0; i < n-1; i++ { + if math.Abs(d[i]) >= math.Abs(dl[i]) { + // No row interchange required. + if d[i] == 0 { + return false + } + fact := dl[i] / d[i] + d[i+1] -= fact * du[i] + for j := 0; j < nrhs; j++ { + b[(i+1)*ldb+j] -= fact * b[i*ldb+j] + } + dl[i] = 0 + } else { + // Interchange rows i and i+1. + fact := d[i] / dl[i] + d[i] = dl[i] + tmp := d[i+1] + d[i+1] = du[i] - fact*tmp + du[i] = tmp + if i+1 < n-1 { + dl[i] = du[i+1] + du[i+1] = -fact * dl[i] + } + for j := 0; j < nrhs; j++ { + tmp = b[i*ldb+j] + b[i*ldb+j] = b[(i+1)*ldb+j] + b[(i+1)*ldb+j] = tmp - fact*b[(i+1)*ldb+j] + } + } + } + if d[n-1] == 0 { + return false + } + + // Back solve with the matrix U from the factorization. + for j := 0; j < nrhs; j++ { + b[(n-1)*ldb+j] /= d[n-1] + if n > 1 { + b[(n-2)*ldb+j] = (b[(n-2)*ldb+j] - du[n-2]*b[(n-1)*ldb+j]) / d[n-2] + } + for i := n - 3; i >= 0; i-- { + b[i*ldb+j] = (b[i*ldb+j] - du[i]*b[(i+1)*ldb+j] - dl[i]*b[(i+2)*ldb+j]) / d[i] + } + } + + return true +} diff --git a/lapack/gonum/lapack_test.go b/lapack/gonum/lapack_test.go index 64551780..4766d54a 100644 --- a/lapack/gonum/lapack_test.go +++ b/lapack/gonum/lapack_test.go @@ -148,6 +148,11 @@ func TestDggsvp3(t *testing.T) { testlapack.Dggsvp3Test(t, impl) } +func TestDgtsv(t *testing.T) { + t.Parallel() + testlapack.DgtsvTest(t, impl) +} + func TestDlabrd(t *testing.T) { t.Parallel() testlapack.DlabrdTest(t, impl) diff --git a/lapack/lapack64/lapack64.go b/lapack/lapack64/lapack64.go index 429ee98e..c441bc7c 100644 --- a/lapack/lapack64/lapack64.go +++ b/lapack/lapack64/lapack64.go @@ -420,6 +420,28 @@ func Ggsvd3(jobU, jobV, jobQ lapack.GSVDJob, a, b blas64.General, alpha, beta [] return lapack64.Dggsvd3(jobU, jobV, jobQ, a.Rows, a.Cols, b.Rows, a.Data, max(1, a.Stride), b.Data, max(1, b.Stride), alpha, beta, u.Data, max(1, u.Stride), v.Data, max(1, v.Stride), q.Data, max(1, q.Stride), work, lwork, iwork) } +// Gtsv solves one of the equations +// A * X = B if trans == blas.NoTrans +// Aᵀ * X = B if trans == blas.Trans or blas.ConjTrans +// where A is an n×n tridiagonal matrix. It uses Gaussian elimination with +// partial pivoting. +// +// On entry, a contains the matrix A, on return it will be overwritten. +// +// On entry, b contains the n×nrhs right-hand side matrix B. On return, it will +// be overwritten. If ok is true, it will be overwritten by the solution matrix X. +// +// Gtsv returns whether the solution X has been successfuly computed. +// +// Dgtsv is not part of the lapack.Float64 interface and so calls to Gtsv are +// always executed by the Gonum implementation. +func Gtsv(trans blas.Transpose, a Tridiagonal, b blas64.General) (ok bool) { + if trans != blas.NoTrans { + a.DL, a.DU = a.DU, a.DL + } + return gonum.Implementation{}.Dgtsv(a.N, b.Cols, a.DL, a.D, a.DU, b.Data, max(1, b.Stride)) +} + // Lagtm performs one of the matrix-matrix operations // C = alpha * A * B + beta * C if trans == blas.NoTrans // C = alpha * Aᵀ * B + beta * C if trans == blas.Trans or blas.ConjTrans diff --git a/lapack/testlapack/dgtsv.go b/lapack/testlapack/dgtsv.go new file mode 100644 index 00000000..cdea4238 --- /dev/null +++ b/lapack/testlapack/dgtsv.go @@ -0,0 +1,92 @@ +// Copyright ©2020 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testlapack + +import ( + "fmt" + "math" + "testing" + + "golang.org/x/exp/rand" + + "gonum.org/v1/gonum/blas" + "gonum.org/v1/gonum/blas/blas64" + "gonum.org/v1/gonum/lapack" +) + +type Dgtsver interface { + Dgtsv(n, nrhs int, dl, d, du []float64, b []float64, ldb int) (ok bool) +} + +func DgtsvTest(t *testing.T, impl Dgtsver) { + rnd := rand.New(rand.NewSource(1)) + for _, n := range []int{0, 1, 2, 3, 4, 5, 10, 25, 50} { + for _, nrhs := range []int{0, 1, 2, 3, 4, 10} { + for _, ldb := range []int{max(1, nrhs), nrhs + 3} { + dgtsvTest(t, impl, rnd, n, nrhs, ldb) + } + } + } +} + +func dgtsvTest(t *testing.T, impl Dgtsver, rnd *rand.Rand, n, nrhs, ldb int) { + const ( + tol = 1e-14 + extra = 10 + ) + + name := fmt.Sprintf("Case n=%d,nrhs=%d,ldb=%d", n, nrhs, ldb) + + if n == 0 { + ok := impl.Dgtsv(n, nrhs, nil, nil, nil, nil, ldb) + if !ok { + t.Errorf("%v: unexpected failure for zero size matrix", name) + } + return + } + + // Generate three random diagonals. + var ( + d, dCopy []float64 + dl, dlCopy []float64 + du, duCopy []float64 + ) + d = randomSlice(n+1+extra, rnd) + dCopy = make([]float64, len(d)) + copy(dCopy, d) + if n > 1 { + dl = randomSlice(n+extra, rnd) + dlCopy = make([]float64, len(dl)) + copy(dlCopy, dl) + + du = randomSlice(n+extra, rnd) + duCopy = make([]float64, len(du)) + copy(duCopy, du) + } + + b := randomGeneral(n, nrhs, ldb, rnd) + got := cloneGeneral(b) + + ok := impl.Dgtsv(n, nrhs, dl, d, du, got.Data, got.Stride) + if !ok { + t.Fatalf("%v: unexpected failure in Dgtsv", name) + return + } + + // Compute A*X - B. + dlagtm(blas.NoTrans, n, nrhs, 1, dlCopy, dCopy, duCopy, got.Data, got.Stride, -1, b.Data, b.Stride) + + anorm := dlangt(lapack.MaxColumnSum, n, dlCopy, dCopy, duCopy) + bi := blas64.Implementation() + var resid float64 + for j := 0; j < nrhs; j++ { + bnorm := bi.Dasum(n, b.Data[j:], b.Stride) + xnorm := bi.Dasum(n, got.Data[j:], got.Stride) + resid = math.Max(resid, bnorm/anorm/xnorm) + } + if resid > tol { + t.Errorf("%v: unexpected result; resid=%v,want<=%v", name, resid, tol) + } +} diff --git a/lapack/testlapack/locallapack.go b/lapack/testlapack/locallapack.go new file mode 100644 index 00000000..6ca8c34e --- /dev/null +++ b/lapack/testlapack/locallapack.go @@ -0,0 +1,162 @@ +// Copyright ©2020 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testlapack + +import ( + "math" + + "gonum.org/v1/gonum/blas" + "gonum.org/v1/gonum/lapack" +) + +func dlagtm(trans blas.Transpose, m, n int, alpha float64, dl, d, du []float64, b []float64, ldb int, beta float64, c []float64, ldc int) { + if m == 0 || n == 0 { + return + } + + if beta != 1 { + if beta == 0 { + for i := 0; i < m; i++ { + ci := c[i*ldc : i*ldc+n] + for j := range ci { + ci[j] = 0 + } + } + } else { + for i := 0; i < m; i++ { + ci := c[i*ldc : i*ldc+n] + for j := range ci { + ci[j] *= beta + } + } + } + } + + if alpha == 0 { + return + } + + if m == 1 { + if alpha == 1 { + for j := 0; j < n; j++ { + c[j] += d[0] * b[j] + } + } else { + for j := 0; j < n; j++ { + c[j] += alpha * d[0] * b[j] + } + } + return + } + + if trans != blas.NoTrans { + dl, du = du, dl + } + + if alpha == 1 { + for j := 0; j < n; j++ { + c[j] += d[0]*b[j] + du[0]*b[ldb+j] + } + for i := 1; i < m-1; i++ { + for j := 0; j < n; j++ { + c[i*ldc+j] += dl[i-1]*b[(i-1)*ldb+j] + d[i]*b[i*ldb+j] + du[i]*b[(i+1)*ldb+j] + } + } + for j := 0; j < n; j++ { + c[(m-1)*ldc+j] += dl[m-2]*b[(m-2)*ldb+j] + d[m-1]*b[(m-1)*ldb+j] + } + } else { + for j := 0; j < n; j++ { + c[j] += alpha * (d[0]*b[j] + du[0]*b[ldb+j]) + } + for i := 1; i < m-1; i++ { + for j := 0; j < n; j++ { + c[i*ldc+j] += alpha * (dl[i-1]*b[(i-1)*ldb+j] + d[i]*b[i*ldb+j] + du[i]*b[(i+1)*ldb+j]) + } + } + for j := 0; j < n; j++ { + c[(m-1)*ldc+j] += alpha * (dl[m-2]*b[(m-2)*ldb+j] + d[m-1]*b[(m-1)*ldb+j]) + } + } +} + +func dlangt(norm lapack.MatrixNorm, n int, dl, d, du []float64) float64 { + if n == 0 { + return 0 + } + + dl = dl[:n-1] + d = d[:n] + du = du[:n-1] + + var anorm float64 + switch norm { + case lapack.MaxAbs: + for _, diag := range [][]float64{dl, d, du} { + for _, di := range diag { + if math.IsNaN(di) { + return di + } + di = math.Abs(di) + if di > anorm { + anorm = di + } + } + } + case lapack.MaxColumnSum: + if n == 1 { + return math.Abs(d[0]) + } + anorm = math.Abs(d[0]) + math.Abs(dl[0]) + if math.IsNaN(anorm) { + return anorm + } + tmp := math.Abs(du[n-2]) + math.Abs(d[n-1]) + if math.IsNaN(tmp) { + return tmp + } + if tmp > anorm { + anorm = tmp + } + for i := 1; i < n-1; i++ { + tmp = math.Abs(du[i-1]) + math.Abs(d[i]) + math.Abs(dl[i]) + if math.IsNaN(tmp) { + return tmp + } + if tmp > anorm { + anorm = tmp + } + } + case lapack.MaxRowSum: + if n == 1 { + return math.Abs(d[0]) + } + anorm = math.Abs(d[0]) + math.Abs(du[0]) + if math.IsNaN(anorm) { + return anorm + } + tmp := math.Abs(dl[n-2]) + math.Abs(d[n-1]) + if math.IsNaN(tmp) { + return tmp + } + if tmp > anorm { + anorm = tmp + } + for i := 1; i < n-1; i++ { + tmp = math.Abs(dl[i-1]) + math.Abs(d[i]) + math.Abs(du[i]) + if math.IsNaN(tmp) { + return tmp + } + if tmp > anorm { + anorm = tmp + } + } + case lapack.Frobenius: + panic("not implemented") + default: + panic("invalid norm") + } + return anorm +}