mirror of https://github.com/gonum/gonum.git
synced 2025-10-04 14:52:57 +08:00
150 lines
3.7 KiB
Go
// Copyright ©2015 The Gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package testlapack
|
|
|
|
import (
|
|
"math"
|
|
"testing"
|
|
|
|
"golang.org/x/exp/rand"
|
|
|
|
"gonum.org/v1/gonum/blas"
|
|
"gonum.org/v1/gonum/blas/blas64"
|
|
"gonum.org/v1/gonum/floats"
|
|
"gonum.org/v1/gonum/lapack"
|
|
)
|
|
|
|
type Dlasrer interface {
|
|
Dlasr(side blas.Side, pivot lapack.Pivot, direct lapack.Direct, m, n int, c, s, a []float64, lda int)
|
|
}
|
|
|
|
func DlasrTest(t *testing.T, impl Dlasrer) {
|
|
rnd := rand.New(rand.NewSource(1))
|
|
for _, side := range []blas.Side{blas.Left, blas.Right} {
|
|
for _, pivot := range []lapack.Pivot{lapack.Variable, lapack.Top, lapack.Bottom} {
|
|
for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} {
|
|
for _, test := range []struct {
|
|
m, n, lda int
|
|
}{
|
|
{5, 5, 0},
|
|
{5, 10, 0},
|
|
{10, 5, 0},
|
|
|
|
{5, 5, 20},
|
|
{5, 10, 20},
|
|
{10, 5, 20},
|
|
} {
|
|
m := test.m
|
|
n := test.n
|
|
lda := test.lda
|
|
if lda == 0 {
|
|
lda = n
|
|
}
|
|
a := make([]float64, m*lda)
|
|
for i := range a {
|
|
a[i] = rnd.Float64()
|
|
}
|
|
var s, c []float64
|
|
if side == blas.Left {
|
|
s = make([]float64, m-1)
|
|
c = make([]float64, m-1)
|
|
} else {
|
|
s = make([]float64, n-1)
|
|
c = make([]float64, n-1)
|
|
}
|
|
for k := range s {
|
|
theta := rnd.Float64() * 2 * math.Pi
|
|
s[k] = math.Sin(theta)
|
|
c[k] = math.Cos(theta)
|
|
}
|
|
aCopy := make([]float64, len(a))
|
|
copy(a, aCopy)
|
|
impl.Dlasr(side, pivot, direct, m, n, c, s, a, lda)
|
|
|
|
pSize := m
|
|
if side == blas.Right {
|
|
pSize = n
|
|
}
|
|
p := blas64.General{
|
|
Rows: pSize,
|
|
Cols: pSize,
|
|
Stride: pSize,
|
|
Data: make([]float64, pSize*pSize),
|
|
}
|
|
pk := blas64.General{
|
|
Rows: pSize,
|
|
Cols: pSize,
|
|
Stride: pSize,
|
|
Data: make([]float64, pSize*pSize),
|
|
}
|
|
ptmp := blas64.General{
|
|
Rows: pSize,
|
|
Cols: pSize,
|
|
Stride: pSize,
|
|
Data: make([]float64, pSize*pSize),
|
|
}
|
|
for i := 0; i < pSize; i++ {
|
|
p.Data[i*p.Stride+i] = 1
|
|
ptmp.Data[i*p.Stride+i] = 1
|
|
}
|
|
// Compare to direct computation.
|
|
for k := range s {
|
|
for i := range p.Data {
|
|
pk.Data[i] = 0
|
|
}
|
|
for i := 0; i < pSize; i++ {
|
|
pk.Data[i*p.Stride+i] = 1
|
|
}
|
|
if pivot == lapack.Variable {
|
|
pk.Data[k*p.Stride+k] = c[k]
|
|
pk.Data[k*p.Stride+k+1] = s[k]
|
|
pk.Data[(k+1)*p.Stride+k] = -s[k]
|
|
pk.Data[(k+1)*p.Stride+k+1] = c[k]
|
|
} else if pivot == lapack.Top {
|
|
pk.Data[0] = c[k]
|
|
pk.Data[k+1] = s[k]
|
|
pk.Data[(k+1)*p.Stride] = -s[k]
|
|
pk.Data[(k+1)*p.Stride+k+1] = c[k]
|
|
} else {
|
|
pk.Data[(pSize-1-k)*p.Stride+pSize-k-1] = c[k]
|
|
pk.Data[(pSize-1-k)*p.Stride+pSize-1] = s[k]
|
|
pk.Data[(pSize-1)*p.Stride+pSize-1-k] = -s[k]
|
|
pk.Data[(pSize-1)*p.Stride+pSize-1] = c[k]
|
|
}
|
|
if direct == lapack.Forward {
|
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, pk, ptmp, 0, p)
|
|
} else {
|
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, ptmp, pk, 0, p)
|
|
}
|
|
copy(ptmp.Data, p.Data)
|
|
}
|
|
|
|
aMat := blas64.General{
|
|
Rows: m,
|
|
Cols: n,
|
|
Stride: lda,
|
|
Data: make([]float64, m*lda),
|
|
}
|
|
copy(a, aCopy)
|
|
newA := blas64.General{
|
|
Rows: m,
|
|
Cols: n,
|
|
Stride: lda,
|
|
Data: make([]float64, m*lda),
|
|
}
|
|
if side == blas.Left {
|
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, p, aMat, 0, newA)
|
|
} else {
|
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, p, 0, newA)
|
|
}
|
|
if !floats.EqualApprox(newA.Data, a, 1e-12) {
|
|
t.Errorf("A update mismatch")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|