mirror of
https://github.com/gonum/gonum.git
synced 2025-09-27 11:32:32 +08:00
190 lines
4.5 KiB
Go
190 lines
4.5 KiB
Go
// Copyright ©2016 The Gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package testlapack
|
|
|
|
import (
|
|
"fmt"
|
|
"math/rand/v2"
|
|
"testing"
|
|
|
|
"gonum.org/v1/gonum/blas"
|
|
"gonum.org/v1/gonum/blas/blas64"
|
|
"gonum.org/v1/gonum/lapack"
|
|
)
|
|
|
|
// Dlaexcer is the interface exercised by DlaexcTest. Its single method,
// Dlaexc, reorders the two adjacent diagonal blocks of orders n1 and n2 that
// start at row and column j1 of the n×n Schur canonical matrix stored in t
// (leading dimension ldt). When wantq is true the orthogonal transformation
// is also applied to the columns of q (leading dimension ldq). work is
// scratch space. The boolean result reports whether the swap was performed.
type Dlaexcer interface {
	Dlaexc(
		wantq bool,
		n int,
		t []float64, ldt int,
		q []float64, ldq int,
		j1, n1, n2 int,
		work []float64,
	) bool
}
|
|
|
|
func DlaexcTest(t *testing.T, impl Dlaexcer) {
|
|
rnd := rand.New(rand.NewPCG(1, 1))
|
|
|
|
for _, n := range []int{1, 2, 3, 4, 5, 6, 10, 18, 31, 53} {
|
|
for _, extra := range []int{0, 3} {
|
|
for cas := 0; cas < 100; cas++ {
|
|
testDlaexc(t, impl, rnd, n, extra)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func testDlaexc(t *testing.T, impl Dlaexcer, rnd *rand.Rand, n, extra int) {
|
|
const tol = 1e-14
|
|
|
|
// Generate random T in Schur canonical form.
|
|
tmat, _, _ := randomSchurCanonical(n, n+extra, true, rnd)
|
|
tmatCopy := cloneGeneral(tmat)
|
|
|
|
// Randomly pick the index of the first block.
|
|
j1 := rnd.IntN(n)
|
|
if j1 > 0 && tmat.Data[j1*tmat.Stride+j1-1] != 0 {
|
|
// Adjust j1 if it points to the second row of a 2x2 block.
|
|
j1--
|
|
}
|
|
// Read sizes of the two blocks based on properties of T.
|
|
var n1, n2 int
|
|
switch j1 {
|
|
case n - 1:
|
|
n1, n2 = 1, 0
|
|
case n - 2:
|
|
if tmat.Data[(j1+1)*tmat.Stride+j1] == 0 {
|
|
n1, n2 = 1, 1
|
|
} else {
|
|
n1, n2 = 2, 0
|
|
}
|
|
case n - 3:
|
|
if tmat.Data[(j1+1)*tmat.Stride+j1] == 0 {
|
|
n1, n2 = 1, 2
|
|
} else {
|
|
n1, n2 = 2, 1
|
|
}
|
|
default:
|
|
if tmat.Data[(j1+1)*tmat.Stride+j1] == 0 {
|
|
n1 = 1
|
|
if tmat.Data[(j1+2)*tmat.Stride+j1+1] == 0 {
|
|
n2 = 1
|
|
} else {
|
|
n2 = 2
|
|
}
|
|
} else {
|
|
n1 = 2
|
|
if tmat.Data[(j1+3)*tmat.Stride+j1+2] == 0 {
|
|
n2 = 1
|
|
} else {
|
|
n2 = 2
|
|
}
|
|
}
|
|
}
|
|
|
|
name := fmt.Sprintf("Case n=%v,j1=%v,n1=%v,n2=%v,extra=%v", n, j1, n1, n2, extra)
|
|
|
|
// 1. Test without accumulating Q.
|
|
|
|
wantq := false
|
|
|
|
work := nanSlice(n)
|
|
|
|
ok := impl.Dlaexc(wantq, n, tmat.Data, tmat.Stride, nil, 1, j1, n1, n2, work)
|
|
|
|
// 2. Test with accumulating Q.
|
|
|
|
wantq = true
|
|
|
|
tmat2 := cloneGeneral(tmatCopy)
|
|
q := eye(n, n+extra)
|
|
qCopy := cloneGeneral(q)
|
|
work = nanSlice(n)
|
|
|
|
ok2 := impl.Dlaexc(wantq, n, tmat2.Data, tmat2.Stride, q.Data, q.Stride, j1, n1, n2, work)
|
|
|
|
if !generalOutsideAllNaN(tmat) {
|
|
t.Errorf("%v: out-of-range write to T", name)
|
|
}
|
|
if !generalOutsideAllNaN(tmat2) {
|
|
t.Errorf("%v: out-of-range write to T2", name)
|
|
}
|
|
if !generalOutsideAllNaN(q) {
|
|
t.Errorf("%v: out-of-range write to Q", name)
|
|
}
|
|
|
|
// Check that outputs from cases 1. and 2. are exactly equal, then check one of them.
|
|
if ok != ok2 {
|
|
t.Errorf("%v: ok != ok2", name)
|
|
}
|
|
if !equalGeneral(tmat, tmat2) {
|
|
t.Errorf("%v: T != T2", name)
|
|
}
|
|
|
|
if !ok {
|
|
if n1 == 1 && n2 == 1 {
|
|
t.Errorf("%v: unexpected failure", name)
|
|
} else {
|
|
t.Logf("%v: Dlaexc returned false", name)
|
|
}
|
|
}
|
|
|
|
if !ok || n1 == 0 || n2 == 0 || j1+n1 >= n {
|
|
// Check that T is not modified.
|
|
if !equalGeneral(tmat, tmatCopy) {
|
|
t.Errorf("%v: unexpected modification of T", name)
|
|
}
|
|
// Check that Q is not modified.
|
|
if !equalGeneral(q, qCopy) {
|
|
t.Errorf("%v: unexpected modification of Q", name)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Check that T is not modified outside of rows and columns [j1:j1+n1+n2].
|
|
for i := 0; i < n; i++ {
|
|
if j1 <= i && i < j1+n1+n2 {
|
|
continue
|
|
}
|
|
for j := 0; j < n; j++ {
|
|
if j1 <= j && j < j1+n1+n2 {
|
|
continue
|
|
}
|
|
diff := tmat.Data[i*tmat.Stride+j] - tmatCopy.Data[i*tmatCopy.Stride+j]
|
|
if diff != 0 {
|
|
t.Errorf("%v: unexpected modification of T[%v,%v]", name, i, j)
|
|
}
|
|
}
|
|
}
|
|
|
|
if !isSchurCanonicalGeneral(tmat) {
|
|
t.Errorf("%v: T is not in Schur canonical form", name)
|
|
}
|
|
|
|
// Check that Q is orthogonal.
|
|
resid := residualOrthogonal(q, false)
|
|
if resid > tol {
|
|
t.Errorf("%v: Q is not orthogonal; resid=%v, want<=%v", name, resid, tol)
|
|
}
|
|
|
|
// Check that Q is unchanged outside of columns [j1:j1+n1+n2].
|
|
for i := 0; i < n; i++ {
|
|
for j := 0; j < n; j++ {
|
|
if j1 <= j && j < j1+n1+n2 {
|
|
continue
|
|
}
|
|
diff := q.Data[i*q.Stride+j] - qCopy.Data[i*qCopy.Stride+j]
|
|
if diff != 0 {
|
|
t.Errorf("%v: unexpected modification of Q[%v,%v]", name, i, j)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check that Qᵀ * TOrig * Q == T
|
|
qt := zeros(n, n, n)
|
|
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, tmatCopy, 0, qt)
|
|
qtq := cloneGeneral(tmat)
|
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, -1, qt, q, 1, qtq)
|
|
resid = dlange(lapack.MaxColumnSum, n, n, qtq.Data, qtq.Stride)
|
|
if resid > float64(n)*tol {
|
|
t.Errorf("%v: mismatch between Qᵀ*(initial T)*Q and (final T); resid=%v, want<=%v",
|
|
name, resid, float64(n)*tol)
|
|
}
|
|
}
|