// Copyright ©2015 The gonum Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package native import ( "math" "gonum.org/v1/gonum/blas" "gonum.org/v1/gonum/blas/blas64" ) // Dlatrs solves a triangular system of equations scaled to prevent overflow. It // solves // A * x = scale * b if trans == blas.NoTrans // A^T * x = scale * b if trans == blas.Trans // where the scale s is set for numeric stability. // // A is an n×n triangular matrix. On entry, the slice x contains the values of // b, and on exit it contains the solution vector x. // // If normin == true, cnorm is an input and cnorm[j] contains the norm of the off-diagonal // part of the j^th column of A. If trans == blas.NoTrans, cnorm[j] must be greater // than or equal to the infinity norm, and greater than or equal to the one-norm // otherwise. If normin == false, then cnorm is treated as an output, and is set // to contain the 1-norm of the off-diagonal part of the j^th column of A. // // Dlatrs is an internal routine. It is exported for testing purposes. 
func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n int, a []float64, lda int, x []float64, cnorm []float64) (scale float64) { if uplo != blas.Upper && uplo != blas.Lower { panic(badUplo) } if trans != blas.Trans && trans != blas.NoTrans { panic(badTrans) } if diag != blas.Unit && diag != blas.NonUnit { panic(badDiag) } upper := uplo == blas.Upper noTrans := trans == blas.NoTrans nonUnit := diag == blas.NonUnit if n < 0 { panic(nLT0) } checkMatrix(n, n, a, lda) checkVector(n, x, 1) checkVector(n, cnorm, 1) if n == 0 { return 0 } smlnum := dlamchS / dlamchP bignum := 1 / smlnum scale = 1 bi := blas64.Implementation() if !normin { if upper { cnorm[0] = 0 for j := 1; j < n; j++ { cnorm[j] = bi.Dasum(j, a[j:], lda) } } else { for j := 0; j < n-1; j++ { cnorm[j] = bi.Dasum(n-j-1, a[(j+1)*lda+j:], lda) } cnorm[n-1] = 0 } } // Scale the column norms by tscal if the maximum element in cnorm is greater than bignum. imax := bi.Idamax(n, cnorm, 1) tmax := cnorm[imax] var tscal float64 if tmax <= bignum { tscal = 1 } else { tscal = 1 / (smlnum * tmax) bi.Dscal(n, tscal, cnorm, 1) } // Compute a bound on the computed solution vector to see if bi.Dtrsv can be used. j := bi.Idamax(n, x, 1) xmax := math.Abs(x[j]) xbnd := xmax var grow float64 var jfirst, jlast, jinc int if noTrans { if upper { jfirst = n - 1 jlast = -1 jinc = -1 } else { jfirst = 0 jlast = n jinc = 1 } // Compute the growth in A * x = b. 
if tscal != 1 { grow = 0 goto Solve } if nonUnit { grow = 1 / math.Max(xbnd, smlnum) xbnd = grow for j := jfirst; j != jlast; j += jinc { if grow <= smlnum { goto Solve } tjj := math.Abs(a[j*lda+j]) xbnd = math.Min(xbnd, math.Min(1, tjj)*grow) if tjj+cnorm[j] >= smlnum { grow *= tjj / (tjj + cnorm[j]) } else { grow = 0 } } grow = xbnd } else { grow = math.Min(1, 1/math.Max(xbnd, smlnum)) for j := jfirst; j != jlast; j += jinc { if grow <= smlnum { goto Solve } grow *= 1 / (1 + cnorm[j]) } } } else { if upper { jfirst = 0 jlast = n jinc = 1 } else { jfirst = n - 1 jlast = -1 jinc = -1 } if tscal != 1 { grow = 0 goto Solve } if nonUnit { grow = 1 / (math.Max(xbnd, smlnum)) xbnd = grow for j := jfirst; j != jlast; j += jinc { if grow <= smlnum { goto Solve } xj := 1 + cnorm[j] grow = math.Min(grow, xbnd/xj) tjj := math.Abs(a[j*lda+j]) if xj > tjj { xbnd *= tjj / xj } } grow = math.Min(grow, xbnd) } else { grow = math.Min(1, 1/math.Max(xbnd, smlnum)) for j := jfirst; j != jlast; j += jinc { if grow <= smlnum { goto Solve } xj := 1 + cnorm[j] grow /= xj } } } Solve: if grow*tscal > smlnum { // Use the Level 2 BLAS solve if the reciprocal of the bound on // elements of X is not too small. bi.Dtrsv(uplo, trans, diag, n, a, lda, x, 1) if tscal != 1 { bi.Dscal(n, 1/tscal, cnorm, 1) } return scale } // Use a Level 1 BLAS solve, scaling intermediate results. 
if xmax > bignum { scale = bignum / xmax bi.Dscal(n, scale, x, 1) xmax = bignum } if noTrans { for j := jfirst; j != jlast; j += jinc { xj := math.Abs(x[j]) var tjj, tjjs float64 if nonUnit { tjjs = a[j*lda+j] * tscal } else { tjjs = tscal if tscal == 1 { goto Skip1 } } tjj = math.Abs(tjjs) if tjj > smlnum { if tjj < 1 { if xj > tjj*bignum { rec := 1 / xj bi.Dscal(n, rec, x, 1) scale *= rec xmax *= rec } } x[j] /= tjjs xj = math.Abs(x[j]) } else if tjj > 0 { if xj > tjj*bignum { rec := (tjj * bignum) / xj if cnorm[j] > 1 { rec /= cnorm[j] } bi.Dscal(n, rec, x, 1) scale *= rec xmax *= rec } x[j] /= tjjs xj = math.Abs(x[j]) } else { for i := 0; i < n; i++ { x[i] = 0 } x[j] = 1 xj = 1 scale = 0 xmax = 0 } Skip1: if xj > 1 { rec := 1 / xj if cnorm[j] > (bignum-xmax)*rec { rec *= 0.5 bi.Dscal(n, rec, x, 1) scale *= rec } } else if xj*cnorm[j] > bignum-xmax { bi.Dscal(n, 0.5, x, 1) scale *= 0.5 } if upper { if j > 0 { bi.Daxpy(j, -x[j]*tscal, a[j:], lda, x, 1) i := bi.Idamax(j, x, 1) xmax = math.Abs(x[i]) } } else { if j < n-1 { bi.Daxpy(n-j-1, -x[j]*tscal, a[(j+1)*lda+j:], lda, x[j+1:], 1) i := j + bi.Idamax(n-j-1, x[j+1:], 1) xmax = math.Abs(x[i]) } } } } else { for j := jfirst; j != jlast; j += jinc { xj := math.Abs(x[j]) uscal := tscal rec := 1 / math.Max(xmax, 1) var tjjs float64 if cnorm[j] > (bignum-xj)*rec { rec *= 0.5 if nonUnit { tjjs = a[j*lda+j] * tscal } else { tjjs = tscal } tjj := math.Abs(tjjs) if tjj > 1 { rec = math.Min(1, rec*tjj) uscal /= tjjs } if rec < 1 { bi.Dscal(n, rec, x, 1) scale *= rec xmax *= rec } } var sumj float64 if uscal == 1 { if upper { sumj = bi.Ddot(j, a[j:], lda, x, 1) } else if j < n-1 { sumj = bi.Ddot(n-j-1, a[(j+1)*lda+j:], lda, x[j+1:], 1) } } else { if upper { for i := 0; i < j; i++ { sumj += (a[i*lda+j] * uscal) * x[i] } } else if j < n { for i := j + 1; i < n; i++ { sumj += (a[i*lda+j] * uscal) * x[i] } } } if uscal == tscal { x[j] -= sumj xj := math.Abs(x[j]) var tjjs float64 if nonUnit { tjjs = a[j*lda+j] * tscal } else { 
tjjs = tscal if tscal == 1 { goto Skip2 } } tjj := math.Abs(tjjs) if tjj > smlnum { if tjj < 1 { if xj > tjj*bignum { rec = 1 / xj bi.Dscal(n, rec, x, 1) scale *= rec xmax *= rec } } x[j] /= tjjs } else if tjj > 0 { if xj > tjj*bignum { rec = (tjj * bignum) / xj bi.Dscal(n, rec, x, 1) scale *= rec xmax *= rec } x[j] /= tjjs } else { for i := 0; i < n; i++ { x[i] = 0 } x[j] = 1 scale = 0 xmax = 0 } } else { x[j] = x[j]/tjjs - sumj } Skip2: xmax = math.Max(xmax, math.Abs(x[j])) } } scale /= tscal if tscal != 1 { bi.Dscal(n, 1/tscal, cnorm, 1) } return scale }