lapack/gonum: add Dgesc2 (#1652)

This commit is contained in:
Patricio Whittingslow
2021-06-16 04:23:56 -03:00
committed by GitHub
parent 7fe5bb7344
commit c0f40d7826
4 changed files with 181 additions and 0 deletions

83
lapack/gonum/dgesc2.go Normal file
View File

@@ -0,0 +1,83 @@
// Copyright ©2021 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gonum
import (
"math"
"gonum.org/v1/gonum/blas/blas64"
)
// Dgesc2 solves a system of linear equations
// A * X = scale * RHS
// with a general N-by-N matrix A using the LU factorization with
// complete pivoting computed by Dgetc2. The result is placed in
// rhs on exit.
//
// Dgesc2 is an internal routine. It is exported for testing purposes.
func (impl Implementation) Dgesc2(n int, a []float64, lda int, rhs []float64, ipiv, jpiv []int) (scale float64) {
switch {
case n < 0:
panic(nLT0)
case lda < max(1, n):
panic(badLdA)
}
// Quick return if possible.
if n == 0 {
return 0
}
switch {
case len(a) < (n-1)*lda+n:
panic(shortA)
case len(rhs) < n:
panic(shortRHS)
case len(ipiv) != n:
panic(badLenIpiv)
case len(jpiv) != n:
panic(badLenJpiv)
}
const smlnum = dlamchS / dlamchP
if len(a) < (n-1)*lda+n {
panic(shortA)
}
// Apply permutations ipiv to RHS.
impl.Dlaswp(1, rhs, 1, 0, n-1, ipiv[:n], 1)
// Solve for L part.
for i := 0; i < n-1; i++ {
for j := i + 1; j < n; j++ {
rhs[j] -= float64(a[j*lda+i] * rhs[i])
}
}
// Solve for U part.
scale = 1.0
// Check for scaling.
bi := blas64.Implementation()
i := bi.Idamax(n, rhs, 1)
if 2*smlnum*math.Abs(rhs[i]) > math.Abs(a[(n-1)*lda+(n-1)]) {
temp := 0.5 / math.Abs(rhs[i])
bi.Dscal(n, temp, rhs, 1)
scale *= temp
}
for i := n - 1; i >= 0; i-- {
temp := 1.0 / a[i*lda+i]
rhs[i] *= temp
for j := i + 1; j < n; j++ {
rhs[i] -= float64(rhs[j] * (a[i*lda+j] * temp))
}
}
// Apply permutations jpiv to the solution (rhs).
impl.Dlaswp(1, rhs, 1, 0, n-1, jpiv[:n], -1)
return scale
}

View File

@@ -101,6 +101,7 @@ const (
badLenAlpha = "lapack: bad length of alpha"
badLenBeta = "lapack: bad length of beta"
badLenIpiv = "lapack: bad length of ipiv"
badLenJpiv = "lapack: bad length of jpiv"
badLenJpvt = "lapack: bad length of jpvt"
badLenK = "lapack: bad length of k"
badLenSelected = "lapack: bad length of selected"
@@ -126,6 +127,7 @@ const (
shortIWork = "lapack: insufficient length of iwork"
shortIsgn = "lapack: insufficient length of isgn"
shortQ = "lapack: insufficient length of q"
shortRHS = "lapack: insufficient length of rhs"
shortS = "lapack: insufficient length of s"
shortScale = "lapack: insufficient length of scale"
shortT = "lapack: insufficient length of t"

View File

@@ -92,6 +92,11 @@ func TestDgerq2(t *testing.T) {
testlapack.Dgerq2Test(t, impl)
}
func TestDgesc2(t *testing.T) {
t.Parallel()
testlapack.Dgesc2Test(t, impl)
}
func TestDgeqp3(t *testing.T) {
t.Parallel()
testlapack.Dgeqp3Test(t, impl)

View File

@@ -0,0 +1,91 @@
// Copyright ©2021 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testlapack
import (
"fmt"
"math"
"testing"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/blas"
"gonum.org/v1/gonum/blas/blas64"
)
type Dgesc2er interface {
Dgetc2er
// Dgesc2 solves a system of linear equations
// A * X = scale * RHS
// with a general n×n matrix A using the LU factorization with
// complete pivoting computed by Dgetc2. The result is placed in
// rhs on exit.
Dgesc2(n int, a []float64, lda int, rhs []float64, ipiv, jpiv []int) (scale float64)
}
func Dgesc2Test(t *testing.T, impl Dgesc2er) {
const tol = 1e-12
rnd := rand.New(rand.NewSource(1))
for _, test := range []struct {
n, lda int
}{
{3, 0},
{5, 0},
{20, 30},
{200, 0},
} {
testSolveDgesc2(t, impl, rnd, test.n, test.lda, tol)
}
}
func testSolveDgesc2(t *testing.T, impl Dgesc2er, rnd *rand.Rand, n, lda int, tol float64) {
name := fmt.Sprintf("n=%v,lda=%v", n, lda)
if lda == 0 {
lda = n
}
// Generate random general matrix.
a := randomGeneral(n, n, lda, rnd)
// anorm := floats.Norm(a.Data, 1)
// Generate a random solution.
xWant := randomGeneral(n, 1, 1, rnd)
// xnorm := floats.Norm(xWant.Data, 1)
// Compute RHS vector that solves for X such that A*X = scale * RHS
rhs := zeros(n, 1, 1)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, a, xWant, 1, rhs)
rhsCopy := zeros(n, 1, 1) // Will contain A*x result.
copyGeneral(rhsCopy, rhs)
// Compute LU factorization with full pivoting.
lu := zeros(n, n, lda)
copyGeneral(lu, a)
ipiv := make([]int, n)
jpiv := make([]int, n)
impl.Dgetc2(n, lu.Data, lu.Stride, ipiv, jpiv)
// Solve using lu factorization.
scale := impl.Dgesc2(lu.Rows, lu.Data, lu.Stride, rhs.Data, ipiv, jpiv)
x := rhs
if scale < 0 || scale > 1 {
t.Errorf("%v: resulting scale out of bounds [0,1]", name)
}
var diff float64
for i := range x.Data {
diff = math.Max(diff, math.Abs(xWant.Data[i]-x.Data[i]))
}
if diff > tol {
t.Errorf("%v: unexpected result, diff=%v", name, diff)
}
// |A*X - scale*RHS| / |A| / |X| is an indicator that solution is good
// AxResult := zeros(n, 1, 1)
// blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, a, x, 1, AxResult)
// blas64.Scal(scale, blas64.Vector{N: n, Data: rhsCopy.Data, Inc: 1})
// floats.Sub(AxResult.Data, rhsCopy.Data)
// residualNorm := floats.Norm(rhsCopy.Data, 1) / anorm / xnorm
// if residualNorm > tol {
// t.Errorf("%v: |A*X - scale*RHS| / |A| / |X| = %g is greater than permissible tol", name, residualNorm)
// }
}