diff --git a/lapack/gonum/dgesc2.go b/lapack/gonum/dgesc2.go new file mode 100644 index 00000000..4d0f1260 --- /dev/null +++ b/lapack/gonum/dgesc2.go @@ -0,0 +1,83 @@ +// Copyright ©2021 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gonum + +import ( + "math" + + "gonum.org/v1/gonum/blas/blas64" +) + +// Dgesc2 solves a system of linear equations +// A * X = scale * RHS +// with a general N-by-N matrix A using the LU factorization with +// complete pivoting computed by Dgetc2. The result is placed in +// rhs on exit. +// +// Dgesc2 is an internal routine. It is exported for testing purposes. +func (impl Implementation) Dgesc2(n int, a []float64, lda int, rhs []float64, ipiv, jpiv []int) (scale float64) { + switch { + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + } + + // Quick return if possible. + if n == 0 { + return 0 + } + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(rhs) < n: + panic(shortRHS) + case len(ipiv) != n: + panic(badLenIpiv) + case len(jpiv) != n: + panic(badLenJpiv) + } + + const smlnum = dlamchS / dlamchP + if len(a) < (n-1)*lda+n { + panic(shortA) + } + + // Apply permutations ipiv to RHS. + impl.Dlaswp(1, rhs, 1, 0, n-1, ipiv[:n], 1) + + // Solve for L part. + for i := 0; i < n-1; i++ { + for j := i + 1; j < n; j++ { + rhs[j] -= float64(a[j*lda+i] * rhs[i]) + } + } + + // Solve for U part. + + scale = 1.0 + + // Check for scaling. + bi := blas64.Implementation() + i := bi.Idamax(n, rhs, 1) + if 2*smlnum*math.Abs(rhs[i]) > math.Abs(a[(n-1)*lda+(n-1)]) { + temp := 0.5 / math.Abs(rhs[i]) + bi.Dscal(n, temp, rhs, 1) + scale *= temp + } + + for i := n - 1; i >= 0; i-- { + temp := 1.0 / a[i*lda+i] + rhs[i] *= temp + for j := i + 1; j < n; j++ { + rhs[i] -= float64(rhs[j] * (a[i*lda+j] * temp)) + } + } + + // Apply permutations jpiv to the solution (rhs). + impl.Dlaswp(1, rhs, 1, 0, n-1, jpiv[:n], -1) + return scale +} diff --git a/lapack/gonum/errors.go b/lapack/gonum/errors.go index 28f3b150..1429e65d 100644 --- a/lapack/gonum/errors.go +++ b/lapack/gonum/errors.go @@ -101,6 +101,7 @@ const ( badLenAlpha = "lapack: bad length of alpha" badLenBeta = "lapack: bad length of beta" badLenIpiv = "lapack: bad length of ipiv" + badLenJpiv = "lapack: bad length of jpiv" badLenJpvt = "lapack: bad length of jpvt" badLenK = "lapack: bad length of k" badLenSelected = "lapack: bad length of selected" @@ -126,6 +127,7 @@ const ( shortIWork = "lapack: insufficient length of iwork" shortIsgn = "lapack: insufficient length of isgn" shortQ = "lapack: insufficient length of q" + shortRHS = "lapack: insufficient length of rhs" shortS = "lapack: insufficient length of s" shortScale = "lapack: insufficient length of scale" shortT = "lapack: insufficient length of t" diff --git a/lapack/gonum/lapack_test.go b/lapack/gonum/lapack_test.go index 2dded735..a296d64f 100644 --- a/lapack/gonum/lapack_test.go +++ b/lapack/gonum/lapack_test.go @@ -92,6 +92,11 @@ func TestDgerq2(t *testing.T) { testlapack.Dgerq2Test(t, impl) } +func TestDgesc2(t *testing.T) { + t.Parallel() + testlapack.Dgesc2Test(t, impl) +} + func TestDgeqp3(t *testing.T) { t.Parallel() testlapack.Dgeqp3Test(t, impl) diff --git a/lapack/testlapack/dgesc2.go b/lapack/testlapack/dgesc2.go new file mode 100644 index 00000000..4c351b51 --- /dev/null +++ b/lapack/testlapack/dgesc2.go @@ -0,0 +1,91 @@ +// Copyright ©2021 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +package testlapack + +import ( + "fmt" + "math" + "testing" + + "golang.org/x/exp/rand" + + "gonum.org/v1/gonum/blas" + "gonum.org/v1/gonum/blas/blas64" +) + +type Dgesc2er interface { + Dgetc2er + // Dgesc2 solves a system of linear equations + // A * X = scale * RHS + // with a general n×n matrix A using the LU factorization with + // complete pivoting computed by Dgetc2. The result is placed in + // rhs on exit. + Dgesc2(n int, a []float64, lda int, rhs []float64, ipiv, jpiv []int) (scale float64) +} + +func Dgesc2Test(t *testing.T, impl Dgesc2er) { + const tol = 1e-12 + rnd := rand.New(rand.NewSource(1)) + for _, test := range []struct { + n, lda int + }{ + {3, 0}, + {5, 0}, + {20, 30}, + {200, 0}, + } { + testSolveDgesc2(t, impl, rnd, test.n, test.lda, tol) + } +} + +func testSolveDgesc2(t *testing.T, impl Dgesc2er, rnd *rand.Rand, n, lda int, tol float64) { + name := fmt.Sprintf("n=%v,lda=%v", n, lda) + if lda == 0 { + lda = n + } + // Generate random general matrix. + a := randomGeneral(n, n, lda, rnd) + // anorm := floats.Norm(a.Data, 1) + + // Generate a random solution. + xWant := randomGeneral(n, 1, 1, rnd) + // xnorm := floats.Norm(xWant.Data, 1) + + // Compute RHS vector that solves for X such that A*X = scale * RHS + rhs := zeros(n, 1, 1) + blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, a, xWant, 1, rhs) + rhsCopy := zeros(n, 1, 1) // Will contain A*x result. + copyGeneral(rhsCopy, rhs) + // Compute LU factorization with full pivoting. + lu := zeros(n, n, lda) + copyGeneral(lu, a) + ipiv := make([]int, n) + jpiv := make([]int, n) + impl.Dgetc2(n, lu.Data, lu.Stride, ipiv, jpiv) + + // Solve using lu factorization. + scale := impl.Dgesc2(lu.Rows, lu.Data, lu.Stride, rhs.Data, ipiv, jpiv) + x := rhs + if scale < 0 || scale > 1 { + t.Errorf("%v: resulting scale out of bounds [0,1]", name) + } + + var diff float64 + for i := range x.Data { + diff = math.Max(diff, math.Abs(xWant.Data[i]-x.Data[i])) + } + if diff > tol { + t.Errorf("%v: unexpected result, diff=%v", name, diff) + } + // |A*X - scale*RHS| / |A| / |X| is an indicator that solution is good + // AxResult := zeros(n, 1, 1) + // blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, a, x, 1, AxResult) + // blas64.Scal(scale, blas64.Vector{N: n, Data: rhsCopy.Data, Inc: 1}) + // floats.Sub(AxResult.Data, rhsCopy.Data) + + // residualNorm := floats.Norm(rhsCopy.Data, 1) / anorm / xnorm + // if residualNorm > tol { + // t.Errorf("%v: |A*X - scale*RHS| / |A| / |X| = %g is greater than permissible tol", name, residualNorm) + // } +}