// Mirror of https://github.com/gonum/gonum.git, synced 2025-10-05 15:16:59 +08:00.
// 92 lines, 2.6 KiB, Go.
// Copyright ©2021 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testlapack
|
||
|
||
import (
|
||
"fmt"
|
||
"math"
|
||
"testing"
|
||
|
||
"golang.org/x/exp/rand"
|
||
|
||
"gonum.org/v1/gonum/blas"
|
||
"gonum.org/v1/gonum/blas/blas64"
|
||
)
|
||
|
||
type Dgesc2er interface {
|
||
Dgetc2er
|
||
// Dgesc2 solves a system of linear equations
|
||
// A * X = scale * RHS
|
||
// with a general n×n matrix A using the LU factorization with
|
||
// complete pivoting computed by Dgetc2. The result is placed in
|
||
// rhs on exit.
|
||
Dgesc2(n int, a []float64, lda int, rhs []float64, ipiv, jpiv []int) (scale float64)
|
||
}
|
||
|
||
func Dgesc2Test(t *testing.T, impl Dgesc2er) {
|
||
const tol = 1e-12
|
||
rnd := rand.New(rand.NewSource(1))
|
||
for _, test := range []struct {
|
||
n, lda int
|
||
}{
|
||
{3, 0},
|
||
{5, 0},
|
||
{20, 30},
|
||
{200, 0},
|
||
} {
|
||
testSolveDgesc2(t, impl, rnd, test.n, test.lda, tol)
|
||
}
|
||
}
|
||
|
||
func testSolveDgesc2(t *testing.T, impl Dgesc2er, rnd *rand.Rand, n, lda int, tol float64) {
|
||
name := fmt.Sprintf("n=%v,lda=%v", n, lda)
|
||
if lda == 0 {
|
||
lda = n
|
||
}
|
||
// Generate random general matrix.
|
||
a := randomGeneral(n, n, lda, rnd)
|
||
// anorm := floats.Norm(a.Data, 1)
|
||
|
||
// Generate a random solution.
|
||
xWant := randomGeneral(n, 1, 1, rnd)
|
||
// xnorm := floats.Norm(xWant.Data, 1)
|
||
|
||
// Compute RHS vector that solves for X such that A*X = scale * RHS
|
||
rhs := zeros(n, 1, 1)
|
||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, a, xWant, 1, rhs)
|
||
rhsCopy := zeros(n, 1, 1) // Will contain A*x result.
|
||
copyGeneral(rhsCopy, rhs)
|
||
// Compute LU factorization with full pivoting.
|
||
lu := zeros(n, n, lda)
|
||
copyGeneral(lu, a)
|
||
ipiv := make([]int, n)
|
||
jpiv := make([]int, n)
|
||
impl.Dgetc2(n, lu.Data, lu.Stride, ipiv, jpiv)
|
||
|
||
// Solve using lu factorization.
|
||
scale := impl.Dgesc2(lu.Rows, lu.Data, lu.Stride, rhs.Data, ipiv, jpiv)
|
||
x := rhs
|
||
if scale < 0 || scale > 1 {
|
||
t.Errorf("%v: resulting scale out of bounds [0,1]", name)
|
||
}
|
||
|
||
var diff float64
|
||
for i := range x.Data {
|
||
diff = math.Max(diff, math.Abs(xWant.Data[i]-x.Data[i]))
|
||
}
|
||
if diff > tol {
|
||
t.Errorf("%v: unexpected result, diff=%v", name, diff)
|
||
}
|
||
// |A*X - scale*RHS| / |A| / |X| is an indicator that solution is good
|
||
// AxResult := zeros(n, 1, 1)
|
||
// blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, a, x, 1, AxResult)
|
||
// blas64.Scal(scale, blas64.Vector{N: n, Data: rhsCopy.Data, Inc: 1})
|
||
// floats.Sub(AxResult.Data, rhsCopy.Data)
|
||
|
||
// residualNorm := floats.Norm(rhsCopy.Data, 1) / anorm / xnorm
|
||
// if residualNorm > tol {
|
||
// t.Errorf("%v: |A*X - scale*RHS| / |A| / |X| = %g is greater than permissible tol", name, residualNorm)
|
||
// }
|
||
}
|