From fce8621a32a08e94d48a16daf2a709122f7ed024 Mon Sep 17 00:00:00 2001 From: Vladimir Chalupecky Date: Sat, 15 Jun 2019 11:26:13 +0200 Subject: [PATCH] lapack/gonum: add Dpbtrs --- lapack/gonum/dpbtrs.go | 62 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 lapack/gonum/dpbtrs.go diff --git a/lapack/gonum/dpbtrs.go b/lapack/gonum/dpbtrs.go new file mode 100644 index 00000000..528f5f26 --- /dev/null +++ b/lapack/gonum/dpbtrs.go @@ -0,0 +1,62 @@ +// Copyright ©2019 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gonum + +import ( + "gonum.org/v1/gonum/blas" + "gonum.org/v1/gonum/blas/blas64" +) + +// Dpbtrs solves a system of linear equations A*X = B with a symmetric positive +// definite band matrix A using the Cholesky factorization A = U^T * U or +// A = L * L^T computed by Dpbtrf. +func (Implementation) Dpbtrs(uplo blas.Uplo, n, kd, nrhs int, ab []float64, ldab int, b []float64, ldb int) { + switch { + case uplo != blas.Upper && uplo != blas.Lower: + panic(badUplo) + case n < 0: + panic(nLT0) + case kd < 0: + panic(kdLT0) + case nrhs < 0: + panic(nrhsLT0) + case ldab < kd+1: + panic(badLdA) + case ldb < max(1, nrhs): + panic(badLdB) + } + + // Quick return if possible. + if n == 0 || nrhs == 0 { + return + } + + if len(ab) < (n-1)*ldab+kd { + panic(shortAB) + } + if len(b) < (n-1)*ldb+nrhs { + panic(shortB) + } + + bi := blas64.Implementation() + if uplo == blas.Upper { + // Solve A*X = B where A = U^T*U. + for j := 0; j < nrhs; j++ { + // Solve U^T*Y = B, overwriting B with Y. + bi.Dtbsv(blas.Upper, blas.Trans, blas.NonUnit, n, kd, ab, ldab, b[j:], ldb) + // Solve U*X = Y, overwriting Y with X. + bi.Dtbsv(blas.Upper, blas.NoTrans, blas.NonUnit, n, kd, ab, ldab, b[j:], ldb) + } + } else { + // Solve A*X = B where A = L*L^T. + for j := 0; j < nrhs; j++ { + // Solve L*Y = B, overwriting B with Y. + bi.Dtbsv(blas.Lower, blas.NoTrans, blas.NonUnit, n, kd, ab, ldab, b[j:], ldb) + // Solve L^T*X = Y, overwriting Y with X. + bi.Dtbsv(blas.Lower, blas.Trans, blas.NonUnit, n, kd, ab, ldab, b[j:], ldb) + } + } + return +}