mirror of
https://github.com/gonum/gonum.git
synced 2025-10-28 01:21:44 +08:00
lapack/gonum: add Dgesc2 (#1652)
This commit is contained in:
committed by
GitHub
parent
7fe5bb7344
commit
c0f40d7826
83
lapack/gonum/dgesc2.go
Normal file
83
lapack/gonum/dgesc2.go
Normal file
@@ -0,0 +1,83 @@
|
||||
// Copyright ©2021 The Gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package gonum
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"gonum.org/v1/gonum/blas/blas64"
|
||||
)
|
||||
|
||||
// Dgesc2 solves a system of linear equations
|
||||
// A * X = scale * RHS
|
||||
// with a general N-by-N matrix A using the LU factorization with
|
||||
// complete pivoting computed by Dgetc2. The result is placed in
|
||||
// rhs on exit.
|
||||
//
|
||||
// Dgesc2 is an internal routine. It is exported for testing purposes.
|
||||
func (impl Implementation) Dgesc2(n int, a []float64, lda int, rhs []float64, ipiv, jpiv []int) (scale float64) {
|
||||
switch {
|
||||
case n < 0:
|
||||
panic(nLT0)
|
||||
case lda < max(1, n):
|
||||
panic(badLdA)
|
||||
}
|
||||
|
||||
// Quick return if possible.
|
||||
if n == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(a) < (n-1)*lda+n:
|
||||
panic(shortA)
|
||||
case len(rhs) < n:
|
||||
panic(shortRHS)
|
||||
case len(ipiv) != n:
|
||||
panic(badLenIpiv)
|
||||
case len(jpiv) != n:
|
||||
panic(badLenJpiv)
|
||||
}
|
||||
|
||||
const smlnum = dlamchS / dlamchP
|
||||
if len(a) < (n-1)*lda+n {
|
||||
panic(shortA)
|
||||
}
|
||||
|
||||
// Apply permutations ipiv to RHS.
|
||||
impl.Dlaswp(1, rhs, 1, 0, n-1, ipiv[:n], 1)
|
||||
|
||||
// Solve for L part.
|
||||
for i := 0; i < n-1; i++ {
|
||||
for j := i + 1; j < n; j++ {
|
||||
rhs[j] -= float64(a[j*lda+i] * rhs[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Solve for U part.
|
||||
|
||||
scale = 1.0
|
||||
|
||||
// Check for scaling.
|
||||
bi := blas64.Implementation()
|
||||
i := bi.Idamax(n, rhs, 1)
|
||||
if 2*smlnum*math.Abs(rhs[i]) > math.Abs(a[(n-1)*lda+(n-1)]) {
|
||||
temp := 0.5 / math.Abs(rhs[i])
|
||||
bi.Dscal(n, temp, rhs, 1)
|
||||
scale *= temp
|
||||
}
|
||||
|
||||
for i := n - 1; i >= 0; i-- {
|
||||
temp := 1.0 / a[i*lda+i]
|
||||
rhs[i] *= temp
|
||||
for j := i + 1; j < n; j++ {
|
||||
rhs[i] -= float64(rhs[j] * (a[i*lda+j] * temp))
|
||||
}
|
||||
}
|
||||
|
||||
// Apply permutations jpiv to the solution (rhs).
|
||||
impl.Dlaswp(1, rhs, 1, 0, n-1, jpiv[:n], -1)
|
||||
return scale
|
||||
}
|
||||
@@ -101,6 +101,7 @@ const (
|
||||
badLenAlpha = "lapack: bad length of alpha"
|
||||
badLenBeta = "lapack: bad length of beta"
|
||||
badLenIpiv = "lapack: bad length of ipiv"
|
||||
badLenJpiv = "lapack: bad length of jpiv"
|
||||
badLenJpvt = "lapack: bad length of jpvt"
|
||||
badLenK = "lapack: bad length of k"
|
||||
badLenSelected = "lapack: bad length of selected"
|
||||
@@ -126,6 +127,7 @@ const (
|
||||
shortIWork = "lapack: insufficient length of iwork"
|
||||
shortIsgn = "lapack: insufficient length of isgn"
|
||||
shortQ = "lapack: insufficient length of q"
|
||||
shortRHS = "lapack: insufficient length of rhs"
|
||||
shortS = "lapack: insufficient length of s"
|
||||
shortScale = "lapack: insufficient length of scale"
|
||||
shortT = "lapack: insufficient length of t"
|
||||
|
||||
@@ -92,6 +92,11 @@ func TestDgerq2(t *testing.T) {
|
||||
testlapack.Dgerq2Test(t, impl)
|
||||
}
|
||||
|
||||
func TestDgesc2(t *testing.T) {
|
||||
t.Parallel()
|
||||
testlapack.Dgesc2Test(t, impl)
|
||||
}
|
||||
|
||||
func TestDgeqp3(t *testing.T) {
|
||||
t.Parallel()
|
||||
testlapack.Dgeqp3Test(t, impl)
|
||||
|
||||
91
lapack/testlapack/dgesc2.go
Normal file
91
lapack/testlapack/dgesc2.go
Normal file
@@ -0,0 +1,91 @@
|
||||
// Copyright ©2021 The Gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/exp/rand"
|
||||
|
||||
"gonum.org/v1/gonum/blas"
|
||||
"gonum.org/v1/gonum/blas/blas64"
|
||||
)
|
||||
|
||||
type Dgesc2er interface {
|
||||
Dgetc2er
|
||||
// Dgesc2 solves a system of linear equations
|
||||
// A * X = scale * RHS
|
||||
// with a general n×n matrix A using the LU factorization with
|
||||
// complete pivoting computed by Dgetc2. The result is placed in
|
||||
// rhs on exit.
|
||||
Dgesc2(n int, a []float64, lda int, rhs []float64, ipiv, jpiv []int) (scale float64)
|
||||
}
|
||||
|
||||
func Dgesc2Test(t *testing.T, impl Dgesc2er) {
|
||||
const tol = 1e-12
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
for _, test := range []struct {
|
||||
n, lda int
|
||||
}{
|
||||
{3, 0},
|
||||
{5, 0},
|
||||
{20, 30},
|
||||
{200, 0},
|
||||
} {
|
||||
testSolveDgesc2(t, impl, rnd, test.n, test.lda, tol)
|
||||
}
|
||||
}
|
||||
|
||||
func testSolveDgesc2(t *testing.T, impl Dgesc2er, rnd *rand.Rand, n, lda int, tol float64) {
|
||||
name := fmt.Sprintf("n=%v,lda=%v", n, lda)
|
||||
if lda == 0 {
|
||||
lda = n
|
||||
}
|
||||
// Generate random general matrix.
|
||||
a := randomGeneral(n, n, lda, rnd)
|
||||
// anorm := floats.Norm(a.Data, 1)
|
||||
|
||||
// Generate a random solution.
|
||||
xWant := randomGeneral(n, 1, 1, rnd)
|
||||
// xnorm := floats.Norm(xWant.Data, 1)
|
||||
|
||||
// Compute RHS vector that solves for X such that A*X = scale * RHS
|
||||
rhs := zeros(n, 1, 1)
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, a, xWant, 1, rhs)
|
||||
rhsCopy := zeros(n, 1, 1) // Will contain A*x result.
|
||||
copyGeneral(rhsCopy, rhs)
|
||||
// Compute LU factorization with full pivoting.
|
||||
lu := zeros(n, n, lda)
|
||||
copyGeneral(lu, a)
|
||||
ipiv := make([]int, n)
|
||||
jpiv := make([]int, n)
|
||||
impl.Dgetc2(n, lu.Data, lu.Stride, ipiv, jpiv)
|
||||
|
||||
// Solve using lu factorization.
|
||||
scale := impl.Dgesc2(lu.Rows, lu.Data, lu.Stride, rhs.Data, ipiv, jpiv)
|
||||
x := rhs
|
||||
if scale < 0 || scale > 1 {
|
||||
t.Errorf("%v: resulting scale out of bounds [0,1]", name)
|
||||
}
|
||||
|
||||
var diff float64
|
||||
for i := range x.Data {
|
||||
diff = math.Max(diff, math.Abs(xWant.Data[i]-x.Data[i]))
|
||||
}
|
||||
if diff > tol {
|
||||
t.Errorf("%v: unexpected result, diff=%v", name, diff)
|
||||
}
|
||||
// |A*X - scale*RHS| / |A| / |X| is an indicator that solution is good
|
||||
// AxResult := zeros(n, 1, 1)
|
||||
// blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, a, x, 1, AxResult)
|
||||
// blas64.Scal(scale, blas64.Vector{N: n, Data: rhsCopy.Data, Inc: 1})
|
||||
// floats.Sub(AxResult.Data, rhsCopy.Data)
|
||||
|
||||
// residualNorm := floats.Norm(rhsCopy.Data, 1) / anorm / xnorm
|
||||
// if residualNorm > tol {
|
||||
// t.Errorf("%v: |A*X - scale*RHS| / |A| / |X| = %g is greater than permissible tol", name, residualNorm)
|
||||
// }
|
||||
}
|
||||
Reference in New Issue
Block a user