Files
gonum/lapack/testlapack/dgesc2.go
2021-06-16 09:23:56 +02:00

92 lines
2.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright ©2021 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testlapack
import (
"fmt"
"math"
"testing"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/blas"
"gonum.org/v1/gonum/blas/blas64"
)
// Dgesc2er is the interface an implementation must satisfy to be exercised
// by Dgesc2Test. It embeds Dgetc2er because the test first computes the LU
// factorization with Dgetc2 and then passes that factorization to Dgesc2.
type Dgesc2er interface {
	Dgetc2er
	// Dgesc2 solves a system of linear equations
	//  A * X = scale * RHS
	// with a general n×n matrix A using the LU factorization with
	// complete pivoting computed by Dgetc2. The result is placed in
	// rhs on exit.
	//
	// a and lda hold the factorization and its stride, and ipiv and jpiv
	// are the pivot slices that were passed to Dgetc2. On entry rhs holds
	// the right-hand side vector. The tests in this file expect the
	// returned scale to lie in [0,1].
	Dgesc2(n int, a []float64, lda int, rhs []float64, ipiv, jpiv []int) (scale float64)
}
// Dgesc2Test runs the Dgesc2 solver check of testSolveDgesc2 against impl
// for a fixed table of problem sizes. A zero lda in the table is replaced
// by n inside testSolveDgesc2. All cases share one random source with a
// fixed seed, so the generated problems are the same on every run.
func Dgesc2Test(t *testing.T, impl Dgesc2er) {
	// Largest accepted absolute error in any element of the solution.
	const tol = 1e-12

	cases := []struct {
		n, lda int
	}{
		{n: 3},
		{n: 5},
		{n: 20, lda: 30},
		{n: 200},
	}

	rnd := rand.New(rand.NewSource(1))
	for i := range cases {
		testSolveDgesc2(t, impl, rnd, cases[i].n, cases[i].lda, tol)
	}
}
// testSolveDgesc2 checks impl.Dgesc2 on one random n×n system.
//
// It generates a random matrix A with stride lda (lda == 0 selects lda = n)
// and a random vector xWant, forms rhs = A*xWant, factorizes a copy of A
// with impl.Dgetc2 and solves with impl.Dgesc2. Since Dgesc2 solves
// A*x = scale*rhs, the computed x is compared with scale*xWant, and the
// largest absolute element-wise difference must not exceed tol.
//
// A NaN scale or a NaN anywhere in the solution is reported as a failure.
// Failures are reported through t.Errorf, prefixed with the n and lda that
// were passed in.
func testSolveDgesc2(t *testing.T, impl Dgesc2er, rnd *rand.Rand, n, lda int, tol float64) {
	// Build the name before lda is defaulted so it shows the requested value.
	name := fmt.Sprintf("n=%v,lda=%v", n, lda)
	if lda == 0 {
		lda = n
	}

	// Generate random general matrix.
	a := randomGeneral(n, n, lda, rnd)

	// Generate a random solution.
	xWant := randomGeneral(n, 1, 1, rnd)

	// Compute the right-hand side rhs = A*xWant; rhs starts out as zeros,
	// so accumulating into it with beta = 1 yields exactly the product.
	rhs := zeros(n, 1, 1)
	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, a, xWant, 1, rhs)

	// Compute LU factorization with full pivoting on a copy of A.
	lu := zeros(n, n, lda)
	copyGeneral(lu, a)
	ipiv := make([]int, n)
	jpiv := make([]int, n)
	impl.Dgetc2(n, lu.Data, lu.Stride, ipiv, jpiv)

	// Solve using the LU factorization. The solution is returned in rhs,
	// so x is just another name for the same storage.
	scale := impl.Dgesc2(lu.Rows, lu.Data, lu.Stride, rhs.Data, ipiv, jpiv)
	x := rhs

	// NaN compares false against everything, so it has to be tested for
	// explicitly or it would slip through the range check.
	if math.IsNaN(scale) || scale < 0 || scale > 1 {
		t.Errorf("%v: resulting scale out of bounds [0,1]", name)
	}

	// The system that was solved is A*x = scale*rhs with rhs = A*xWant,
	// hence the expected solution is scale*xWant.
	var diff float64
	for i := range x.Data {
		diff = math.Max(diff, math.Abs(scale*xWant.Data[i]-x.Data[i]))
	}
	// math.Max propagates NaN, so a NaN element makes diff NaN; treat that
	// as a failure instead of letting "NaN > tol" evaluate to false.
	if math.IsNaN(diff) || diff > tol {
		t.Errorf("%v: unexpected result, diff=%v", name, diff)
	}

	// TODO: additionally check the scaled residual
	//  |A*x - scale*rhs| / (|A| * |x|)
	// which indicates solution quality independently of the conditioning
	// of A.
}