diff --git a/native/dlatrs.go b/native/dlatrs.go index 2599348f..4ee6e4e9 100644 --- a/native/dlatrs.go +++ b/native/dlatrs.go @@ -55,7 +55,8 @@ func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag bla bi := blas64.Implementation() if !normin { if upper { - for j := 0; j < n; j++ { + cnorm[0] = 0 + for j := 1; j < n; j++ { cnorm[j] = bi.Dasum(j, a[j:], lda) } } else { @@ -95,14 +96,14 @@ func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag bla // Compute the growth in A * x = b. if tscal != 1 { grow = 0 - goto Finish + goto Solve } if nonUnit { grow = 1 / math.Max(xbnd, smlnum) xbnd = grow for j := jfirst; j != jlast; j += jinc { if grow <= smlnum { - goto Finish + goto Solve } tjj := math.Abs(a[j*lda+j]) xbnd = math.Min(xbnd, math.Min(1, tjj)*grow) @@ -117,7 +118,7 @@ func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag bla grow = math.Min(1, 1/math.Max(xbnd, smlnum)) for j := jfirst; j != jlast; j += jinc { if grow <= smlnum { - goto Finish + goto Solve } grow *= 1 / (1 + cnorm[j]) } @@ -134,14 +135,14 @@ func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag bla } if tscal != 1 { grow = 0 - goto Finish + goto Solve } if nonUnit { grow = 1 / (math.Max(xbnd, smlnum)) xbnd = grow for j := jfirst; j != jlast; j += jinc { if grow <= smlnum { - goto Finish + goto Solve } xj := 1 + cnorm[j] grow = math.Min(grow, xbnd/xj) @@ -155,7 +156,7 @@ func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag bla grow = math.Min(1, 1/math.Max(xbnd, smlnum)) for j := jfirst; j != jlast; j += jinc { if grow <= smlnum { - goto Finish + goto Solve } xj := 1 + cnorm[j] grow /= xj @@ -163,177 +164,183 @@ func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag bla } } -Finish: +Solve: if grow*tscal > smlnum { + // Use the Level 2 BLAS solve if the reciprocal of the bound on + // elements of X is not too small. bi.Dtrsv(uplo, trans, diag, n, a, lda, x, 1) - // TODO(btracey): check if this else is everything - } else { - if xmax > bignum { - scale = bignum / xmax - bi.Dscal(n, scale, x, 1) - xmax = bignum + if tscal != 1 { + bi.Dscal(n, 1/tscal, cnorm, 1) } - if noTrans { - for j := jfirst; j != jlast; j += jinc { + return scale + } + + // Use a Level 1 BLAS solve, scaling intermediate results. + if xmax > bignum { + scale = bignum / xmax + bi.Dscal(n, scale, x, 1) + xmax = bignum + } + if noTrans { + for j := jfirst; j != jlast; j += jinc { + xj := math.Abs(x[j]) + var tjj, tjjs float64 + if nonUnit { + tjjs = a[j*lda+j] * tscal + } else { + tjjs = tscal + if tscal == 1 { + goto Skip1 + } + } + tjj = math.Abs(tjjs) + if tjj > smlnum { + if tjj < 1 { + if xj > tjj*bignum { + rec := 1 / xj + bi.Dscal(n, rec, x, 1) + scale *= rec + xmax *= rec + } + } + x[j] /= tjjs + xj = math.Abs(x[j]) + } else if tjj > 0 { + if xj > tjj*bignum { + rec := (tjj * bignum) / xj + if cnorm[j] > 1 { + rec /= cnorm[j] + } + bi.Dscal(n, rec, x, 1) + scale *= rec + xmax *= rec + } + x[j] /= tjjs + xj = math.Abs(x[j]) + } else { + for i := 0; i < n; i++ { + x[i] = 0 + } + x[j] = 1 + xj = 1 + scale = 0 + xmax = 0 + } + Skip1: + if xj > 1 { + rec := 1 / xj + if cnorm[j] > (bignum-xmax)*rec { + rec *= 0.5 + bi.Dscal(n, rec, x, 1) + scale *= rec + } + } else if xj*cnorm[j] > bignum-xmax { + bi.Dscal(n, 0.5, x, 1) + scale *= 0.5 + } + if upper { + if j > 0 { + bi.Daxpy(j, -x[j]*tscal, a[j:], lda, x, 1) + i := bi.Idamax(j, x, 1) + xmax = math.Abs(x[i]) + } + } else { + if j < n-1 { + bi.Daxpy(n-j-1, -x[j]*tscal, a[(j+1)*lda+j:], lda, x[j+1:], 1) + i := j + bi.Idamax(n-j-1, x[j+1:], 1) + xmax = math.Abs(x[i]) + } + } + } + } else { + for j := jfirst; j != jlast; j += jinc { + xj := math.Abs(x[j]) + uscal := tscal + rec := 1 / math.Max(xmax, 1) + var tjjs float64 + if cnorm[j] > (bignum-xj)*rec { + rec *= 0.5 + if nonUnit { + tjjs = a[j*lda+j] * tscal + } else { + tjjs = tscal + } + tjj := math.Abs(tjjs) + if tjj > 1 { + rec = math.Min(1, rec*tjj) + uscal /= tjjs + } + if rec < 1 { + bi.Dscal(n, rec, x, 1) + scale *= rec + xmax *= rec + } + } + var sumj float64 + if uscal == 1 { + if upper { + sumj = bi.Ddot(j, a[j:], lda, x, 1) + } else if j < n-1 { + sumj = bi.Ddot(n-j-1, a[(j+1)*lda+j:], lda, x[j+1:], 1) + } + } else { + if upper { + for i := 0; i < j; i++ { + sumj += (a[i*lda+j] * uscal) * x[i] + } + } else if j < n { + for i := j + 1; i < n; i++ { + sumj += (a[i*lda+j] * uscal) * x[i] + } + } + } + if uscal == tscal { + x[j] -= sumj xj := math.Abs(x[j]) - var tjj, tjjs float64 + var tjjs float64 if nonUnit { tjjs = a[j*lda+j] * tscal } else { tjjs = tscal if tscal == 1 { - goto Skip1 + goto Skip2 } } - tjj = math.Abs(tjjs) + tjj := math.Abs(tjjs) if tjj > smlnum { if tjj < 1 { if xj > tjj*bignum { - rec := 1 / xj + rec = 1 / xj bi.Dscal(n, rec, x, 1) scale *= rec xmax *= rec } } x[j] /= tjjs - xj = math.Abs(x[j]) } else if tjj > 0 { if xj > tjj*bignum { - rec := (tjj * bignum) / xj - if cnorm[j] > 1 { - rec /= cnorm[j] - } + rec = (tjj * bignum) / xj bi.Dscal(n, rec, x, 1) scale *= rec xmax *= rec } x[j] /= tjjs - xj = math.Abs(x[j]) } else { for i := 0; i < n; i++ { x[i] = 0 } x[j] = 1 - xj = 1 scale = 0 xmax = 0 } - Skip1: - if xj > 1 { - rec := 1 / xj - if cnorm[j] > (bignum-xmax)*rec { - rec *= 0.5 - bi.Dscal(n, rec, x, 1) - scale *= rec - } - } else if xj*cnorm[j] > bignum-xmax { - bi.Dscal(n, 0.5, x, 1) - scale *= 0.5 - } - if upper { - if j > 0 { - bi.Daxpy(j, -x[j]*tscal, a[j:], lda, x, 1) - i := bi.Idamax(j, x, 1) - xmax = math.Abs(x[i]) - } - } else { - if j < n-1 { - bi.Daxpy(n-j-1, -x[j]*tscal, a[(j+1)*lda+j:], lda, x[j+1:], 1) - i := j + bi.Idamax(n-j-1, x[j+1:], 1) - xmax = math.Abs(x[i]) - } - } - } - } else { - for j := jfirst; j != jlast; j += jinc { - xj := math.Abs(x[j]) - uscal := tscal - rec := 1 / math.Max(xmax, 1) - var tjjs float64 - if cnorm[j] > (bignum-xj)*rec { - rec *= 0.5 - if nonUnit { - tjjs = a[j*lda+j] * tscal - } else { - tjjs = tscal - } - tjj := math.Abs(tjjs) - if tjj > 1 { - rec = math.Min(1, rec*tjj) - uscal /= tjjs - } - if rec < 1 { - bi.Dscal(n, rec, x, 1) - scale *= rec - xmax *= rec - } - } - var sumj float64 - if uscal == 1 { - if upper { - sumj = bi.Ddot(j, a[j:], lda, x, 1) - } else if j < n-1 { - sumj = bi.Ddot(n-j-1, a[(j+1)*lda+j:], lda, x[j+1:], 1) - } - } else { - if upper { - for i := 0; i < j; i++ { - sumj += (a[i*lda+j] * uscal) * x[i] - } - } else if j < n { - for i := j + 1; i < n; i++ { - sumj += (a[i*lda+j] * uscal) * x[i] - } - } - } - if uscal == tscal { - x[j] -= sumj - xj := math.Abs(x[j]) - var tjjs float64 - if nonUnit { - tjjs = a[j*lda+j] * tscal - } else { - tjjs = tscal - if tscal == 1 { - goto Skip2 - } - } - tjj := math.Abs(tjjs) - if tjj > smlnum { - if tjj < 1 { - if xj > tjj*bignum { - rec = 1 / xj - bi.Dscal(n, rec, x, 1) - scale *= rec - xmax *= rec - } - } - x[j] /= tjjs - } else if tjj > 0 { - if xj > tjj*bignum { - rec = (tjj * bignum) / xj - bi.Dscal(n, rec, x, 1) - scale *= rec - xmax *= rec - } - x[j] /= tjjs - } else { - for i := 0; i < n; i++ { - x[i] = 0 - } - x[j] = 1 - scale = 0 - xmax = 0 - } - } else { - x[j] = x[j]/tjjs - sumj - } - Skip2: - xmax = math.Max(xmax, math.Abs(x[j])) + } else { + x[j] = x[j]/tjjs - sumj } + Skip2: + xmax = math.Max(xmax, math.Abs(x[j])) } - scale /= tscal } + scale /= tscal if tscal != 1 { bi.Dscal(n, 1/tscal, cnorm, 1) }