mirror of
https://github.com/gonum/gonum.git
synced 2025-10-20 13:55:20 +08:00
cgo/clapack: use LAPACKE_func_work functions
This reduces allocations and harmonises the cgo and native behaviours.
This commit is contained in:
@@ -24,6 +24,7 @@ open(my $clapack, "<", $clapackHeader) or die;
|
||||
open(my $golapack, ">", "clapack.go") or die;
|
||||
|
||||
my %done;
|
||||
my %hasWork;
|
||||
|
||||
printf $golapack <<"EOH";
|
||||
// Do not manually edit this file. It was created by the genLapack.pl script from ${clapackHeader}.
|
||||
@@ -158,6 +159,10 @@ our %allUplo = (
|
||||
"lacpy" => undef
|
||||
);
|
||||
|
||||
foreach my $line (@lines) {
|
||||
assess($line);
|
||||
}
|
||||
|
||||
foreach my $line (@lines) {
|
||||
process($line);
|
||||
}
|
||||
@@ -165,6 +170,31 @@ foreach my $line (@lines) {
|
||||
close($golapack);
|
||||
`go fmt .`;
|
||||
|
||||
# assess is the per-line entry point of the first pass over the input lines
# (the foreach loop that calls it runs before the loop that calls process).
# It strips the trailing newline from the line and hands the result to
# assessWork; no value is returned for callers to use.
sub assess {
|
||||
my $line = shift;
|
||||
chomp $line;
|
||||
assessWork($line);
|
||||
}
|
||||
|
||||
# assessWork inspects one prototype line and, when it declares a
# LAPACKE_<name>_work function, records <name> as a key in %hasWork.
#
# Parameter: $proto - a single line of text, expected to be a C prototype.
# Side effect: sets $hasWork{<name>} = 1 for matching prototypes.
# Lines that do not mention LAPACKE, whose name part contains "rook", or
# whose name does not have "work" as its third '_'-separated field are
# ignored.
sub assessWork {
|
||||
my $proto = shift;
|
||||
# Only lines mentioning LAPACKE are of interest.
if (not($proto =~ /LAPACKE/)) {
|
||||
return
|
||||
}
|
||||
# Split at the parentheses: the text before '(' is "<ret> <name>", the text
# between the parentheses is the parameter list (unused in this sub).
my ($func, $paramList) = split /[()]/, $proto;
|
||||
# NOTE(review): "rook" functions are skipped here; the reason is not visible
# in this part of the file - confirm against the rest of the script.
if ($func =~ /rook/) {
|
||||
return
|
||||
}
|
||||
|
||||
# Separate the return type from the function name on whitespace.
(my $ret, $func) = split ' ', $func;
|
||||
# Break the name on '_' into its first three fields, e.g. a name of the form
# LAPACKE_<name>_work yields ($pack, $func, $tail) = (LAPACKE, <name>, work).
(my $pack, $func, my $tail) = split '_', $func;
|
||||
# Anything without a third field equal to "work" is not a _work variant.
if (!defined $tail or $tail ne "work") {
|
||||
return
|
||||
}
|
||||
|
||||
# Remember that a _work variant exists for this function name.
$hasWork{$func} = 1;
|
||||
}
|
||||
|
||||
sub process {
|
||||
my $line = shift;
|
||||
chomp $line;
|
||||
@@ -173,13 +203,22 @@ sub process {
|
||||
|
||||
sub processProto {
|
||||
my $proto = shift;
|
||||
if(not($proto =~ /LAPACKE/)) {
|
||||
if (not($proto =~ /LAPACKE/)) {
|
||||
return
|
||||
}
|
||||
my ($func, $paramList) = split /[()]/, $proto;
|
||||
|
||||
(my $ret, $func) = split ' ', $func;
|
||||
(my $pack, $func, my $tail) = split '_', $func;
|
||||
if ($hasWork{$func} && (!defined $tail || $tail ne "work")) {
|
||||
# This is ilaver only at this stage.
|
||||
return
|
||||
}
|
||||
if (defined $tail) {
|
||||
$tail = "_$tail";
|
||||
} else {
|
||||
$tail = "";
|
||||
}
|
||||
|
||||
if ($done{$func} or $xobjs{$func} or $deprecated{$func}){
|
||||
return
|
||||
@@ -199,13 +238,7 @@ sub processProto {
|
||||
}
|
||||
$done{$func} = 1;
|
||||
|
||||
|
||||
my $gofunc;
|
||||
if ($tail) {
|
||||
$gofunc = ucfirst $func . ucfirst $tail;
|
||||
}else{
|
||||
$gofunc = ucfirst $func;
|
||||
}
|
||||
my $gofunc = ucfirst $func;
|
||||
|
||||
my $GoRet = $typeConv{$ret."_return"};
|
||||
my $GoRetType = $typeConv{$ret."_return_type"};
|
||||
@@ -221,7 +254,7 @@ sub processProto {
|
||||
if ($ret ne 'void') {
|
||||
print $golapack $bp."return ".$GoRet."(";
|
||||
}
|
||||
print $golapack "C.LAPACKE_$func(".processParamToC($func, $paramList).")";
|
||||
print $golapack "C.LAPACKE_$func$tail(".processParamToC($func, $paramList).")";
|
||||
if ($ret ne 'void') {
|
||||
print $golapack ")";
|
||||
}
|
||||
|
@@ -112,7 +112,7 @@ func (impl Implementation) Dlange(norm lapack.MatrixNorm, m, n int, a []float64,
|
||||
if norm == lapack.MaxColumnSum && len(work) < n {
|
||||
panic(badWork)
|
||||
}
|
||||
return clapack.Dlange(byte(norm), m, n, a, lda)
|
||||
return clapack.Dlange(byte(norm), m, n, a, lda, work)
|
||||
}
|
||||
|
||||
// Dlansy computes the specified norm of an n×n symmetric matrix. If
|
||||
@@ -131,7 +131,7 @@ func (impl Implementation) Dlansy(norm lapack.MatrixNorm, uplo blas.Uplo, n int,
|
||||
if uplo != blas.Upper && uplo != blas.Lower {
|
||||
panic(badUplo)
|
||||
}
|
||||
return clapack.Dlansy(byte(norm), uplo, n, a, lda)
|
||||
return clapack.Dlansy(byte(norm), uplo, n, a, lda, work)
|
||||
}
|
||||
|
||||
// Dlantr computes the specified norm of an m×n trapezoidal matrix A. If
|
||||
@@ -153,7 +153,7 @@ func (impl Implementation) Dlantr(norm lapack.MatrixNorm, uplo blas.Uplo, diag b
|
||||
if norm == lapack.MaxColumnSum && len(work) < n {
|
||||
panic(badWork)
|
||||
}
|
||||
return clapack.Dlantr(byte(norm), uplo, diag, m, n, a, lda)
|
||||
return clapack.Dlantr(byte(norm), uplo, diag, m, n, a, lda, work)
|
||||
}
|
||||
|
||||
// Dpotrf computes the Cholesky decomposition of the symmetric positive definite
|
||||
@@ -242,7 +242,7 @@ func (impl Implementation) Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, v
|
||||
if len(c) == 0 {
|
||||
c = make([]float64, 1)
|
||||
}
|
||||
return clapack.Dbdsqr(uplo, n, ncvt, nru, ncc, d, e, vt, ldvt, u, ldu, c, ldc)
|
||||
return clapack.Dbdsqr(uplo, n, ncvt, nru, ncc, d, e, vt, ldvt, u, ldu, c, ldc, work)
|
||||
}
|
||||
|
||||
// Dgebrd reduces a general m×n matrix A to upper or lower bidiagonal form B by
|
||||
@@ -312,7 +312,7 @@ func (impl Implementation) Dgebrd(m, n int, a []float64, lda int, d, e, tauQ, ta
|
||||
panic(badWork)
|
||||
}
|
||||
|
||||
clapack.Dgebrd(m, n, a, lda, d, e, tauQ, tauP)
|
||||
clapack.Dgebrd(m, n, a, lda, d, e, tauQ, tauP, work, lwork)
|
||||
}
|
||||
|
||||
// Dgecon estimates the reciprocal of the condition number of the n×n matrix A
|
||||
@@ -326,6 +326,7 @@ func (impl Implementation) Dgebrd(m, n int, a []float64, lda int, d, e, tauQ, ta
|
||||
// work is a temporary data slice of length at least 4*n and Dgecon will panic otherwise.
|
||||
//
|
||||
// iwork is a temporary data slice of length at least n and Dgecon will panic otherwise.
|
||||
// Elements of iwork must fit within the int32 type or Dgecon will panic.
|
||||
func (impl Implementation) Dgecon(norm lapack.MatrixNorm, n int, a []float64, lda int, anorm float64, work []float64, iwork []int) float64 {
|
||||
checkMatrix(n, n, a, lda)
|
||||
if norm != lapack.MaxColumnSum && norm != lapack.MaxRowSum {
|
||||
@@ -338,7 +339,17 @@ func (impl Implementation) Dgecon(norm lapack.MatrixNorm, n int, a []float64, ld
|
||||
panic(badWork)
|
||||
}
|
||||
rcond := make([]float64, 1)
|
||||
clapack.Dgecon(byte(norm), n, a, lda, anorm, rcond)
|
||||
_iwork := make([]int32, len(iwork))
|
||||
for i, v := range iwork {
|
||||
if v != int(int32(v)) {
|
||||
panic("lapack: iwork element out of range")
|
||||
}
|
||||
_iwork[i] = int32(v)
|
||||
}
|
||||
clapack.Dgecon(byte(norm), n, a, lda, anorm, rcond, work, _iwork)
|
||||
for i, v := range _iwork {
|
||||
iwork[i] = int(v)
|
||||
}
|
||||
return rcond[0]
|
||||
}
|
||||
|
||||
@@ -367,7 +378,7 @@ func (impl Implementation) Dgelq2(m, n int, a []float64, lda int, tau, work []fl
|
||||
if len(work) < m {
|
||||
panic(badWork)
|
||||
}
|
||||
clapack.Dgelq2(m, n, a, lda, tau)
|
||||
clapack.Dgelq2(m, n, a, lda, tau, work)
|
||||
}
|
||||
|
||||
// Dgelqf computes the LQ factorization of the m×n matrix A using a blocked
|
||||
@@ -394,7 +405,7 @@ func (impl Implementation) Dgelqf(m, n int, a []float64, lda int, tau, work []fl
|
||||
if len(tau) < min(m, n) {
|
||||
panic(badTau)
|
||||
}
|
||||
clapack.Dgelqf(m, n, a, lda, tau)
|
||||
clapack.Dgelqf(m, n, a, lda, tau, work, lwork)
|
||||
}
|
||||
|
||||
// Dgeqr2 computes a QR factorization of the m×n matrix A.
|
||||
@@ -428,7 +439,7 @@ func (impl Implementation) Dgeqr2(m, n int, a []float64, lda int, tau, work []fl
|
||||
if len(tau) < k {
|
||||
panic(badTau)
|
||||
}
|
||||
clapack.Dgeqr2(m, n, a, lda, tau)
|
||||
clapack.Dgeqr2(m, n, a, lda, tau, work)
|
||||
}
|
||||
|
||||
// Dgeqrf computes the QR factorization of the m×n matrix A using a blocked
|
||||
@@ -456,7 +467,7 @@ func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []fl
|
||||
if len(tau) < k {
|
||||
panic(badTau)
|
||||
}
|
||||
clapack.Dgeqrf(m, n, a, lda, tau)
|
||||
clapack.Dgeqrf(m, n, a, lda, tau, work, lwork)
|
||||
}
|
||||
|
||||
// Dgels finds a minimum-norm solution based on the matrices A and B using the
|
||||
@@ -500,7 +511,7 @@ func (impl Implementation) Dgels(trans blas.Transpose, m, n, nrhs int, a []float
|
||||
if lwork < mn+max(mn, nrhs) {
|
||||
panic(badWork)
|
||||
}
|
||||
return clapack.Dgels(trans, m, n, nrhs, a, lda, b, ldb)
|
||||
return clapack.Dgels(trans, m, n, nrhs, a, lda, b, ldb, work, lwork)
|
||||
}
|
||||
|
||||
const noSVDO = "dgesvd: not coded for overwrite"
|
||||
@@ -579,7 +590,7 @@ func (impl Implementation) Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float
|
||||
work[0] = float64(minWork)
|
||||
return true
|
||||
}
|
||||
return clapack.Dgesvd(lapack.Job(jobU), lapack.Job(jobVT), m, n, a, lda, s, u, ldu, vt, ldvt, work[1:])
|
||||
return clapack.Dgesvd(lapack.Job(jobU), lapack.Job(jobVT), m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork)
|
||||
}
|
||||
|
||||
// Dgetf2 computes the LU decomposition of the m×n matrix A.
|
||||
@@ -671,7 +682,7 @@ func (impl Implementation) Dgetri(n int, a []float64, lda int, ipiv []int, work
|
||||
for i, v := range ipiv {
|
||||
ipiv32[i] = int32(v) + 1 // Transform to one-indexed.
|
||||
}
|
||||
return clapack.Dgetri(n, a, lda, ipiv32)
|
||||
return clapack.Dgetri(n, a, lda, ipiv32, work, lwork)
|
||||
}
|
||||
|
||||
// Dgetrs solves a system of equations using an LU factorization.
|
||||
@@ -735,7 +746,7 @@ func (impl Implementation) Dorgbr(vect lapack.DecompUpdate, m, n, k int, a []flo
|
||||
if lwork < mn {
|
||||
panic(badWork)
|
||||
}
|
||||
clapack.Dorgbr(byte(vect), m, n, k, a, lda, tau)
|
||||
clapack.Dorgbr(byte(vect), m, n, k, a, lda, tau, work, lwork)
|
||||
}
|
||||
|
||||
// Dorglq generates an m×n matrix Q with orthonormal rows defined by the product
|
||||
@@ -777,7 +788,7 @@ func (impl Implementation) Dorglq(m, n, k int, a []float64, lda int, tau, work [
|
||||
if lwork < m {
|
||||
panic(badWork)
|
||||
}
|
||||
clapack.Dorglq(m, n, k, a, lda, tau)
|
||||
clapack.Dorglq(m, n, k, a, lda, tau, work, lwork)
|
||||
}
|
||||
|
||||
// Dorgqr generates an m×n matrix Q with orthonormal columns defined by the
|
||||
@@ -819,7 +830,7 @@ func (impl Implementation) Dorgqr(m, n, k int, a []float64, lda int, tau, work [
|
||||
if lwork < n {
|
||||
panic(badWork)
|
||||
}
|
||||
clapack.Dorgqr(m, n, k, a, lda, tau)
|
||||
clapack.Dorgqr(m, n, k, a, lda, tau, work, lwork)
|
||||
}
|
||||
|
||||
// Dormbr applies a multiplicative update to the matrix C based on a
|
||||
@@ -866,7 +877,7 @@ func (impl Implementation) Dormbr(vect lapack.DecompUpdate, side blas.Side, tran
|
||||
} else {
|
||||
checkMatrix(min(nq, k), nq, a, lda)
|
||||
}
|
||||
clapack.Dormbr(byte(vect), side, trans, m, n, k, a, lda, tau, c, ldc)
|
||||
clapack.Dormbr(byte(vect), side, trans, m, n, k, a, lda, tau, c, ldc, work, lwork)
|
||||
}
|
||||
|
||||
// Dormlq multiplies the matrix C by the orthogonal matrix Q defined by the
|
||||
@@ -921,7 +932,7 @@ func (impl Implementation) Dormlq(side blas.Side, trans blas.Transpose, m, n, k
|
||||
panic(badWork)
|
||||
}
|
||||
}
|
||||
clapack.Dormlq(side, trans, m, n, k, a, lda, tau, c, ldc)
|
||||
clapack.Dormlq(side, trans, m, n, k, a, lda, tau, c, ldc, work, lwork)
|
||||
}
|
||||
|
||||
// Dormqr multiplies the matrix C by the orthogonal matrix Q defined by the
|
||||
@@ -973,7 +984,7 @@ func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k
|
||||
}
|
||||
}
|
||||
|
||||
clapack.Dormqr(side, trans, m, n, k, a, lda, tau, c, ldc)
|
||||
clapack.Dormqr(side, trans, m, n, k, a, lda, tau, c, ldc, work, lwork)
|
||||
}
|
||||
|
||||
// Dpocon estimates the reciprocal of the condition number of a positive-definite
|
||||
@@ -985,6 +996,7 @@ func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k
|
||||
// work is a temporary data slice of length at least 3*n and Dpocon will panic otherwise.
|
||||
//
|
||||
// iwork is a temporary data slice of length at least n and Dpocon will panic otherwise.
|
||||
// Elements of iwork must fit within the int32 type or Dpocon will panic.
|
||||
func (impl Implementation) Dpocon(uplo blas.Uplo, n int, a []float64, lda int, anorm float64, work []float64, iwork []int) float64 {
|
||||
checkMatrix(n, n, a, lda)
|
||||
if uplo != blas.Upper && uplo != blas.Lower {
|
||||
@@ -997,7 +1009,17 @@ func (impl Implementation) Dpocon(uplo blas.Uplo, n int, a []float64, lda int, a
|
||||
panic(badWork)
|
||||
}
|
||||
rcond := make([]float64, 1)
|
||||
clapack.Dpocon(uplo, n, a, lda, anorm, rcond)
|
||||
_iwork := make([]int32, len(iwork))
|
||||
for i, v := range iwork {
|
||||
if v != int(int32(v)) {
|
||||
panic("lapack: iwork element out of range")
|
||||
}
|
||||
_iwork[i] = int32(v)
|
||||
}
|
||||
clapack.Dpocon(uplo, n, a, lda, anorm, rcond, work, _iwork)
|
||||
for i, v := range _iwork {
|
||||
iwork[i] = int(v)
|
||||
}
|
||||
return rcond[0]
|
||||
}
|
||||
|
||||
@@ -1027,7 +1049,7 @@ func (impl Implementation) Dsyev(jobz lapack.EigComp, uplo blas.Uplo, n int, a [
|
||||
if lwork < 3*n-1 {
|
||||
panic(badWork)
|
||||
}
|
||||
return clapack.Dsyev(lapack.Job(jobz), uplo, n, a, lda, w)
|
||||
return clapack.Dsyev(lapack.Job(jobz), uplo, n, a, lda, w, work, lwork)
|
||||
}
|
||||
|
||||
// Dtrcon estimates the reciprocal of the condition number of a triangular matrix A.
|
||||
@@ -1036,6 +1058,7 @@ func (impl Implementation) Dsyev(jobz lapack.EigComp, uplo blas.Uplo, n int, a [
|
||||
// work is a temporary data slice of length at least 3*n and Dtrcon will panic otherwise.
|
||||
//
|
||||
// iwork is a temporary data slice of length at least n and Dtrcon will panic otherwise.
|
||||
// Elements of iwork must fit within the int32 type or Dtrcon will panic.
|
||||
func (impl Implementation) Dtrcon(norm lapack.MatrixNorm, uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int, work []float64, iwork []int) float64 {
|
||||
if norm != lapack.MaxColumnSum && norm != lapack.MaxRowSum {
|
||||
panic(badNorm)
|
||||
@@ -1053,7 +1076,17 @@ func (impl Implementation) Dtrcon(norm lapack.MatrixNorm, uplo blas.Uplo, diag b
|
||||
panic(badWork)
|
||||
}
|
||||
rcond := []float64{0}
|
||||
clapack.Dtrcon(byte(norm), uplo, diag, n, a, lda, rcond)
|
||||
_iwork := make([]int32, len(iwork))
|
||||
for i, v := range iwork {
|
||||
if v != int(int32(v)) {
|
||||
panic("lapack: iwork element out of range")
|
||||
}
|
||||
_iwork[i] = int32(v)
|
||||
}
|
||||
clapack.Dtrcon(byte(norm), uplo, diag, n, a, lda, rcond, work, _iwork)
|
||||
for i, v := range _iwork {
|
||||
iwork[i] = int(v)
|
||||
}
|
||||
return rcond[0]
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user