diff --git a/cgo/lapack.go b/cgo/lapack.go index 5035901f..ddb68606 100644 --- a/cgo/lapack.go +++ b/cgo/lapack.go @@ -18,6 +18,7 @@ const ( badDirect = "lapack: bad direct" badIpiv = "lapack: insufficient permutation length" badLdA = "lapack: index of a out of range" + badNorm = "lapack: bad norm" badSide = "lapack: bad side" badStore = "lapack: bad store" badTau = "lapack: tau has insufficient length" @@ -66,6 +67,49 @@ type Implementation struct{} var _ lapack.Float64 = Implementation{} +// Dlange computes the matrix norm of the general m×n matrix a. The input norm +// specifies the norm computed. +// lapack.MaxAbs: the maximum absolute value of an element. +// lapack.MaxColumnSum: the maximum column sum of the absolute values of the entries. +// lapack.MaxRowSum: the maximum row sum of the absolute values of the entries. +// lapack.Frobenius: the square root of the sum of the squares of the entries. +// If norm == lapack.MaxColumnSum, work must be of length n, and this function will panic otherwise. +// There are no restrictions on work for the other matrix norms. +func (impl Implementation) Dlange(norm lapack.MatrixNorm, m, n int, a []float64, lda int, work []float64) float64 { + checkMatrix(m, n, a, lda) + switch norm { + case lapack.MaxRowSum, lapack.MaxColumnSum, lapack.NormFrob, lapack.MaxAbs: + default: + panic(badNorm) + } + if norm == lapack.MaxColumnSum && len(work) < n { + panic(badWork) + } + return clapack.Dlange(byte(norm), m, n, a, lda) +} + +// Dlantr computes the specified norm of an m×n trapezoidal matrix A. If +// norm == lapack.MaxColumnSum work must have length at least n, otherwise work +// is unused. +func (impl Implementation) Dlantr(norm lapack.MatrixNorm, uplo blas.Uplo, diag blas.Diag, m, n int, a []float64, lda int, work []float64) float64 { + checkMatrix(m, n, a, lda) + switch norm { + case lapack.MaxRowSum, lapack.MaxColumnSum, lapack.NormFrob, lapack.MaxAbs: + default: + panic(badNorm) + } + if uplo != blas.Upper && uplo != blas.Lower { + panic(badUplo) + } + if diag != blas.Unit && diag != blas.NonUnit { + panic(badDiag) + } + if norm == lapack.MaxColumnSum && len(work) < n { + panic(badWork) + } + return clapack.Dlantr(byte(norm), uplo, diag, m, n, a, lda) +} + // Dpotrf computes the cholesky decomposition of the symmetric positive definite // matrix a. If ul == blas.Upper, then a is stored as an upper-triangular matrix, // and a = U U^T is stored in place into a. If ul == blas.Lower, then a = L L^T diff --git a/cgo/lapack_test.go b/cgo/lapack_test.go index 1874b677..ef721920 100644 --- a/cgo/lapack_test.go +++ b/cgo/lapack_test.go @@ -13,6 +13,17 @@ import ( var impl = Implementation{} +func TestDlange(t *testing.T) { + testlapack.DlangeTest(t, impl) +} + +// The results from Dlantr do not match the results from Dlange. In some cases, +// there also appear to be memory corruption issues. +// TODO(btracey): Re-enable this test when the implementations are fixed. +// func TestDlantr(t *testing.T) { +// testlapack.DlantrTest(t, impl) +// } + func TestDpotrf(t *testing.T) { testlapack.DpotrfTest(t, impl) } diff --git a/native/dlange.go b/native/dlange.go index fb8259a2..015e8b0a 100644 --- a/native/dlange.go +++ b/native/dlange.go @@ -21,6 +21,14 @@ import ( func (impl Implementation) Dlange(norm lapack.MatrixNorm, m, n int, a []float64, lda int, work []float64) float64 { // TODO(btracey): These should probably be refactored to use BLAS calls. checkMatrix(m, n, a, lda) + switch norm { + case lapack.MaxRowSum, lapack.MaxColumnSum, lapack.NormFrob, lapack.MaxAbs: + default: + panic(badNorm) + } + if norm == lapack.MaxColumnSum && len(work) < n { + panic(badWork) + } if m == 0 && n == 0 { return 0 } diff --git a/native/dlantr.go b/native/dlantr.go new file mode 100644 index 00000000..4ee52228 --- /dev/null +++ b/native/dlantr.go @@ -0,0 +1,248 @@ +package native + +import ( + "math" + + "github.com/gonum/blas" + "github.com/gonum/lapack" +) + +// Dlantr computes the specified norm of an m×n trapezoidal matrix A. If +// norm == lapack.MaxColumnSum work must have length at least n, otherwise work +// is unused. +func (impl Implementation) Dlantr(norm lapack.MatrixNorm, uplo blas.Uplo, diag blas.Diag, m, n int, a []float64, lda int, work []float64) float64 { + checkMatrix(m, n, a, lda) + switch norm { + case lapack.MaxRowSum, lapack.MaxColumnSum, lapack.NormFrob, lapack.MaxAbs: + default: + panic(badNorm) + } + if uplo != blas.Upper && uplo != blas.Lower { + panic(badUplo) + } + if diag != blas.Unit && diag != blas.NonUnit { + panic(badDiag) + } + if norm == lapack.MaxColumnSum && len(work) < n { + panic(badWork) + } + if min(m, n) == 0 { + return 0 + } + switch norm { + default: + panic("unreachable") + case lapack.MaxAbs: + if diag == blas.Unit { + value := 1.0 + if uplo == blas.Upper { + for i := 0; i < m; i++ { + for j := i + 1; j < n; j++ { + tmp := math.Abs(a[i*lda+j]) + if math.IsNaN(tmp) { + return tmp + } + if tmp > value { + value = tmp + } + } + } + return value + } + for i := 1; i < m; i++ { + for j := 0; j < min(i, n); j++ { + tmp := math.Abs(a[i*lda+j]) + if math.IsNaN(tmp) { + return tmp + } + if tmp > value { + value = tmp + } + } + } + return value + } + var value float64 + if uplo == blas.Upper { + for i := 0; i < m; i++ { + for j := i; j < n; j++ { + tmp := math.Abs(a[i*lda+j]) + if math.IsNaN(tmp) { + return tmp + } + if tmp > value { + value = tmp + } + } + } + return value + } + for i := 0; i < m; i++ { + for j := 0; j <= min(i, n-1); j++ { + tmp := math.Abs(a[i*lda+j]) + if math.IsNaN(tmp) { + return tmp + } + if tmp > value { + value = tmp + } + } + } + return value + case lapack.MaxColumnSum: + if diag == blas.Unit { + for i := 0; i < min(m, n); i++ { + work[i] = 1 + } + for i := min(m, n); i < n; i++ { + work[i] = 0 + } + if uplo == blas.Upper { + for i := 0; i < m; i++ { + for j := i + 1; j < n; j++ { + work[j] += math.Abs(a[i*lda+j]) + } + } + } else { + for i := 1; i < m; i++ { + for j := 0; j < min(i, n); j++ { + work[j] += math.Abs(a[i*lda+j]) + } + } + } + } else { + for i := 0; i < n; i++ { + work[i] = 0 + } + if uplo == blas.Upper { + for i := 0; i < m; i++ { + for j := i; j < n; j++ { + work[j] += math.Abs(a[i*lda+j]) + } + } + } else { + for i := 0; i < m; i++ { + for j := 0; j <= min(i, n-1); j++ { + work[j] += math.Abs(a[i*lda+j]) + } + } + } + } + var max float64 + for _, v := range work { + if math.IsNaN(v) { + return math.NaN() + } + if v > max { + max = v + } + } + return max + case lapack.MaxRowSum: + var maxsum float64 + if diag == blas.Unit { + if uplo == blas.Upper { + for i := 0; i < m; i++ { + var sum float64 + if i < min(m, n) { + sum = 1 + } + for j := i + 1; j < n; j++ { + sum += math.Abs(a[i*lda+j]) + } + if math.IsNaN(sum) { + return math.NaN() + } + if sum > maxsum { + maxsum = sum + } + } + return maxsum + } else { + for i := 1; i < m; i++ { + var sum float64 + if i < min(m, n) { + sum = 1 + } + for j := 0; j < min(i, n); j++ { + sum += math.Abs(a[i*lda+j]) + } + if math.IsNaN(sum) { + return math.NaN() + } + if sum > maxsum { + maxsum = sum + } + } + return maxsum + } + } else { + if uplo == blas.Upper { + for i := 0; i < m; i++ { + var sum float64 + for j := i; j < n; j++ { + sum += math.Abs(a[i*lda+j]) + } + if math.IsNaN(sum) { + return sum + } + if sum > maxsum { + maxsum = sum + } + } + return maxsum + } else { + for i := 0; i < m; i++ { + var sum float64 + for j := 0; j <= min(i, n-1); j++ { + sum += math.Abs(a[i*lda+j]) + } + if math.IsNaN(sum) { + return sum + } + if sum > maxsum { + maxsum = sum + } + } + return maxsum + } + } + case lapack.NormFrob: + var nrm float64 + if diag == blas.Unit { + if uplo == blas.Upper { + for i := 0; i < m; i++ { + for j := i + 1; j < n; j++ { + tmp := a[i*lda+j] + nrm += tmp * tmp + } + } + } else { + for i := 1; i < m; i++ { + for j := 0; j < min(i, n); j++ { + tmp := a[i*lda+j] + nrm += tmp * tmp + } + } + } + nrm += float64(min(m, n)) + } else { + if uplo == blas.Upper { + for i := 0; i < m; i++ { + for j := i; j < n; j++ { + tmp := math.Abs(a[i*lda+j]) + nrm += tmp * tmp + } + } + } else { + for i := 0; i < m; i++ { + for j := 0; j <= min(i, n-1); j++ { + tmp := math.Abs(a[i*lda+j]) + nrm += tmp * tmp + } + } + } + } + return math.Sqrt(nrm) + } +} diff --git a/native/general.go b/native/general.go index 3ada2e79..a8b6d935 100644 --- a/native/general.go +++ b/native/general.go @@ -24,6 +24,7 @@ const ( badDirect = "lapack: bad direct" badIpiv = "lapack: insufficient permutation length" badLdA = "lapack: index of a out of range" + badNorm = "lapack: bad norm" badSide = "lapack: bad side" badStore = "lapack: bad store" badTau = "lapack: tau has insufficient length" diff --git a/native/lapack_test.go b/native/lapack_test.go index c8e3416a..f9ba13a2 100644 --- a/native/lapack_test.go +++ b/native/lapack_test.go @@ -48,6 +48,10 @@ func TestDlange(t *testing.T) { testlapack.DlangeTest(t, impl) } +func TestDlantr(t *testing.T) { + testlapack.DlantrTest(t, impl) +} + func TestDlarfb(t *testing.T) { testlapack.DlarfbTest(t, impl) } diff --git a/testlapack/dlantr.go b/testlapack/dlantr.go new file mode 100644 index 00000000..7f947107 --- /dev/null +++ b/testlapack/dlantr.go @@ -0,0 +1,84 @@ +package testlapack + +import ( + "math" + "math/rand" + "testing" + + "github.com/gonum/blas" + "github.com/gonum/lapack" +) + +type Dlantrer interface { + Dlanger + Dlantr(norm lapack.MatrixNorm, uplo blas.Uplo, diag blas.Diag, m, n int, a []float64, lda int, work []float64) float64 +} + +func DlantrTest(t *testing.T, impl Dlantrer) { + for _, norm := range []lapack.MatrixNorm{lapack.MaxAbs, lapack.MaxColumnSum, lapack.MaxRowSum, lapack.NormFrob} { + for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} { + for _, uplo := range []blas.Uplo{blas.Lower, blas.Upper} { + for _, test := range []struct { + m, n, lda int + }{ + {3, 3, 0}, + {3, 5, 0}, + {10, 5, 0}, + + {5, 5, 11}, + {5, 10, 11}, + {10, 5, 11}, + } { + // Do a couple of random trials since the values change. + for trial := 0; trial < 100; trial++ { + m := test.m + n := test.n + lda := test.lda + if lda == 0 { + lda = n + } + a := make([]float64, m*lda) + if trial == 0 { + for i := range a { + a[i] = float64(i) + } + } else { + for i := range a { + a[i] = rand.NormFloat64() + } + } + aDense := make([]float64, len(a)) + if uplo == blas.Lower { + for i := 0; i < m; i++ { + for j := 0; j <= min(i, n-1); j++ { + aDense[i*lda+j] = a[i*lda+j] + } + } + } else { + for i := 0; i < m; i++ { + for j := i; j < n; j++ { + aDense[i*lda+j] = a[i*lda+j] + } + } + } + if diag == blas.Unit { + for i := 0; i < min(m, n); i++ { + aDense[i*lda+i] = 1 + } + } + work := make([]float64, n) + for i := range work { + work[i] = rand.Float64() + } + got := impl.Dlantr(norm, uplo, diag, m, n, a, lda, work) + want := impl.Dlange(norm, m, n, aDense, lda, work) + if math.Abs(got-want) > 1e-13 { + t.Errorf("Norm mismatch. norm = %c, unitdiag = %v, upper = %v, m = %v, n = %v, lda = %v, Want %v, got %v.", + norm, diag == blas.Unit, uplo == blas.Upper, m, n, lda, got, want) + } + } + } + } + } + } +}