// Copyright ©2015 The Gonum Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package testlapack import ( "math" "testing" "golang.org/x/exp/rand" "gonum.org/v1/gonum/blas" "gonum.org/v1/gonum/blas/blas64" "gonum.org/v1/gonum/floats" "gonum.org/v1/gonum/lapack" ) type Dlasrer interface { Dlasr(side blas.Side, pivot lapack.Pivot, direct lapack.Direct, m, n int, c, s, a []float64, lda int) } func DlasrTest(t *testing.T, impl Dlasrer) { rnd := rand.New(rand.NewSource(1)) for _, side := range []blas.Side{blas.Left, blas.Right} { for _, pivot := range []lapack.Pivot{lapack.Variable, lapack.Top, lapack.Bottom} { for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} { for _, test := range []struct { m, n, lda int }{ {5, 5, 0}, {5, 10, 0}, {10, 5, 0}, {5, 5, 20}, {5, 10, 20}, {10, 5, 20}, } { m := test.m n := test.n lda := test.lda if lda == 0 { lda = n } a := make([]float64, m*lda) for i := range a { a[i] = rnd.Float64() } var s, c []float64 if side == blas.Left { s = make([]float64, m-1) c = make([]float64, m-1) } else { s = make([]float64, n-1) c = make([]float64, n-1) } for k := range s { theta := rnd.Float64() * 2 * math.Pi s[k] = math.Sin(theta) c[k] = math.Cos(theta) } aCopy := make([]float64, len(a)) copy(a, aCopy) impl.Dlasr(side, pivot, direct, m, n, c, s, a, lda) pSize := m if side == blas.Right { pSize = n } p := blas64.General{ Rows: pSize, Cols: pSize, Stride: pSize, Data: make([]float64, pSize*pSize), } pk := blas64.General{ Rows: pSize, Cols: pSize, Stride: pSize, Data: make([]float64, pSize*pSize), } ptmp := blas64.General{ Rows: pSize, Cols: pSize, Stride: pSize, Data: make([]float64, pSize*pSize), } for i := 0; i < pSize; i++ { p.Data[i*p.Stride+i] = 1 ptmp.Data[i*p.Stride+i] = 1 } // Compare to direct computation. for k := range s { for i := range p.Data { pk.Data[i] = 0 } for i := 0; i < pSize; i++ { pk.Data[i*p.Stride+i] = 1 } if pivot == lapack.Variable { pk.Data[k*p.Stride+k] = c[k] pk.Data[k*p.Stride+k+1] = s[k] pk.Data[(k+1)*p.Stride+k] = -s[k] pk.Data[(k+1)*p.Stride+k+1] = c[k] } else if pivot == lapack.Top { pk.Data[0] = c[k] pk.Data[k+1] = s[k] pk.Data[(k+1)*p.Stride] = -s[k] pk.Data[(k+1)*p.Stride+k+1] = c[k] } else { pk.Data[(pSize-1-k)*p.Stride+pSize-k-1] = c[k] pk.Data[(pSize-1-k)*p.Stride+pSize-1] = s[k] pk.Data[(pSize-1)*p.Stride+pSize-1-k] = -s[k] pk.Data[(pSize-1)*p.Stride+pSize-1] = c[k] } if direct == lapack.Forward { blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, pk, ptmp, 0, p) } else { blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, ptmp, pk, 0, p) } copy(ptmp.Data, p.Data) } aMat := blas64.General{ Rows: m, Cols: n, Stride: lda, Data: make([]float64, m*lda), } copy(a, aCopy) newA := blas64.General{ Rows: m, Cols: n, Stride: lda, Data: make([]float64, m*lda), } if side == blas.Left { blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, p, aMat, 0, newA) } else { blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, p, 0, newA) } if !floats.EqualApprox(newA.Data, a, 1e-12) { t.Errorf("A update mismatch") } } } } } }