diff --git a/lapack/gonum/dgetc2.go b/lapack/gonum/dgetc2.go
new file mode 100644
index 00000000..2ba2a91e
--- /dev/null
+++ b/lapack/gonum/dgetc2.go
@@ -0,0 +1,117 @@
+// Copyright ©2021 The Gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gonum
+
+import (
+	"math"
+
+	"gonum.org/v1/gonum/blas/blas64"
+)
+
+// Dgetc2 computes an LU factorization with complete pivoting of the
+// n×n matrix A. The factorization has the form
+//  A = P * L * U * Q,
+// where P and Q are permutation matrices, L is lower triangular with
+// unit diagonal elements and U is upper triangular.
+//
+// On return, a holds L in its strictly lower triangle (the unit diagonal is
+// not stored) and U in its upper triangle. Row i was interchanged with row
+// ipiv[i] and column i with column jpiv[i], so P and Q can be constructed
+// from ipiv and jpiv, which must have length n. k is -1 if U was not
+// perturbed; otherwise U(k, k) was too small to solve A*x = b without risking
+// overflow and was replaced by a small threshold (k is the last such index).
+//
+// Dgetc2 is an internal routine. It is exported for testing purposes.
+func (impl Implementation) Dgetc2(n int, a []float64, lda int, ipiv, jpiv []int) (k int) {
+	switch {
+	case n < 0:
+		panic(nLT0)
+	case lda < max(1, n):
+		panic(badLdA)
+	}
+
+	// Negative k indicates U was not perturbed.
+	k = -1
+	// Quick return if possible. Slice lengths are not checked when n is 0.
+	if n == 0 {
+		return k
+	}
+
+	switch {
+	case len(a) < (n-1)*lda+n:
+		panic(shortA)
+	case len(ipiv) != n:
+		panic(badLenIpiv)
+	case len(jpiv) != n:
+		panic(badLenJpvt) // NOTE(review): constant is named for jpvt, not jpiv — confirm the message fits.
+	}
+
+	const (
+		eps    = dlamchP
+		smlnum = dlamchS / eps
+	)
+	if n == 1 { // 1×1 case: no pivot search, only the tiny-pivot check.
+		ipiv[0], jpiv[0] = 0, 0
+		if math.Abs(a[0]) < smlnum {
+			a[0] = smlnum
+			k = 0
+		}
+		return k
+	}
+
+	// Factorize A using complete pivoting. Pivots smaller in magnitude than
+	// the threshold lc, which is fixed in the first iteration, are set to lc.
+	var lc float64
+	var ipv, jpv int
+	bi := blas64.Implementation()
+	for i := 0; i < n-1; i++ {
+		xmax := 0.0 // Largest magnitude found in the trailing submatrix a[i:n, i:n].
+		for ip := i; ip < n; ip++ {
+			for jp := i; jp < n; jp++ {
+				if math.Abs(a[ip*lda+jp]) >= xmax { // >= keeps the last of equal maxima.
+					xmax = math.Abs(a[ip*lda+jp])
+					ipv = ip
+					jpv = jp
+				}
+			}
+		}
+		if i == 0 { // Here xmax is the largest magnitude in all of A.
+			lc = math.Max(eps*xmax, smlnum)
+		}
+
+		// Swap rows i and ipv and record the interchange.
+		if ipv != i {
+			bi.Dswap(n, a[ipv*lda:], 1, a[i*lda:], 1)
+		}
+		ipiv[i] = ipv
+
+		// Swap columns i and jpv and record the interchange.
+		if jpv != i {
+			bi.Dswap(n, a[jpv:], lda, a[i:], lda)
+		}
+		jpiv[i] = jpv
+
+		// Check for singularity: replace a too-small pivot by lc and record i in k.
+		if math.Abs(a[i*lda+i]) < lc {
+			k = i
+			a[i*lda+i] = lc
+		}
+
+		for j := i + 1; j < n; j++ { // Divide column i below the pivot by the pivot.
+			a[j*lda+i] /= a[i*lda+i]
+		}
+		bi.Dger(n-i-1, n-i-1, -1, a[(i+1)*lda+i:], lda, a[i*lda+i+1:], 1, a[(i+1)*lda+i+1:], lda) // Rank-one update of a[i+1:n, i+1:n].
+	}
+
+	if math.Abs(a[(n-1)*lda+n-1]) < lc { // The last diagonal element is only checked against lc.
+		k = n - 1
+		a[(n-1)*lda+(n-1)] = lc
+	}
+
+	// Set last pivots to last index.
+	ipiv[n-1] = n - 1
+	jpiv[n-1] = n - 1
+	return k
+}
diff --git a/lapack/gonum/lapack_test.go b/lapack/gonum/lapack_test.go
index 4d76525b..2dded735 100644
--- a/lapack/gonum/lapack_test.go
+++ b/lapack/gonum/lapack_test.go
@@ -118,6 +118,11 @@ func TestDgesvd(t *testing.T) {
 	testlapack.DgesvdTest(t, impl, tol)
 }
 
+func TestDgetc2(t *testing.T) {
+	t.Parallel()
+	testlapack.Dgetc2Test(t, impl)
+}
+
 func TestDgetri(t *testing.T) {
 	t.Parallel()
 	testlapack.DgetriTest(t, impl)
diff --git a/lapack/testlapack/dgetc2.go b/lapack/testlapack/dgetc2.go
new file mode 100644
index 00000000..461c28df
--- /dev/null
+++ b/lapack/testlapack/dgetc2.go
@@ -0,0 +1,98 @@
+// Copyright ©2021 The Gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package testlapack
+
+import (
+	"fmt"
+	"math"
+	"testing"
+
+	"golang.org/x/exp/rand"
+
+	"gonum.org/v1/gonum/blas"
+	"gonum.org/v1/gonum/blas/blas64"
+)
+
+// Dgetc2er is implemented by types that provide the Dgetc2 routine under test.
+type Dgetc2er interface {
+	Dgetc2(n int, a []float64, lda int, ipiv, jpiv []int) (k int)
+}
+
+// Dgetc2Test checks impl.Dgetc2 on random n×n matrices for several n and lda.
+func Dgetc2Test(t *testing.T, impl Dgetc2er) {
+	const tol = 1e-12
+	rnd := rand.New(rand.NewSource(1))
+	for _, n := range []int{0, 1, 2, 3, 4, 5, 10, 20} {
+		for _, lda := range []int{n, n + 5} {
+			dgetc2Test(t, impl, rnd, n, lda, tol)
+		}
+	}
+}
+
+// dgetc2Test factors one random n×n matrix and checks that P*L*U*Q matches it.
+func dgetc2Test(t *testing.T, impl Dgetc2er, rnd *rand.Rand, n, lda int, tol float64) {
+	name := fmt.Sprintf("n=%v,lda=%v", n, lda)
+	if lda == 0 {
+		lda = 1
+	}
+	// Generate a random general matrix A.
+	a := randomGeneral(n, n, lda, rnd)
+	// ipiv and jpiv are outputs.
+	ipiv := make([]int, n)
+	jpiv := make([]int, n)
+	for i := 0; i < n; i++ {
+		ipiv[i], jpiv[i] = -1, -1 // Set to non-indices.
+	}
+	// Copy to store output (LU decomposition).
+	lu := cloneGeneral(a)
+	k := impl.Dgetc2(n, lu.Data, lu.Stride, ipiv, jpiv)
+	if k >= 0 {
+		t.Logf("%v: matrix was perturbed at %d", name, k)
+	}
+
+	// Verify all pivots are valid indices; the swaps below index with them.
+	for i := 0; i < n; i++ {
+		if ipiv[i] < 0 || n <= ipiv[i] || jpiv[i] < 0 || n <= jpiv[i] {
+			t.Fatalf("%v: pivot out of range [0,%d): ipiv[%d]=%d, jpiv[%d]=%d", name, n, i, ipiv[i], i, jpiv[i])
+		}
+	}
+	bi := blas64.Implementation()
+	// Construct L and U triangular matrices from Dgetc2 output.
+	L := zeros(n, n, lda)
+	U := zeros(n, n, lda)
+	for i := 0; i < n; i++ {
+		L.Data[i*lda+i] = 1 // The unit diagonal of L is not stored in lu.
+		for j := 0; j < n; j++ {
+			idx := i*lda + j
+			if j < i { // Strictly lower triangle holds L.
+				L.Data[idx] = lu.Data[idx]
+			} else { // Diagonal and upper triangle hold U.
+				U.Data[idx] = lu.Data[idx]
+			}
+		}
+	}
+	work := zeros(n, n, lda)
+	bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, L.Data, L.Stride, U.Data, U.Stride, 0, work.Data, work.Stride)
+
+	// Undo the interchanges in reverse order to apply P and Q to L*U.
+	for i := n - 1; i >= 0; i-- {
+		ipv, jpv := ipiv[i], jpiv[i]
+		if ipv != i {
+			bi.Dswap(n, work.Data[i*lda:], 1, work.Data[ipv*lda:], 1)
+		}
+		if jpv != i {
+			bi.Dswap(n, work.Data[i:], work.Stride, work.Data[jpv:], work.Stride)
+		}
+	}
+
+	// A should be reconstructed by now; compare only the n×n part, NaN fails too.
+	for i := 0; i < n; i++ {
+		for j := 0; j < n; j++ {
+			if d := math.Abs(work.Data[i*work.Stride+j] - a.Data[i*a.Stride+j]); d > tol || math.IsNaN(d) {
+				t.Errorf("%v: A[%d,%d] not equal after reconstruction. got %g, expected %g", name, i, j, work.Data[i*work.Stride+j], a.Data[i*a.Stride+j])
+			}
+		}
+	}
+}