diff --git a/lapack/gonum/dpstrf.go b/lapack/gonum/dpstrf.go
new file mode 100644
index 00000000..1958238d
--- /dev/null
+++ b/lapack/gonum/dpstrf.go
@@ -0,0 +1,221 @@
+// Copyright ©2021 The Gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gonum
+
+import (
+	"math"
+
+	"gonum.org/v1/gonum/blas"
+	"gonum.org/v1/gonum/blas/blas64"
+)
+
+// Dpstrf computes the Cholesky factorization with complete pivoting of an n×n
+// symmetric positive semidefinite matrix A.
+//
+// The factorization has the form
+//  Pᵀ * A * P = Uᵀ * U ,  if uplo = blas.Upper,
+//  Pᵀ * A * P = L  * Lᵀ,  if uplo = blas.Lower,
+// where U is an upper triangular matrix and L is lower triangular, and P is
+// stored as vector piv.
+//
+// Dpstrf does not attempt to check that A is positive semidefinite.
+//
+// The length of piv must be n and the length of work must be at least 2*n,
+// otherwise Dpstrf will panic.
+//
+// Dpstrf is an internal routine. It is exported for testing purposes.
+func (impl Implementation) Dpstrf(uplo blas.Uplo, n int, a []float64, lda int, piv []int, tol float64, work []float64) (rank int, ok bool) {
+	switch {
+	case uplo != blas.Upper && uplo != blas.Lower:
+		panic(badUplo)
+	case n < 0:
+		panic(nLT0)
+	case lda < max(1, n):
+		panic(badLdA)
+	}
+
+	// Quick return if possible.
+	if n == 0 {
+		return 0, true
+	}
+
+	switch {
+	case len(a) < (n-1)*lda+n:
+		panic(shortA)
+	case len(piv) != n:
+		panic(badLenPiv)
+	case len(work) < 2*n:
+		panic(shortWork)
+	}
+
+	// Get block size; Ilaenv is queried under the DPOTRF routine name.
+	nb := impl.Ilaenv(1, "DPOTRF", string(uplo), n, -1, -1, -1)
+	if nb <= 1 || n <= nb {
+		// Use unblocked code: nb is trivial or a single block covers A.
+		return impl.Dpstf2(uplo, n, a, lda, piv, tol, work)
+	}
+
+	// Initialize piv to the identity permutation.
+	for i := range piv[:n] {
+		piv[i] = i
+	}
+
+	// Compute the first pivot as the largest diagonal element of A.
+	pvt := 0
+	ajj := a[0]
+	for i := 1; i < n; i++ {
+		aii := a[i*lda+i]
+		if aii > ajj {
+			pvt = i
+			ajj = aii
+		}
+	}
+	if ajj <= 0 || math.IsNaN(ajj) {
+		return 0, false
+	}
+
+	// Compute stopping value if not supplied, that is, if tol is negative.
+	dstop := tol
+	if dstop < 0 {
+		dstop = float64(n) * dlamchE * ajj
+	}
+
+	bi := blas64.Implementation()
+	// Split work in half: dot products in the first, pivot candidates in the second.
+	dots := work[:n]
+	work2 := work[n : 2*n]
+	if uplo == blas.Upper {
+		// Compute the Cholesky factorization Pᵀ * A * P = Uᵀ * U.
+		for k := 0; k < n; k += nb {
+			// Account for last block not being nb wide.
+			jb := min(nb, n-k)
+			// Set relevant part of dot products to zero.
+			for i := k; i < n; i++ {
+				dots[i] = 0
+			}
+			for j := k; j < k+jb; j++ {
+				// Update dot products and compute possible pivots which are stored
+				// in the second half of work.
+				for i := j; i < n; i++ {
+					if j > k {
+						tmp := a[(j-1)*lda+i]
+						dots[i] += tmp * tmp
+					}
+					work2[i] = a[i*lda+i] - dots[i]
+				}
+				if j > 0 {
+					// Find the pivot; for j == 0 it was already computed above.
+					pvt = j
+					ajj = work2[pvt]
+					for l := j + 1; l < n; l++ {
+						wl := work2[l]
+						if wl > ajj {
+							pvt = l
+							ajj = wl
+						}
+					}
+					// Test for exit: the largest remaining pivot is too small or NaN.
+					if ajj <= dstop || math.IsNaN(ajj) {
+						a[j*lda+j] = ajj
+						return j, false
+					}
+				}
+				if j != pvt {
+					// Swap pivot rows and columns.
+					a[pvt*lda+pvt] = a[j*lda+j]
+					bi.Dswap(j, a[j:], lda, a[pvt:], lda)
+					if pvt < n-1 {
+						bi.Dswap(n-pvt-1, a[j*lda+(pvt+1):], 1, a[pvt*lda+(pvt+1):], 1)
+					}
+					bi.Dswap(pvt-j-1, a[j*lda+(j+1):], 1, a[(j+1)*lda+pvt:], lda)
+					// Swap dot products and piv.
+					dots[j], dots[pvt] = dots[pvt], dots[j]
+					piv[j], piv[pvt] = piv[pvt], piv[j]
+				}
+				ajj = math.Sqrt(ajj)
+				a[j*lda+j] = ajj
+				// Compute elements j+1:n of row j.
+				if j < n-1 {
+					bi.Dgemv(blas.Trans, j-k, n-j-1,
+						-1, a[k*lda+j+1:], lda, a[k*lda+j:], lda,
+						1, a[j*lda+j+1:], 1)
+					bi.Dscal(n-j-1, 1/ajj, a[j*lda+j+1:], 1)
+				}
+			}
+			// Update trailing matrix unless this was the last block.
+			if k+jb < n {
+				j := k + jb
+				bi.Dsyrk(blas.Upper, blas.Trans, n-j, jb,
+					-1, a[k*lda+j:], lda, 1, a[j*lda+j:], lda)
+			}
+		}
+	} else {
+		// Compute the Cholesky factorization Pᵀ * A * P = L * Lᵀ.
+		for k := 0; k < n; k += nb {
+			// Account for last block not being nb wide.
+			jb := min(nb, n-k)
+			// Set relevant part of dot products to zero.
+			for i := k; i < n; i++ {
+				dots[i] = 0
+			}
+			for j := k; j < k+jb; j++ {
+				// Update dot products and compute possible pivots which are stored
+				// in the second half of work.
+				for i := j; i < n; i++ {
+					if j > k {
+						tmp := a[i*lda+(j-1)]
+						dots[i] += tmp * tmp
+					}
+					work2[i] = a[i*lda+i] - dots[i]
+				}
+				if j > 0 {
+					// Find the pivot; for j == 0 it was already computed above.
+					pvt = j
+					ajj = work2[pvt]
+					for l := j + 1; l < n; l++ {
+						wl := work2[l]
+						if wl > ajj {
+							pvt = l
+							ajj = wl
+						}
+					}
+					// Test for exit: the largest remaining pivot is too small or NaN.
+					if ajj <= dstop || math.IsNaN(ajj) {
+						a[j*lda+j] = ajj
+						return j, false
+					}
+				}
+				if j != pvt {
+					// Swap pivot rows and columns.
+					a[pvt*lda+pvt] = a[j*lda+j]
+					bi.Dswap(j, a[j*lda:], 1, a[pvt*lda:], 1)
+					if pvt < n-1 {
+						bi.Dswap(n-pvt-1, a[(pvt+1)*lda+j:], lda, a[(pvt+1)*lda+pvt:], lda)
+					}
+					bi.Dswap(pvt-j-1, a[(j+1)*lda+j:], lda, a[pvt*lda+(j+1):], 1)
+					// Swap dot products and piv.
+					dots[j], dots[pvt] = dots[pvt], dots[j]
+					piv[j], piv[pvt] = piv[pvt], piv[j]
+				}
+				ajj = math.Sqrt(ajj)
+				a[j*lda+j] = ajj
+				// Compute elements j+1:n of column j.
+				if j < n-1 {
+					bi.Dgemv(blas.NoTrans, n-j-1, j-k,
+						-1, a[(j+1)*lda+k:], lda, a[j*lda+k:], 1,
+						1, a[(j+1)*lda+j:], lda)
+					bi.Dscal(n-j-1, 1/ajj, a[(j+1)*lda+j:], lda)
+				}
+			}
+			// Update trailing matrix unless this was the last block.
+			if k+jb < n {
+				j := k + jb
+				bi.Dsyrk(blas.Lower, blas.NoTrans, n-j, jb,
+					-1, a[j*lda+k:], lda, 1, a[j*lda+j:], lda)
+			}
+		}
+	}
+	return n, true
+}
diff --git a/lapack/gonum/lapack_test.go b/lapack/gonum/lapack_test.go
index ab00b6fe..50ae3e0f 100644
--- a/lapack/gonum/lapack_test.go
+++ b/lapack/gonum/lapack_test.go
@@ -558,6 +558,11 @@ func TestDpstf2(t *testing.T) {
 	testlapack.Dpstf2Test(t, impl)
 }
 
+func TestDpstrf(t *testing.T) {
+	t.Parallel()
+	testlapack.DpstrfTest(t, impl)
+}
+
 func TestDrscl(t *testing.T) {
 	t.Parallel()
 	testlapack.DrsclTest(t, impl)
diff --git a/lapack/testlapack/dpstf2.go b/lapack/testlapack/dpstf2.go
index f7223688..51fc5ede 100644
--- a/lapack/testlapack/dpstf2.go
+++ b/lapack/testlapack/dpstf2.go
@@ -13,7 +13,6 @@ import (
 
 	"gonum.org/v1/gonum/blas"
 	"gonum.org/v1/gonum/blas/blas64"
-	"gonum.org/v1/gonum/lapack"
 )
 
 type Dpstf2er interface {
@@ -76,90 +75,9 @@ func dpstf2Test(t *testing.T, impl Dpstf2er, rnd *rand.Rand, uplo blas.Uplo, n,
 		return
 	}
 
-	// Reconstruct the symmetric positive semi-definite matrix A from its L or U
-	// factors and the permutation matrix P.
-	perm := zeros(n, n, n)
-	if uplo == blas.Upper {
-		// Change notation.
-		u, ldu := aFac, lda
-		// Zero out last n-rank rows of the factor U.
-		for i := rank; i < n; i++ {
-			for j := i; j < n; j++ {
-				u[i*ldu+j] = 0
-			}
-		}
-		// Extract U to aRec.
-		aRec := zeros(n, n, n)
-		for i := 0; i < n; i++ {
-			for j := i; j < n; j++ {
-				aRec.Data[i*aRec.Stride+j] = u[i*ldu+j]
-			}
-		}
-		// Multiply U by Uᵀ from the left.
-		bi.Dtrmm(blas.Left, blas.Upper, blas.Trans, blas.NonUnit, n, n,
-			1, u, ldu, aRec.Data, aRec.Stride)
-		// Form P * Uᵀ * U * Pᵀ.
-		for i := 0; i < n; i++ {
-			for j := 0; j < n; j++ {
-				if piv[i] > piv[j] {
-					// Don't set the lower triangle.
-					continue
-				}
-				if i <= j {
-					perm.Data[piv[i]*perm.Stride+piv[j]] = aRec.Data[i*aRec.Stride+j]
-				} else {
-					perm.Data[piv[i]*perm.Stride+piv[j]] = aRec.Data[j*aRec.Stride+i]
-				}
-			}
-		}
-		// Compute the difference P*Uᵀ*U*Pᵀ - A.
-		for i := 0; i < n; i++ {
-			for j := i; j < n; j++ {
-				perm.Data[i*perm.Stride+j] -= a[i*lda+j]
-			}
-		}
-	} else {
-		// Change notation.
-		l, ldl := aFac, lda
-		// Zero out last n-rank columns of the factor L.
-		for i := rank; i < n; i++ {
-			for j := rank; j <= i; j++ {
-				l[i*ldl+j] = 0
-			}
-		}
-		// Extract L to aRec.
-		aRec := zeros(n, n, n)
-		for i := 0; i < n; i++ {
-			for j := 0; j <= i; j++ {
-				aRec.Data[i*aRec.Stride+j] = l[i*ldl+j]
-			}
-		}
-		// Multiply L by Lᵀ from the right.
-		bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.NonUnit, n, n,
-			1, l, ldl, aRec.Data, aRec.Stride)
-		// Form P * L * Lᵀ * Pᵀ.
-		for i := 0; i < n; i++ {
-			for j := 0; j < n; j++ {
-				if piv[i] < piv[j] {
-					// Don't set the upper triangle.
-					continue
-				}
-				if i >= j {
-					perm.Data[piv[i]*perm.Stride+piv[j]] = aRec.Data[i*aRec.Stride+j]
-				} else {
-					perm.Data[piv[i]*perm.Stride+piv[j]] = aRec.Data[j*aRec.Stride+i]
-				}
-			}
-		}
-		// Compute the difference P*L*Lᵀ*Pᵀ - A.
-		for i := 0; i < n; i++ {
-			for j := 0; j <= i; j++ {
-				perm.Data[i*perm.Stride+j] -= a[i*lda+j]
-			}
-		}
-	}
-	// Compute |P*Uᵀ*U*Pᵀ - A| / n or |P*L*Lᵀ*Pᵀ - A| / n.
-	resid := dlansy(lapack.MaxColumnSum, uplo, n, perm.Data, perm.Stride) / float64(n)
+	// Check that the residual |P*Uᵀ*U*Pᵀ - A| / n or |P*L*Lᵀ*Pᵀ - A| / n is
+	// sufficiently small.
+	resid := residualDpstrf(uplo, n, a, aFac, lda, rank, piv)
 	if resid > tol || math.IsNaN(resid) {
 		t.Errorf("%v: residual too large; got %v, want<=%v", name, resid, tol)
 	}
diff --git a/lapack/testlapack/dpstrf.go b/lapack/testlapack/dpstrf.go
new file mode 100644
index 00000000..a097b30a
--- /dev/null
+++ b/lapack/testlapack/dpstrf.go
@@ -0,0 +1,173 @@
+// Copyright ©2021 The Gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package testlapack
+
+import (
+	"fmt"
+	"math"
+	"testing"
+
+	"golang.org/x/exp/rand"
+
+	"gonum.org/v1/gonum/blas"
+	"gonum.org/v1/gonum/blas/blas64"
+	"gonum.org/v1/gonum/lapack"
+)
+
+type Dpstrfer interface {
+	Dpstrf(uplo blas.Uplo, n int, a []float64, lda int, piv []int, tol float64, work []float64) (rank int, ok bool)
+}
+
+func DpstrfTest(t *testing.T, impl Dpstrfer) {
+	rnd := rand.New(rand.NewSource(1))
+	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
+		t.Run(uploToString(uplo), func(t *testing.T) {
+			for _, n := range []int{0, 1, 2, 3, 4, 5, 31, 32, 33, 63, 64, 65, 127, 128, 129} {
+				for _, lda := range []int{max(1, n), n + 5} {
+					for _, rank := range []int{int(0.7 * float64(n)), n} {
+						dpstrfTest(t, impl, rnd, uplo, n, lda, rank)
+					}
+				}
+			}
+		})
+	}
+}
+
+func dpstrfTest(t *testing.T, impl Dpstrfer, rnd *rand.Rand, uplo blas.Uplo, n, lda, rankWant int) {
+	const tol = 1e-13
+
+	name := fmt.Sprintf("n=%v,lda=%v", n, lda)
+	bi := blas64.Implementation()
+
+	// Generate a random, symmetric A with the given rank by applying rankWant
+	// rank-1 updates to the zero matrix.
+	a := make([]float64, n*lda)
+	for i := 0; i < rankWant; i++ {
+		x := randomSlice(n, rnd)
+		bi.Dsyr(uplo, n, 1, x, 1, a, lda)
+	}
+
+	// Make a copy of A for storing the factorization.
+	aFac := make([]float64, len(a))
+	copy(aFac, a)
+
+	// Allocate a slice for pivots and fill it with invalid index values.
+	piv := make([]int, n)
+	for i := range piv {
+		piv[i] = -1
+	}
+
+	// Allocate the work slice; Dpstrf requires a length of at least 2*n.
+	work := make([]float64, 2*n)
+
+	// Call Dpstrf to compute the Cholesky factorization with complete pivoting.
+	rank, ok := impl.Dpstrf(uplo, n, aFac, lda, piv, -1, work)
+
+	if ok != (rank == n) {
+		t.Errorf("%v: unexpected ok; got %v, want %v", name, ok, rank == n)
+	}
+	if rank != rankWant {
+		t.Errorf("%v: unexpected rank; got %v, want %v", name, rank, rankWant)
+	}
+
+	if n == 0 {
+		return
+	}
+
+	// Check that the residual |P*Uᵀ*U*Pᵀ - A| / n or |P*L*Lᵀ*Pᵀ - A| / n is
+	// sufficiently small.
+	resid := residualDpstrf(uplo, n, a, aFac, lda, rank, piv)
+	if resid > tol || math.IsNaN(resid) {
+		t.Errorf("%v: residual too large; got %v, want<=%v", name, resid, tol)
+	}
+}
+
+func residualDpstrf(uplo blas.Uplo, n int, a, aFac []float64, lda int, rank int, piv []int) float64 {
+	bi := blas64.Implementation()
+	// Reconstruct the symmetric positive semi-definite matrix A from its L or U
+	// factors and the permutation matrix P. Note that aFac is modified in place.
+	perm := zeros(n, n, n)
+	if uplo == blas.Upper {
+		// Change notation.
+		u, ldu := aFac, lda
+		// Zero out last n-rank rows of the factor U.
+		for i := rank; i < n; i++ {
+			for j := i; j < n; j++ {
+				u[i*ldu+j] = 0
+			}
+		}
+		// Extract U to aRec.
+		aRec := zeros(n, n, n)
+		for i := 0; i < n; i++ {
+			for j := i; j < n; j++ {
+				aRec.Data[i*aRec.Stride+j] = u[i*ldu+j]
+			}
+		}
+		// Multiply U by Uᵀ from the left so that aRec holds Uᵀ * U.
+		bi.Dtrmm(blas.Left, blas.Upper, blas.Trans, blas.NonUnit, n, n,
+			1, u, ldu, aRec.Data, aRec.Stride)
+		// Form P * Uᵀ * U * Pᵀ.
+		for i := 0; i < n; i++ {
+			for j := 0; j < n; j++ {
+				if piv[i] > piv[j] {
+					// Don't set the lower triangle.
+					continue
+				}
+				if i <= j {
+					perm.Data[piv[i]*perm.Stride+piv[j]] = aRec.Data[i*aRec.Stride+j]
+				} else {
+					perm.Data[piv[i]*perm.Stride+piv[j]] = aRec.Data[j*aRec.Stride+i]
+				}
+			}
+		}
+		// Compute the difference P*Uᵀ*U*Pᵀ - A.
+		for i := 0; i < n; i++ {
+			for j := i; j < n; j++ {
+				perm.Data[i*perm.Stride+j] -= a[i*lda+j]
+			}
+		}
+	} else {
+		// Change notation.
+		l, ldl := aFac, lda
+		// Zero out last n-rank columns of the factor L.
+		for i := rank; i < n; i++ {
+			for j := rank; j <= i; j++ {
+				l[i*ldl+j] = 0
+			}
+		}
+		// Extract L to aRec.
+		aRec := zeros(n, n, n)
+		for i := 0; i < n; i++ {
+			for j := 0; j <= i; j++ {
+				aRec.Data[i*aRec.Stride+j] = l[i*ldl+j]
+			}
+		}
+		// Multiply L by Lᵀ from the right so that aRec holds L * Lᵀ.
+		bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.NonUnit, n, n,
+			1, l, ldl, aRec.Data, aRec.Stride)
+		// Form P * L * Lᵀ * Pᵀ.
+		for i := 0; i < n; i++ {
+			for j := 0; j < n; j++ {
+				if piv[i] < piv[j] {
+					// Don't set the upper triangle.
+					continue
+				}
+				if i >= j {
+					perm.Data[piv[i]*perm.Stride+piv[j]] = aRec.Data[i*aRec.Stride+j]
+				} else {
+					perm.Data[piv[i]*perm.Stride+piv[j]] = aRec.Data[j*aRec.Stride+i]
+				}
+			}
+		}
+		// Compute the difference P*L*Lᵀ*Pᵀ - A.
+		for i := 0; i < n; i++ {
+			for j := 0; j <= i; j++ {
+				perm.Data[i*perm.Stride+j] -= a[i*lda+j]
+			}
+		}
+	}
+	// Compute |P*Uᵀ*U*Pᵀ - A| / n or |P*L*Lᵀ*Pᵀ - A| / n.
+	return dlansy(lapack.MaxColumnSum, uplo, n, perm.Data, perm.Stride) / float64(n)
+}