mirror of
https://github.com/gonum/gonum.git
synced 2025-10-20 21:59:25 +08:00
Add Dlantr
This commit is contained in:
@@ -18,6 +18,7 @@ const (
|
|||||||
badDirect = "lapack: bad direct"
|
badDirect = "lapack: bad direct"
|
||||||
badIpiv = "lapack: insufficient permutation length"
|
badIpiv = "lapack: insufficient permutation length"
|
||||||
badLdA = "lapack: index of a out of range"
|
badLdA = "lapack: index of a out of range"
|
||||||
|
badNorm = "lapack: bad norm"
|
||||||
badSide = "lapack: bad side"
|
badSide = "lapack: bad side"
|
||||||
badStore = "lapack: bad store"
|
badStore = "lapack: bad store"
|
||||||
badTau = "lapack: tau has insufficient length"
|
badTau = "lapack: tau has insufficient length"
|
||||||
@@ -66,6 +67,49 @@ type Implementation struct{}
|
|||||||
|
|
||||||
var _ lapack.Float64 = Implementation{}
|
var _ lapack.Float64 = Implementation{}
|
||||||
|
|
||||||
|
// Dlange computes the matrix norm of the general m×n matrix a. The input norm
|
||||||
|
// specifies the norm computed.
|
||||||
|
// lapack.MaxAbs: the maximum absolute value of an element.
|
||||||
|
// lapack.MaxColumnSum: the maximum column sum of the absolute values of the entries.
|
||||||
|
// lapack.MaxRowSum: the maximum row sum of the absolute values of the entries.
|
||||||
|
// lapack.Frobenius: the square root of the sum of the squares of the entries.
|
||||||
|
// If norm == lapack.MaxColumnSum, work must be of length n, and this function will panic otherwise.
|
||||||
|
// There are no restrictions on work for the other matrix norms.
|
||||||
|
func (impl Implementation) Dlange(norm lapack.MatrixNorm, m, n int, a []float64, lda int, work []float64) float64 {
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
switch norm {
|
||||||
|
case lapack.MaxRowSum, lapack.MaxColumnSum, lapack.NormFrob, lapack.MaxAbs:
|
||||||
|
default:
|
||||||
|
panic(badNorm)
|
||||||
|
}
|
||||||
|
if norm == lapack.MaxColumnSum && len(work) < n {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
return clapack.Dlange(byte(norm), m, n, a, lda)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dlantr computes the specified norm of an m×n trapezoidal matrix A. If
|
||||||
|
// norm == lapack.MaxColumnSum work must have length at least n, otherwise work
|
||||||
|
// is unused.
|
||||||
|
func (impl Implementation) Dlantr(norm lapack.MatrixNorm, uplo blas.Uplo, diag blas.Diag, m, n int, a []float64, lda int, work []float64) float64 {
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
switch norm {
|
||||||
|
case lapack.MaxRowSum, lapack.MaxColumnSum, lapack.NormFrob, lapack.MaxAbs:
|
||||||
|
default:
|
||||||
|
panic(badNorm)
|
||||||
|
}
|
||||||
|
if uplo != blas.Upper && uplo != blas.Lower {
|
||||||
|
panic(badUplo)
|
||||||
|
}
|
||||||
|
if diag != blas.Unit && diag != blas.NonUnit {
|
||||||
|
panic(badDiag)
|
||||||
|
}
|
||||||
|
if norm == lapack.MaxColumnSum && len(work) < n {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
return clapack.Dlantr(byte(norm), uplo, diag, m, n, a, lda)
|
||||||
|
}
|
||||||
|
|
||||||
// Dpotrf computes the cholesky decomposition of the symmetric positive definite
|
// Dpotrf computes the cholesky decomposition of the symmetric positive definite
|
||||||
// matrix a. If ul == blas.Upper, then a is stored as an upper-triangular matrix,
|
// matrix a. If ul == blas.Upper, then a is stored as an upper-triangular matrix,
|
||||||
// and a = U U^T is stored in place into a. If ul == blas.Lower, then a = L L^T
|
// and a = U U^T is stored in place into a. If ul == blas.Lower, then a = L L^T
|
||||||
|
@@ -13,6 +13,17 @@ import (
|
|||||||
|
|
||||||
var impl = Implementation{}
|
var impl = Implementation{}
|
||||||
|
|
||||||
|
func TestDlange(t *testing.T) {
|
||||||
|
testlapack.DlangeTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The results from Dlantr do not match the results from Dlange. In some cases,
|
||||||
|
// there also appear to be memory corruption issues.
|
||||||
|
// TODO(btracey): Re-enable this test when the implementations are fixed.
|
||||||
|
// func TestDlantr(t *testing.T) {
|
||||||
|
// testlapack.DlantrTest(t, impl)
|
||||||
|
// }
|
||||||
|
|
||||||
func TestDpotrf(t *testing.T) {
|
func TestDpotrf(t *testing.T) {
|
||||||
testlapack.DpotrfTest(t, impl)
|
testlapack.DpotrfTest(t, impl)
|
||||||
}
|
}
|
||||||
|
@@ -21,6 +21,14 @@ import (
|
|||||||
func (impl Implementation) Dlange(norm lapack.MatrixNorm, m, n int, a []float64, lda int, work []float64) float64 {
|
func (impl Implementation) Dlange(norm lapack.MatrixNorm, m, n int, a []float64, lda int, work []float64) float64 {
|
||||||
// TODO(btracey): These should probably be refactored to use BLAS calls.
|
// TODO(btracey): These should probably be refactored to use BLAS calls.
|
||||||
checkMatrix(m, n, a, lda)
|
checkMatrix(m, n, a, lda)
|
||||||
|
switch norm {
|
||||||
|
case lapack.MaxRowSum, lapack.MaxColumnSum, lapack.NormFrob, lapack.MaxAbs:
|
||||||
|
default:
|
||||||
|
panic(badNorm)
|
||||||
|
}
|
||||||
|
if norm == lapack.MaxColumnSum && len(work) < n {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
if m == 0 && n == 0 {
|
if m == 0 && n == 0 {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
248
native/dlantr.go
Normal file
248
native/dlantr.go
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
package native
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/lapack"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dlantr computes the specified norm of an m×n trapezoidal matrix A. If
|
||||||
|
// norm == lapack.MaxColumnSum work must have length at least n, otherwise work
|
||||||
|
// is unused.
|
||||||
|
func (impl Implementation) Dlantr(norm lapack.MatrixNorm, uplo blas.Uplo, diag blas.Diag, m, n int, a []float64, lda int, work []float64) float64 {
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
switch norm {
|
||||||
|
case lapack.MaxRowSum, lapack.MaxColumnSum, lapack.NormFrob, lapack.MaxAbs:
|
||||||
|
default:
|
||||||
|
panic(badNorm)
|
||||||
|
}
|
||||||
|
if uplo != blas.Upper && uplo != blas.Lower {
|
||||||
|
panic(badUplo)
|
||||||
|
}
|
||||||
|
if diag != blas.Unit && diag != blas.NonUnit {
|
||||||
|
panic(badDiag)
|
||||||
|
}
|
||||||
|
if norm == lapack.MaxColumnSum && len(work) < n {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
if min(m, n) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
switch norm {
|
||||||
|
default:
|
||||||
|
panic("unreachable")
|
||||||
|
case lapack.MaxAbs:
|
||||||
|
if diag == blas.Unit {
|
||||||
|
value := 1.0
|
||||||
|
if uplo == blas.Upper {
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := i + 1; j < n; j++ {
|
||||||
|
tmp := math.Abs(a[i*lda+j])
|
||||||
|
if math.IsNaN(tmp) {
|
||||||
|
return tmp
|
||||||
|
}
|
||||||
|
if tmp > value {
|
||||||
|
value = tmp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
for i := 1; i < m; i++ {
|
||||||
|
for j := 0; j < min(i, n); j++ {
|
||||||
|
tmp := math.Abs(a[i*lda+j])
|
||||||
|
if math.IsNaN(tmp) {
|
||||||
|
return tmp
|
||||||
|
}
|
||||||
|
if tmp > value {
|
||||||
|
value = tmp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
var value float64
|
||||||
|
if uplo == blas.Upper {
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := i; j < n; j++ {
|
||||||
|
tmp := math.Abs(a[i*lda+j])
|
||||||
|
if math.IsNaN(tmp) {
|
||||||
|
return tmp
|
||||||
|
}
|
||||||
|
if tmp > value {
|
||||||
|
value = tmp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j <= min(i, n-1); j++ {
|
||||||
|
tmp := math.Abs(a[i*lda+j])
|
||||||
|
if math.IsNaN(tmp) {
|
||||||
|
return tmp
|
||||||
|
}
|
||||||
|
if tmp > value {
|
||||||
|
value = tmp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
case lapack.MaxColumnSum:
|
||||||
|
if diag == blas.Unit {
|
||||||
|
for i := 0; i < min(m, n); i++ {
|
||||||
|
work[i] = 1
|
||||||
|
}
|
||||||
|
for i := min(m, n); i < n; i++ {
|
||||||
|
work[i] = 0
|
||||||
|
}
|
||||||
|
if uplo == blas.Upper {
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := i + 1; j < n; j++ {
|
||||||
|
work[j] += math.Abs(a[i*lda+j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := 1; i < m; i++ {
|
||||||
|
for j := 0; j < min(i, n); j++ {
|
||||||
|
work[j] += math.Abs(a[i*lda+j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
work[i] = 0
|
||||||
|
}
|
||||||
|
if uplo == blas.Upper {
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := i; j < n; j++ {
|
||||||
|
work[j] += math.Abs(a[i*lda+j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j <= min(i, n-1); j++ {
|
||||||
|
work[j] += math.Abs(a[i*lda+j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var max float64
|
||||||
|
for _, v := range work {
|
||||||
|
if math.IsNaN(v) {
|
||||||
|
return math.NaN()
|
||||||
|
}
|
||||||
|
if v > max {
|
||||||
|
max = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return max
|
||||||
|
case lapack.MaxRowSum:
|
||||||
|
var maxsum float64
|
||||||
|
if diag == blas.Unit {
|
||||||
|
if uplo == blas.Upper {
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
var sum float64
|
||||||
|
if i < min(m, n) {
|
||||||
|
sum = 1
|
||||||
|
}
|
||||||
|
for j := i + 1; j < n; j++ {
|
||||||
|
sum += math.Abs(a[i*lda+j])
|
||||||
|
}
|
||||||
|
if math.IsNaN(sum) {
|
||||||
|
return math.NaN()
|
||||||
|
}
|
||||||
|
if sum > maxsum {
|
||||||
|
maxsum = sum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return maxsum
|
||||||
|
} else {
|
||||||
|
for i := 1; i < m; i++ {
|
||||||
|
var sum float64
|
||||||
|
if i < min(m, n) {
|
||||||
|
sum = 1
|
||||||
|
}
|
||||||
|
for j := 0; j < min(i, n); j++ {
|
||||||
|
sum += math.Abs(a[i*lda+j])
|
||||||
|
}
|
||||||
|
if math.IsNaN(sum) {
|
||||||
|
return math.NaN()
|
||||||
|
}
|
||||||
|
if sum > maxsum {
|
||||||
|
maxsum = sum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return maxsum
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if uplo == blas.Upper {
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
var sum float64
|
||||||
|
for j := i; j < n; j++ {
|
||||||
|
sum += math.Abs(a[i*lda+j])
|
||||||
|
}
|
||||||
|
if math.IsNaN(sum) {
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
if sum > maxsum {
|
||||||
|
maxsum = sum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return maxsum
|
||||||
|
} else {
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
var sum float64
|
||||||
|
for j := 0; j <= min(i, n-1); j++ {
|
||||||
|
sum += math.Abs(a[i*lda+j])
|
||||||
|
}
|
||||||
|
if math.IsNaN(sum) {
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
if sum > maxsum {
|
||||||
|
maxsum = sum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return maxsum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case lapack.NormFrob:
|
||||||
|
var nrm float64
|
||||||
|
if diag == blas.Unit {
|
||||||
|
if uplo == blas.Upper {
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := i + 1; j < n; j++ {
|
||||||
|
tmp := a[i*lda+j]
|
||||||
|
nrm += tmp * tmp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := 1; i < m; i++ {
|
||||||
|
for j := 0; j < min(i, n); j++ {
|
||||||
|
tmp := a[i*lda+j]
|
||||||
|
nrm += tmp * tmp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nrm += float64(min(m, n))
|
||||||
|
} else {
|
||||||
|
if uplo == blas.Upper {
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := i; j < n; j++ {
|
||||||
|
tmp := math.Abs(a[i*lda+j])
|
||||||
|
nrm += tmp * tmp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j <= min(i, n-1); j++ {
|
||||||
|
tmp := math.Abs(a[i*lda+j])
|
||||||
|
nrm += tmp * tmp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return math.Sqrt(nrm)
|
||||||
|
}
|
||||||
|
}
|
@@ -24,6 +24,7 @@ const (
|
|||||||
badDirect = "lapack: bad direct"
|
badDirect = "lapack: bad direct"
|
||||||
badIpiv = "lapack: insufficient permutation length"
|
badIpiv = "lapack: insufficient permutation length"
|
||||||
badLdA = "lapack: index of a out of range"
|
badLdA = "lapack: index of a out of range"
|
||||||
|
badNorm = "lapack: bad norm"
|
||||||
badSide = "lapack: bad side"
|
badSide = "lapack: bad side"
|
||||||
badStore = "lapack: bad store"
|
badStore = "lapack: bad store"
|
||||||
badTau = "lapack: tau has insufficient length"
|
badTau = "lapack: tau has insufficient length"
|
||||||
|
@@ -48,6 +48,10 @@ func TestDlange(t *testing.T) {
|
|||||||
testlapack.DlangeTest(t, impl)
|
testlapack.DlangeTest(t, impl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDlantr(t *testing.T) {
|
||||||
|
testlapack.DlantrTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDlarfb(t *testing.T) {
|
func TestDlarfb(t *testing.T) {
|
||||||
testlapack.DlarfbTest(t, impl)
|
testlapack.DlarfbTest(t, impl)
|
||||||
}
|
}
|
||||||
|
84
testlapack/dlantr.go
Normal file
84
testlapack/dlantr.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/lapack"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dlantrer interface {
|
||||||
|
Dlanger
|
||||||
|
Dlantr(norm lapack.MatrixNorm, uplo blas.Uplo, diag blas.Diag, m, n int, a []float64, lda int, work []float64) float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func DlantrTest(t *testing.T, impl Dlantrer) {
|
||||||
|
for _, norm := range []lapack.MatrixNorm{lapack.MaxAbs, lapack.MaxColumnSum, lapack.MaxRowSum, lapack.NormFrob} {
|
||||||
|
for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} {
|
||||||
|
for _, uplo := range []blas.Uplo{blas.Lower, blas.Upper} {
|
||||||
|
for _, test := range []struct {
|
||||||
|
m, n, lda int
|
||||||
|
}{
|
||||||
|
{3, 3, 0},
|
||||||
|
{3, 5, 0},
|
||||||
|
{10, 5, 0},
|
||||||
|
|
||||||
|
{5, 5, 11},
|
||||||
|
{5, 10, 11},
|
||||||
|
{10, 5, 11},
|
||||||
|
} {
|
||||||
|
// Do a couple of random trials since the values change.
|
||||||
|
for trial := 0; trial < 100; trial++ {
|
||||||
|
m := test.m
|
||||||
|
n := test.n
|
||||||
|
lda := test.lda
|
||||||
|
if lda == 0 {
|
||||||
|
lda = n
|
||||||
|
}
|
||||||
|
a := make([]float64, m*lda)
|
||||||
|
if trial == 0 {
|
||||||
|
for i := range a {
|
||||||
|
a[i] = float64(i)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := range a {
|
||||||
|
a[i] = rand.NormFloat64()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
aDense := make([]float64, len(a))
|
||||||
|
if uplo == blas.Lower {
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j <= min(i, n-1); j++ {
|
||||||
|
aDense[i*lda+j] = a[i*lda+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := i; j < n; j++ {
|
||||||
|
aDense[i*lda+j] = a[i*lda+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if diag == blas.Unit {
|
||||||
|
for i := 0; i < min(m, n); i++ {
|
||||||
|
aDense[i*lda+i] = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
work := make([]float64, n)
|
||||||
|
for i := range work {
|
||||||
|
work[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
got := impl.Dlantr(norm, uplo, diag, m, n, a, lda, work)
|
||||||
|
want := impl.Dlange(norm, m, n, aDense, lda, work)
|
||||||
|
if math.Abs(got-want) > 1e-13 {
|
||||||
|
t.Errorf("Norm mismatch. norm = %c, unitdiag = %v, upper = %v, m = %v, n = %v, lda = %v, Want %v, got %v.",
|
||||||
|
norm, diag == blas.Unit, uplo == blas.Upper, m, n, lda, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user