mirror of
https://github.com/gonum/gonum.git
synced 2025-10-16 12:10:37 +08:00
lapack: add Dgtsv
This commit is contained in:

committed by
Vladimír Chalupecký

parent
df1c4f0d6a
commit
6703b9cb87
99
lapack/gonum/dgtsv.go
Normal file
99
lapack/gonum/dgtsv.go
Normal file
@@ -0,0 +1,99 @@
|
||||
// Copyright ©2020 The Gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package gonum
|
||||
|
||||
import "math"
|
||||
|
||||
// Dgtsv solves the equation
|
||||
// A * X = B
|
||||
// where A is an n×n tridiagonal matrix. It uses Gaussian elimination with
|
||||
// partial pivoting. The equation Aᵀ * X = B may be solved by swapping the
|
||||
// arguments for du and dl.
|
||||
//
|
||||
// On entry, dl, d and du contain the sub-diagonal, the diagonal and the
|
||||
// super-diagonal, respectively, of A. On return, the first n-2 elements of dl,
|
||||
// the first n-1 elements of du and the first n elements of d may be
|
||||
// overwritten.
|
||||
//
|
||||
// On entry, b contains the n×nrhs right-hand side matrix B. On return, b will
|
||||
// be overwritten. If ok is true, it will be overwritten by the solution matrix X.
|
||||
//
|
||||
// Dgtsv returns whether the solution X has been successfuly computed.
|
||||
func (impl Implementation) Dgtsv(n, nrhs int, dl, d, du []float64, b []float64, ldb int) (ok bool) {
|
||||
switch {
|
||||
case n < 0:
|
||||
panic(nLT0)
|
||||
case nrhs < 0:
|
||||
panic(nrhsLT0)
|
||||
case ldb < max(1, nrhs):
|
||||
panic(badLdB)
|
||||
}
|
||||
|
||||
if n == 0 || nrhs == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(dl) < n-1:
|
||||
panic(shortDL)
|
||||
case len(d) < n:
|
||||
panic(shortD)
|
||||
case len(du) < n-1:
|
||||
panic(shortDU)
|
||||
case len(b) < (n-1)*ldb+nrhs:
|
||||
panic(shortB)
|
||||
}
|
||||
|
||||
dl = dl[:n-1]
|
||||
d = d[:n]
|
||||
du = du[:n-1]
|
||||
|
||||
for i := 0; i < n-1; i++ {
|
||||
if math.Abs(d[i]) >= math.Abs(dl[i]) {
|
||||
// No row interchange required.
|
||||
if d[i] == 0 {
|
||||
return false
|
||||
}
|
||||
fact := dl[i] / d[i]
|
||||
d[i+1] -= fact * du[i]
|
||||
for j := 0; j < nrhs; j++ {
|
||||
b[(i+1)*ldb+j] -= fact * b[i*ldb+j]
|
||||
}
|
||||
dl[i] = 0
|
||||
} else {
|
||||
// Interchange rows i and i+1.
|
||||
fact := d[i] / dl[i]
|
||||
d[i] = dl[i]
|
||||
tmp := d[i+1]
|
||||
d[i+1] = du[i] - fact*tmp
|
||||
du[i] = tmp
|
||||
if i+1 < n-1 {
|
||||
dl[i] = du[i+1]
|
||||
du[i+1] = -fact * dl[i]
|
||||
}
|
||||
for j := 0; j < nrhs; j++ {
|
||||
tmp = b[i*ldb+j]
|
||||
b[i*ldb+j] = b[(i+1)*ldb+j]
|
||||
b[(i+1)*ldb+j] = tmp - fact*b[(i+1)*ldb+j]
|
||||
}
|
||||
}
|
||||
}
|
||||
if d[n-1] == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Back solve with the matrix U from the factorization.
|
||||
for j := 0; j < nrhs; j++ {
|
||||
b[(n-1)*ldb+j] /= d[n-1]
|
||||
if n > 1 {
|
||||
b[(n-2)*ldb+j] = (b[(n-2)*ldb+j] - du[n-2]*b[(n-1)*ldb+j]) / d[n-2]
|
||||
}
|
||||
for i := n - 3; i >= 0; i-- {
|
||||
b[i*ldb+j] = (b[i*ldb+j] - du[i]*b[(i+1)*ldb+j] - dl[i]*b[(i+2)*ldb+j]) / d[i]
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
@@ -148,6 +148,11 @@ func TestDggsvp3(t *testing.T) {
|
||||
testlapack.Dggsvp3Test(t, impl)
|
||||
}
|
||||
|
||||
func TestDgtsv(t *testing.T) {
|
||||
t.Parallel()
|
||||
testlapack.DgtsvTest(t, impl)
|
||||
}
|
||||
|
||||
func TestDlabrd(t *testing.T) {
|
||||
t.Parallel()
|
||||
testlapack.DlabrdTest(t, impl)
|
||||
|
@@ -420,6 +420,28 @@ func Ggsvd3(jobU, jobV, jobQ lapack.GSVDJob, a, b blas64.General, alpha, beta []
|
||||
return lapack64.Dggsvd3(jobU, jobV, jobQ, a.Rows, a.Cols, b.Rows, a.Data, max(1, a.Stride), b.Data, max(1, b.Stride), alpha, beta, u.Data, max(1, u.Stride), v.Data, max(1, v.Stride), q.Data, max(1, q.Stride), work, lwork, iwork)
|
||||
}
|
||||
|
||||
// Gtsv solves one of the equations
|
||||
// A * X = B if trans == blas.NoTrans
|
||||
// Aᵀ * X = B if trans == blas.Trans or blas.ConjTrans
|
||||
// where A is an n×n tridiagonal matrix. It uses Gaussian elimination with
|
||||
// partial pivoting.
|
||||
//
|
||||
// On entry, a contains the matrix A, on return it will be overwritten.
|
||||
//
|
||||
// On entry, b contains the n×nrhs right-hand side matrix B. On return, it will
|
||||
// be overwritten. If ok is true, it will be overwritten by the solution matrix X.
|
||||
//
|
||||
// Gtsv returns whether the solution X has been successfuly computed.
|
||||
//
|
||||
// Dgtsv is not part of the lapack.Float64 interface and so calls to Gtsv are
|
||||
// always executed by the Gonum implementation.
|
||||
func Gtsv(trans blas.Transpose, a Tridiagonal, b blas64.General) (ok bool) {
|
||||
if trans != blas.NoTrans {
|
||||
a.DL, a.DU = a.DU, a.DL
|
||||
}
|
||||
return gonum.Implementation{}.Dgtsv(a.N, b.Cols, a.DL, a.D, a.DU, b.Data, max(1, b.Stride))
|
||||
}
|
||||
|
||||
// Lagtm performs one of the matrix-matrix operations
|
||||
// C = alpha * A * B + beta * C if trans == blas.NoTrans
|
||||
// C = alpha * Aᵀ * B + beta * C if trans == blas.Trans or blas.ConjTrans
|
||||
|
92
lapack/testlapack/dgtsv.go
Normal file
92
lapack/testlapack/dgtsv.go
Normal file
@@ -0,0 +1,92 @@
|
||||
// Copyright ©2020 The Gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/exp/rand"
|
||||
|
||||
"gonum.org/v1/gonum/blas"
|
||||
"gonum.org/v1/gonum/blas/blas64"
|
||||
"gonum.org/v1/gonum/lapack"
|
||||
)
|
||||
|
||||
type Dgtsver interface {
|
||||
Dgtsv(n, nrhs int, dl, d, du []float64, b []float64, ldb int) (ok bool)
|
||||
}
|
||||
|
||||
func DgtsvTest(t *testing.T, impl Dgtsver) {
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
for _, n := range []int{0, 1, 2, 3, 4, 5, 10, 25, 50} {
|
||||
for _, nrhs := range []int{0, 1, 2, 3, 4, 10} {
|
||||
for _, ldb := range []int{max(1, nrhs), nrhs + 3} {
|
||||
dgtsvTest(t, impl, rnd, n, nrhs, ldb)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func dgtsvTest(t *testing.T, impl Dgtsver, rnd *rand.Rand, n, nrhs, ldb int) {
|
||||
const (
|
||||
tol = 1e-14
|
||||
extra = 10
|
||||
)
|
||||
|
||||
name := fmt.Sprintf("Case n=%d,nrhs=%d,ldb=%d", n, nrhs, ldb)
|
||||
|
||||
if n == 0 {
|
||||
ok := impl.Dgtsv(n, nrhs, nil, nil, nil, nil, ldb)
|
||||
if !ok {
|
||||
t.Errorf("%v: unexpected failure for zero size matrix", name)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Generate three random diagonals.
|
||||
var (
|
||||
d, dCopy []float64
|
||||
dl, dlCopy []float64
|
||||
du, duCopy []float64
|
||||
)
|
||||
d = randomSlice(n+1+extra, rnd)
|
||||
dCopy = make([]float64, len(d))
|
||||
copy(dCopy, d)
|
||||
if n > 1 {
|
||||
dl = randomSlice(n+extra, rnd)
|
||||
dlCopy = make([]float64, len(dl))
|
||||
copy(dlCopy, dl)
|
||||
|
||||
du = randomSlice(n+extra, rnd)
|
||||
duCopy = make([]float64, len(du))
|
||||
copy(duCopy, du)
|
||||
}
|
||||
|
||||
b := randomGeneral(n, nrhs, ldb, rnd)
|
||||
got := cloneGeneral(b)
|
||||
|
||||
ok := impl.Dgtsv(n, nrhs, dl, d, du, got.Data, got.Stride)
|
||||
if !ok {
|
||||
t.Fatalf("%v: unexpected failure in Dgtsv", name)
|
||||
return
|
||||
}
|
||||
|
||||
// Compute A*X - B.
|
||||
dlagtm(blas.NoTrans, n, nrhs, 1, dlCopy, dCopy, duCopy, got.Data, got.Stride, -1, b.Data, b.Stride)
|
||||
|
||||
anorm := dlangt(lapack.MaxColumnSum, n, dlCopy, dCopy, duCopy)
|
||||
bi := blas64.Implementation()
|
||||
var resid float64
|
||||
for j := 0; j < nrhs; j++ {
|
||||
bnorm := bi.Dasum(n, b.Data[j:], b.Stride)
|
||||
xnorm := bi.Dasum(n, got.Data[j:], got.Stride)
|
||||
resid = math.Max(resid, bnorm/anorm/xnorm)
|
||||
}
|
||||
if resid > tol {
|
||||
t.Errorf("%v: unexpected result; resid=%v,want<=%v", name, resid, tol)
|
||||
}
|
||||
}
|
162
lapack/testlapack/locallapack.go
Normal file
162
lapack/testlapack/locallapack.go
Normal file
@@ -0,0 +1,162 @@
|
||||
// Copyright ©2020 The Gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"gonum.org/v1/gonum/blas"
|
||||
"gonum.org/v1/gonum/lapack"
|
||||
)
|
||||
|
||||
func dlagtm(trans blas.Transpose, m, n int, alpha float64, dl, d, du []float64, b []float64, ldb int, beta float64, c []float64, ldc int) {
|
||||
if m == 0 || n == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if beta != 1 {
|
||||
if beta == 0 {
|
||||
for i := 0; i < m; i++ {
|
||||
ci := c[i*ldc : i*ldc+n]
|
||||
for j := range ci {
|
||||
ci[j] = 0
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < m; i++ {
|
||||
ci := c[i*ldc : i*ldc+n]
|
||||
for j := range ci {
|
||||
ci[j] *= beta
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if alpha == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if m == 1 {
|
||||
if alpha == 1 {
|
||||
for j := 0; j < n; j++ {
|
||||
c[j] += d[0] * b[j]
|
||||
}
|
||||
} else {
|
||||
for j := 0; j < n; j++ {
|
||||
c[j] += alpha * d[0] * b[j]
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if trans != blas.NoTrans {
|
||||
dl, du = du, dl
|
||||
}
|
||||
|
||||
if alpha == 1 {
|
||||
for j := 0; j < n; j++ {
|
||||
c[j] += d[0]*b[j] + du[0]*b[ldb+j]
|
||||
}
|
||||
for i := 1; i < m-1; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
c[i*ldc+j] += dl[i-1]*b[(i-1)*ldb+j] + d[i]*b[i*ldb+j] + du[i]*b[(i+1)*ldb+j]
|
||||
}
|
||||
}
|
||||
for j := 0; j < n; j++ {
|
||||
c[(m-1)*ldc+j] += dl[m-2]*b[(m-2)*ldb+j] + d[m-1]*b[(m-1)*ldb+j]
|
||||
}
|
||||
} else {
|
||||
for j := 0; j < n; j++ {
|
||||
c[j] += alpha * (d[0]*b[j] + du[0]*b[ldb+j])
|
||||
}
|
||||
for i := 1; i < m-1; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
c[i*ldc+j] += alpha * (dl[i-1]*b[(i-1)*ldb+j] + d[i]*b[i*ldb+j] + du[i]*b[(i+1)*ldb+j])
|
||||
}
|
||||
}
|
||||
for j := 0; j < n; j++ {
|
||||
c[(m-1)*ldc+j] += alpha * (dl[m-2]*b[(m-2)*ldb+j] + d[m-1]*b[(m-1)*ldb+j])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func dlangt(norm lapack.MatrixNorm, n int, dl, d, du []float64) float64 {
|
||||
if n == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
dl = dl[:n-1]
|
||||
d = d[:n]
|
||||
du = du[:n-1]
|
||||
|
||||
var anorm float64
|
||||
switch norm {
|
||||
case lapack.MaxAbs:
|
||||
for _, diag := range [][]float64{dl, d, du} {
|
||||
for _, di := range diag {
|
||||
if math.IsNaN(di) {
|
||||
return di
|
||||
}
|
||||
di = math.Abs(di)
|
||||
if di > anorm {
|
||||
anorm = di
|
||||
}
|
||||
}
|
||||
}
|
||||
case lapack.MaxColumnSum:
|
||||
if n == 1 {
|
||||
return math.Abs(d[0])
|
||||
}
|
||||
anorm = math.Abs(d[0]) + math.Abs(dl[0])
|
||||
if math.IsNaN(anorm) {
|
||||
return anorm
|
||||
}
|
||||
tmp := math.Abs(du[n-2]) + math.Abs(d[n-1])
|
||||
if math.IsNaN(tmp) {
|
||||
return tmp
|
||||
}
|
||||
if tmp > anorm {
|
||||
anorm = tmp
|
||||
}
|
||||
for i := 1; i < n-1; i++ {
|
||||
tmp = math.Abs(du[i-1]) + math.Abs(d[i]) + math.Abs(dl[i])
|
||||
if math.IsNaN(tmp) {
|
||||
return tmp
|
||||
}
|
||||
if tmp > anorm {
|
||||
anorm = tmp
|
||||
}
|
||||
}
|
||||
case lapack.MaxRowSum:
|
||||
if n == 1 {
|
||||
return math.Abs(d[0])
|
||||
}
|
||||
anorm = math.Abs(d[0]) + math.Abs(du[0])
|
||||
if math.IsNaN(anorm) {
|
||||
return anorm
|
||||
}
|
||||
tmp := math.Abs(dl[n-2]) + math.Abs(d[n-1])
|
||||
if math.IsNaN(tmp) {
|
||||
return tmp
|
||||
}
|
||||
if tmp > anorm {
|
||||
anorm = tmp
|
||||
}
|
||||
for i := 1; i < n-1; i++ {
|
||||
tmp = math.Abs(dl[i-1]) + math.Abs(d[i]) + math.Abs(du[i])
|
||||
if math.IsNaN(tmp) {
|
||||
return tmp
|
||||
}
|
||||
if tmp > anorm {
|
||||
anorm = tmp
|
||||
}
|
||||
}
|
||||
case lapack.Frobenius:
|
||||
panic("not implemented")
|
||||
default:
|
||||
panic("invalid norm")
|
||||
}
|
||||
return anorm
|
||||
}
|
Reference in New Issue
Block a user