cgo/clapack: use LAPACKE_func_work functions

This reduces allocations and harmonises the cgo and native behaviours.
This commit is contained in:
kortschak
2016-06-03 13:03:32 +09:30
parent c254fffffb
commit 0a88dcb8a8
3 changed files with 98 additions and 32 deletions

View File

@@ -24,6 +24,7 @@ open(my $clapack, "<", $clapackHeader) or die;
open(my $golapack, ">", "clapack.go") or die;
my %done;
my %hasWork;
printf $golapack <<"EOH";
// Do not manually edit this file. It was created by the genLapack.pl script from ${clapackHeader}.
@@ -158,6 +159,10 @@ our %allUplo = (
"lacpy" => undef
);
foreach my $line (@lines) {
assess($line);
}
foreach my $line (@lines) {
process($line);
}
@@ -165,6 +170,31 @@ foreach my $line (@lines) {
close($golapack);
`go fmt .`;
sub assess {
my $line = shift;
chomp $line;
assessWork($line);
}
sub assessWork {
my $proto = shift;
if (not($proto =~ /LAPACKE/)) {
return
}
my ($func, $paramList) = split /[()]/, $proto;
if ($func =~ /rook/) {
return
}
(my $ret, $func) = split ' ', $func;
(my $pack, $func, my $tail) = split '_', $func;
if (!defined $tail or $tail ne "work") {
return
}
$hasWork{$func} = 1;
}
sub process {
my $line = shift;
chomp $line;
@@ -173,13 +203,22 @@ sub process {
sub processProto {
my $proto = shift;
if(not($proto =~ /LAPACKE/)) {
if (not($proto =~ /LAPACKE/)) {
return
}
my ($func, $paramList) = split /[()]/, $proto;
(my $ret, $func) = split ' ', $func;
(my $pack, $func, my $tail) = split '_', $func;
if ($hasWork{$func} && (!defined $tail || $tail ne "work")) {
# This is ilaver only at this stage.
return
}
if (defined $tail) {
$tail = "_$tail";
} else {
$tail = "";
}
if ($done{$func} or $xobjs{$func} or $deprecated{$func}){
return
@@ -199,13 +238,7 @@ sub processProto {
}
$done{$func} = 1;
my $gofunc;
if ($tail) {
$gofunc = ucfirst $func . ucfirst $tail;
}else{
$gofunc = ucfirst $func;
}
my $gofunc = ucfirst $func;
my $GoRet = $typeConv{$ret."_return"};
my $GoRetType = $typeConv{$ret."_return_type"};
@@ -221,7 +254,7 @@ sub processProto {
if ($ret ne 'void') {
print $golapack $bp."return ".$GoRet."(";
}
print $golapack "C.LAPACKE_$func(".processParamToC($func, $paramList).")";
print $golapack "C.LAPACKE_$func$tail(".processParamToC($func, $paramList).")";
if ($ret ne 'void') {
print $golapack ")";
}

View File

@@ -112,7 +112,7 @@ func (impl Implementation) Dlange(norm lapack.MatrixNorm, m, n int, a []float64,
if norm == lapack.MaxColumnSum && len(work) < n {
panic(badWork)
}
return clapack.Dlange(byte(norm), m, n, a, lda)
return clapack.Dlange(byte(norm), m, n, a, lda, work)
}
// Dlansy computes the specified norm of an n×n symmetric matrix. If
@@ -131,7 +131,7 @@ func (impl Implementation) Dlansy(norm lapack.MatrixNorm, uplo blas.Uplo, n int,
if uplo != blas.Upper && uplo != blas.Lower {
panic(badUplo)
}
return clapack.Dlansy(byte(norm), uplo, n, a, lda)
return clapack.Dlansy(byte(norm), uplo, n, a, lda, work)
}
// Dlantr computes the specified norm of an m×n trapezoidal matrix A. If
@@ -153,7 +153,7 @@ func (impl Implementation) Dlantr(norm lapack.MatrixNorm, uplo blas.Uplo, diag b
if norm == lapack.MaxColumnSum && len(work) < n {
panic(badWork)
}
return clapack.Dlantr(byte(norm), uplo, diag, m, n, a, lda)
return clapack.Dlantr(byte(norm), uplo, diag, m, n, a, lda, work)
}
// Dpotrf computes the Cholesky decomposition of the symmetric positive definite
@@ -242,7 +242,7 @@ func (impl Implementation) Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, v
if len(c) == 0 {
c = make([]float64, 1)
}
return clapack.Dbdsqr(uplo, n, ncvt, nru, ncc, d, e, vt, ldvt, u, ldu, c, ldc)
return clapack.Dbdsqr(uplo, n, ncvt, nru, ncc, d, e, vt, ldvt, u, ldu, c, ldc, work)
}
// Dgebrd reduces a general m×n matrix A to upper or lower bidiagonal form B by
@@ -312,7 +312,7 @@ func (impl Implementation) Dgebrd(m, n int, a []float64, lda int, d, e, tauQ, ta
panic(badWork)
}
clapack.Dgebrd(m, n, a, lda, d, e, tauQ, tauP)
clapack.Dgebrd(m, n, a, lda, d, e, tauQ, tauP, work, lwork)
}
// Dgecon estimates the reciprocal of the condition number of the n×n matrix A
@@ -326,6 +326,7 @@ func (impl Implementation) Dgebrd(m, n int, a []float64, lda int, d, e, tauQ, ta
// work is a temporary data slice of length at least 4*n and Dgecon will panic otherwise.
//
// iwork is a temporary data slice of length at least n and Dgecon will panic otherwise.
// Elements of iwork must fit within the int32 type or Dgecon will panic.
func (impl Implementation) Dgecon(norm lapack.MatrixNorm, n int, a []float64, lda int, anorm float64, work []float64, iwork []int) float64 {
checkMatrix(n, n, a, lda)
if norm != lapack.MaxColumnSum && norm != lapack.MaxRowSum {
@@ -338,7 +339,17 @@ func (impl Implementation) Dgecon(norm lapack.MatrixNorm, n int, a []float64, ld
panic(badWork)
}
rcond := make([]float64, 1)
clapack.Dgecon(byte(norm), n, a, lda, anorm, rcond)
_iwork := make([]int32, len(iwork))
for i, v := range iwork {
if v != int(int32(v)) {
panic("lapack: iwork element out of range")
}
_iwork[i] = int32(v)
}
clapack.Dgecon(byte(norm), n, a, lda, anorm, rcond, work, _iwork)
for i, v := range _iwork {
iwork[i] = int(v)
}
return rcond[0]
}
@@ -367,7 +378,7 @@ func (impl Implementation) Dgelq2(m, n int, a []float64, lda int, tau, work []fl
if len(work) < m {
panic(badWork)
}
clapack.Dgelq2(m, n, a, lda, tau)
clapack.Dgelq2(m, n, a, lda, tau, work)
}
// Dgelqf computes the LQ factorization of the m×n matrix A using a blocked
@@ -394,7 +405,7 @@ func (impl Implementation) Dgelqf(m, n int, a []float64, lda int, tau, work []fl
if len(tau) < min(m, n) {
panic(badTau)
}
clapack.Dgelqf(m, n, a, lda, tau)
clapack.Dgelqf(m, n, a, lda, tau, work, lwork)
}
// Dgeqr2 computes a QR factorization of the m×n matrix A.
@@ -428,7 +439,7 @@ func (impl Implementation) Dgeqr2(m, n int, a []float64, lda int, tau, work []fl
if len(tau) < k {
panic(badTau)
}
clapack.Dgeqr2(m, n, a, lda, tau)
clapack.Dgeqr2(m, n, a, lda, tau, work)
}
// Dgeqrf computes the QR factorization of the m×n matrix A using a blocked
@@ -456,7 +467,7 @@ func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []fl
if len(tau) < k {
panic(badTau)
}
clapack.Dgeqrf(m, n, a, lda, tau)
clapack.Dgeqrf(m, n, a, lda, tau, work, lwork)
}
// Dgels finds a minimum-norm solution based on the matrices A and B using the
@@ -500,7 +511,7 @@ func (impl Implementation) Dgels(trans blas.Transpose, m, n, nrhs int, a []float
if lwork < mn+max(mn, nrhs) {
panic(badWork)
}
return clapack.Dgels(trans, m, n, nrhs, a, lda, b, ldb)
return clapack.Dgels(trans, m, n, nrhs, a, lda, b, ldb, work, lwork)
}
const noSVDO = "dgesvd: not coded for overwrite"
@@ -579,7 +590,7 @@ func (impl Implementation) Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float
work[0] = float64(minWork)
return true
}
return clapack.Dgesvd(lapack.Job(jobU), lapack.Job(jobVT), m, n, a, lda, s, u, ldu, vt, ldvt, work[1:])
return clapack.Dgesvd(lapack.Job(jobU), lapack.Job(jobVT), m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork)
}
// Dgetf2 computes the LU decomposition of the m×n matrix A.
@@ -671,7 +682,7 @@ func (impl Implementation) Dgetri(n int, a []float64, lda int, ipiv []int, work
for i, v := range ipiv {
ipiv32[i] = int32(v) + 1 // Transform to one-indexed.
}
return clapack.Dgetri(n, a, lda, ipiv32)
return clapack.Dgetri(n, a, lda, ipiv32, work, lwork)
}
// Dgetrs solves a system of equations using an LU factorization.
@@ -735,7 +746,7 @@ func (impl Implementation) Dorgbr(vect lapack.DecompUpdate, m, n, k int, a []flo
if lwork < mn {
panic(badWork)
}
clapack.Dorgbr(byte(vect), m, n, k, a, lda, tau)
clapack.Dorgbr(byte(vect), m, n, k, a, lda, tau, work, lwork)
}
// Dorglq generates an m×n matrix Q with orthonormal rows defined by the product
@@ -777,7 +788,7 @@ func (impl Implementation) Dorglq(m, n, k int, a []float64, lda int, tau, work [
if lwork < m {
panic(badWork)
}
clapack.Dorglq(m, n, k, a, lda, tau)
clapack.Dorglq(m, n, k, a, lda, tau, work, lwork)
}
// Dorgqr generates an m×n matrix Q with orthonormal columns defined by the
@@ -819,7 +830,7 @@ func (impl Implementation) Dorgqr(m, n, k int, a []float64, lda int, tau, work [
if lwork < n {
panic(badWork)
}
clapack.Dorgqr(m, n, k, a, lda, tau)
clapack.Dorgqr(m, n, k, a, lda, tau, work, lwork)
}
// Dormbr applies a multiplicative update to the matrix C based on a
@@ -866,7 +877,7 @@ func (impl Implementation) Dormbr(vect lapack.DecompUpdate, side blas.Side, tran
} else {
checkMatrix(min(nq, k), nq, a, lda)
}
clapack.Dormbr(byte(vect), side, trans, m, n, k, a, lda, tau, c, ldc)
clapack.Dormbr(byte(vect), side, trans, m, n, k, a, lda, tau, c, ldc, work, lwork)
}
// Dormlq multiplies the matrix C by the orthogonal matrix Q defined by the
@@ -921,7 +932,7 @@ func (impl Implementation) Dormlq(side blas.Side, trans blas.Transpose, m, n, k
panic(badWork)
}
}
clapack.Dormlq(side, trans, m, n, k, a, lda, tau, c, ldc)
clapack.Dormlq(side, trans, m, n, k, a, lda, tau, c, ldc, work, lwork)
}
// Dormqr multiplies the matrix C by the orthogonal matrix Q defined by the
@@ -973,7 +984,7 @@ func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k
}
}
clapack.Dormqr(side, trans, m, n, k, a, lda, tau, c, ldc)
clapack.Dormqr(side, trans, m, n, k, a, lda, tau, c, ldc, work, lwork)
}
// Dpocon estimates the reciprocal of the condition number of a positive-definite
@@ -985,6 +996,7 @@ func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k
// work is a temporary data slice of length at least 3*n and Dpocon will panic otherwise.
//
// iwork is a temporary data slice of length at least n and Dpocon will panic otherwise.
// Elements of iwork must fit within the int32 type or Dpocon will panic.
func (impl Implementation) Dpocon(uplo blas.Uplo, n int, a []float64, lda int, anorm float64, work []float64, iwork []int) float64 {
checkMatrix(n, n, a, lda)
if uplo != blas.Upper && uplo != blas.Lower {
@@ -997,7 +1009,17 @@ func (impl Implementation) Dpocon(uplo blas.Uplo, n int, a []float64, lda int, a
panic(badWork)
}
rcond := make([]float64, 1)
clapack.Dpocon(uplo, n, a, lda, anorm, rcond)
_iwork := make([]int32, len(iwork))
for i, v := range iwork {
if v != int(int32(v)) {
panic("lapack: iwork element out of range")
}
_iwork[i] = int32(v)
}
clapack.Dpocon(uplo, n, a, lda, anorm, rcond, work, _iwork)
for i, v := range _iwork {
iwork[i] = int(v)
}
return rcond[0]
}
@@ -1027,7 +1049,7 @@ func (impl Implementation) Dsyev(jobz lapack.EigComp, uplo blas.Uplo, n int, a [
if lwork < 3*n-1 {
panic(badWork)
}
return clapack.Dsyev(lapack.Job(jobz), uplo, n, a, lda, w)
return clapack.Dsyev(lapack.Job(jobz), uplo, n, a, lda, w, work, lwork)
}
// Dtrcon estimates the reciprocal of the condition number of a triangular matrix A.
@@ -1036,6 +1058,7 @@ func (impl Implementation) Dsyev(jobz lapack.EigComp, uplo blas.Uplo, n int, a [
// work is a temporary data slice of length at least 3*n and Dtrcon will panic otherwise.
//
// iwork is a temporary data slice of length at least n and Dtrcon will panic otherwise.
// Elements of iwork must fit within the int32 type or Dtrcon will panic.
func (impl Implementation) Dtrcon(norm lapack.MatrixNorm, uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int, work []float64, iwork []int) float64 {
if norm != lapack.MaxColumnSum && norm != lapack.MaxRowSum {
panic(badNorm)
@@ -1053,7 +1076,17 @@ func (impl Implementation) Dtrcon(norm lapack.MatrixNorm, uplo blas.Uplo, diag b
panic(badWork)
}
rcond := []float64{0}
clapack.Dtrcon(byte(norm), uplo, diag, n, a, lda, rcond)
_iwork := make([]int32, len(iwork))
for i, v := range iwork {
if v != int(int32(v)) {
panic("lapack: iwork element out of range")
}
_iwork[i] = int32(v)
}
clapack.Dtrcon(byte(norm), uplo, diag, n, a, lda, rcond, work, _iwork)
for i, v := range _iwork {
iwork[i] = int(v)
}
return rcond[0]
}