Add dlasr

This commit is contained in:
btracey
2015-10-06 08:00:20 -06:00
parent 28f483e1da
commit 1ea1bb2a96
8 changed files with 478 additions and 8 deletions

View File

@@ -19,7 +19,9 @@ const (
badIpiv = "lapack: insufficient permutation length" badIpiv = "lapack: insufficient permutation length"
badLdA = "lapack: index of a out of range" badLdA = "lapack: index of a out of range"
badNorm = "lapack: bad norm" badNorm = "lapack: bad norm"
badPivot = "lapack: bad pivot"
badSide = "lapack: bad side" badSide = "lapack: bad side"
badSlice = "lapack: bad input slice length"
badStore = "lapack: bad store" badStore = "lapack: bad store"
badTau = "lapack: tau has insufficient length" badTau = "lapack: tau has insufficient length"
badTrans = "lapack: bad trans" badTrans = "lapack: bad trans"

View File

@@ -74,3 +74,12 @@ type MatrixType byte
const ( const (
General MatrixType = 'G' // A dense matrix (like blas64.General). General MatrixType = 'G' // A dense matrix (like blas64.General).
) )
// Pivot specifies the pivot type for plane rotations
type Pivot byte
const (
Variable Pivot = 'V'
Top Pivot = 'T'
Bottom Pivot = 'B'
)

266
native/dlasr.go Normal file
View File

@@ -0,0 +1,266 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"github.com/gonum/blas"
"github.com/gonum/lapack"
)
// Dlasr applies a sequence of plane rotations to the m×n matrix A. This series
// of plane rotations is implicitly represented by a matrix P. P is multiplied
// by a depending on the value of side -- A = P * A if side == lapack.Left,
// A = A * P^T if side == lapack.Right.
//
//The exact value of P depends on the value of pivot, but in all cases P is
// implicitly represented by a series of 2×2 rotation matrices. The entries of
// rotation matrix k are defined by s[k] and c[k]
// R(k) = [ c[k] s[k]]
// [-s[k] s[k]]
// If direct == lapack.Forward, the rotation matrices are applied as
// P = P(z-1) * ... * P(2) * P(1), while if direct == lapack.Backward they are
// applied as P = P(1) * P(2) * ... * P(n).
//
// pivot defines the mapping of the elements in R(k) to P(k).
// If pivot == lapack.Variable, the rotation is performed for the (k, k+1) plane.
// P(k) = [1 ]
// [ ... ]
// [ 1 ]
// [ c[k] s[k] ]
// [ -s[k] c[k] ]
// [ 1 ]
// [ ... ]
// [ 1]
// if pivot == lapack.Top, the rotation is performed for the (1, k+1) plane,
// P(k) = [c[k] s[k] ]
// [ 1 ]
// [ ... ]
// [ 1 ]
// [-s[k] c[k] ]
// [ 1 ]
// [ ... ]
// [ 1]
// and if pivot == lapack.Bottom, the rotation is performed for the (k, z) plane.
// P(k) = [1 ]
// [ ... ]
// [ 1 ]
// [ c[k] s[k]]
// [ 1 ]
// [ ... ]
// [ 1 ]
// [ -s[k] c[k]]
// s and c have length m - 1 if side == blas.Left, and n - 1 if side == blas.Right.
func (impl Implementation) Dlasr(side blas.Side, pivot lapack.Pivot, direct lapack.Direct, m, n int, c, s, a []float64, lda int) {
checkMatrix(m, n, a, lda)
if side != blas.Left && side != blas.Right {
panic(badSide)
}
if pivot != lapack.Variable && pivot != lapack.Top && pivot != lapack.Bottom {
panic(badPivot)
}
if direct != lapack.Forward && direct != lapack.Backward {
panic(badDirect)
}
if side == blas.Left {
if len(c) < m-1 {
panic(badSlice)
}
if len(s) < m-1 {
panic(badSlice)
}
} else {
if len(c) < n-1 {
panic(badSlice)
}
if len(s) < n-1 {
panic(badSlice)
}
}
if m == 0 || n == 0 {
return
}
if side == blas.Left {
if pivot == lapack.Variable {
if direct == lapack.Forward {
for j := 0; j < m-1; j++ {
ctmp := c[j]
stmp := s[j]
if ctmp != 1 || stmp != 0 {
for i := 0; i < n; i++ {
tmp2 := a[j*lda+i]
tmp := a[(j+1)*lda+i]
a[(j+1)*lda+i] = ctmp*tmp - stmp*tmp2
a[j*lda+i] = stmp*tmp + ctmp*tmp2
}
}
}
return
}
for j := m - 2; j >= 0; j-- {
ctmp := c[j]
stmp := s[j]
if ctmp != 1 || stmp != 0 {
for i := 0; i < n; i++ {
tmp2 := a[j*lda+i]
tmp := a[(j+1)*lda+i]
a[(j+1)*lda+i] = ctmp*tmp - stmp*tmp2
a[j*lda+i] = stmp*tmp + ctmp*tmp2
}
}
}
return
} else if pivot == lapack.Top {
if direct == lapack.Forward {
for j := 1; j < m; j++ {
ctmp := c[j-1]
stmp := s[j-1]
if ctmp != 1 || stmp != 0 {
for i := 0; i < n; i++ {
tmp := a[j*lda+i]
tmp2 := a[i]
a[j*lda+i] = ctmp*tmp - stmp*tmp2
a[i] = stmp*tmp + ctmp*tmp2
}
}
}
return
}
for j := m - 1; j >= 1; j-- {
ctmp := c[j-1]
stmp := s[j-1]
if ctmp != 1 || stmp != 0 {
for i := 0; i < n; i++ {
ctmp := c[j-1]
stmp := s[j-1]
if ctmp != 1 || stmp != 0 {
for i := 0; i < n; i++ {
tmp := a[j*lda+i]
tmp2 := a[i]
a[j*lda+i] = ctmp*tmp - stmp*tmp2
a[i] = stmp*tmp + ctmp*tmp2
}
}
}
}
}
return
}
if direct == lapack.Forward {
for j := 0; j < m-1; j++ {
ctmp := c[j]
stmp := s[j]
if ctmp != 1 || stmp != 0 {
for i := 0; i < n; i++ {
tmp := a[j*lda+i]
tmp2 := a[(m-1)*lda+i]
a[j*lda+i] = stmp*tmp2 + ctmp*tmp
a[(m-1)*lda+i] = ctmp*tmp2 - stmp*tmp
}
}
}
return
}
for j := m - 2; j >= 0; j-- {
ctmp := c[j]
stmp := s[j]
if ctmp != 1 || stmp != 0 {
for i := 0; i < n; i++ {
tmp := a[j*lda+i]
tmp2 := a[(m-1)*lda+i]
a[j*lda+i] = stmp*tmp2 + ctmp*tmp
a[(m-1)*lda+i] = ctmp*tmp2 - stmp*tmp
}
}
}
return
}
if pivot == lapack.Variable {
if direct == lapack.Forward {
for j := 0; j < n-1; j++ {
ctmp := c[j]
stmp := s[j]
if ctmp != 1 || stmp != 0 {
for i := 0; i < m; i++ {
tmp := a[i*lda+j+1]
tmp2 := a[i*lda+j]
a[i*lda+j+1] = ctmp*tmp - stmp*tmp2
a[i*lda+j] = stmp*tmp + ctmp*tmp2
}
}
}
return
}
for j := n - 2; j >= 0; j-- {
ctmp := c[j]
stmp := s[j]
if ctmp != 1 || stmp != 0 {
for i := 0; i < m; i++ {
tmp := a[i*lda+j+1]
tmp2 := a[i*lda+j]
a[i*lda+j+1] = ctmp*tmp - stmp*tmp2
a[i*lda+j] = stmp*tmp + ctmp*tmp2
}
}
}
return
} else if pivot == lapack.Top {
if direct == lapack.Forward {
for j := 1; j < n; j++ {
ctmp := c[j-1]
stmp := s[j-1]
if ctmp != 1 || stmp != 0 {
for i := 0; i < m; i++ {
tmp := a[i*lda+j]
tmp2 := a[i*lda]
a[i*lda+j] = ctmp*tmp - stmp*tmp2
a[i*lda] = stmp*tmp + ctmp*tmp2
}
}
}
return
}
for j := n - 1; j >= 1; j-- {
ctmp := c[j-1]
stmp := s[j-1]
if ctmp != 1 || stmp != 0 {
for i := 0; i < m; i++ {
tmp := a[i*lda+j]
tmp2 := a[i*lda]
a[i*lda+j] = ctmp*tmp - stmp*tmp2
a[i*lda] = stmp*tmp + ctmp*tmp2
}
}
}
return
}
if direct == lapack.Forward {
for j := 0; j < n-1; j++ {
ctmp := c[j]
stmp := s[j]
if ctmp != 1 || stmp != 0 {
for i := 0; i < m; i++ {
tmp := a[i*lda+j]
tmp2 := a[i*lda+n-1]
a[i*lda+j] = stmp*tmp2 + ctmp*tmp
a[i*lda+n-1] = ctmp*tmp2 - stmp*tmp
}
}
}
return
}
for j := n - 2; j >= 0; j-- {
ctmp := c[j]
stmp := s[j]
if ctmp != 1 || stmp != 0 {
for i := 0; i < m; i++ {
tmp := a[i*lda+j]
tmp2 := a[i*lda+n-1]
a[i*lda+j] = stmp*tmp2 + ctmp*tmp
a[i*lda+n-1] = ctmp*tmp2 - stmp*tmp
}
}
}
}

View File

@@ -25,7 +25,9 @@ const (
badIpiv = "lapack: insufficient permutation length" badIpiv = "lapack: insufficient permutation length"
badLdA = "lapack: index of a out of range" badLdA = "lapack: index of a out of range"
badNorm = "lapack: bad norm" badNorm = "lapack: bad norm"
badPivot = "lapack: bad pivot"
badSide = "lapack: bad side" badSide = "lapack: bad side"
badSlice = "lapack: bad input slice length"
badStore = "lapack: bad store" badStore = "lapack: bad store"
badTau = "lapack: tau has insufficient length" badTau = "lapack: tau has insufficient length"
badTrans = "lapack: bad trans" badTrans = "lapack: bad trans"

View File

@@ -88,6 +88,10 @@ func TestDlarft(t *testing.T) {
testlapack.DlarftTest(t, impl) testlapack.DlarftTest(t, impl)
} }
func TestDlasr(t *testing.T) {
testlapack.DlasrTest(t, impl)
}
func TestDorml2(t *testing.T) { func TestDorml2(t *testing.T) {
testlapack.Dorml2Test(t, impl) testlapack.Dorml2Test(t, impl)
} }

View File

@@ -1,9 +1,10 @@
package testlapack package testlapack
import ( import (
"math" "log"
"testing" "testing"
"github.com/gonum/floats"
"github.com/gonum/lapack" "github.com/gonum/lapack"
) )
@@ -44,6 +45,18 @@ func DgeconTest(t *testing.T, impl Dgeconer) {
condOne: 0.024740155174938, condOne: 0.024740155174938,
condInf: 0.012034465570035, condInf: 0.012034465570035,
}, },
// Dgecon does not match Dpocon for this case. https://github.com/xianyi/OpenBLAS/issues/664.
{
a: []float64{
2.9995576045549965, -2.0898894566158663, 3.965560740124006,
-2.0898894566158663, 1.9634729526261008, -2.8681002706874104,
3.965560740124006, -2.8681002706874104, 5.502416670471008,
},
m: 3,
n: 3,
condOne: 0.024054837369015203,
condInf: 0.024054837369015203,
},
} { } {
m := test.m m := test.m
n := test.n n := test.n
@@ -64,11 +77,17 @@ func DgeconTest(t *testing.T, impl Dgeconer) {
iwork := make([]int, n) iwork := make([]int, n)
condOne := impl.Dgecon(lapack.MaxColumnSum, n, a, lda, oneNorm, work, iwork) condOne := impl.Dgecon(lapack.MaxColumnSum, n, a, lda, oneNorm, work, iwork)
condInf := impl.Dgecon(lapack.MaxRowSum, n, a, lda, infNorm, work, iwork) condInf := impl.Dgecon(lapack.MaxRowSum, n, a, lda, infNorm, work, iwork)
if math.Abs(condOne-test.condOne) > 1e-13 {
// Error if not the same order, otherwise log the difference.
if !floats.EqualWithinAbsOrRel(condOne, test.condOne, 1e0, 1e0) {
t.Errorf("One norm mismatch. Want %v, got %v.", test.condOne, condOne) t.Errorf("One norm mismatch. Want %v, got %v.", test.condOne, condOne)
} else if !floats.EqualWithinAbsOrRel(condOne, test.condOne, 1e-14, 1e-14) {
log.Printf("Dgecon one norm mismatch. Want %v, got %v.", test.condOne, condOne)
} }
if math.Abs(condInf-test.condInf) > 1e-13 { if !floats.EqualWithinAbsOrRel(condInf, test.condInf, 1e0, 1e0) {
t.Errorf("Inf norm mismatch. Want %v, got %v.", test.condInf, condInf) t.Errorf("One norm mismatch. Want %v, got %v.", test.condInf, condInf)
} else if !floats.EqualWithinAbsOrRel(condInf, test.condInf, 1e-14, 1e-14) {
log.Printf("Dgecon one norm mismatch. Want %v, got %v.", test.condInf, condInf)
} }
} }
} }

147
testlapack/dlasr.go Normal file
View File

@@ -0,0 +1,147 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testlapack
import (
"math"
"math/rand"
"testing"
"github.com/gonum/blas"
"github.com/gonum/blas/blas64"
"github.com/gonum/floats"
"github.com/gonum/lapack"
)
type Dlasrer interface {
Dlasr(side blas.Side, pivot lapack.Pivot, direct lapack.Direct, m, n int, c, s, a []float64, lda int)
}
func DlasrTest(t *testing.T, impl Dlasrer) {
for _, side := range []blas.Side{blas.Left, blas.Right} {
for _, pivot := range []lapack.Pivot{lapack.Variable, lapack.Top, lapack.Bottom} {
for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} {
for _, test := range []struct {
m, n, lda int
}{
{5, 5, 0},
{5, 10, 0},
{10, 5, 0},
{5, 5, 20},
{5, 10, 20},
{10, 5, 20},
} {
m := test.m
n := test.n
lda := test.lda
if lda == 0 {
lda = n
}
a := make([]float64, m*lda)
for i := range a {
a[i] = rand.Float64()
}
var s, c []float64
if side == blas.Left {
s = make([]float64, m-1)
c = make([]float64, m-1)
} else {
s = make([]float64, n-1)
c = make([]float64, n-1)
}
for k := range s {
theta := rand.Float64() * 2 * math.Pi
s[k] = math.Sin(theta)
c[k] = math.Cos(theta)
}
aCopy := make([]float64, len(a))
copy(a, aCopy)
impl.Dlasr(side, pivot, direct, m, n, c, s, a, lda)
pSize := m
if side == blas.Right {
pSize = n
}
p := blas64.General{
Rows: pSize,
Cols: pSize,
Stride: pSize,
Data: make([]float64, pSize*pSize),
}
pk := blas64.General{
Rows: pSize,
Cols: pSize,
Stride: pSize,
Data: make([]float64, pSize*pSize),
}
ptmp := blas64.General{
Rows: pSize,
Cols: pSize,
Stride: pSize,
Data: make([]float64, pSize*pSize),
}
for i := 0; i < pSize; i++ {
p.Data[i*p.Stride+i] = 1
ptmp.Data[i*p.Stride+i] = 1
}
// Compare to direct computation.
for k := range s {
for i := range p.Data {
pk.Data[i] = 0
}
for i := 0; i < pSize; i++ {
pk.Data[i*p.Stride+i] = 1
}
if pivot == lapack.Variable {
pk.Data[k*p.Stride+k] = c[k]
pk.Data[k*p.Stride+k+1] = s[k]
pk.Data[(k+1)*p.Stride+k] = -s[k]
pk.Data[(k+1)*p.Stride+k+1] = c[k]
} else if pivot == lapack.Top {
pk.Data[0] = c[k]
pk.Data[k+1] = s[k]
pk.Data[(k+1)*p.Stride] = -s[k]
pk.Data[(k+1)*p.Stride+k+1] = c[k]
} else {
pk.Data[(pSize-1-k)*p.Stride+pSize-k-1] = c[k]
pk.Data[(pSize-1-k)*p.Stride+pSize-1] = s[k]
pk.Data[(pSize-1)*p.Stride+pSize-1-k] = -s[k]
pk.Data[(pSize-1)*p.Stride+pSize-1] = c[k]
}
if direct == lapack.Forward {
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, pk, ptmp, 0, p)
} else {
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, ptmp, pk, 0, p)
}
copy(ptmp.Data, p.Data)
}
aMat := blas64.General{
Rows: m,
Cols: n,
Stride: lda,
Data: make([]float64, m*lda),
}
copy(a, aCopy)
newA := blas64.General{
Rows: m,
Cols: n,
Stride: lda,
Data: make([]float64, m*lda),
}
if side == blas.Left {
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, p, aMat, 0, newA)
} else {
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, p, 0, newA)
}
if !floats.EqualApprox(newA.Data, a, 1e-12) {
t.Errorf("A update mismatch")
}
}
}
}
}
}

View File

@@ -1,12 +1,13 @@
package testlapack package testlapack
import ( import (
"math" "log"
"math/rand" "math/rand"
"testing" "testing"
"github.com/gonum/blas" "github.com/gonum/blas"
"github.com/gonum/blas/blas64" "github.com/gonum/blas/blas64"
"github.com/gonum/floats"
"github.com/gonum/lapack" "github.com/gonum/lapack"
) )
@@ -44,6 +45,17 @@ func DpoconTest(t *testing.T, impl Dpoconer) {
n: 3, n: 3,
cond: 0.050052137643379, cond: 0.050052137643379,
}, },
// Dgecon does not match Dpocon for this case. https://github.com/xianyi/OpenBLAS/issues/664.
{
a: []float64{
2.9995576045549965, -2.0898894566158663, 3.965560740124006,
0, 1.9634729526261008, -2.8681002706874104,
0, 0, 5.502416670471008,
},
uplo: blas.Upper,
n: 3,
cond: 0.024054837369015203,
},
} { } {
n := test.n n := test.n
a := make([]float64, len(test.a)) a := make([]float64, len(test.a))
@@ -60,8 +72,11 @@ func DpoconTest(t *testing.T, impl Dpoconer) {
} }
iwork := make([]int, n) iwork := make([]int, n)
cond := impl.Dpocon(uplo, n, a, lda, anorm, work, iwork) cond := impl.Dpocon(uplo, n, a, lda, anorm, work, iwork)
if math.Abs(cond-test.cond) > 1e-14 { // Error if not the same order, otherwise log the difference.
if !floats.EqualWithinAbsOrRel(cond, test.cond, 1e0, 1e0) {
t.Errorf("Cond mismatch. Want %v, got %v.", test.cond, cond) t.Errorf("Cond mismatch. Want %v, got %v.", test.cond, cond)
} else if !floats.EqualWithinAbsOrRel(cond, test.cond, 1e-14, 1e-14) {
log.Printf("Dpocon cond mismatch. Want %v, got %v.", test.cond, cond)
} }
} }
bi := blas64.Implementation() bi := blas64.Implementation()
@@ -89,6 +104,9 @@ func DpoconTest(t *testing.T, impl Dpoconer) {
copy(aCopy, a) copy(aCopy, a)
bi.Dgemm(blas.Trans, blas.NoTrans, n, n, n, 1, aCopy, lda, aCopy, lda, 0, a, lda) bi.Dgemm(blas.Trans, blas.NoTrans, n, n, n, 1, aCopy, lda, aCopy, lda, 0, a, lda)
aDat := make([]float64, len(aCopy))
copy(aDat, a)
aDense := make([]float64, len(a)) aDense := make([]float64, len(a))
if uplo == blas.Upper { if uplo == blas.Upper {
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
@@ -122,8 +140,11 @@ func DpoconTest(t *testing.T, impl Dpoconer) {
ipiv := make([]int, n) ipiv := make([]int, n)
impl.Dgetrf(n, n, aDense, lda, ipiv) impl.Dgetrf(n, n, aDense, lda, ipiv)
want := impl.Dgecon(lapack.MaxColumnSum, n, aDense, lda, denseNorm, work, iwork) want := impl.Dgecon(lapack.MaxColumnSum, n, aDense, lda, denseNorm, work, iwork)
if math.Abs(got-want) > 1e-14 { // Error if not the same order, otherwise log the difference.
t.Errorf("Cond mismatch. Want %v, got %v.", want, got) if !floats.EqualWithinAbsOrRel(want, got, 1e0, 1e0) {
t.Errorf("Dpocon and Dgecon mismatch. Dpocon %v, Dgecon %v.", got, want)
} else if !floats.EqualWithinAbsOrRel(want, got, 1e-14, 1e-14) {
log.Printf("Dpocon and Dgecon mismatch. Dpocon %v, Dgecon %v.", got, want)
} }
} }
} }