lapack: add Dgtsv

This commit is contained in:
Vladimir Chalupecky
2020-10-21 11:58:12 +02:00
committed by Vladimír Chalupecký
parent df1c4f0d6a
commit 6703b9cb87
5 changed files with 380 additions and 0 deletions

99
lapack/gonum/dgtsv.go Normal file
View File

@@ -0,0 +1,99 @@
// Copyright ©2020 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gonum
import "math"
// Dgtsv solves the equation
// A * X = B
// where A is an n×n tridiagonal matrix. It uses Gaussian elimination with
// partial pivoting. The equation Aᵀ * X = B may be solved by swapping the
// arguments for du and dl.
//
// On entry, dl, d and du contain the sub-diagonal, the diagonal and the
// super-diagonal, respectively, of A. On return, the first n-2 elements of dl,
// the first n-1 elements of du and the first n elements of d may be
// overwritten.
//
// On entry, b contains the n×nrhs right-hand side matrix B. On return, b will
// be overwritten. If ok is true, it will be overwritten by the solution matrix X.
//
// Dgtsv returns whether the solution X has been successfuly computed.
func (impl Implementation) Dgtsv(n, nrhs int, dl, d, du []float64, b []float64, ldb int) (ok bool) {
switch {
case n < 0:
panic(nLT0)
case nrhs < 0:
panic(nrhsLT0)
case ldb < max(1, nrhs):
panic(badLdB)
}
if n == 0 || nrhs == 0 {
return true
}
switch {
case len(dl) < n-1:
panic(shortDL)
case len(d) < n:
panic(shortD)
case len(du) < n-1:
panic(shortDU)
case len(b) < (n-1)*ldb+nrhs:
panic(shortB)
}
dl = dl[:n-1]
d = d[:n]
du = du[:n-1]
for i := 0; i < n-1; i++ {
if math.Abs(d[i]) >= math.Abs(dl[i]) {
// No row interchange required.
if d[i] == 0 {
return false
}
fact := dl[i] / d[i]
d[i+1] -= fact * du[i]
for j := 0; j < nrhs; j++ {
b[(i+1)*ldb+j] -= fact * b[i*ldb+j]
}
dl[i] = 0
} else {
// Interchange rows i and i+1.
fact := d[i] / dl[i]
d[i] = dl[i]
tmp := d[i+1]
d[i+1] = du[i] - fact*tmp
du[i] = tmp
if i+1 < n-1 {
dl[i] = du[i+1]
du[i+1] = -fact * dl[i]
}
for j := 0; j < nrhs; j++ {
tmp = b[i*ldb+j]
b[i*ldb+j] = b[(i+1)*ldb+j]
b[(i+1)*ldb+j] = tmp - fact*b[(i+1)*ldb+j]
}
}
}
if d[n-1] == 0 {
return false
}
// Back solve with the matrix U from the factorization.
for j := 0; j < nrhs; j++ {
b[(n-1)*ldb+j] /= d[n-1]
if n > 1 {
b[(n-2)*ldb+j] = (b[(n-2)*ldb+j] - du[n-2]*b[(n-1)*ldb+j]) / d[n-2]
}
for i := n - 3; i >= 0; i-- {
b[i*ldb+j] = (b[i*ldb+j] - du[i]*b[(i+1)*ldb+j] - dl[i]*b[(i+2)*ldb+j]) / d[i]
}
}
return true
}

View File

@@ -148,6 +148,11 @@ func TestDggsvp3(t *testing.T) {
testlapack.Dggsvp3Test(t, impl) testlapack.Dggsvp3Test(t, impl)
} }
func TestDgtsv(t *testing.T) {
t.Parallel()
testlapack.DgtsvTest(t, impl)
}
func TestDlabrd(t *testing.T) { func TestDlabrd(t *testing.T) {
t.Parallel() t.Parallel()
testlapack.DlabrdTest(t, impl) testlapack.DlabrdTest(t, impl)

View File

@@ -420,6 +420,28 @@ func Ggsvd3(jobU, jobV, jobQ lapack.GSVDJob, a, b blas64.General, alpha, beta []
return lapack64.Dggsvd3(jobU, jobV, jobQ, a.Rows, a.Cols, b.Rows, a.Data, max(1, a.Stride), b.Data, max(1, b.Stride), alpha, beta, u.Data, max(1, u.Stride), v.Data, max(1, v.Stride), q.Data, max(1, q.Stride), work, lwork, iwork) return lapack64.Dggsvd3(jobU, jobV, jobQ, a.Rows, a.Cols, b.Rows, a.Data, max(1, a.Stride), b.Data, max(1, b.Stride), alpha, beta, u.Data, max(1, u.Stride), v.Data, max(1, v.Stride), q.Data, max(1, q.Stride), work, lwork, iwork)
} }
// Gtsv solves one of the equations
// A * X = B if trans == blas.NoTrans
// Aᵀ * X = B if trans == blas.Trans or blas.ConjTrans
// where A is an n×n tridiagonal matrix. It uses Gaussian elimination with
// partial pivoting.
//
// On entry, a contains the matrix A, on return it will be overwritten.
//
// On entry, b contains the n×nrhs right-hand side matrix B. On return, it will
// be overwritten. If ok is true, it will be overwritten by the solution matrix X.
//
// Gtsv returns whether the solution X has been successfuly computed.
//
// Dgtsv is not part of the lapack.Float64 interface and so calls to Gtsv are
// always executed by the Gonum implementation.
func Gtsv(trans blas.Transpose, a Tridiagonal, b blas64.General) (ok bool) {
if trans != blas.NoTrans {
a.DL, a.DU = a.DU, a.DL
}
return gonum.Implementation{}.Dgtsv(a.N, b.Cols, a.DL, a.D, a.DU, b.Data, max(1, b.Stride))
}
// Lagtm performs one of the matrix-matrix operations // Lagtm performs one of the matrix-matrix operations
// C = alpha * A * B + beta * C if trans == blas.NoTrans // C = alpha * A * B + beta * C if trans == blas.NoTrans
// C = alpha * Aᵀ * B + beta * C if trans == blas.Trans or blas.ConjTrans // C = alpha * Aᵀ * B + beta * C if trans == blas.Trans or blas.ConjTrans

View File

@@ -0,0 +1,92 @@
// Copyright ©2020 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testlapack
import (
"fmt"
"math"
"testing"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/blas"
"gonum.org/v1/gonum/blas/blas64"
"gonum.org/v1/gonum/lapack"
)
type Dgtsver interface {
Dgtsv(n, nrhs int, dl, d, du []float64, b []float64, ldb int) (ok bool)
}
func DgtsvTest(t *testing.T, impl Dgtsver) {
rnd := rand.New(rand.NewSource(1))
for _, n := range []int{0, 1, 2, 3, 4, 5, 10, 25, 50} {
for _, nrhs := range []int{0, 1, 2, 3, 4, 10} {
for _, ldb := range []int{max(1, nrhs), nrhs + 3} {
dgtsvTest(t, impl, rnd, n, nrhs, ldb)
}
}
}
}
func dgtsvTest(t *testing.T, impl Dgtsver, rnd *rand.Rand, n, nrhs, ldb int) {
const (
tol = 1e-14
extra = 10
)
name := fmt.Sprintf("Case n=%d,nrhs=%d,ldb=%d", n, nrhs, ldb)
if n == 0 {
ok := impl.Dgtsv(n, nrhs, nil, nil, nil, nil, ldb)
if !ok {
t.Errorf("%v: unexpected failure for zero size matrix", name)
}
return
}
// Generate three random diagonals.
var (
d, dCopy []float64
dl, dlCopy []float64
du, duCopy []float64
)
d = randomSlice(n+1+extra, rnd)
dCopy = make([]float64, len(d))
copy(dCopy, d)
if n > 1 {
dl = randomSlice(n+extra, rnd)
dlCopy = make([]float64, len(dl))
copy(dlCopy, dl)
du = randomSlice(n+extra, rnd)
duCopy = make([]float64, len(du))
copy(duCopy, du)
}
b := randomGeneral(n, nrhs, ldb, rnd)
got := cloneGeneral(b)
ok := impl.Dgtsv(n, nrhs, dl, d, du, got.Data, got.Stride)
if !ok {
t.Fatalf("%v: unexpected failure in Dgtsv", name)
return
}
// Compute A*X - B.
dlagtm(blas.NoTrans, n, nrhs, 1, dlCopy, dCopy, duCopy, got.Data, got.Stride, -1, b.Data, b.Stride)
anorm := dlangt(lapack.MaxColumnSum, n, dlCopy, dCopy, duCopy)
bi := blas64.Implementation()
var resid float64
for j := 0; j < nrhs; j++ {
bnorm := bi.Dasum(n, b.Data[j:], b.Stride)
xnorm := bi.Dasum(n, got.Data[j:], got.Stride)
resid = math.Max(resid, bnorm/anorm/xnorm)
}
if resid > tol {
t.Errorf("%v: unexpected result; resid=%v,want<=%v", name, resid, tol)
}
}

View File

@@ -0,0 +1,162 @@
// Copyright ©2020 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testlapack
import (
"math"
"gonum.org/v1/gonum/blas"
"gonum.org/v1/gonum/lapack"
)
func dlagtm(trans blas.Transpose, m, n int, alpha float64, dl, d, du []float64, b []float64, ldb int, beta float64, c []float64, ldc int) {
if m == 0 || n == 0 {
return
}
if beta != 1 {
if beta == 0 {
for i := 0; i < m; i++ {
ci := c[i*ldc : i*ldc+n]
for j := range ci {
ci[j] = 0
}
}
} else {
for i := 0; i < m; i++ {
ci := c[i*ldc : i*ldc+n]
for j := range ci {
ci[j] *= beta
}
}
}
}
if alpha == 0 {
return
}
if m == 1 {
if alpha == 1 {
for j := 0; j < n; j++ {
c[j] += d[0] * b[j]
}
} else {
for j := 0; j < n; j++ {
c[j] += alpha * d[0] * b[j]
}
}
return
}
if trans != blas.NoTrans {
dl, du = du, dl
}
if alpha == 1 {
for j := 0; j < n; j++ {
c[j] += d[0]*b[j] + du[0]*b[ldb+j]
}
for i := 1; i < m-1; i++ {
for j := 0; j < n; j++ {
c[i*ldc+j] += dl[i-1]*b[(i-1)*ldb+j] + d[i]*b[i*ldb+j] + du[i]*b[(i+1)*ldb+j]
}
}
for j := 0; j < n; j++ {
c[(m-1)*ldc+j] += dl[m-2]*b[(m-2)*ldb+j] + d[m-1]*b[(m-1)*ldb+j]
}
} else {
for j := 0; j < n; j++ {
c[j] += alpha * (d[0]*b[j] + du[0]*b[ldb+j])
}
for i := 1; i < m-1; i++ {
for j := 0; j < n; j++ {
c[i*ldc+j] += alpha * (dl[i-1]*b[(i-1)*ldb+j] + d[i]*b[i*ldb+j] + du[i]*b[(i+1)*ldb+j])
}
}
for j := 0; j < n; j++ {
c[(m-1)*ldc+j] += alpha * (dl[m-2]*b[(m-2)*ldb+j] + d[m-1]*b[(m-1)*ldb+j])
}
}
}
func dlangt(norm lapack.MatrixNorm, n int, dl, d, du []float64) float64 {
if n == 0 {
return 0
}
dl = dl[:n-1]
d = d[:n]
du = du[:n-1]
var anorm float64
switch norm {
case lapack.MaxAbs:
for _, diag := range [][]float64{dl, d, du} {
for _, di := range diag {
if math.IsNaN(di) {
return di
}
di = math.Abs(di)
if di > anorm {
anorm = di
}
}
}
case lapack.MaxColumnSum:
if n == 1 {
return math.Abs(d[0])
}
anorm = math.Abs(d[0]) + math.Abs(dl[0])
if math.IsNaN(anorm) {
return anorm
}
tmp := math.Abs(du[n-2]) + math.Abs(d[n-1])
if math.IsNaN(tmp) {
return tmp
}
if tmp > anorm {
anorm = tmp
}
for i := 1; i < n-1; i++ {
tmp = math.Abs(du[i-1]) + math.Abs(d[i]) + math.Abs(dl[i])
if math.IsNaN(tmp) {
return tmp
}
if tmp > anorm {
anorm = tmp
}
}
case lapack.MaxRowSum:
if n == 1 {
return math.Abs(d[0])
}
anorm = math.Abs(d[0]) + math.Abs(du[0])
if math.IsNaN(anorm) {
return anorm
}
tmp := math.Abs(dl[n-2]) + math.Abs(d[n-1])
if math.IsNaN(tmp) {
return tmp
}
if tmp > anorm {
anorm = tmp
}
for i := 1; i < n-1; i++ {
tmp = math.Abs(dl[i-1]) + math.Abs(d[i]) + math.Abs(du[i])
if math.IsNaN(tmp) {
return tmp
}
if tmp > anorm {
anorm = tmp
}
}
case lapack.Frobenius:
panic("not implemented")
default:
panic("invalid norm")
}
return anorm
}