diff --git a/cgo/lapack.go b/cgo/lapack.go index a876aec6..5035901f 100644 --- a/cgo/lapack.go +++ b/cgo/lapack.go @@ -14,6 +14,7 @@ import ( // Copied from lapack/native. Keep in sync. const ( absIncNotOne = "lapack: increment not one or negative one" + badDiag = "lapack: bad diag" badDirect = "lapack: bad direct" badIpiv = "lapack: insufficient permutation length" badLdA = "lapack: index of a out of range" diff --git a/native/dlatrs.go b/native/dlatrs.go new file mode 100644 index 00000000..283a5d3a --- /dev/null +++ b/native/dlatrs.go @@ -0,0 +1,334 @@ +package native + +import ( + "math" + + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" +) + +// Dlatrs solves a triangular system of equations scaled to prevent overflow. It +// solves +// A * x = scale * b if trans == blas.NoTrans +// A^T * x = scale * b if trans == blas.Trans +// where the scale s is set for numeric stability. +// +// A is an n×n triangular matrix. On entry, the slice x contains the values of +// of b, and on exit it contains the solution vector x. +// +// If normin == true, cnorm is an input and cnorm[j] contains the norm of the off-diagonal +// part of the j^th column of A. If trans == blas.NoTrans, cnorm[j] must be greater +// than or equal to the infinity norm, and greater than or equal to the one-norm +// otherwise. If normin == false, then cnorm is treated as an output, and is set +// to contain the 1-norm of the off-diagonal part of the j^th column of A. +func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n int, a []float64, lda int, x []float64, cnorm []float64) (scale float64) { + if uplo != blas.Upper && uplo != blas.Lower { + panic(badUplo) + } + if trans != blas.Trans && trans != blas.NoTrans { + panic(badTrans) + } + if diag != blas.Unit && diag != blas.NonUnit { + panic(badDiag) + } + upper := uplo == blas.Upper + noTrans := trans == blas.NoTrans + nonUnit := diag == blas.NonUnit + + if n < 0 { + panic(nLT0) + } + checkMatrix(n, n, a, lda) + checkVector(n, x, 1) + checkVector(n, cnorm, 1) + + if n == 0 { + return + } + scale = 1 + bi := blas64.Implementation() + if !normin { + if upper { + for j := 0; j < n; j++ { + cnorm[j] = bi.Dasum(j, a[j:], lda) + } + } else { + for j := 0; j < n-1; j++ { + cnorm[j] = bi.Dasum(n-j-1, a[(j+1)*lda+j:], lda) + } + cnorm[n-1] = 0 + } + } + // Scale the column norms by tscal if the maximum element in cnorm is greater than bignum. + imax := bi.Idamax(n, cnorm, 1) + tmax := cnorm[imax] + var tscal float64 + if tmax <= bignum { + tscal = 1 + } else { + tscal = 1 / (smlnum * tmax) + bi.Dscal(n, tscal, cnorm, 1) + } + + // Compute a bound on the computed solution vector to see if bi.Dtrsv can be used. + j := bi.Idamax(n, x, 1) + xmax := math.Abs(x[j]) + xbnd := xmax + var grow float64 + var jfirst, jlast, jinc int + if noTrans { + if upper { + jfirst = n - 1 + jlast = 0 + jinc = -1 + } else { + jfirst = 0 + jlast = n - 1 + jinc = 1 + } + // Compute the growth in A * x = b. + if tscal != 1 { + grow = 0 + goto Finish + } + if nonUnit { + grow = 1 / math.Max(xbnd, smlnum) + xbnd = grow + for j := jfirst; j != jlast; j += jinc { + if grow <= smlnum { + goto Finish + } + tjj := math.Abs(a[j*lda+j]) + xbnd = math.Min(xbnd, math.Min(1, tjj)*grow) + if tjj+cnorm[j] >= smlnum { + grow *= tjj / (tjj + cnorm[j]) + } else { + grow = 0 + } + } + grow = xbnd + } else { + grow = math.Min(1, 1/math.Max(xbnd, smlnum)) + for j := jfirst; j != jlast; j += jinc { + if grow <= smlnum { + goto Finish + } + grow *= 1 / (1 + cnorm[j]) + } + } + } else { + if upper { + jfirst = 0 + jlast = n - 1 + jinc = 1 + } else { + jfirst = n - 1 + jlast = 0 + jinc = -1 + } + if tscal != 1 { + grow = 0 + goto Finish + } + if nonUnit { + grow = 1 / (math.Max(xbnd, smlnum)) + xbnd = grow + for j := jfirst; j != jlast; j += jinc { + if grow <= smlnum { + goto Finish + } + xj := 1 + cnorm[j] + grow = math.Min(grow, xbnd/xj) + tjj := math.Abs(a[j*lda+j]) + if xj > tjj { + xbnd *= tjj / xj + } + } + grow = math.Min(grow, xbnd) + } else { + grow = math.Min(1, 1/math.Max(xbnd, smlnum)) + for j := jfirst; j != jlast; j += jinc { + if grow <= smlnum { + goto Finish + } + xj := 1 + cnorm[j] + grow /= xj + } + } + } + +Finish: + if grow*tscal > smlnum { + bi.Dtrsv(uplo, trans, diag, n, a, lda, x, 1) + // TODO(btracey): check if this else is everything + } else { + if xmax > bignum { + scale = bignum / xmax + bi.Dscal(n, scale, x, 1) + xmax = bignum + } + if noTrans { + for j := jfirst; j != jlast; j += jinc { + xj := math.Abs(x[j]) + var tjjs float64 + if nonUnit { + tjjs = a[j*lda+j] * tscal + } else { + tjjs = tscal + if tscal == 1 { + break + } + } + tjj := math.Abs(tjjs) + if tjj > smlnum { + if tjj < 1 { + if xj > tjj*bignum { + rec := 1 / xj + bi.Dscal(n, rec, x, 1) + scale *= rec + xmax *= rec + } + } + x[j] /= tjjs + xj = math.Abs(x[j]) + } else if tjj > 0 { + if xj > tjj*bignum { + rec := (tjj * bignum) / xj + if cnorm[j] > 1 { + rec /= cnorm[j] + } + bi.Dscal(n, rec, x, 1) + scale *= rec + xmax *= rec + } + x[j] /= tjjs + xj = math.Abs(x[j]) + } else { + for i := 0; i < n; i++ { + x[i] = 0 + } + x[j] = 1 + xj = 1 + scale = 0 + xmax = 0 + } + if xj > 1 { + rec := 1 / xj + if cnorm[j] > (bignum-xmax)*rec { + rec *= 0.5 + bi.Dscal(n, rec, x, 1) + scale *= rec + } + } else if xj*cnorm[j] > bignum-xmax { + bi.Dscal(n, 0.5, x, 1) + scale *= 0.5 + } + if upper { + if j > 0 { + bi.Daxpy(j, -x[j]*tscal, a[j:], lda, x, 1) + i := bi.Idamax(j, x, 1) + xmax = math.Abs(x[i]) + } + } else { + if j < n-1 { + bi.Daxpy(n-j-1, -x[j]*tscal, a[(j+1)*lda+j:], lda, x[j+1:], 1) + i := j + bi.Idamax(n-j-1, x[j+1:], 1) + xmax = math.Abs(x[i]) + } + } + } + } else { + for j := jfirst; j != jlast; j += jinc { + xj := math.Abs(x[j]) + uscal := tscal + rec := 1 / math.Max(xmax, 1) + var tjjs float64 + if cnorm[j] > (bignum-xj)*rec { + rec *= 0.5 + if nonUnit { + tjjs = a[j*lda+j] * tscal + } else { + tjjs = tscal + } + tjj := math.Abs(tjjs) + if tjj > 1 { + rec = math.Min(1, rec*tjj) + uscal /= tjjs + } + if rec < 1 { + bi.Dscal(n, rec, x, 1) + scale *= rec + xmax *= rec + } + } + var sumj float64 + if uscal == 1 { + if upper { + sumj = bi.Ddot(j, a[j:], lda, x, 1) + } else if j < n-1 { + sumj = bi.Ddot(n-j-1, a[(j+1)*lda+j:], lda, x[j+1:], 1) + } + } else { + if upper { + for i := 0; i < j; i++ { + sumj += (a[i*lda+j] * uscal) * x[i] + } + } else if j < n { + for i := j + 1; i < n; i++ { + sumj += (a[i*lda+j] * uscal) * x[i] + } + } + } + if uscal == tscal { + x[j] -= sumj + xj := math.Abs(x[j]) + var tjjs float64 + if nonUnit { + tjjs = a[j*lda+j] * tscal + } else { + tjjs = tscal + if tscal == 1 { + goto Out2 + } + } + tjj := math.Abs(tjjs) + if tjj > smlnum { + if tjj < 1 { + if xj > tjj*bignum { + rec = 1 / xj + bi.Dscal(n, rec, x, 1) + scale *= rec + xmax *= rec + } + } + x[j] /= tjjs + } else if tjj > 0 { + if xj > tjj*bignum { + rec = (tjj * bignum) / xj + bi.Dscal(n, rec, x, 1) + scale *= rec + xmax *= rec + } + x[j] /= tjjs + } else { + for i := 0; i < n; i++ { + x[i] = 0 + } + x[j] = 1 + scale = 0 + xmax = 0 + } + } else { + x[j] = x[j]/tjjs - sumj + } + Out2: + xmax = math.Max(xmax, math.Abs(x[j])) + } + } + scale /= tscal + } + if tscal != 1 { + bi.Dscal(n, 1/tscal, cnorm, 1) + } + return scale +} diff --git a/native/general.go b/native/general.go index d91b4ad3..c635319f 100644 --- a/native/general.go +++ b/native/general.go @@ -20,6 +20,7 @@ var _ lapack.Float64 = Implementation{} // This list is duplicated in lapack/cgo. Keep in sync. const ( absIncNotOne = "lapack: increment not one or negative one" + badDiag = "lapack: bad diag" badDirect = "lapack: bad direct" badIpiv = "lapack: insufficient permutation length" badLdA = "lapack: index of a out of range" @@ -79,6 +80,7 @@ func max(a, b int) int { // TODO(btracey): Is there a better way to find the smallest number such that 1+E > 1 var dlamchE, dlamchS, dlamchP float64 +var smlnum, bignum float64 func init() { onePlusEps := math.Nextafter(1, math.Inf(1)) @@ -92,4 +94,6 @@ func init() { dlamchS = sfmin radix := 2.0 dlamchP = radix * eps + smlnum = dlamchS / dlamchP + bignum = 1 / smlnum }