mirror of
https://github.com/gonum/gonum.git
synced 2025-09-27 11:32:32 +08:00
118 lines
2.8 KiB
Go
118 lines
2.8 KiB
Go
// Copyright ©2015 The Gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package testlapack
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"gonum.org/v1/gonum/blas"
|
|
"gonum.org/v1/gonum/floats"
|
|
)
|
|
|
|
type Dpotf2er interface {
|
|
Dpotf2(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
|
|
}
|
|
|
|
// Dpotf2Test checks impl against fixed matrices whose Cholesky factors are
// known. Each case is run for blas.Upper (upper triangle of a, expected factor
// U) and for blas.Lower (transposes of a and U). Each uplo is run with a tight
// leading dimension (lda = n) and a padded one (lda = n+5). Cases with pos set
// to false must be reported as not positive definite.
func Dpotf2Test(t *testing.T, impl Dpotf2er) {
	for _, test := range []struct {
		a   [][]float64
		pos bool
		U   [][]float64
	}{
		{
			a: [][]float64{
				{23, 37, 34, 32},
				{108, 71, 48, 48},
				{109, 109, 67, 58},
				{106, 107, 106, 63},
			},
			pos: true,
			U: [][]float64{
				{4.795831523312719, 7.715033320111766, 7.089490077940543, 6.672461249826393},
				{0, 3.387958215439679, -1.976308959006481, -1.026654004678691},
				{0, 0, 3.582364210034111, 2.419258947036024},
				{0, 0, 0, 3.401680257083044},
			},
		},
		{
			a: [][]float64{
				{8, 2},
				{2, 4},
			},
			pos: true,
			U: [][]float64{
				{2.82842712474619, 0.707106781186547},
				{0, 1.870828693386971},
			},
		},
		{
			// Symmetric but indefinite (eigenvalues 3 and -1), so the
			// factorization must report failure. The factor is not compared
			// when pos is false. U still needs a's shape because it is
			// transposed and flattened below.
			a: [][]float64{
				{1, 2},
				{2, 1},
			},
			pos: false,
			U: [][]float64{
				{0, 0},
				{0, 0},
			},
		},
	} {
		n := len(test.a[0])
		testDpotf2(t, impl, test.pos, test.a, test.U, n, blas.Upper)
		testDpotf2(t, impl, test.pos, test.a, test.U, n+5, blas.Upper)
		// The lower-triangular problem is the transpose of the upper one,
		// and its factor is the transpose of U.
		aT := transpose(test.a)
		L := transpose(test.U)
		testDpotf2(t, impl, test.pos, aT, L, n, blas.Lower)
		testDpotf2(t, impl, test.pos, aT, L, n+5, blas.Lower)
	}
}
|
|
|
|
// testDpotf2 runs impl.Dpotf2 on the ul triangle of the square matrix a,
// flattened with leading dimension stride by flattenTri, and checks the result.
//
// It reports an error if the returned positive-definiteness flag differs from
// testPos. When testPos is true, it also compares the entire output buffer
// against ans flattened the same way, using floats.EqualApprox with tolerance
// 1e-14. The comparison includes the filler values outside the referenced
// triangle, so writes there are caught too.
func testDpotf2(t *testing.T, impl Dpotf2er, testPos bool, a, ans [][]float64, stride int, ul blas.Uplo) {
	t.Helper()

	// Name the variant so that failures from the several calls made per
	// test case can be told apart.
	uplo := "Lower"
	if ul == blas.Upper {
		uplo = "Upper"
	}
	n := len(a[0])

	aFlat := flattenTri(a, stride, ul)
	ansFlat := flattenTri(ans, stride, ul)
	pos := impl.Dpotf2(ul, n, aFlat, stride)
	if pos != testPos {
		t.Errorf("%s, n=%d, lda=%d: Positive definite mismatch: Want %v, Got %v", uplo, n, stride, testPos, pos)
		return
	}
	if testPos && !floats.EqualApprox(ansFlat, aFlat, 1e-14) {
		t.Errorf("%s, n=%d, lda=%d: Result mismatch: Want %v, Got %v", uplo, n, stride, ansFlat, aFlat)
	}
}
|
|
|
|
// flattenTri with a certain stride. stride must be >= dimension. Puts repeatable
|
|
// nonce values in non-accessed places
|
|
func flattenTri(a [][]float64, stride int, ul blas.Uplo) []float64 {
|
|
m := len(a)
|
|
n := len(a[0])
|
|
if stride < n {
|
|
panic("bad stride")
|
|
}
|
|
upper := ul == blas.Upper
|
|
v := make([]float64, m*stride)
|
|
count := 1000.0
|
|
for i := 0; i < m; i++ {
|
|
for j := 0; j < stride; j++ {
|
|
if j >= n || (upper && j < i) || (!upper && j > i) {
|
|
// not accessed, so give a unique crazy number
|
|
v[i*stride+j] = count
|
|
count++
|
|
continue
|
|
}
|
|
v[i*stride+j] = a[i][j]
|
|
}
|
|
}
|
|
return v
|
|
}
|
|
|
|
// transpose returns a newly allocated transpose of the square matrix a, which
// is stored as a slice of rows. The input is not modified and the result
// shares no storage with it. An empty (0×0) matrix yields an empty result.
//
// transpose panics with "not square" if the first row's length differs from
// the number of rows. It panics with "bad n size" if any row's length differs
// from the first row's.
func transpose(a [][]float64) [][]float64 {
	m := len(a)
	if m == 0 {
		// A 0×0 matrix is trivially square; its transpose is empty too.
		return [][]float64{}
	}
	n := len(a[0])
	if m != n {
		panic("not square")
	}
	// Validate every row before allocating so a ragged input fails fast.
	for i := range a {
		if len(a[i]) != n {
			panic("bad n size")
		}
	}
	aNew := make([][]float64, n)
	for j := range aNew {
		row := make([]float64, m)
		for i := range a {
			row[i] = a[i][j]
		}
		aNew[j] = row
	}
	return aNew
}
|