mirror of
https://github.com/gonum/gonum.git
synced 2025-10-20 13:55:20 +08:00
Merge from master
This commit is contained in:
@@ -19,7 +19,9 @@ const (
|
||||
badIpiv = "lapack: insufficient permutation length"
|
||||
badLdA = "lapack: index of a out of range"
|
||||
badNorm = "lapack: bad norm"
|
||||
badPivot = "lapack: bad pivot"
|
||||
badSide = "lapack: bad side"
|
||||
badSlice = "lapack: bad input slice length"
|
||||
badStore = "lapack: bad store"
|
||||
badTau = "lapack: tau has insufficient length"
|
||||
badTrans = "lapack: bad trans"
|
||||
|
17
lapack.go
17
lapack.go
@@ -50,6 +50,14 @@ const (
|
||||
Backward Direct = 'B' // Reflectors are left-multiplied, H_k * ... * H_2 * H_1
|
||||
)
|
||||
|
||||
// Sort is the sorting order.
|
||||
type Sort byte
|
||||
|
||||
const (
|
||||
SortIncreasing Sort = 'I'
|
||||
SortDecreasing Sort = 'D'
|
||||
)
|
||||
|
||||
// StoreV indicates the storage direction of elementary reflectors.
|
||||
type StoreV byte
|
||||
|
||||
@@ -74,3 +82,12 @@ type MatrixType byte
|
||||
const (
|
||||
General MatrixType = 'G' // A dense matrix (like blas64.General).
|
||||
)
|
||||
|
||||
// Pivot specifies the pivot type for plane rotations
|
||||
type Pivot byte
|
||||
|
||||
const (
|
||||
Variable Pivot = 'V'
|
||||
Top Pivot = 'T'
|
||||
Bottom Pivot = 'B'
|
||||
)
|
||||
|
41
native/dlas2.go
Normal file
41
native/dlas2.go
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package native
|
||||
|
||||
import "math"
|
||||
|
||||
// Dlas2 computes the singular values of the 2×2 matrix defined by
|
||||
// [F G]
|
||||
// [0 H]
|
||||
// The smaller and larger singular values are returned in that order.
|
||||
func (impl Implementation) Dlas2(f, g, h float64) (ssmin, ssmax float64) {
|
||||
fa := math.Abs(f)
|
||||
ga := math.Abs(g)
|
||||
ha := math.Abs(h)
|
||||
fhmin := math.Min(fa, ha)
|
||||
fhmax := math.Max(fa, ha)
|
||||
if fhmin == 0 {
|
||||
if fhmax == 0 {
|
||||
return 0, ga
|
||||
}
|
||||
v := math.Min(fhmax, ga) / math.Max(fhmax, ga)
|
||||
return 0, math.Max(fhmax, ga) * math.Sqrt(1+v*v)
|
||||
}
|
||||
if ga < fhmax {
|
||||
as := 1 + fhmin/fhmax
|
||||
at := (fhmax - fhmin) / fhmax
|
||||
au := (ga / fhmax) * (ga / fhmax)
|
||||
c := 2 / (math.Sqrt(as*as+au) + math.Sqrt(at*at+au))
|
||||
return fhmin * c, fhmax / c
|
||||
}
|
||||
au := fhmax / ga
|
||||
if au == 0 {
|
||||
return fhmin * fhmax / ga, ga
|
||||
}
|
||||
as := 1 + fhmin/fhmax
|
||||
at := (fhmax - fhmin) / fhmax
|
||||
c := 1 / (math.Sqrt(1+(as*au)*(as*au)) + math.Sqrt(1+(at*au)*(at*au)))
|
||||
return 2 * (fhmin * c) * au, ga / (c + c)
|
||||
}
|
266
native/dlasr.go
Normal file
266
native/dlasr.go
Normal file
@@ -0,0 +1,266 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package native
|
||||
|
||||
import (
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/lapack"
|
||||
)
|
||||
|
||||
// Dlasr applies a sequence of plane rotations to the m×n matrix A. This series
|
||||
// of plane rotations is implicitly represented by a matrix P. P is multiplied
|
||||
// by a depending on the value of side -- A = P * A if side == lapack.Left,
|
||||
// A = A * P^T if side == lapack.Right.
|
||||
//
|
||||
//The exact value of P depends on the value of pivot, but in all cases P is
|
||||
// implicitly represented by a series of 2×2 rotation matrices. The entries of
|
||||
// rotation matrix k are defined by s[k] and c[k]
|
||||
// R(k) = [ c[k] s[k]]
|
||||
// [-s[k] s[k]]
|
||||
// If direct == lapack.Forward, the rotation matrices are applied as
|
||||
// P = P(z-1) * ... * P(2) * P(1), while if direct == lapack.Backward they are
|
||||
// applied as P = P(1) * P(2) * ... * P(n).
|
||||
//
|
||||
// pivot defines the mapping of the elements in R(k) to P(k).
|
||||
// If pivot == lapack.Variable, the rotation is performed for the (k, k+1) plane.
|
||||
// P(k) = [1 ]
|
||||
// [ ... ]
|
||||
// [ 1 ]
|
||||
// [ c[k] s[k] ]
|
||||
// [ -s[k] c[k] ]
|
||||
// [ 1 ]
|
||||
// [ ... ]
|
||||
// [ 1]
|
||||
// if pivot == lapack.Top, the rotation is performed for the (1, k+1) plane,
|
||||
// P(k) = [c[k] s[k] ]
|
||||
// [ 1 ]
|
||||
// [ ... ]
|
||||
// [ 1 ]
|
||||
// [-s[k] c[k] ]
|
||||
// [ 1 ]
|
||||
// [ ... ]
|
||||
// [ 1]
|
||||
// and if pivot == lapack.Bottom, the rotation is performed for the (k, z) plane.
|
||||
// P(k) = [1 ]
|
||||
// [ ... ]
|
||||
// [ 1 ]
|
||||
// [ c[k] s[k]]
|
||||
// [ 1 ]
|
||||
// [ ... ]
|
||||
// [ 1 ]
|
||||
// [ -s[k] c[k]]
|
||||
// s and c have length m - 1 if side == blas.Left, and n - 1 if side == blas.Right.
|
||||
func (impl Implementation) Dlasr(side blas.Side, pivot lapack.Pivot, direct lapack.Direct, m, n int, c, s, a []float64, lda int) {
|
||||
checkMatrix(m, n, a, lda)
|
||||
if side != blas.Left && side != blas.Right {
|
||||
panic(badSide)
|
||||
}
|
||||
if pivot != lapack.Variable && pivot != lapack.Top && pivot != lapack.Bottom {
|
||||
panic(badPivot)
|
||||
}
|
||||
if direct != lapack.Forward && direct != lapack.Backward {
|
||||
panic(badDirect)
|
||||
}
|
||||
if side == blas.Left {
|
||||
if len(c) < m-1 {
|
||||
panic(badSlice)
|
||||
}
|
||||
if len(s) < m-1 {
|
||||
panic(badSlice)
|
||||
}
|
||||
} else {
|
||||
if len(c) < n-1 {
|
||||
panic(badSlice)
|
||||
}
|
||||
if len(s) < n-1 {
|
||||
panic(badSlice)
|
||||
}
|
||||
}
|
||||
if m == 0 || n == 0 {
|
||||
return
|
||||
}
|
||||
if side == blas.Left {
|
||||
if pivot == lapack.Variable {
|
||||
if direct == lapack.Forward {
|
||||
for j := 0; j < m-1; j++ {
|
||||
ctmp := c[j]
|
||||
stmp := s[j]
|
||||
if ctmp != 1 || stmp != 0 {
|
||||
for i := 0; i < n; i++ {
|
||||
tmp2 := a[j*lda+i]
|
||||
tmp := a[(j+1)*lda+i]
|
||||
a[(j+1)*lda+i] = ctmp*tmp - stmp*tmp2
|
||||
a[j*lda+i] = stmp*tmp + ctmp*tmp2
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
for j := m - 2; j >= 0; j-- {
|
||||
ctmp := c[j]
|
||||
stmp := s[j]
|
||||
if ctmp != 1 || stmp != 0 {
|
||||
for i := 0; i < n; i++ {
|
||||
tmp2 := a[j*lda+i]
|
||||
tmp := a[(j+1)*lda+i]
|
||||
a[(j+1)*lda+i] = ctmp*tmp - stmp*tmp2
|
||||
a[j*lda+i] = stmp*tmp + ctmp*tmp2
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
} else if pivot == lapack.Top {
|
||||
if direct == lapack.Forward {
|
||||
for j := 1; j < m; j++ {
|
||||
ctmp := c[j-1]
|
||||
stmp := s[j-1]
|
||||
if ctmp != 1 || stmp != 0 {
|
||||
for i := 0; i < n; i++ {
|
||||
tmp := a[j*lda+i]
|
||||
tmp2 := a[i]
|
||||
a[j*lda+i] = ctmp*tmp - stmp*tmp2
|
||||
a[i] = stmp*tmp + ctmp*tmp2
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
for j := m - 1; j >= 1; j-- {
|
||||
ctmp := c[j-1]
|
||||
stmp := s[j-1]
|
||||
if ctmp != 1 || stmp != 0 {
|
||||
for i := 0; i < n; i++ {
|
||||
ctmp := c[j-1]
|
||||
stmp := s[j-1]
|
||||
if ctmp != 1 || stmp != 0 {
|
||||
for i := 0; i < n; i++ {
|
||||
tmp := a[j*lda+i]
|
||||
tmp2 := a[i]
|
||||
a[j*lda+i] = ctmp*tmp - stmp*tmp2
|
||||
a[i] = stmp*tmp + ctmp*tmp2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
if direct == lapack.Forward {
|
||||
for j := 0; j < m-1; j++ {
|
||||
ctmp := c[j]
|
||||
stmp := s[j]
|
||||
if ctmp != 1 || stmp != 0 {
|
||||
for i := 0; i < n; i++ {
|
||||
tmp := a[j*lda+i]
|
||||
tmp2 := a[(m-1)*lda+i]
|
||||
a[j*lda+i] = stmp*tmp2 + ctmp*tmp
|
||||
a[(m-1)*lda+i] = ctmp*tmp2 - stmp*tmp
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
for j := m - 2; j >= 0; j-- {
|
||||
ctmp := c[j]
|
||||
stmp := s[j]
|
||||
if ctmp != 1 || stmp != 0 {
|
||||
for i := 0; i < n; i++ {
|
||||
tmp := a[j*lda+i]
|
||||
tmp2 := a[(m-1)*lda+i]
|
||||
a[j*lda+i] = stmp*tmp2 + ctmp*tmp
|
||||
a[(m-1)*lda+i] = ctmp*tmp2 - stmp*tmp
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
if pivot == lapack.Variable {
|
||||
if direct == lapack.Forward {
|
||||
for j := 0; j < n-1; j++ {
|
||||
ctmp := c[j]
|
||||
stmp := s[j]
|
||||
if ctmp != 1 || stmp != 0 {
|
||||
for i := 0; i < m; i++ {
|
||||
tmp := a[i*lda+j+1]
|
||||
tmp2 := a[i*lda+j]
|
||||
a[i*lda+j+1] = ctmp*tmp - stmp*tmp2
|
||||
a[i*lda+j] = stmp*tmp + ctmp*tmp2
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
for j := n - 2; j >= 0; j-- {
|
||||
ctmp := c[j]
|
||||
stmp := s[j]
|
||||
if ctmp != 1 || stmp != 0 {
|
||||
for i := 0; i < m; i++ {
|
||||
tmp := a[i*lda+j+1]
|
||||
tmp2 := a[i*lda+j]
|
||||
a[i*lda+j+1] = ctmp*tmp - stmp*tmp2
|
||||
a[i*lda+j] = stmp*tmp + ctmp*tmp2
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
} else if pivot == lapack.Top {
|
||||
if direct == lapack.Forward {
|
||||
for j := 1; j < n; j++ {
|
||||
ctmp := c[j-1]
|
||||
stmp := s[j-1]
|
||||
if ctmp != 1 || stmp != 0 {
|
||||
for i := 0; i < m; i++ {
|
||||
tmp := a[i*lda+j]
|
||||
tmp2 := a[i*lda]
|
||||
a[i*lda+j] = ctmp*tmp - stmp*tmp2
|
||||
a[i*lda] = stmp*tmp + ctmp*tmp2
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
for j := n - 1; j >= 1; j-- {
|
||||
ctmp := c[j-1]
|
||||
stmp := s[j-1]
|
||||
if ctmp != 1 || stmp != 0 {
|
||||
for i := 0; i < m; i++ {
|
||||
tmp := a[i*lda+j]
|
||||
tmp2 := a[i*lda]
|
||||
a[i*lda+j] = ctmp*tmp - stmp*tmp2
|
||||
a[i*lda] = stmp*tmp + ctmp*tmp2
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
if direct == lapack.Forward {
|
||||
for j := 0; j < n-1; j++ {
|
||||
ctmp := c[j]
|
||||
stmp := s[j]
|
||||
if ctmp != 1 || stmp != 0 {
|
||||
for i := 0; i < m; i++ {
|
||||
tmp := a[i*lda+j]
|
||||
tmp2 := a[i*lda+n-1]
|
||||
a[i*lda+j] = stmp*tmp2 + ctmp*tmp
|
||||
a[i*lda+n-1] = ctmp*tmp2 - stmp*tmp
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
for j := n - 2; j >= 0; j-- {
|
||||
ctmp := c[j]
|
||||
stmp := s[j]
|
||||
if ctmp != 1 || stmp != 0 {
|
||||
for i := 0; i < m; i++ {
|
||||
tmp := a[i*lda+j]
|
||||
tmp2 := a[i*lda+n-1]
|
||||
a[i*lda+j] = stmp*tmp2 + ctmp*tmp
|
||||
a[i*lda+n-1] = ctmp*tmp2 - stmp*tmp
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
27
native/dlasrt.go
Normal file
27
native/dlasrt.go
Normal file
@@ -0,0 +1,27 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package native
|
||||
|
||||
import (
|
||||
"sort"
|
||||
|
||||
"github.com/gonum/lapack"
|
||||
)
|
||||
|
||||
// Dlasrt sorts the numbers in the input slice d. If sort == lapack.SortIncreasing,
|
||||
// the elements are sorted in increasing order. If sort == lapack.SortDecreasing,
|
||||
// the elements are sorted in decreasing order.
|
||||
func (impl Implementation) Dlasrt(s lapack.Sort, n int, d []float64) {
|
||||
checkVector(n, d, 1)
|
||||
d = d[:n]
|
||||
switch s {
|
||||
default:
|
||||
panic("lapack: bad sort")
|
||||
case lapack.SortIncreasing:
|
||||
sort.Sort(sort.Reverse(sort.Float64Slice(d)))
|
||||
case lapack.SortDecreasing:
|
||||
sort.Float64s(d)
|
||||
}
|
||||
}
|
@@ -25,7 +25,9 @@ const (
|
||||
badIpiv = "lapack: insufficient permutation length"
|
||||
badLdA = "lapack: index of a out of range"
|
||||
badNorm = "lapack: bad norm"
|
||||
badPivot = "lapack: bad pivot"
|
||||
badSide = "lapack: bad side"
|
||||
badSlice = "lapack: bad input slice length"
|
||||
badStore = "lapack: bad store"
|
||||
badTau = "lapack: tau has insufficient length"
|
||||
badTrans = "lapack: bad trans"
|
||||
|
@@ -64,6 +64,10 @@ func TestDlange(t *testing.T) {
|
||||
testlapack.DlangeTest(t, impl)
|
||||
}
|
||||
|
||||
func TestDlas2(t *testing.T) {
|
||||
testlapack.Dlas2Test(t, impl)
|
||||
}
|
||||
|
||||
func TestDlansy(t *testing.T) {
|
||||
testlapack.DlansyTest(t, impl)
|
||||
}
|
||||
@@ -92,6 +96,10 @@ func TestDlartg(t *testing.T) {
|
||||
testlapack.DlartgTest(t, impl)
|
||||
}
|
||||
|
||||
func TestDlasr(t *testing.T) {
|
||||
testlapack.DlasrTest(t, impl)
|
||||
}
|
||||
|
||||
func TestDorml2(t *testing.T) {
|
||||
testlapack.Dorml2Test(t, impl)
|
||||
}
|
||||
|
@@ -1,9 +1,10 @@
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math"
|
||||
"log"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/floats"
|
||||
"github.com/gonum/lapack"
|
||||
)
|
||||
|
||||
@@ -44,6 +45,18 @@ func DgeconTest(t *testing.T, impl Dgeconer) {
|
||||
condOne: 0.024740155174938,
|
||||
condInf: 0.012034465570035,
|
||||
},
|
||||
// Dgecon does not match Dpocon for this case. https://github.com/xianyi/OpenBLAS/issues/664.
|
||||
{
|
||||
a: []float64{
|
||||
2.9995576045549965, -2.0898894566158663, 3.965560740124006,
|
||||
-2.0898894566158663, 1.9634729526261008, -2.8681002706874104,
|
||||
3.965560740124006, -2.8681002706874104, 5.502416670471008,
|
||||
},
|
||||
m: 3,
|
||||
n: 3,
|
||||
condOne: 0.024054837369015203,
|
||||
condInf: 0.024054837369015203,
|
||||
},
|
||||
} {
|
||||
m := test.m
|
||||
n := test.n
|
||||
@@ -64,11 +77,17 @@ func DgeconTest(t *testing.T, impl Dgeconer) {
|
||||
iwork := make([]int, n)
|
||||
condOne := impl.Dgecon(lapack.MaxColumnSum, n, a, lda, oneNorm, work, iwork)
|
||||
condInf := impl.Dgecon(lapack.MaxRowSum, n, a, lda, infNorm, work, iwork)
|
||||
if math.Abs(condOne-test.condOne) > 1e-13 {
|
||||
|
||||
// Error if not the same order, otherwise log the difference.
|
||||
if !floats.EqualWithinAbsOrRel(condOne, test.condOne, 1e0, 1e0) {
|
||||
t.Errorf("One norm mismatch. Want %v, got %v.", test.condOne, condOne)
|
||||
} else if !floats.EqualWithinAbsOrRel(condOne, test.condOne, 1e-14, 1e-14) {
|
||||
log.Printf("Dgecon one norm mismatch. Want %v, got %v.", test.condOne, condOne)
|
||||
}
|
||||
if math.Abs(condInf-test.condInf) > 1e-13 {
|
||||
t.Errorf("Inf norm mismatch. Want %v, got %v.", test.condInf, condInf)
|
||||
if !floats.EqualWithinAbsOrRel(condInf, test.condInf, 1e0, 1e0) {
|
||||
t.Errorf("One norm mismatch. Want %v, got %v.", test.condInf, condInf)
|
||||
} else if !floats.EqualWithinAbsOrRel(condInf, test.condInf, 1e-14, 1e-14) {
|
||||
log.Printf("Dgecon one norm mismatch. Want %v, got %v.", test.condInf, condInf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
34
testlapack/dlas2.go
Normal file
34
testlapack/dlas2.go
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type Dlas2er interface {
|
||||
Dlas2(f, g, h float64) (min, max float64)
|
||||
}
|
||||
|
||||
func Dlas2Test(t *testing.T, impl Dlas2er) {
|
||||
for i, test := range []struct {
|
||||
f, g, h, ssmin, ssmax float64
|
||||
}{
|
||||
// Singular values computed from Octave.
|
||||
{10, 30, 12, 3.567778859365365, 33.634371616111189},
|
||||
{10, 30, -12, 3.567778859365365, 33.634371616111189},
|
||||
{2, 30, -12, 0.741557056404952, 32.364333658088754},
|
||||
{-2, 5, 12, 1.842864429909778, 13.023204317408728},
|
||||
} {
|
||||
ssmin, ssmax := impl.Dlas2(test.f, test.g, test.h)
|
||||
if math.Abs(ssmin-test.ssmin) > 1e-12 {
|
||||
t.Errorf("Case %d, minimal singular value mismatch. Want %v, got %v", i, test.ssmin, ssmin)
|
||||
}
|
||||
if math.Abs(ssmax-test.ssmax) > 1e-12 {
|
||||
t.Errorf("Case %d, minimal singular value mismatch. Want %v, got %v", i, test.ssmin, ssmin)
|
||||
}
|
||||
}
|
||||
}
|
147
testlapack/dlasr.go
Normal file
147
testlapack/dlasr.go
Normal file
@@ -0,0 +1,147 @@
|
||||
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
"github.com/gonum/floats"
|
||||
"github.com/gonum/lapack"
|
||||
)
|
||||
|
||||
type Dlasrer interface {
|
||||
Dlasr(side blas.Side, pivot lapack.Pivot, direct lapack.Direct, m, n int, c, s, a []float64, lda int)
|
||||
}
|
||||
|
||||
func DlasrTest(t *testing.T, impl Dlasrer) {
|
||||
for _, side := range []blas.Side{blas.Left, blas.Right} {
|
||||
for _, pivot := range []lapack.Pivot{lapack.Variable, lapack.Top, lapack.Bottom} {
|
||||
for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} {
|
||||
for _, test := range []struct {
|
||||
m, n, lda int
|
||||
}{
|
||||
{5, 5, 0},
|
||||
{5, 10, 0},
|
||||
{10, 5, 0},
|
||||
|
||||
{5, 5, 20},
|
||||
{5, 10, 20},
|
||||
{10, 5, 20},
|
||||
} {
|
||||
m := test.m
|
||||
n := test.n
|
||||
lda := test.lda
|
||||
if lda == 0 {
|
||||
lda = n
|
||||
}
|
||||
a := make([]float64, m*lda)
|
||||
for i := range a {
|
||||
a[i] = rand.Float64()
|
||||
}
|
||||
var s, c []float64
|
||||
if side == blas.Left {
|
||||
s = make([]float64, m-1)
|
||||
c = make([]float64, m-1)
|
||||
} else {
|
||||
s = make([]float64, n-1)
|
||||
c = make([]float64, n-1)
|
||||
}
|
||||
for k := range s {
|
||||
theta := rand.Float64() * 2 * math.Pi
|
||||
s[k] = math.Sin(theta)
|
||||
c[k] = math.Cos(theta)
|
||||
}
|
||||
aCopy := make([]float64, len(a))
|
||||
copy(a, aCopy)
|
||||
impl.Dlasr(side, pivot, direct, m, n, c, s, a, lda)
|
||||
|
||||
pSize := m
|
||||
if side == blas.Right {
|
||||
pSize = n
|
||||
}
|
||||
p := blas64.General{
|
||||
Rows: pSize,
|
||||
Cols: pSize,
|
||||
Stride: pSize,
|
||||
Data: make([]float64, pSize*pSize),
|
||||
}
|
||||
pk := blas64.General{
|
||||
Rows: pSize,
|
||||
Cols: pSize,
|
||||
Stride: pSize,
|
||||
Data: make([]float64, pSize*pSize),
|
||||
}
|
||||
ptmp := blas64.General{
|
||||
Rows: pSize,
|
||||
Cols: pSize,
|
||||
Stride: pSize,
|
||||
Data: make([]float64, pSize*pSize),
|
||||
}
|
||||
for i := 0; i < pSize; i++ {
|
||||
p.Data[i*p.Stride+i] = 1
|
||||
ptmp.Data[i*p.Stride+i] = 1
|
||||
}
|
||||
// Compare to direct computation.
|
||||
for k := range s {
|
||||
for i := range p.Data {
|
||||
pk.Data[i] = 0
|
||||
}
|
||||
for i := 0; i < pSize; i++ {
|
||||
pk.Data[i*p.Stride+i] = 1
|
||||
}
|
||||
if pivot == lapack.Variable {
|
||||
pk.Data[k*p.Stride+k] = c[k]
|
||||
pk.Data[k*p.Stride+k+1] = s[k]
|
||||
pk.Data[(k+1)*p.Stride+k] = -s[k]
|
||||
pk.Data[(k+1)*p.Stride+k+1] = c[k]
|
||||
} else if pivot == lapack.Top {
|
||||
pk.Data[0] = c[k]
|
||||
pk.Data[k+1] = s[k]
|
||||
pk.Data[(k+1)*p.Stride] = -s[k]
|
||||
pk.Data[(k+1)*p.Stride+k+1] = c[k]
|
||||
} else {
|
||||
pk.Data[(pSize-1-k)*p.Stride+pSize-k-1] = c[k]
|
||||
pk.Data[(pSize-1-k)*p.Stride+pSize-1] = s[k]
|
||||
pk.Data[(pSize-1)*p.Stride+pSize-1-k] = -s[k]
|
||||
pk.Data[(pSize-1)*p.Stride+pSize-1] = c[k]
|
||||
}
|
||||
if direct == lapack.Forward {
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, pk, ptmp, 0, p)
|
||||
} else {
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, ptmp, pk, 0, p)
|
||||
}
|
||||
copy(ptmp.Data, p.Data)
|
||||
}
|
||||
|
||||
aMat := blas64.General{
|
||||
Rows: m,
|
||||
Cols: n,
|
||||
Stride: lda,
|
||||
Data: make([]float64, m*lda),
|
||||
}
|
||||
copy(a, aCopy)
|
||||
newA := blas64.General{
|
||||
Rows: m,
|
||||
Cols: n,
|
||||
Stride: lda,
|
||||
Data: make([]float64, m*lda),
|
||||
}
|
||||
if side == blas.Left {
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, p, aMat, 0, newA)
|
||||
} else {
|
||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, p, 0, newA)
|
||||
}
|
||||
if !floats.EqualApprox(newA.Data, a, 1e-12) {
|
||||
t.Errorf("A update mismatch")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,12 +1,13 @@
|
||||
package testlapack
|
||||
|
||||
import (
|
||||
"math"
|
||||
"log"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/gonum/blas"
|
||||
"github.com/gonum/blas/blas64"
|
||||
"github.com/gonum/floats"
|
||||
"github.com/gonum/lapack"
|
||||
)
|
||||
|
||||
@@ -44,6 +45,17 @@ func DpoconTest(t *testing.T, impl Dpoconer) {
|
||||
n: 3,
|
||||
cond: 0.050052137643379,
|
||||
},
|
||||
// Dgecon does not match Dpocon for this case. https://github.com/xianyi/OpenBLAS/issues/664.
|
||||
{
|
||||
a: []float64{
|
||||
2.9995576045549965, -2.0898894566158663, 3.965560740124006,
|
||||
0, 1.9634729526261008, -2.8681002706874104,
|
||||
0, 0, 5.502416670471008,
|
||||
},
|
||||
uplo: blas.Upper,
|
||||
n: 3,
|
||||
cond: 0.024054837369015203,
|
||||
},
|
||||
} {
|
||||
n := test.n
|
||||
a := make([]float64, len(test.a))
|
||||
@@ -60,8 +72,11 @@ func DpoconTest(t *testing.T, impl Dpoconer) {
|
||||
}
|
||||
iwork := make([]int, n)
|
||||
cond := impl.Dpocon(uplo, n, a, lda, anorm, work, iwork)
|
||||
if math.Abs(cond-test.cond) > 1e-14 {
|
||||
// Error if not the same order, otherwise log the difference.
|
||||
if !floats.EqualWithinAbsOrRel(cond, test.cond, 1e0, 1e0) {
|
||||
t.Errorf("Cond mismatch. Want %v, got %v.", test.cond, cond)
|
||||
} else if !floats.EqualWithinAbsOrRel(cond, test.cond, 1e-14, 1e-14) {
|
||||
log.Printf("Dpocon cond mismatch. Want %v, got %v.", test.cond, cond)
|
||||
}
|
||||
}
|
||||
bi := blas64.Implementation()
|
||||
@@ -89,6 +104,9 @@ func DpoconTest(t *testing.T, impl Dpoconer) {
|
||||
copy(aCopy, a)
|
||||
bi.Dgemm(blas.Trans, blas.NoTrans, n, n, n, 1, aCopy, lda, aCopy, lda, 0, a, lda)
|
||||
|
||||
aDat := make([]float64, len(aCopy))
|
||||
copy(aDat, a)
|
||||
|
||||
aDense := make([]float64, len(a))
|
||||
if uplo == blas.Upper {
|
||||
for i := 0; i < n; i++ {
|
||||
@@ -122,8 +140,11 @@ func DpoconTest(t *testing.T, impl Dpoconer) {
|
||||
ipiv := make([]int, n)
|
||||
impl.Dgetrf(n, n, aDense, lda, ipiv)
|
||||
want := impl.Dgecon(lapack.MaxColumnSum, n, aDense, lda, denseNorm, work, iwork)
|
||||
if math.Abs(got-want) > 1e-14 {
|
||||
t.Errorf("Cond mismatch. Want %v, got %v.", want, got)
|
||||
// Error if not the same order, otherwise log the difference.
|
||||
if !floats.EqualWithinAbsOrRel(want, got, 1e0, 1e0) {
|
||||
t.Errorf("Dpocon and Dgecon mismatch. Dpocon %v, Dgecon %v.", got, want)
|
||||
} else if !floats.EqualWithinAbsOrRel(want, got, 1e-14, 1e-14) {
|
||||
log.Printf("Dpocon and Dgecon mismatch. Dpocon %v, Dgecon %v.", got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user